diff --git a/.bazelrc b/.bazelrc index 7fc6ef6c94889b..505589d8ef9531 100644 --- a/.bazelrc +++ b/.bazelrc @@ -482,30 +482,36 @@ build:avx_linux --copt=-mavx build:avx_linux --host_copt=-mavx build:avx_win --copt=/arch:AVX +# TODO(belitskiy): Remove once Win2019 is gone. # Use Clang-cl compiler on Windows -build:win_clang --copt=/clang:-Weverything -build:win_clang --host_copt=/clang:-Weverything build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl build:win_clang --extra_execution_platforms=//tensorflow/tools/toolchains/win:x64_windows-clang-cl build:win_clang --host_platform=//tensorflow/tools/toolchains/win:x64_windows-clang-cl +build:win_clang --copt=/clang:-Weverything +build:win_clang --host_copt=/clang:-Weverything build:win_clang --compiler=clang-cl build:win_clang --linkopt=/FORCE:MULTIPLE build:win_clang --host_linkopt=/FORCE:MULTIPLE test:win_clang --linkopt=/FORCE:MULTIPLE test:win_clang --host_linkopt=/FORCE:MULTIPLE - -# Same config as above but for XLA, which has different toolchain paths -build:win_clang_xla --copt=/clang:-Weverything -build:win_clang_xla --host_copt=/clang:-Weverything -build:win_clang_xla --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang_xla --extra_execution_platforms=//tools/toolchains/win:x64_windows-clang-cl -build:win_clang_xla --host_platform=//tools/toolchains/win:x64_windows-clang-cl -build:win_clang_xla --compiler=clang-cl -build:win_clang_xla --linkopt=/FORCE:MULTIPLE -build:win_clang_xla --host_linkopt=/FORCE:MULTIPLE -test:win_clang_xla --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW -test:win_clang_xla --linkopt=/FORCE:MULTIPLE -test:win_clang_xla --host_linkopt=/FORCE:MULTIPLE +test:win_clang --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW + +# build:windows_x86_cpu --extra_toolchains="//tensorflow/tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +# build:windows_x86_cpu --extra_execution_platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +# build:windows_x86_cpu --host_platform="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --crosstool_top="//tensorflow/tools/toolchains/win2022/20241118:toolchain" +build:windows_x86_cpu --extra_toolchains="//tensorflow/tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +build:windows_x86_cpu --extra_execution_platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --host_platform="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --copt=/clang:-Weverything +build:windows_x86_cpu --host_copt=/clang:-Weverything +build:windows_x86_cpu --compiler=clang-cl +build:windows_x86_cpu --linkopt=/FORCE:MULTIPLE +build:windows_x86_cpu --host_linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --host_linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW # Options to build TensorFlow 1.x or 2.x. # TODO(kanglan): Change v2's define to default behavior @@ -564,9 +570,9 @@ build:rbe_linux_cpu --crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_linux_cpu --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cpu --repo_env=CC="/usr/lib/llvm-18/bin/clang" build:rbe_linux_cpu --repo_env=TF_SYSROOT="/dt9" -build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.17-clang_config_platform//:platform" -build:rbe_linux_cpu --host_platform="@sigbuild-r2.17-clang_config_platform//:platform" -build:rbe_linux_cpu --platforms="@sigbuild-r2.17-clang_config_platform//:platform" +build:rbe_linux_cpu --extra_execution_platforms="@ml_build_config_platform//:platform" +build:rbe_linux_cpu --host_platform="@ml_build_config_platform//:platform" +build:rbe_linux_cpu --platforms="@ml_build_config_platform//:platform" # This is needed for all Clang17 builds but must not be present in GCC builds. build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument # This was added in clang-16 by https://reviews.llvm.org/D133574. @@ -796,48 +802,54 @@ build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_ # LIBTENSORFLOW TESTS are for building Libtensorflow archives. These are CUDA/CPU-agnostic. test:linux_libtensorflow_test --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip +build:windows_libtensorflow_build --config=cuda_wheel --config=windows_x86_cpu -- //:LICENSE //tensorflow:tensorflow.dll //tensorflow:tensorflow_dll_import_lib //tensorflow/tools/lib_package:clicenses_generate //tensorflow/java:tensorflow_jni.dll //tensorflow/tools/lib_package:jnilicenses_generate # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# WINDOWS X86 WHEEL +test:windows_x86_cpu_wheel_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_wheel_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_wheel_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" +test:windows_x86_cpu_wheel_test --build_tests_only --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. # LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # LINUX ARM64 PYCPP # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on @@ -848,35 +860,35 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. -build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... @@ -893,38 +905,15 @@ build:cross_compile_base --host_cpu=k8 build:cross_compile_base --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite build:cross_compile_base --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 -# XLA related settings for cross-compiled build. Certain paths are -# different in the XLA repo. -build:cross_compile_base_xla --host_cpu=k8 -build:cross_compile_base_xla --host_crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_base_xla --extra_execution_platforms=//tools/toolchains/cross_compile/config:linux_x86_64 - build:rbe_cross_compile_base --config=rbe_base build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance -# XLA depends on some local Python headers that are configured as Genrule. They -# are present on the local host machine but not on the remote execution machine, -# leading to build failures. To resolve the issue, the following line is added -# to make sure all Genrule targets are excuted locally. -build:rbe_cross_compile_base_xla --config=rbe_cross_compile_base -build:rbe_cross_compile_base_xla --strategy=Genrule=standalone - -# Due to the above strategy, all Genrule commands are executed locally, but the -# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are -# only executabe on the RBE (x86) machine, so the strategy_regexp options are -# added to override and run the actions using remote strategy. -build:rbe_cross_compile_base_xla --strategy_regexp='Generating code from table.*=remote' -build:rbe_cross_compile_base_xla --strategy_regexp='Generating flatbuffer files.*=remote' -build:rbe_cross_compile_base_xla --strategy_regexp='Executing genrule @llvm-project.*=remote' - # Test-related settings below this point # We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to # force all tests to run locally on the Aarch64 host. test:rbe_cross_compile_base --strategy=TestRunner=local --build_tests_only test:rbe_cross_compile_base --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors -test:rbe_cross_compile_base_xla --config=rbe_cross_compile_base - # START LINUX AARCH64 CROSS-COMPILE CONFIGS build:cross_compile_linux_arm64 --config=cross_compile_base @@ -933,21 +922,11 @@ build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_ build:cross_compile_linux_arm64 --cpu=aarch64 build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -# XLA uses different paths for platforms and crosstool_top. -build:cross_compile_linux_arm64_xla --config=cross_compile_base_xla -build:cross_compile_linux_arm64_xla --platforms=//tools/toolchains/cross_compile/config:linux_aarch64 -build:cross_compile_linux_arm64_xla --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite - # RBE cross-compile configs for Linux Aarch64 build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base test:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base -# RBE cross-compile configs for XLA Linux Aarch64 -build:rbe_cross_compile_linux_arm64_xla --config=cross_compile_linux_arm64_xla -build:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla -test:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla - # END LINUX AARCH64 CROSS-COMPILE CONFIGS # START MACOS CROSS-COMPILE CONFIGS diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 58123a3cddd9b4..f48c37a84c7b03 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -250,7 +250,7 @@ There are two ways to run TensorFlow unit tests. bazel by doing as follows: ```bash - export flags="--config=opt -k" + export flags="--config=linux -k" ``` If the tests are to be run on the GPU: @@ -259,7 +259,7 @@ There are two ways to run TensorFlow unit tests. flag. ```bash - export flags="--config=opt --config=cuda -k" + export flags="--config=linux --config=cuda -k" ``` * For TensorFlow versions prior v.2.18.0: Add CUDA paths to @@ -267,7 +267,7 @@ There are two ways to run TensorFlow unit tests. ```bash export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" + export flags="--config=linux --config=cuda -k" ``` For example, to run all tests under tensorflow/python, do: diff --git a/ci/devinfra/docker/windows/Dockerfile b/ci/devinfra/docker/windows/Dockerfile index e1a7f949d5f48b..5ce20a017134e2 100644 --- a/ci/devinfra/docker/windows/Dockerfile +++ b/ci/devinfra/docker/windows/Dockerfile @@ -42,6 +42,7 @@ RUN C:\TEMP\vs_community.exe \ --add Microsoft.VisualStudio.Workload.NativeDesktop \ --add Microsoft.VisualStudio.Component.VC.14.39.17.9.x86.64 \ --add Microsoft.VisualStudio.Component.Windows11SDK.22621 \ + --add Microsoft.VisualStudio.Component.VC.ATL \ || IF "%ERRORLEVEL%"=="3010" EXIT 0 SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", \ @@ -152,4 +153,18 @@ RUN (New-Object Net.WebClient).DownloadFile( \ $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\bazel'; \ [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); +ENV CLOUDSDK_CORE_DISABLE_PROMPTS 1 +RUN (New-Object Net.WebClient).DownloadFile('https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.zip', 'C:\Temp\google-cloud-sdk.zip'); \ + Expand-Archive -Path 'C:\Temp\google-cloud-sdk.zip' -DestinationPath $env:ProgramFiles -Verbose:$false +RUN & \"$env:ProgramFiles\\google-cloud-sdk\\install.bat\" --path-update false +RUN $env:Path += \";$env:ProgramFiles\\google-cloud-sdk\\bin\"; \ + [Environment]::SetEnvironmentVariable('Path', $env:Path, [EnvironmentVariableTarget]::Machine); +# Re-enable prompts for interactive use. +ENV CLOUDSDK_CORE_DISABLE_PROMPTS="" + +# MSYS attempts to use non-cmd versions, which aren't meant for Windows +RUN Add-Content -Path C:\tools\msys64\.bashrc -Value 'alias gcloud=gcloud.cmd' +RUN Add-Content -Path C:\tools\msys64\.bashrc -Value 'alias gsutil=gsutil.cmd' +RUN Add-Content -Path C:\tools\msys64\.bashrc -Value 'alias bq=bq.cmd' + SHELL ["cmd.exe", "/s", "/c"] diff --git a/ci/official/any.sh b/ci/official/any.sh index dc1484b64dc9ea..4706b0212cea09 100755 --- a/ci/official/any.sh +++ b/ci/official/any.sh @@ -36,7 +36,7 @@ # export TF_ANY_EXTRA_ENV=ci/official/envs/local_rbe # ./any.sh # ... -set -euxo pipefail +set -exo pipefail cd "$(dirname "$0")/../../" # tensorflow/ # Any request that includes "nightly_upload" should just use the # local multi-cache (public read-only cache + disk cache) instead. diff --git a/ci/official/bisect.sh b/ci/official/bisect.sh index 7f18dd1460ff5b..72cd6e684a6827 100755 --- a/ci/official/bisect.sh +++ b/ci/official/bisect.sh @@ -32,7 +32,7 @@ # export TF_BISECT_BAD=a_failing_commit_sha # export TF_ANY_TARGETS="quoted list of targets, like on the command line" # export TF_ANY_MODE=test -set -euxo pipefail +set -exo pipefail cd "$(dirname "$0")/../../" # tensorflow/ export TFCI="$(echo $TFCI | sed 's/,nightly_upload/,public_cache,disk_cache/')" git bisect start "$TF_BISECT_BAD" "$TF_BISECT_GOOD" diff --git a/ci/official/containers/linux_arm64/devel.usertools/rename_and_verify_wheels.sh b/ci/official/containers/linux_arm64/devel.usertools/rename_and_verify_wheels.sh index 0b56b5f5b9f0bd..23f3b532dd5eba 100755 --- a/ci/official/containers/linux_arm64/devel.usertools/rename_and_verify_wheels.sh +++ b/ci/official/containers/linux_arm64/devel.usertools/rename_and_verify_wheels.sh @@ -17,7 +17,7 @@ # Check and rename wheels with auditwheel. Inserts the platform tags like # "manylinux_xyz" into the wheel filename. -set -euxo pipefail +set -exo pipefail for wheel in /tf/pkg/*.whl; do echo "Checking and renaming $wheel..." diff --git a/ci/official/containers/linux_arm64/devel.usertools/setup_venv_test.sh b/ci/official/containers/linux_arm64/devel.usertools/setup_venv_test.sh index db05f3d3c1dec9..4158e04bd16051 100755 --- a/ci/official/containers/linux_arm64/devel.usertools/setup_venv_test.sh +++ b/ci/official/containers/linux_arm64/devel.usertools/setup_venv_test.sh @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -set -euxo pipefail +set -exo pipefail # Run this from inside the tensorflow github directory. # Usage: setup_venv_test.sh venv_and_symlink_name "glob pattern for one wheel file" diff --git a/ci/official/containers/ml_build/Dockerfile b/ci/official/containers/ml_build/Dockerfile index fb17fd97bebdd4..9b7686a166e92c 100644 --- a/ci/official/containers/ml_build/Dockerfile +++ b/ci/official/containers/ml_build/Dockerfile @@ -1,5 +1,8 @@ ################################################################################ -FROM ubuntu:22.04@sha256:58b87898e82351c6cf9cf5b9f3c20257bb9e2dcf33af051e12ce532d7f94e3fe AS devel +ARG BASE_IMAGE=ubuntu:22.04@sha256:58b87898e82351c6cf9cf5b9f3c20257bb9e2dcf33af051e12ce532d7f94e3fe +FROM $BASE_IMAGE AS devel +# See https://docs.docker.com/reference/dockerfile/#understand-how-arg-and-from-interact +# on why we cannot reference BASE_IMAGE again unless we declare it again. ################################################################################ # Install devtoolset build dependencies @@ -20,15 +23,15 @@ RUN /build_devtoolset.sh devtoolset-9 /dt9 # Setup Python COPY setup.python.sh /setup.python.sh COPY builder.requirements.txt /builder.requirements.txt -RUN /setup.python.sh python3.9 builder.requirements.txt -RUN /setup.python.sh python3.10 builder.requirements.txt -RUN /setup.python.sh python3.11 builder.requirements.txt -RUN /setup.python.sh python3.13 builder.requirements.txt +RUN /setup.python.sh python3.9 /builder.requirements.txt +RUN /setup.python.sh python3.10 /builder.requirements.txt +RUN /setup.python.sh python3.11 /builder.requirements.txt +RUN /setup.python.sh python3.13 /builder.requirements.txt # Since we are using python3.12 as the default python version, we need to # install python3.12 last for now. # TODO(b/376338367): switch to pyenv. -RUN /setup.python.sh python3.12 builder.requirements.txt +RUN /setup.python.sh python3.12 /builder.requirements.txt # Setup links for TensorFlow to compile. # Referenced in devel.usertools/*.bazelrc. @@ -41,6 +44,9 @@ RUN ln -sf /usr/lib/python3.12 /usr/lib/tf_python # Make sure clang is on the path RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang +# Link the compat driver to the location if available. +RUN if [ -e "/usr/local/cuda/compat/libcuda.so.1" ]; then ln -s /usr/local/cuda/compat/libcuda.so.1 /usr/lib/x86_64-linux-gnu/libcuda.so.1; fi + # Install various tools. # - bats: bash unit testing framework # - bazelisk: always use the correct bazel version diff --git a/ci/official/containers/ml_build/setup.python.sh b/ci/official/containers/ml_build/setup.python.sh index 831bd612c41bb0..05e955d45471d9 100755 --- a/ci/official/containers/ml_build/setup.python.sh +++ b/ci/official/containers/ml_build/setup.python.sh @@ -24,7 +24,7 @@ VERSION=$1 REQUIREMENTS=$2 # Install Python packages for this container's version -if [[ ${VERSION} == "python3.13" ]]; then +if [[ ${VERSION} == "python3.13" || ${VERSION} == "python3.12" ]]; then cat >pythons.txt < != T:\ mismatches, +# when using variables like `TFCI_OUTPUT_DIR` in `docker exec commands, +# requiring conditional path adjustments throughout the CI scripts. +# Note: This does not work for `docker cp` commands. +TFCI_OUTPUT_WIN_DOCKER_DIR='C:/drive_t' + +# Docker on Windows doesn't support the `host` networking mode, and so +# port-forwarding is required for the container to detect it's running on GCE. +export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress") +netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80 +# A local firewall rule for the container is added in +# ci/official/utilities/setup_docker.sh. diff --git a/ci/official/envs/windows_x86_2022 b/ci/official/envs/windows_x86_2022 new file mode 100644 index 00000000000000..f4305982df806a --- /dev/null +++ b/ci/official/envs/windows_x86_2022 @@ -0,0 +1,49 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc" +TFCI_BAZEL_BAZELRC_ARGS="--output_user_root=C:/t" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=windows_x86_cpu" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu +TFCI_OUTPUT_DIR=build_output +TFCI_FIND_BIN=C:/tools/msys64/usr/bin/find.exe +TFCI_LIB_SUFFIX="-cpu-windows-x86_64" +# auditwheel is not supported for Windows +TFCI_WHL_AUDIT_ENABLE=0 +TFCI_WHL_AUDIT_PLAT=0 +# Tests are extremely slow at the moment +TFCI_WHL_BAZEL_TEST_ENABLE=0 +TFCI_WHL_SIZE_LIMIT=450M +TFCI_WHL_SIZE_LIMIT_ENABLE=1 +TFCI_WHL_IMPORT_TEST_ENABLE=1 +TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS="" + +# TODO(belitskiy): Add a link to the Dockerfile comment that explains this more. +# Used to simulate a T:\ drive within the container, to a limited extent, +# via a symlink. +# Helpful since the internal CI utilizes a T:\ drive, part of which is mounted +# to the container, and would result in C:\ != T:\ mismatches, +# when using variables like `TFCI_OUTPUT_DIR` in `docker exec commands, +# requiring conditional path adjustments throughout the CI scripts. +# Note: This does not work for `docker cp` commands. +TFCI_OUTPUT_WIN_DOCKER_DIR='C:/drive_t' + +# Docker on Windows doesn't support the `host` networking mode, and so +# port-forwarding is required for the container to detect it's running on GCE. +export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress") +netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80 +# A local firewall rule for the container is added in +# ci/official/utilities/setup_docker.sh. diff --git a/ci/official/libtensorflow.sh b/ci/official/libtensorflow.sh index ded7b90da421f0..331851b3c17ca6 100755 --- a/ci/official/libtensorflow.sh +++ b/ci/official/libtensorflow.sh @@ -25,10 +25,14 @@ if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then tfrun python3 tensorflow/tools/ci_build/update_version.py --nightly fi -tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS --config=linux_libtensorflow_test -tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS --config=linux_libtensorflow_build +if [[ $(uname -s) != MSYS_NT* ]]; then + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS --config=linux_libtensorflow_test + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS --config=linux_libtensorflow_build +else + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS --config=windows_libtensorflow_build +fi -tfrun ./ci/official/utilities/repack_libtensorflow.sh "$TFCI_OUTPUT_DIR" "$TFCI_LIB_SUFFIX" +tfrun bash ./ci/official/utilities/repack_libtensorflow.sh "$TFCI_OUTPUT_DIR" "$TFCI_LIB_SUFFIX" if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then # Note: -n disables overwriting previously created files. diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index 0f4df1a7a83d73..cf2f258c90b0c4 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -16,7 +16,7 @@ source "${BASH_SOURCE%/*}/utilities/setup.sh" if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then - PROFILE_JSON_PATH=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR") + PROFILE_JSON_PATH=$(replace_drive_letter_with_prefix "$TFCI_OUTPUT_WIN_DOCKER_DIR") PROFILE_JSON_PATH="$PROFILE_JSON_PATH/profile.json.gz" else PROFILE_JSON_PATH="$TFCI_OUTPUT_DIR/profile.json.gz" @@ -29,14 +29,9 @@ if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then fi if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then - tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --@local_config_cuda//cuda:override_include_cuda_libs=true --@local_tsl//third_party/py:verify_manylinux=false --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --@local_config_cuda//cuda:override_include_cuda_libs=true --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" else - # TODO(belitskiy): Clean this up when migrating to new VM/Docker image - if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then - tfrun bazel --output_user_root 'C:/tmp' test $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --@local_config_cuda//cuda:override_include_cuda_libs=true --@local_tsl//third_party/py:verify_manylinux=false --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" - else - tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --@local_config_cuda//cuda:override_include_cuda_libs=true --@local_tsl//third_party/py:verify_manylinux=false --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" - fi + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --@local_config_cuda//cuda:override_include_cuda_libs=true --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" fi # Note: the profile can be viewed by visiting chrome://tracing in a Chrome browser. diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements.in b/ci/official/requirements_updater/numpy1_requirements/requirements.in index 2cbb31ca920105..6daebb3f7094dd 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements.in +++ b/ci/official/requirements_updater/numpy1_requirements/requirements.in @@ -28,6 +28,20 @@ requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 jax==0.4.7 +# NVIDIA CUDA dependencies +# Note that the wheels are downloaded only when the targets in bazel command +# contain dependencies on these wheels. +nvidia-cublas-cu12 == 12.5.3.2 +nvidia-cuda-cupti-cu12 == 12.5.82 +nvidia-cuda-nvrtc-cu12 == 12.5.82 +nvidia-cuda-runtime-cu12 == 12.5.82 +nvidia-cudnn-cu12 == 9.3.0.75 +nvidia-cufft-cu12 == 11.2.3.61 +nvidia-curand-cu12 == 10.3.6.82 +nvidia-cusolver-cu12 == 11.6.3.83 +nvidia-cusparse-cu12 == 12.5.1.3 +nvidia-nccl-cu12 == 2.23.4 +nvidia-nvjitlink-cu12 == 12.5.82 # The dependencies below are needed for TF wheel testing. tensorflow-io-gcs-filesystem==0.37.1 libclang >= 13.0.0 diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt index a89874be35acb9..dce8c939f26c2f 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt @@ -430,6 +430,69 @@ numpy==1.26.4 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt index 3dc9ccbb7eff80..b637200d71addd 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt @@ -430,6 +430,69 @@ numpy==1.26.4 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt index 2ea408a671a827..a5ab8820abfcbb 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt @@ -430,6 +430,69 @@ numpy==1.26.4 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt index d520f09659073c..3ebea86d0a62e1 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt @@ -434,6 +434,69 @@ numpy==1.26.4 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index a1738d6008c7a9..4832983df6ce74 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -28,6 +28,20 @@ requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 jax==0.4.7 +# NVIDIA CUDA dependencies +# Note that the wheels are downloaded only when the targets in bazel command +# contain dependencies on these wheels. +nvidia-cublas-cu12 == 12.5.3.2 +nvidia-cuda-cupti-cu12 == 12.5.82 +nvidia-cuda-nvrtc-cu12 == 12.5.82 +nvidia-cuda-runtime-cu12 == 12.5.82 +nvidia-cudnn-cu12 == 9.3.0.75 +nvidia-cufft-cu12 == 11.2.3.61 +nvidia-curand-cu12 == 10.3.6.82 +nvidia-cusolver-cu12 == 11.6.3.83 +nvidia-cusparse-cu12 == 12.5.1.3 +nvidia-nccl-cu12 == 2.23.4 +nvidia-nvjitlink-cu12 == 12.5.82 # The dependencies below are needed for TF wheel testing. tensorflow-io-gcs-filesystem==0.37.1 libclang >= 13.0.0 diff --git a/ci/official/utilities/cleanup_summary.sh b/ci/official/utilities/cleanup_summary.sh index 6b6fdfaa855106..1cb89f017104ea 100755 --- a/ci/official/utilities/cleanup_summary.sh +++ b/ci/official/utilities/cleanup_summary.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -set -euxo pipefail +set -exo pipefail function resultstore_extract_fallback { # In case the main script fails somehow. diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index f681a78b2461e3..5adc64e62f7f62 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -304,7 +304,7 @@ EOF # anything with a Windows-only toolchain, and bazel errors if trying to build # that directory. @test "bazel nobuild passes on all of TF except TF Lite and win toolchains" { - bazel build --experimental_cc_shared_library --nobuild --keep_going -- //tensorflow/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/win/... -//tensorflow/tools/toolchains/win_1803/... + bazel build --experimental_cc_shared_library --nobuild --keep_going -- //tensorflow/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/win/... -//tensorflow/tools/toolchains/win_1803/... -//tensorflow/tools/toolchains/win2022/... } @test "API compatibility test passes, ensuring no unexpected changes to the TF API" { @@ -316,7 +316,7 @@ EOF # See b/279852433 (internal). # TODO(b/279852433) Replace deps(//tensorflow/...) with deps(//...) @test "Verify that it's possible to query every TensorFlow target without BUILD errors" { - bazel query "deps(//tensorflow/... -//tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test)" > /dev/null + bazel query "deps(//tensorflow/... -attr(tags, 'manual', //tensorflow/...))" > /dev/null } teardown_file() { diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 2111be61b802cc..34389f79264f12 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -19,7 +19,7 @@ # This script is aware of TFCI_ variables, so it doesn't need any arguments. # Puts new wheel through auditwheel to rename and verify it, deletes the old # one, checks the filesize, and then ensures the new wheel is installable. -set -euxo pipefail +set -exo pipefail cd "$TFCI_OUTPUT_DIR" @@ -46,7 +46,7 @@ fi # Check if size is too big. TFCI_WHL_SIZE_LIMIT is in find's format, which can be # 'k' for kilobytes, 'M' for megabytes, or 'G' for gigabytes, and the + to indicate # "anything greater than" is added by the script. -if [[ "$TFCI_WHL_SIZE_LIMIT_ENABLE" == "1" ]] && [[ -n "$(find . -iname "*.whl" -size "+$TFCI_WHL_SIZE_LIMIT")" ]]; then +if [[ "$TFCI_WHL_SIZE_LIMIT_ENABLE" == "1" ]] && [[ -n "$("$TFCI_FIND_BIN" . -iname "*.whl" -size "+$TFCI_WHL_SIZE_LIMIT")" ]]; then echo "Error: Generated wheel is too big! Limit is $TFCI_WHL_SIZE_LIMIT" echo '(search for TFCI_WHL_SIZE_LIMIT to change it)' ls -sh *.whl @@ -54,9 +54,18 @@ if [[ "$TFCI_WHL_SIZE_LIMIT_ENABLE" == "1" ]] && [[ -n "$(find . -iname "*.whl" fi # Quick install checks -venv=$(mktemp -d) -"python${TFCI_PYTHON_VERSION}" -m venv "$venv" -python="$venv/bin/python3" +venv_dir=$(mktemp -d) +if [[ $(uname -s) != MSYS_NT* ]]; then + "python${TFCI_PYTHON_VERSION}" -m venv "$venv_dir" + python="$venv_dir/bin/python3" +else + # When using the Linux-like path, venv creation quietly fails, which is + # why it's converted here. + venv_dir=$(cygpath -m $venv_dir) + "/c/python${TFCI_PYTHON_VERSION}/python.exe" -m venv "$venv_dir" + python="$venv_dir/Scripts/python.exe" +fi + # TODO(b/366266944) Remove the check after tf docker image upgrade for NumPy 2 # and numpy 1 support is dropped b/361369076. if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then diff --git a/ci/official/utilities/repack_libtensorflow.sh b/ci/official/utilities/repack_libtensorflow.sh index 0f549bf0975d73..5dc6f6c60f5a25 100755 --- a/ci/official/utilities/repack_libtensorflow.sh +++ b/ci/official/utilities/repack_libtensorflow.sh @@ -54,11 +54,94 @@ function cp_normalized_srcjar() { cp "${tmp_dir}/new.jar" "${dest_jar}" rm -rf "${tmp_dir}" } + DIR=$1 -TARBALL_SUFFIX=$2 mkdir -p "$DIR" -cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz "${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz" -cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz "${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz" -cp bazel-bin/tensorflow/java/libtensorflow.jar "${DIR}" -cp_normalized_srcjar bazel-bin/tensorflow/java/libtensorflow-src.jar "${DIR}/libtensorflow-src.jar" -cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_proto.zip "${DIR}" +TARBALL_SUFFIX=$2 + +if [[ $(uname -s) != MSYS_NT* ]]; then + cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz "${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz" + cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz "${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz" + cp bazel-bin/tensorflow/java/libtensorflow.jar "${DIR}" + cp_normalized_srcjar bazel-bin/tensorflow/java/libtensorflow-src.jar "${DIR}/libtensorflow-src.jar" + cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_proto.zip "${DIR}" +else + LIB_PKG="$1/lib_package" + mkdir -p ${LIB_PKG} + + # Zip up the .dll and the LICENSE for the JNI library. + cp bazel-bin/tensorflow/java/tensorflow_jni.dll ${LIB_PKG}/tensorflow_jni.dll + zip -j ${LIB_PKG}/libtensorflow_jni-cpu-windows-$(uname -m).zip \ + ${LIB_PKG}/tensorflow_jni.dll \ + bazel-bin/tensorflow/tools/lib_package/include/tensorflow/THIRD_PARTY_TF_JNI_LICENSES \ + LICENSE + rm -f ${LIB_PKG}/tensorflow_jni.dll + + # Zip up the .dll, LICENSE and include files for the C library. + mkdir -p ${LIB_PKG}/include/tensorflow/c + mkdir -p ${LIB_PKG}/include/tensorflow/c/eager + mkdir -p ${LIB_PKG}/include/tensorflow/core/platform + mkdir -p ${LIB_PKG}/include/xla/tsl/c + mkdir -p ${LIB_PKG}/include/tsl/platform + mkdir -p ${LIB_PKG}/lib + cp bazel-bin/tensorflow/tensorflow.dll ${LIB_PKG}/lib/tensorflow.dll + cp bazel-bin/tensorflow/tensorflow.lib ${LIB_PKG}/lib/tensorflow.lib + cp tensorflow/c/c_api.h \ + tensorflow/c/tf_attrtype.h \ + tensorflow/c/tf_buffer.h \ + tensorflow/c/tf_datatype.h \ + tensorflow/c/tf_status.h \ + tensorflow/c/tf_tensor.h \ + tensorflow/c/tf_tensor_helper.h \ + tensorflow/c/tf_tstring.h \ + tensorflow/c/tf_file_statistics.h \ + tensorflow/c/tensor_interface.h \ + tensorflow/c/c_api_macros.h \ + tensorflow/c/c_api_experimental.h \ + ${LIB_PKG}/include/tensorflow/c + cp tensorflow/c/eager/c_api.h \ + tensorflow/c/eager/c_api_experimental.h \ + tensorflow/c/eager/dlpack.h \ + ${LIB_PKG}/include/tensorflow/c/eager + cp tensorflow/core/platform/ctstring.h \ + tensorflow/core/platform/ctstring_internal.h \ + ${LIB_PKG}/include/tensorflow/core/platform + cp third_party/xla/xla/tsl/c/tsl_status.h ${LIB_PKG}/include/xla/tsl/c + cp third_party/xla/third_party/tsl/tsl/platform/ctstring.h \ + third_party/xla/third_party/tsl/tsl/platform/ctstring_internal.h \ + ${LIB_PKG}/include/tsl/platform + cp LICENSE ${LIB_PKG}/LICENSE + cp bazel-bin/tensorflow/tools/lib_package/THIRD_PARTY_TF_C_LICENSES ${LIB_PKG}/ + cd ${LIB_PKG} + zip libtensorflow-cpu-windows-$(uname -m).zip \ + lib/tensorflow.dll \ + lib/tensorflow.lib \ + include/tensorflow/c/eager/c_api.h \ + include/tensorflow/c/eager/c_api_experimental.h \ + include/tensorflow/c/eager/dlpack.h \ + include/tensorflow/c/c_api.h \ + include/tensorflow/c/tf_attrtype.h \ + include/tensorflow/c/tf_buffer.h \ + include/tensorflow/c/tf_datatype.h \ + include/tensorflow/c/tf_status.h \ + include/tensorflow/c/tf_tensor.h \ + include/tensorflow/c/tf_tensor_helper.h \ + include/tensorflow/c/tf_tstring.h \ + include/tensorflow/c/tf_file_statistics.h \ + include/tensorflow/c/tensor_interface.h \ + include/tensorflow/c/c_api_macros.h \ + include/tensorflow/c/c_api_experimental.h \ + include/tensorflow/core/platform/ctstring.h \ + include/tensorflow/core/platform/ctstring_internal.h \ + include/xla/tsl/c/tsl_status.h \ + include/tsl/platform/ctstring.h \ + include/tsl/platform/ctstring_internal.h \ + LICENSE \ + THIRD_PARTY_TF_C_LICENSES + rm -rf lib include + + cd .. + tar -zcvf windows_cpu_libtensorflow_binaries.tar.gz $LIB_PKG + rm -rf $LIB_PKG + +fi diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index bca1c781802046..829fdbdc34f911 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -29,7 +29,7 @@ # -o history: record shell history # -o allexport: export all functions and variables to be available to subscripts # (affects 'source $TFCI') -set -euxo pipefail -o history -o allexport +set -exo pipefail -o history -o allexport # Set TFCI_GIT_DIR, the root directory for all commands, to two directories # above the location of this file (setup.sh). We could also use "git rev-parse @@ -81,6 +81,7 @@ else source "$FROM_ENV" rm "$FROM_ENV" fi + set +u fi # If building installer wheels, set the required environment variables that are @@ -118,7 +119,7 @@ exec > >(tee "$TFCI_OUTPUT_DIR/script.log") 2>&1 # functionality instead. tfrun() { "$@"; } -if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then +if [[ $(uname -s) = MSYS_NT* ]]; then source ./ci/official/utilities/windows.sh echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)' source <(python ./ci/official/utilities/convert_msys_paths_to_win_paths.py --whitelist-prefix TFCI_) diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh index 61db7c2e124d0a..d928272d5ae1a3 100755 --- a/ci/official/utilities/setup_docker.sh +++ b/ci/official/utilities/setup_docker.sh @@ -38,15 +38,17 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then env_file=$(mktemp) env | grep ^TFCI_ > "$env_file" + if [[ $(uname -s) == MSYS_NT* ]]; then + is_windows=true + else + is_windows=false + fi + WORKING_DIR="$TFCI_GIT_DIR" - if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + if [[ "$is_windows" == true ]]; then env_file=$(cygpath -m $env_file) - # Host dirs can only be mapped to an existing drive inside the container, so - # T:\ is replaced with C:\. - _TFCI_OUTPUT_DIR_WIN=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR") - sed -iE 's|^TFCI_OUTPUT_DIR=.*|TFCI_OUTPUT_DIR='"$_TFCI_OUTPUT_DIR_WIN"'|g' $env_file - WORKING_DIR=$(replace_drive_letter_with_c "$TFCI_GIT_DIR") - echo "GCE_METADATA_HOST=$IP_ADDR" > $env_file + WORKING_DIR=$(replace_drive_letter_with_prefix "$TFCI_GIT_DIR" "$TFCI_OUTPUT_WIN_DOCKER_DIR") + echo "GCE_METADATA_HOST=$IP_ADDR" >> $env_file fi docker run $TFCI_DOCKER_ARGS --name tf -w "$WORKING_DIR" -itd --rm \ @@ -55,7 +57,7 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then "$TFCI_DOCKER_IMAGE" \ bash - if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + if [[ "$is_windows" == true ]]; then # Allow requests from the container. # Additional setup is contained in ci/official/envs/rbe. CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf) diff --git a/ci/official/utilities/windows.sh b/ci/official/utilities/windows.sh index 1ab2d89ef327f6..00c564d0363ba5 100644 --- a/ci/official/utilities/windows.sh +++ b/ci/official/utilities/windows.sh @@ -19,8 +19,15 @@ # Docker on Windows has difficulty using volumes other than C:\, when it comes # to setting up up volume mappings. -# Thus, the drive letter is replaced with C:\, in case it's -# something else (ex. T:), which is frequently the case inside Kokoro jobs. -function replace_drive_letter_with_c () { - sed -E "s|^[a-zA-Z]:|C:|g" <<< $1 +# Thus, the drive letter is replaced with the passed prefix. +# If no prefix is passed, by default, it's replaced with C:\, in case it's +# something else (ex. T:), which is a volume used in internal CI. +function replace_drive_letter_with_prefix () { + local path_prefix + if [[ -z "$2" ]]; then + path_prefix="C:" + else + path_prefix="$2" + fi + sed -E "s|^[a-zA-Z]:|${path_prefix}|g" <<< "$1" } diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index 11017934a009f7..b51c7ece243309 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -33,13 +33,12 @@ if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then cp ./ci/official/requirements_updater/numpy1_requirements/*.txt . fi -# TODO(ybaturina): add --@local_tsl//third_party/py:verify_manylinux=true when -# hermetic CC toolchain is ready. -tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --config=cuda_wheel //tensorflow/tools/pip_package:wheel $TFCI_BUILD_PIP_PACKAGE_ARGS -tfrun find ./bazel-bin/tensorflow/tools/pip_package -iname "*.whl" -exec cp {} $TFCI_OUTPUT_DIR \; +tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS build $TFCI_BAZEL_COMMON_ARGS --config=cuda_wheel //tensorflow/tools/pip_package:wheel $TFCI_BUILD_PIP_PACKAGE_ARGS + +tfrun "$TFCI_FIND_BIN" ./bazel-bin/tensorflow/tools/pip_package -iname "*.whl" -exec cp {} $TFCI_OUTPUT_DIR \; tfrun mkdir ./dist tfrun cp $TFCI_OUTPUT_DIR/*.whl ./dist -tfrun ./ci/official/utilities/rename_and_verify_wheels.sh +tfrun bash ./ci/official/utilities/rename_and_verify_wheels.sh if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then # Note: -n disables overwriting previously created files. @@ -47,5 +46,5 @@ if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then fi if [[ "$TFCI_WHL_BAZEL_TEST_ENABLE" == 1 ]]; then - tfrun bazel test $TFCI_BAZEL_COMMON_ARGS $TFCI_BUILD_PIP_PACKAGE_ARGS --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_wheel_test" + tfrun bazel $TFCI_BAZEL_BAZELRC_ARGS test $TFCI_BAZEL_COMMON_ARGS $TFCI_BUILD_PIP_PACKAGE_ARGS --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_wheel_test" fi diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 4f0dd8497c8979..b298293f6c1cd9 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -447,6 +447,69 @@ numpy==2.1.1 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 8c922a4ab3c9ae..c667c4e63dd595 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -447,6 +447,69 @@ numpy==2.1.1 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 8f971ecfe5cc67..9ae1aa3c7418f3 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -447,6 +447,69 @@ numpy==2.1.1 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 1c5ecd23b50bd1..6187d7cca59e2b 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -443,6 +443,69 @@ numpy==2.0.2 \ # opt-einsum # scipy # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.23.4 \ + --hash=sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec \ + --hash=sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 3aa8e6469e6a72..682490aa5ec884 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -90,7 +90,6 @@ PACKAGE_STATIC_DEPS = [ "@com_googlesource_code_re2//:__subpackages__", "@compute_library//:__subpackages__", "@curl//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@farmhash_archive//:__subpackages__", "@farmhash_gpu_archive//:__subpackages__", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index bbdb42167319ca..b27ced84840280 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -354,9 +354,9 @@ tf_cuda_library( ], deps = [ ":c_api_macros_hdrs", - "@local_tsl//tsl/platform:status", "@local_xla//xla/tsl/c:tsl_status", "@local_xla//xla/tsl/c:tsl_status_internal", + "@local_xla//xla/tsl/platform:status", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 08c5de71906e31..c4828432584347 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -832,7 +832,7 @@ void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, const void* value, size_t length) { - tensorflow::StringPiece s(static_cast(value), length); + absl::string_view s(static_cast(value), length); desc->node_builder.Attr(attr_name, s); } @@ -846,7 +846,7 @@ void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, lengths[i]); } } else { - std::vector v; + std::vector v; v.reserve(num_values); for (int i = 0; i < num_values; ++i) { v.emplace_back(static_cast(values[i]), lengths[i]); diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 48cb17b190f334..4361ea41feadd8 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -64,7 +64,7 @@ absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { -static void ExpectHasSubstr(StringPiece s, StringPiece expected) { +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 8019b01edeca77..be1e88384bdb00 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -987,7 +987,7 @@ REGISTER_KERNEL_BUILDER( Name("TestCommUnavailable").Device(tensorflow::DEVICE_DEFAULT), TestUnavailableErrorOp); -string FunctionWithErrorOp(const tensorflow::StringPiece op_name) { +string FunctionWithErrorOp(const absl::string_view op_name) { const std::string& func_str = " signature {" " name: 'FunctionWith__OP_NAME__'" diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 23ebb99839c46b..a277766f9be280 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -187,7 +187,7 @@ class GraphOperation : public TracingOperation { absl::Status SetAttrString(const char* attr_name, const char* data, size_t length) override { - tensorflow::StringPiece s(data, length); + absl::string_view s(data, length); op_->node_builder.Attr(attr_name, s); return absl::OkStatus(); } @@ -251,7 +251,7 @@ class GraphOperation : public TracingOperation { lengths[i]); } } else { - std::vector v; + std::vector v; v.reserve(num_values); for (int i = 0; i < num_values; ++i) { v.emplace_back(static_cast(values[i]), lengths[i]); diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 2fa9f90726896a..93140659df13d4 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -324,7 +324,7 @@ absl::Status AddInputList(AbstractOperation* op_, absl::Status SetAttrString(AbstractOperation* op_, const char* attr_name, const char* data, size_t length, ForwardOperation* forward_op_) { - forward_op_->attrs.Set(attr_name, StringPiece(data, length)); + forward_op_->attrs.Set(attr_name, absl::string_view(data, length)); return op_->SetAttrString(attr_name, data, length); } absl::Status SetAttrInt(AbstractOperation* op_, const char* attr_name, @@ -390,9 +390,9 @@ absl::Status SetAttrTensor(AbstractOperation* op_, const char* attr_name, absl::Status SetAttrStringList(AbstractOperation* op_, const char* attr_name, const void* const* values, const size_t* lengths, int num_values, ForwardOperation* forward_op_) { - std::vector v(num_values); + std::vector v(num_values); for (int i = 0; i < num_values; ++i) { - v[i] = StringPiece(static_cast(values[i]), lengths[i]); + v[i] = absl::string_view(static_cast(values[i]), lengths[i]); } forward_op_->attrs.Set(attr_name, v); return op_->SetAttrStringList(attr_name, values, lengths, num_values); diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index ee9aec47da4c0c..54d4bb30c6f888 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -79,11 +79,12 @@ tf_cc_test( deps = [ ":ram_file_block_cache", "//tensorflow/c:tf_status_internal", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform/cloud:now_seconds_env", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc index 17ab386f271f04..4ad4a8ea1868f3 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -25,13 +24,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/cloud/now_seconds_env.h" -#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -511,11 +512,19 @@ TEST(RamFileBlockCacheTest, ParallelReads) { // concurrently (at which point it will respond with success to all callers), // or 10 seconds have elapsed (at which point it will respond with an error). const int callers = 4; - BlockingCounter counter(callers); - auto fetcher = [&counter](const string& filename, size_t offset, size_t n, - char* buffer, TF_Status* status) -> int64_t { - counter.DecrementCount(); - if (!counter.WaitFor(std::chrono::seconds(10))) { + absl::BlockingCounter counter(callers); + absl::Notification notification; + auto fetcher = [&counter, ¬ification]( + const string& filename, size_t offset, size_t n, + char* buffer, TF_Status* status) -> int64_t { + if (counter.DecrementCount()) { + notification.Notify(); + // This call to `Wait()` is not expected to block. Calling `Wait()` here + // allows us to satisfy `BlockingCounter`'s requirement: "When `Wait()` + // returns, it is legal to destroy the `BlockingCounter`.". + counter.Wait(); + } + if (!notification.WaitForNotificationWithTimeout(absl::Seconds(10))) { // This avoids having the test time out, which is harder to debug. TF_SetStatus(status, TF_FAILED_PRECONDITION, "desired concurrency not reached"); @@ -549,7 +558,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { // Concurrent reads to the same file blocks should be de-duplicated. const size_t block_size = 16; int num_requests = 0; - Notification notification; + absl::Notification notification; auto fetcher = [&num_requests, ¬ification, block_size]( const string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD index 3097c31e289fd6..6dfd0fffa6e83c 100644 --- a/tensorflow/c/experimental/gradients/tape/BUILD +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/c/eager:gradients_internal", "//tensorflow/core:portable_gif_internal", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", ], ) @@ -75,6 +76,7 @@ cc_library( "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.cc b/tensorflow/c/experimental/gradients/tape/tape_context.cc index 5285b6a088e5b0..bdf080733f9bd9 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_context.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_context.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "absl/status/status.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_function.h" #include "tensorflow/c/eager/gradients.h" @@ -40,10 +41,10 @@ TapeContext::~TapeContext() { TapeOperation* TapeContext::CreateOperation() { return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_); } -Status TapeContext::RegisterFunction(AbstractFunction* f) { +absl::Status TapeContext::RegisterFunction(AbstractFunction* f) { return parent_ctx_->RegisterFunction(f); } -Status TapeContext::RemoveFunction(const string& func) { +absl::Status TapeContext::RemoveFunction(const string& func) { return parent_ctx_->RemoveFunction(func); } diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.h b/tensorflow/c/experimental/gradients/tape/tape_context.h index a7588362325fc1..f92c35f27f4235 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_context.h +++ b/tensorflow/c/experimental/gradients/tape/tape_context.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ +#include "absl/status/status.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_function.h" #include "tensorflow/c/eager/gradients.h" @@ -29,8 +30,8 @@ class TapeContext : public AbstractContext { explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&); void Release() override; TapeOperation* CreateOperation() override; - Status RegisterFunction(AbstractFunction*) override; - Status RemoveFunction(const string& func) override; + absl::Status RegisterFunction(AbstractFunction*) override; + absl::Status RemoveFunction(const string& func) override; // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kTape; diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc b/tensorflow/c/experimental/gradients/tape/tape_operation.cc index 5bd3daa4037fbe..f0cba24b9f87c8 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -14,6 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/gradients/tape/tape_operation.h" +#include +#include +#include +#include + #include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" @@ -50,7 +55,7 @@ TapeOperation::~TapeOperation() { // TODO(b/172003047): Consider making AbstractOperation RefCounted. // parent_op->Unref(); } -Status TapeOperation::Reset(const char* op, const char* raw_device_name) { +absl::Status TapeOperation::Reset(const char* op, const char* raw_device_name) { forward_op_.op_name = op; forward_op_.attrs.Reset(op); forward_op_.inputs.clear(); @@ -61,15 +66,15 @@ const string& TapeOperation::Name() const { return parent_op_->Name(); } const string& TapeOperation::DeviceName() const { return parent_op_->DeviceName(); } -Status TapeOperation::SetDeviceName(const char* name) { +absl::Status TapeOperation::SetDeviceName(const char* name) { return parent_op_->SetDeviceName(name); } -Status TapeOperation::AddInput(AbstractTensorHandle* input) { +absl::Status TapeOperation::AddInput(AbstractTensorHandle* input) { TF_RETURN_IF_ERROR(parent_op_->AddInput(input)); forward_op_.inputs.push_back(input); return absl::OkStatus(); } -Status TapeOperation::AddInputList( +absl::Status TapeOperation::AddInputList( absl::Span inputs) { TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs)); for (auto input : inputs) { @@ -77,29 +82,30 @@ Status TapeOperation::AddInputList( } return absl::OkStatus(); } -Status TapeOperation::SetAttrString(const char* attr_name, const char* data, - size_t length) { - forward_op_.attrs.Set(attr_name, StringPiece(data, length)); +absl::Status TapeOperation::SetAttrString(const char* attr_name, + const char* data, size_t length) { + forward_op_.attrs.Set(attr_name, absl::string_view(data, length)); return parent_op_->SetAttrString(attr_name, data, length); } -Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) { +absl::Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) { forward_op_.attrs.Set(attr_name, static_cast(value)); return parent_op_->SetAttrInt(attr_name, value); } -Status TapeOperation::SetAttrFloat(const char* attr_name, float value) { +absl::Status TapeOperation::SetAttrFloat(const char* attr_name, float value) { forward_op_.attrs.Set(attr_name, value); return parent_op_->SetAttrFloat(attr_name, value); } -Status TapeOperation::SetAttrBool(const char* attr_name, bool value) { +absl::Status TapeOperation::SetAttrBool(const char* attr_name, bool value) { forward_op_.attrs.Set(attr_name, value); return parent_op_->SetAttrBool(attr_name, value); } -Status TapeOperation::SetAttrType(const char* attr_name, DataType value) { +absl::Status TapeOperation::SetAttrType(const char* attr_name, DataType value) { forward_op_.attrs.Set(attr_name, value); return parent_op_->SetAttrType(attr_name, value); } -Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims, - const int num_dims) { +absl::Status TapeOperation::SetAttrShape(const char* attr_name, + const int64_t* dims, + const int num_dims) { if (num_dims > TensorShape::MaxDimensions()) { return errors::InvalidArgument("Value specified for `", attr_name, "` has ", num_dims, @@ -118,54 +124,59 @@ Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims, forward_op_.attrs.Set(attr_name, proto); return parent_op_->SetAttrShape(attr_name, dims, num_dims); } -Status TapeOperation::SetAttrFunction(const char* attr_name, - const AbstractOperation* value) { +absl::Status TapeOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { return tensorflow::errors::Unimplemented( "SetAttrFunction has not been implemented yet."); } -Status TapeOperation::SetAttrFunctionName(const char* attr_name, - const char* value, size_t length) { +absl::Status TapeOperation::SetAttrFunctionName(const char* attr_name, + const char* value, + size_t length) { return tensorflow::errors::Unimplemented( "SetAttrFunctionName has not been implemented " "yet."); } -Status TapeOperation::SetAttrTensor(const char* attr_name, - AbstractTensorInterface* tensor) { +absl::Status TapeOperation::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { return tensorflow::errors::Unimplemented( "SetAttrTensor has not been implemented yet."); } -Status TapeOperation::SetAttrStringList(const char* attr_name, - const void* const* values, - const size_t* lengths, int num_values) { - std::vector v(num_values); +absl::Status TapeOperation::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) { + std::vector v(num_values); for (int i = 0; i < num_values; ++i) { - v[i] = StringPiece(static_cast(values[i]), lengths[i]); + v[i] = absl::string_view(static_cast(values[i]), lengths[i]); } forward_op_.attrs.Set(attr_name, v); return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values); } -Status TapeOperation::SetAttrFloatList(const char* attr_name, - const float* values, int num_values) { +absl::Status TapeOperation::SetAttrFloatList(const char* attr_name, + const float* values, + int num_values) { forward_op_.attrs.Set(attr_name, gtl::ArraySlice(values, num_values)); return parent_op_->SetAttrFloatList(attr_name, values, num_values); } -Status TapeOperation::SetAttrIntList(const char* attr_name, - const int64_t* values, int num_values) { +absl::Status TapeOperation::SetAttrIntList(const char* attr_name, + const int64_t* values, + int num_values) { forward_op_.attrs.Set( attr_name, gtl::ArraySlice( reinterpret_cast(values), num_values)); return parent_op_->SetAttrIntList(attr_name, values, num_values); } -Status TapeOperation::SetAttrTypeList(const char* attr_name, - const DataType* values, int num_values) { +absl::Status TapeOperation::SetAttrTypeList(const char* attr_name, + const DataType* values, + int num_values) { forward_op_.attrs.Set(attr_name, gtl::ArraySlice(values, num_values)); return parent_op_->SetAttrTypeList(attr_name, values, num_values); } -Status TapeOperation::SetAttrBoolList(const char* attr_name, - const unsigned char* values, - int num_values) { +absl::Status TapeOperation::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { std::unique_ptr b(new bool[num_values]); for (int i = 0; i < num_values; ++i) { b[i] = values[i]; @@ -174,9 +185,10 @@ Status TapeOperation::SetAttrBoolList(const char* attr_name, gtl::ArraySlice(b.get(), num_values)); return parent_op_->SetAttrBoolList(attr_name, values, num_values); } -Status TapeOperation::SetAttrShapeList(const char* attr_name, - const int64_t** dims, - const int* num_dims, int num_values) { +absl::Status TapeOperation::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, + int num_values) { std::unique_ptr proto(new TensorShapeProto[num_values]); for (int i = 0; i < num_values; ++i) { const auto num_dims_i = num_dims[i]; @@ -201,15 +213,15 @@ Status TapeOperation::SetAttrShapeList(const char* attr_name, attr_name, gtl::ArraySlice(proto.get(), num_values)); return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); } -Status TapeOperation::SetAttrFunctionList( +absl::Status TapeOperation::SetAttrFunctionList( const char* attr_name, absl::Span values) { return tensorflow::errors::Unimplemented( "SetAttrFunctionList has not been " "implemented yet."); } AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; } -Status TapeOperation::Execute(absl::Span retvals, - int* num_retvals) { +absl::Status TapeOperation::Execute(absl::Span retvals, + int* num_retvals) { TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals)); for (int i = 0; i < *num_retvals; i++) { // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.h b/tensorflow/c/experimental/gradients/tape/tape_operation.h index 2ab67394988cf9..8f447440768912 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.h +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.h @@ -15,6 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ +#include +#include + +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" @@ -30,41 +34,45 @@ class TapeOperation : public AbstractOperation { public: explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&); void Release() override; - Status Reset(const char* op, const char* raw_device_name) override; + absl::Status Reset(const char* op, const char* raw_device_name) override; const string& Name() const override; const string& DeviceName() const override; - Status SetDeviceName(const char* name) override; - Status AddInput(AbstractTensorHandle* input) override; - Status AddInputList(absl::Span inputs) override; - Status Execute(absl::Span retvals, - int* num_retvals) override; - Status SetAttrString(const char* attr_name, const char* data, - size_t length) override; - Status SetAttrInt(const char* attr_name, int64_t value) override; - Status SetAttrFloat(const char* attr_name, float value) override; - Status SetAttrBool(const char* attr_name, bool value) override; - Status SetAttrType(const char* attr_name, DataType value) override; - Status SetAttrShape(const char* attr_name, const int64_t* dims, - const int num_dims) override; - Status SetAttrFunction(const char* attr_name, - const AbstractOperation* value) override; - Status SetAttrFunctionName(const char* attr_name, const char* value, + absl::Status SetDeviceName(const char* name) override; + absl::Status AddInput(AbstractTensorHandle* input) override; + absl::Status AddInputList( + absl::Span inputs) override; + absl::Status Execute(absl::Span retvals, + int* num_retvals) override; + absl::Status SetAttrString(const char* attr_name, const char* data, size_t length) override; - Status SetAttrTensor(const char* attr_name, - AbstractTensorInterface* tensor) override; - Status SetAttrStringList(const char* attr_name, const void* const* values, - const size_t* lengths, int num_values) override; - Status SetAttrFloatList(const char* attr_name, const float* values, - int num_values) override; - Status SetAttrIntList(const char* attr_name, const int64_t* values, - int num_values) override; - Status SetAttrTypeList(const char* attr_name, const DataType* values, - int num_values) override; - Status SetAttrBoolList(const char* attr_name, const unsigned char* values, - int num_values) override; - Status SetAttrShapeList(const char* attr_name, const int64_t** dims, - const int* num_dims, int num_values) override; - Status SetAttrFunctionList( + absl::Status SetAttrInt(const char* attr_name, int64_t value) override; + absl::Status SetAttrFloat(const char* attr_name, float value) override; + absl::Status SetAttrBool(const char* attr_name, bool value) override; + absl::Status SetAttrType(const char* attr_name, DataType value) override; + absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + absl::Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) override; + absl::Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + absl::Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + absl::Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) override; + absl::Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + absl::Status SetAttrFunctionList( const char* attr_name, absl::Span values) override; AbstractOperation* GetBackingOperation(); diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 2df0776922e2df..fdbcef8cee3582 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/compiler/jit:variable_info", "//tensorflow/compiler/jit:variable_info_util", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/next_pluggable_device:plugin_resource", "//tensorflow/core/platform:refcount", "//tensorflow/core/platform:status", @@ -84,14 +85,17 @@ tf_cc_test( deps = [ ":tensor_pjrt_buffer_util", "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/tfrt/common:async_value_tensor", "//tensorflow/core/tfrt/common:pjrt_util", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/pjrt:pjrt_api", "@local_xla//xla/pjrt:pjrt_c_api_client", "@local_xla//xla/pjrt/c:pjrt_c_api_cpu", diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc index 595775abb26d84..fdb8a9e7f47794 100644 --- a/tensorflow/c/experimental/next_pluggable_device/c_api.cc +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h index c2b1051f75c39e..c2378b68109fc9 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_TENSOR_PJRT_BUFFER_UTIL_H_ #define TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_TENSOR_PJRT_BUFFER_UTIL_H_ +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc index 84edac2bc4e825..3c1d1e760a0755 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include #include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_cpu.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" @@ -33,7 +35,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/tfrt/common/async_value_tensor.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" #include "tsl/platform/casts.h" diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index 88c1b6dccee0d3..9920fb114a62d2 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -49,6 +49,7 @@ cc_library( "//tensorflow/c/eager:tracing_utils", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], @@ -71,6 +72,7 @@ cc_library( "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:tracing_utils", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], @@ -93,6 +95,7 @@ cc_library( "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:tracing_utils", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], @@ -118,6 +121,7 @@ cc_library( "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], diff --git a/tensorflow/c/experimental/ops/gen/common/BUILD b/tensorflow/c/experimental/ops/gen/common/BUILD index 1782722cac7f72..447c6a2a480be7 100644 --- a/tensorflow/c/experimental/ops/gen/common/BUILD +++ b/tensorflow/c/experimental/ops/gen/common/BUILD @@ -25,6 +25,8 @@ cc_library( "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:str_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/experimental/ops/gen/common/case_format.cc b/tensorflow/c/experimental/ops/gen/common/case_format.cc index d23f7b75149c8f..82acc32f623fd8 100644 --- a/tensorflow/c/experimental/ops/gen/common/case_format.cc +++ b/tensorflow/c/experimental/ops/gen/common/case_format.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/case_format.h" +#include + #include "absl/strings/ascii.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/c/experimental/ops/gen/common/controller.cc b/tensorflow/c/experimental/ops/gen/common/controller.cc index cafb57c0919403..16908012f296bb 100644 --- a/tensorflow/c/experimental/ops/gen/common/controller.cc +++ b/tensorflow/c/experimental/ops/gen/common/controller.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/controller.h" +#include + +#include "absl/log/check.h" #include "absl/strings/substitute.h" #include "tensorflow/c/experimental/ops/gen/common/path_config.h" #include "tensorflow/c/experimental/ops/gen/common/source_code.h" diff --git a/tensorflow/c/experimental/ops/gen/common/controller.h b/tensorflow/c/experimental/ops/gen/common/controller.h index a86779eedb598f..e152efeb6d8f9f 100644 --- a/tensorflow/c/experimental/ops/gen/common/controller.h +++ b/tensorflow/c/experimental/ops/gen/common/controller.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CONTROLLER_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CONTROLLER_H_ +#include + #include "tensorflow/c/experimental/ops/gen/common/path_config.h" #include "tensorflow/c/experimental/ops/gen/common/source_code.h" #include "tensorflow/c/experimental/ops/gen/model/op_spec.h" diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.cc b/tensorflow/c/experimental/ops/gen/common/path_config.cc index b8f84d5f31f4d3..2ec57d67c9d6f7 100644 --- a/tensorflow/c/experimental/ops/gen/common/path_config.cc +++ b/tensorflow/c/experimental/ops/gen/common/path_config.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/path_config.h" -#include +#include +#include #include "absl/strings/str_join.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.h b/tensorflow/c/experimental/ops/gen/common/path_config.h index 7d76f7c987a376..ce29063be5f682 100644 --- a/tensorflow/c/experimental/ops/gen/common/path_config.h +++ b/tensorflow/c/experimental/ops/gen/common/path_config.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_PATH_CONFIG_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_PATH_CONFIG_H_ +#include + #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/common/source_code.cc b/tensorflow/c/experimental/ops/gen/common/source_code.cc index 5868b20dc7e5d2..2b7bce6a263184 100644 --- a/tensorflow/c/experimental/ops/gen/common/source_code.cc +++ b/tensorflow/c/experimental/ops/gen/common/source_code.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "absl/log/log.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -48,7 +49,7 @@ void SourceCode::IncreaseIndent() { current_indent_++; } void SourceCode::DecreaseIndent() { current_indent_--; } void SourceCode::ValidateAndAddLine(int indent, const string& raw_line) { - StringPiece line(raw_line); + absl::string_view line(raw_line); bool had_trailing_newline = absl::ConsumeSuffix(&line, "\n"); if (absl::StrContains(line, '\n')) { diff --git a/tensorflow/c/experimental/ops/gen/common/source_code.h b/tensorflow/c/experimental/ops/gen/common/source_code.h index 471b63f1f6a902..df1aa90acf7b8c 100644 --- a/tensorflow/c/experimental/ops/gen/common/source_code.h +++ b/tensorflow/c/experimental/ops/gen/common/source_code.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_SOURCE_CODE_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_SOURCE_CODE_H_ +#include + #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.cc b/tensorflow/c/experimental/ops/gen/common/view_util.cc index 7c8717067b08fe..388aa0646db82b 100644 --- a/tensorflow/c/experimental/ops/gen/common/view_util.cc +++ b/tensorflow/c/experimental/ops/gen/common/view_util.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/view_util.h" +#include + #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.h b/tensorflow/c/experimental/ops/gen/common/view_util.h index 4fff7189acbf2c..7ab437a90e4fd8 100644 --- a/tensorflow/c/experimental/ops/gen/common/view_util.h +++ b/tensorflow/c/experimental/ops/gen/common/view_util.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_VIEW_UTIL_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_VIEW_UTIL_H_ +#include + #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD index ba3fe1575c781a..5403d1bf46d9a9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h index 8adf390561c442..fa7571d98a1214 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_CONFIG_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_CONFIG_H_ +#include + #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc index 71132cfc3bf8b2..c274d00d816019 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h" +#include + #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc index 7a4275b532eda7..1a685cac0c405c 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h" +#include + #include "tensorflow/c/experimental/ops/gen/common/case_format.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc index c58e67782dfc34..c459d239ca699f 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h" #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc index 4764fe799523ae..a9efb94335c0a6 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" @@ -39,7 +40,7 @@ Renderer& Renderer::CodeLine(const string& text) { } Renderer& Renderer::CodeLines(const string& text) { - StringPiece trimmed_text(text); + absl::string_view trimmed_text(text); str_util::RemoveWhitespaceContext(&trimmed_text); for (const string& line : str_util::Split(trimmed_text, '\n')) { context_.code.AddLineWithoutIndent(line); diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD index fd8194d584d32b..1790ddc8d86978 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc b/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc index f47851ddbd404e..eeb300271abdae 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/views/op_view.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "absl/log/check.h" #include "tensorflow/c/experimental/ops/gen/common/view_util.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h" diff --git a/tensorflow/c/experimental/ops/io_ops.cc b/tensorflow/c/experimental/ops/io_ops.cc index 920d82cf1be3ec..7c5be2c67e7476 100644 --- a/tensorflow/c/experimental/ops/io_ops.cc +++ b/tensorflow/c/experimental/ops/io_ops.cc @@ -17,6 +17,9 @@ limitations under the License. #include "tensorflow/c/experimental/ops/io_ops.h" +#include + +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_operation.h" diff --git a/tensorflow/c/experimental/ops/io_ops.h b/tensorflow/c/experimental/ops/io_ops.h index ceccddad5ea188..939c853616d10a 100644 --- a/tensorflow/c/experimental/ops/io_ops.h +++ b/tensorflow/c/experimental/ops/io_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_ +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index 2a2ea0f26534b9..cd1c6e3a2209ca 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/experimental/ops/math_ops.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_operation.h" diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h index c7cde54acad483..c33c89fd00ff9a 100644 --- a/tensorflow/c/experimental/ops/math_ops.h +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ +#include "absl/status/status.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc index 6be53fb7fe0bf5..c7e9589f053ec9 100644 --- a/tensorflow/c/experimental/ops/nn_ops.cc +++ b/tensorflow/c/experimental/ops/nn_ops.cc @@ -17,6 +17,9 @@ limitations under the License. #include "tensorflow/c/experimental/ops/nn_ops.h" +#include + +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_operation.h" diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h index 204ed13a3ba9fd..0006267f627113 100644 --- a/tensorflow/c/experimental/ops/nn_ops.h +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ +#include "absl/status/status.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/c/experimental/ops/resource_variable_ops.cc b/tensorflow/c/experimental/ops/resource_variable_ops.cc index 68304ebff5bbbe..042ef809886313 100644 --- a/tensorflow/c/experimental/ops/resource_variable_ops.cc +++ b/tensorflow/c/experimental/ops/resource_variable_ops.cc @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/c/experimental/ops/resource_variable_ops.h" +#include +#include + +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_operation.h" diff --git a/tensorflow/c/experimental/ops/resource_variable_ops.h b/tensorflow/c/experimental/ops/resource_variable_ops.h index 5ba2b8fdd5656d..02b42bf4caa706 100644 --- a/tensorflow/c/experimental/ops/resource_variable_ops.h +++ b/tensorflow/c/experimental/ops/resource_variable_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_RESOURCE_VARIABLE_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_RESOURCE_VARIABLE_OPS_H_ +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" diff --git a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc index d179d0de6b7d09..c2bf61d785e6b2 100644 --- a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc +++ b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc @@ -24,7 +24,7 @@ limitations under the License. namespace tensorflow { namespace { -SavedObjectGraph ParseSavedObjectGraph(StringPiece text_proto) { +SavedObjectGraph ParseSavedObjectGraph(absl::string_view text_proto) { SavedObjectGraph value; CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), &value)); diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index be8856e1055017..4214f76cee1cee 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) @@ -53,6 +54,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) @@ -79,6 +81,7 @@ tf_cc_test( "//tensorflow/core/common_runtime:core_cpu_lib", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", + "@com_google_absl//absl/status", ], ) @@ -100,5 +103,6 @@ tf_cc_test( "//tensorflow/core/common_runtime:core_cpu_lib", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", + "@com_google_absl//absl/status", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc index 0db50bd6faa32b..30b6adde2df81b 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" +#include +#include + +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_context.h" diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h index f559978b5de345..5a0ec2bce5fe1e 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc index f3bb9e93d24486..1d55dabcc9ab87 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" +#include +#include + +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/test_utils.h" #include "tensorflow/c/tensor_interface.h" @@ -31,7 +35,7 @@ limitations under the License. namespace tensorflow { namespace { -std::string CheckpointPrefix(StringPiece saved_model_dir) { +std::string CheckpointPrefix(absl::string_view saved_model_dir) { return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", saved_model_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename); diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index 9d10241ad21bb7..2804456f4f4ecb 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" +#include +#include + +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_context.h" diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h index d16bd3b2557345..ee01935b6ebf0d 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H_ +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index 04f9441e89ec57..bbff929015dd6a 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/test_utils.h" #include "tensorflow/c/tensor_interface.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index f2647901a81c76..5dd21f10d817c8 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -47,6 +47,7 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", ], ) @@ -65,6 +66,8 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) @@ -119,6 +122,8 @@ cc_library( "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:lib", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -171,6 +176,8 @@ cc_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", ], ) @@ -241,6 +248,7 @@ cc_library( "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/asset.h b/tensorflow/c/experimental/saved_model/core/revived_types/asset.h index c09a16ab61b844..4f4bff8643bb06 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/asset.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/asset.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc b/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc index 865b24ae515fd3..8d8342bb304368 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/constant.h b/tensorflow/c/experimental/saved_model/core/revived_types/constant.h index 2558fa14b9efbc..0d89cf37dbf0c9 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/constant.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/constant.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc index 70b25c7fc5739f..a50b50fef7b888 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h index ac0c67a7b6545a..810a42ec88784f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_context.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc index 5a32806980c797..2ac31f313230ac 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -54,7 +54,7 @@ using StructuredValueDictEntry = protobuf::MapPair; using NamedParamMap = - gtl::FlatMap; + gtl::FlatMap; absl::Status AssertAllCreateResourceFunctionsHaveNoCaptures( const PartiallyRevivedObjects& objects) { diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc index fde245d6830956..b5a3e5b8d5fda5 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h index fd2db397cfe688..691a591cb54a2d 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc index 6d9cbe61c0c414..8c16b2ea2b7bc9 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h index eedf2aae295422..c9b98189ef174a 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_context.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc index db5f0428dea65a..cdf81e69835767 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h index 0897a96b8fd363..5a9ad51ae54c42 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 7d9dd3f73375c3..50c9c2c6271500 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -51,11 +51,12 @@ using StructuredValueDictEntry = // Maps from a Nodedef's name to its corresponding AttrValues, for a given // Graphdef using NodeAttrMap = - gtl::FlatMap; + gtl::FlatMap; // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary -using FunctionDefMap = gtl::FlatMap; +using FunctionDefMap = + gtl::FlatMap; // Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and // returns a tensorflow::Constant. @@ -331,7 +332,7 @@ absl::Status FlattenSignature( } } -absl::optional FindNodeAtPath(StringPiece path, +absl::optional FindNodeAtPath(absl::string_view path, const SavedObjectGraph& object_graph) { const auto& nodes = object_graph.nodes(); if (nodes.empty()) { @@ -361,18 +362,21 @@ absl::optional FindNodeAtPath(StringPiece path, return node_id; } -gtl::FlatMap NodeToAttrMap( - const tensorflow::GraphDef& graphdef) { - gtl::FlatMap result; +gtl::FlatMap +NodeToAttrMap(const tensorflow::GraphDef& graphdef) { + gtl::FlatMap + result; for (const tensorflow::NodeDef& node : graphdef.node()) { result[node.name()] = &node.attr(); } return result; } -gtl::FlatMap +gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) { - gtl::FlatMap + gtl::FlatMap result; for (const FunctionDef& function_def : library.function()) { result[function_def.signature().name()] = &function_def; diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 6cebe518a6cfd8..9a6108dbb0c438 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -83,17 +83,18 @@ absl::Status FlattenSignature( // Find the node id in `object_graph` at location `path`. `path` must be // a dot-delimited string of object names relative to the root object. If no // object is found, returns absl::nullopt. -absl::optional FindNodeAtPath(StringPiece path, +absl::optional FindNodeAtPath(absl::string_view path, const SavedObjectGraph& object_graph); // Maps each node in `graphdef` to its corresponding Attribute Map. // Callers must ensure that `graphdef` outlives the returned map. -gtl::FlatMap NodeToAttrMap( - const tensorflow::GraphDef& graphdef); +gtl::FlatMap +NodeToAttrMap(const tensorflow::GraphDef& graphdef); // Maps the name of each FunctionDef in `library` to its corresponding // FunctionDef. Callers must ensure `library` outlives the returned map. -gtl::FlatMap +gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); // Finds the "signatures" object in the object graph, and fills a mapping of diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 2f8230af3f028e..66dd039650103a 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -65,8 +65,9 @@ limitations under the License. namespace tensorflow { // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary -using FunctionDefMap = gtl::FlatMap; +using FunctionDefMap = + gtl::FlatMap; // Maps from a functiondef's name to the corresponding "TFConcreteFunction" using FlatTensorFunctionMap = diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 987a157a5e9e6d..dee65387df04b6 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -152,6 +152,7 @@ cc_library( "//tensorflow/c/experimental/saved_model/core:tf_saved_model_api", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 7feefc4bd671e1..f07beb42fa6ec4 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/optional.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" @@ -99,7 +100,7 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, const char* function_path, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; - tensorflow::Status get_function_status = + absl::Status get_function_status = tensorflow::unwrap(model)->GetFunction(function_path, &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { @@ -113,7 +114,7 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { tensorflow::SignatureDefFunction* result = nullptr; - tensorflow::Status get_function_status = + absl::Status get_function_status = tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key, &result); status->status.Update(get_function_status); diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index 2aaabe180770a0..51c0d5971501fa 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -49,7 +49,7 @@ using tensorflow::tstring; constexpr char kTestData[] = "cc/saved_model/testdata"; const char* kServeTag[] = {"serve"}; -std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { +std::string SavedModelPath(absl::string_view saved_model_dir) { return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), kTestData, saved_model_dir); } diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 1cdcb0df9babf0..0e80c72eb1f24a 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -47,6 +47,7 @@ cc_library( "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:status", "@local_xla//xla/stream_executor:device_description", "@local_xla//xla/stream_executor:executor_cache", @@ -98,11 +99,17 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor:event", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor:stream", "@local_xla//xla/stream_executor:stream_executor_h", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index ff2d6146c5ead1..b19195ec208c81 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/c_api_macros_internal.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 680a1d9d1db1f5..810e72aa48b436 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -14,15 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include +#include +#include +#include #include +#include +#include #include +#include +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc index 41928bc469c104..f145e6c3376f7b 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h" +#include + #include "tensorflow/c/experimental/stream_executor/stream_executor.h" namespace stream_executor { diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index a178b0fb66e5b7..9569eda9fb12af 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -47,6 +47,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log:check", "@eigen_archive//:eigen3", ], ) @@ -61,6 +62,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log:check", ], ) diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc index 169c9b17da3a78..ef93d0d4438f96 100644 --- a/tensorflow/c/kernels/bitcast_op_test.cc +++ b/tensorflow/c/kernels/bitcast_op_test.cc @@ -44,7 +44,7 @@ class DummyDevice : public DeviceBase { void TestBitcastOp(Tensor* input_tensor, DataType out_type, TensorShape expected_shape, error::Code expected_code) { - Status status; + absl::Status status; NodeDef def; def.set_op("Bitcast"); def.set_device(DEVICE_CPU); diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc index 87adad4104f222..7f34e5217c20ba 100644 --- a/tensorflow/c/kernels/histogram_summary_op.cc +++ b/tensorflow/c/kernels/histogram_summary_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/c/kernels.h" #include "tensorflow/c/tf_status.h" diff --git a/tensorflow/c/kernels/merge_summary_op.cc b/tensorflow/c/kernels/merge_summary_op.cc index 2a7ddc6e93c678..339267d094a554 100644 --- a/tensorflow/c/kernels/merge_summary_op.cc +++ b/tensorflow/c/kernels/merge_summary_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "tensorflow/c/kernels.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" diff --git a/tensorflow/c/kernels/summary_op_test.cc b/tensorflow/c/kernels/summary_op_test.cc index 9bb23eefe2d4bd..11a7c06c1d2e30 100644 --- a/tensorflow/c/kernels/summary_op_test.cc +++ b/tensorflow/c/kernels/summary_op_test.cc @@ -54,7 +54,7 @@ void ExpectSummaryMatches(const Summary& actual, const string& expected_str) { void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output, error::Code expected_code) { // Initialize node used to fetch OpKernel - Status status; + absl::Status status; NodeDef def; def.set_op("ScalarSummary"); diff --git a/tensorflow/c/kernels/tensor_shape_utils_test.cc b/tensorflow/c/kernels/tensor_shape_utils_test.cc index 783105f3ad7009..dc972a428a01d3 100644 --- a/tensorflow/c/kernels/tensor_shape_utils_test.cc +++ b/tensorflow/c/kernels/tensor_shape_utils_test.cc @@ -36,7 +36,7 @@ struct TF_TensorWrapper { void TestShapeMatch(TensorShape shape) { Tensor tensor(DT_FLOAT, shape); - Status status; + absl::Status status; TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status); TF_TensorWrapper tensor_wrapper = TF_TensorWrapper(tf_tensor); ASSERT_TRUE(status.ok()) << status.ToString(); diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h index 9a9eaadc08c30d..448207bf42993d 100644 --- a/tensorflow/c/tf_datatype.h +++ b/tensorflow/c/tf_datatype.h @@ -55,10 +55,12 @@ typedef enum TF_DataType { TF_FLOAT8_E5M2 = 24, // 5 exponent bits, 2 mantissa bits. TF_FLOAT8_E4M3FN = 25, // 4 exponent bits, 3 mantissa bits, finite-only, with // 2 NaNs (0bS1111111). - // TODO - b/299182407: Leaving room for remaining float8 types. - // TF_FLOAT8_E4M3FNUZ = 26, - // TF_FLOAT8_E4M3B11FNUZ = 27, - // TF_FLOAT8_E5M2FNUZ = 28, + TF_FLOAT8_E4M3FNUZ = 26, // 4 exponent bits, 3 mantissa bits, + // finite-only,with NaN. + TF_FLOAT8_E4M3B11FNUZ = 27, // 4 exponent bits, 3 mantissa bits, 11 bits + // bias, finite-only, with NaNs. + TF_FLOAT8_E5M2FNUZ = 28, // 5 exponent bits, 2 mantissa bits, + // finite-only,with NaN. TF_INT4 = 29, TF_UINT4 = 30, } TF_DataType; diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc index 5ce9dde5c81ec7..03f4515521c7db 100644 --- a/tensorflow/cc/experimental/base/tests/tensor_test.cc +++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include + #include #include "absl/types/span.h" #include "tensorflow/c/tf_datatype.h" diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc index c5751a18ce57ce..71a4c5f5e4f628 100644 --- a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc +++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include "absl/types/span.h" diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 031451d3d2d339..5d9bf652b7829c 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -45,8 +45,8 @@ namespace { const int kRightMargin = 79; -string GetConstructorDecl(const OpInfo& op_info, StringPiece op_name_prefix, - bool include_attr) { +string GetConstructorDecl(const OpInfo& op_info, + absl::string_view op_name_prefix, bool include_attr) { const string prefix = strings::StrCat(op_name_prefix, op_info.op_name, "("); string c_decl; for (int i = 0; i < op_info.arg_types.size(); ++i) { diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index c42ae6323c9763..02545e9bcecc17 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -61,7 +61,7 @@ int main(int argc, char* argv[]) { exit(1); } - bool include_internal = tensorflow::StringPiece("1") == argv[3]; + bool include_internal = absl::string_view("1") == argv[3]; std::vector api_def_dirs = tensorflow::str_util::Split( argv[4], ",", tensorflow::str_util::SkipEmpty()); tensorflow::cc_op::PrintAllCCOps(argv[1], argv[2], include_internal, diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 71521b71e88928..846291cbb2b54d 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -61,12 +61,12 @@ op { } )"; -void ExpectHasSubstr(StringPiece s, StringPiece expected) { +void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } -void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { +void ExpectDoesNotHaveSubstr(absl::string_view s, absl::string_view expected) { EXPECT_FALSE(absl::StrContains(s, expected)) << "'" << s << "' contains '" << expected << "'"; } diff --git a/tensorflow/cc/framework/cc_op_gen_util.cc b/tensorflow/cc/framework/cc_op_gen_util.cc index 23280b6bdc4736..3503fd83053d67 100644 --- a/tensorflow/cc/framework/cc_op_gen_util.cc +++ b/tensorflow/cc/framework/cc_op_gen_util.cc @@ -59,7 +59,7 @@ absl::StatusOr LoadOpsAndApiDefs( return api_def_map; } -string GetPath(StringPiece dot_h_fname) { +string GetPath(absl::string_view dot_h_fname) { auto pos = dot_h_fname.find("/bin/"); string result(dot_h_fname); if (pos != string::npos) { @@ -82,14 +82,14 @@ string GetPath(StringPiece dot_h_fname) { return result; } -string GetFilename(StringPiece path) { +string GetFilename(absl::string_view path) { size_t slash_pos = path.rfind('/'); if (slash_pos == path.npos) slash_pos = -1; size_t dot_pos = path.rfind('.'); return string(path.substr(slash_pos + 1, dot_pos - (slash_pos + 1))); } -string ToGuard(StringPiece path) { +string ToGuard(absl::string_view path) { string guard; guard.reserve(path.size() + 1); // + 1 -> trailing _ for (const char c : path) { @@ -105,7 +105,7 @@ string ToGuard(StringPiece path) { return guard; } -string ToTitle(StringPiece name) { +string ToTitle(absl::string_view name) { string title(name); for (int i = 0; i < title.size(); ++i) { if (title[i] == '_') title[i] = ' '; @@ -114,7 +114,7 @@ string ToTitle(StringPiece name) { return title; } -string MakeComment(StringPiece text, StringPiece indent) { +string MakeComment(absl::string_view text, absl::string_view indent) { string ret; while (!text.empty()) { int last_non_space = -1; @@ -134,7 +134,7 @@ string MakeComment(StringPiece text, StringPiece indent) { return ret; } -string PrintString(StringPiece str) { +string PrintString(absl::string_view str) { return strings::StrCat("\"", absl::CEscape(str), "\""); } @@ -280,7 +280,7 @@ bool IsEmptyList(const AttrValue::ListValue& list) { list.shape_size() == 0 && list.tensor_size() == 0; } -string ToCamelCase(StringPiece str) { +string ToCamelCase(absl::string_view str) { string result; const char joiner = '_'; size_t i = 0; @@ -301,7 +301,7 @@ string ToCamelCase(StringPiece str) { return result; } -string SeparateNamespaces(StringPiece str) { +string SeparateNamespaces(absl::string_view str) { string result; const char joiner = '_'; size_t i = 0; @@ -316,27 +316,26 @@ string SeparateNamespaces(StringPiece str) { return result; } -std::pair AttrTypeName(StringPiece attr_type) { - static const auto* attr_type_map = - new std::unordered_map, - StringPieceHasher>{ - {"string", {"StringPiece", false}}, - {"list(string)", {"gtl::ArraySlice<::tensorflow::tstring>", true}}, - {"int", {"int64", false}}, - {"list(int)", {"gtl::ArraySlice", true}}, - {"float", {"float", false}}, - {"list(float)", {"gtl::ArraySlice", true}}, - {"bool", {"bool", false}}, - {"list(bool)", {"gtl::ArraySlice", true}}, - {"type", {"DataType", false}}, - {"list(type)", {"DataTypeSlice", true}}, - {"shape", {"PartialTensorShape", false}}, - {"list(shape)", {"gtl::ArraySlice", true}}, - {"tensor", {"TensorProto", true}}, - {"list(tensor)", {"gtl::ArraySlice", true}}, - {"func", {"NameAttrList", true}}, - {"list(func)", {"gtl::ArraySlice", true}}, - }; +std::pair AttrTypeName(absl::string_view attr_type) { + static const auto* attr_type_map = new std::unordered_map< + absl::string_view, std::pair, StringPieceHasher>{ + {"string", {"StringPiece", false}}, + {"list(string)", {"gtl::ArraySlice<::tensorflow::tstring>", true}}, + {"int", {"int64", false}}, + {"list(int)", {"gtl::ArraySlice", true}}, + {"float", {"float", false}}, + {"list(float)", {"gtl::ArraySlice", true}}, + {"bool", {"bool", false}}, + {"list(bool)", {"gtl::ArraySlice", true}}, + {"type", {"DataType", false}}, + {"list(type)", {"DataTypeSlice", true}}, + {"shape", {"PartialTensorShape", false}}, + {"list(shape)", {"gtl::ArraySlice", true}}, + {"tensor", {"TensorProto", true}}, + {"list(tensor)", {"gtl::ArraySlice", true}}, + {"func", {"NameAttrList", true}}, + {"list(func)", {"gtl::ArraySlice", true}}, + }; auto entry = attr_type_map->find(attr_type); if (entry == attr_type_map->end()) { @@ -346,17 +345,14 @@ std::pair AttrTypeName(StringPiece attr_type) { return entry->second; } -StringPiece ListElementTypeName(StringPiece attr_type) { - static const auto* attr_list_type_map = - new absl::flat_hash_map{ - {"list(string)", "string"}, - {"list(int)", "int"}, - {"list(float)", "float"}, - {"list(bool)", "bool"}, - {"list(type)", "DataType"}, - {"list(shape)", "PartialTensorShape"}, - {"list(tensor)", "TensorProto"}, - }; +absl::string_view ListElementTypeName(absl::string_view attr_type) { + static const auto* attr_list_type_map = new absl::flat_hash_map< + absl::string_view, absl::string_view, StringPieceHasher>{ + {"list(string)", "string"}, {"list(int)", "int"}, + {"list(float)", "float"}, {"list(bool)", "bool"}, + {"list(type)", "DataType"}, {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; auto entry = attr_list_type_map->find(attr_type); if (entry == attr_list_type_map->end()) { @@ -366,10 +362,11 @@ StringPiece ListElementTypeName(StringPiece attr_type) { return entry->second; } -bool IsCPPKeyword(StringPiece name) { - static const absl::flat_hash_set* +bool IsCPPKeyword(absl::string_view name) { + static const absl::flat_hash_set* // Keywords obtained from http://en.cppreference.com/w/cpp/keyword - kCPPReserved = new absl::flat_hash_set{ + kCPPReserved = new absl::flat_hash_set{ "alignas", "alignof", "and", @@ -477,7 +474,7 @@ bool IsCPPKeyword(StringPiece name) { return kCPPReserved->count(name) > 0; } -string AvoidCPPKeywords(StringPiece name) { +string AvoidCPPKeywords(absl::string_view name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } @@ -558,7 +555,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); // TODO(keveman): Include input type information. - StringPiece description = api_def_arg.description(); + absl::string_view description = api_def_arg.description(); if (!description.empty()) { ConsumeEquals(&description); strings::StrAppend(&comment, "* ", diff --git a/tensorflow/cc/framework/cc_op_gen_util.h b/tensorflow/cc/framework/cc_op_gen_util.h index 128c3ca6877288..4e3272c7e38c0d 100644 --- a/tensorflow/cc/framework/cc_op_gen_util.h +++ b/tensorflow/cc/framework/cc_op_gen_util.h @@ -40,30 +40,30 @@ absl::StatusOr LoadOpsAndApiDefs( // Converts: // bazel-out/.../(bin|genfiles)/(external/YYY/)?XX // to: XX. -string GetPath(StringPiece dot_h_fname); +string GetPath(absl::string_view dot_h_fname); // Converts: some/path/to/file.xx // to: file // (note that suffix is removed) -string GetFilename(StringPiece path); +string GetFilename(absl::string_view path); // Converts: // cc/ops/gen_foo_ops.h // to: // CC_OPS_GEN_FOO_OPS_H_ -string ToGuard(StringPiece path); +string ToGuard(absl::string_view path); // Converts: some_name_xyz // to: Some Name Xyz -string ToTitle(StringPiece name); +string ToTitle(absl::string_view name); // Change: Into: // ABC /// ABC // /// // DEF /// DEF -string MakeComment(StringPiece text, StringPiece indent); +string MakeComment(absl::string_view text, absl::string_view indent); -string PrintString(StringPiece str); +string PrintString(absl::string_view str); string PrintTensorShape(const TensorShapeProto& shape_proto); @@ -81,25 +81,25 @@ string PrintTensor(const TensorProto& tensor_proto); string PrintTensorProto(const TensorProto& proto); -string PrintAttrValue(StringPiece, const AttrValue& attr_value); +string PrintAttrValue(absl::string_view, const AttrValue& attr_value); bool IsEmptyList(const AttrValue::ListValue& list); -string ToCamelCase(StringPiece str); +string ToCamelCase(absl::string_view str); -string SeparateNamespaces(StringPiece str); +string SeparateNamespaces(absl::string_view str); // Returns a pair. The string is the C++ type name to be used for // attr_type when defining an object of that type. The bool is a flag to // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. -std::pair AttrTypeName(StringPiece attr_type); +std::pair AttrTypeName(absl::string_view attr_type); -StringPiece ListElementTypeName(StringPiece attr_type); +absl::string_view ListElementTypeName(absl::string_view attr_type); -bool IsCPPKeyword(StringPiece name); +bool IsCPPKeyword(absl::string_view name); -string AvoidCPPKeywords(StringPiece name); +string AvoidCPPKeywords(absl::string_view name); void InferArgAttributes(const OpDef::ArgDef& arg, std::unordered_map* inferred_attrs); @@ -123,7 +123,7 @@ struct OpInfo { const std::vector& aliases); OpInfo(const OpDef& graph_op_def, const ApiDef& api_def); string GetOpAttrStruct() const; - string GetConstructorDecl(StringPiece op_name_prefix, + string GetConstructorDecl(absl::string_view op_name_prefix, bool include_attr) const; string op_name; diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc index cacc15ca32d28f..3d114ff18e95c5 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc @@ -15,13 +15,10 @@ limitations under the License. #include "tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h" -#include #include -#include #include #include #include -#include #include #include @@ -47,7 +44,8 @@ namespace { string DefaultValue(OpDef_AttrDef attr) { static const auto* attr_default_value_map = - new absl::flat_hash_map{ + new absl::flat_hash_map{ {"int", "0"}, {"string", "\"\""}, {"list(int)", "{ 0, 1 }"}, diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc index 667d566a4fa5c1..f521e11fea3c6a 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc @@ -15,8 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include +#include #include -#include #include #include "absl/status/status.h" diff --git a/tensorflow/cc/framework/grad_op_registry.cc b/tensorflow/cc/framework/grad_op_registry.cc index 26628759277889..d95b05ee24d1b1 100644 --- a/tensorflow/cc/framework/grad_op_registry.cc +++ b/tensorflow/cc/framework/grad_op_registry.cc @@ -29,7 +29,7 @@ bool GradOpRegistry::Register(const string& op, GradFunc func) { return true; } -Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const { +absl::Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const { auto iter = registry_.find(op); if (iter == registry_.end()) { const string error_msg = diff --git a/tensorflow/cc/framework/grad_op_registry.h b/tensorflow/cc/framework/grad_op_registry.h index 951144cf8ce43a..b08478443d78dc 100644 --- a/tensorflow/cc/framework/grad_op_registry.h +++ b/tensorflow/cc/framework/grad_op_registry.h @@ -29,9 +29,9 @@ namespace ops { /// GradFunc is the signature for all gradient functions in GradOpRegistry. /// Implementations should add operations to compute the gradient outputs of /// 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. -typedef Status (*GradFunc)(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs); +typedef absl::Status (*GradFunc)(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs); /// GradOpRegistry maintains a static registry of gradient functions. /// Gradient functions are indexed in the registry by the forward op name (i.e. @@ -47,7 +47,7 @@ class GradOpRegistry { /// Note that 'func' can be null for ops that have registered no-gradient with /// the registry. /// Returns error status otherwise. - Status Lookup(const string& op, GradFunc* func) const; + absl::Status Lookup(const string& op, GradFunc* func) const; /// Returns a pointer to the global gradient function registry. static GradOpRegistry* Global(); diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index 90f104bc24b129..039b36f54ace40 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -104,7 +104,7 @@ SET_JACOBIAN_STRIDE(complex64, 2); SET_JACOBIAN_STRIDE(complex128, 2); template -Status ComputeTheoreticalJacobianTranspose( +absl::Status ComputeTheoreticalJacobianTranspose( const Scope& scope, const OutputList& xs, const std::vector& x_shapes, const std::vector& x_datas, const OutputList& ys, @@ -186,9 +186,9 @@ Status ComputeTheoreticalJacobianTranspose( return absl::OkStatus(); } -Status EvaluateGraph(ClientSession* session, const OutputList& xs, - const OutputList& ys, std::vector* x_datas, - std::vector* y_datas) { +absl::Status EvaluateGraph(ClientSession* session, const OutputList& xs, + const OutputList& ys, std::vector* x_datas, + std::vector* y_datas) { // Create the feed list. ClientSession::FeedType feed_list; for (int i = 0; i < x_datas->size(); i++) { @@ -212,13 +212,11 @@ Status EvaluateGraph(ClientSession* session, const OutputList& xs, } template -Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, - const std::vector& x_shapes, - const OutputList& ys, - const std::vector& y_shapes, - const JAC_T delta, - std::vector* x_datas, - std::vector* jacobian_ts) { +absl::Status ComputeNumericJacobianTranspose( + const Scope& scope, const OutputList& xs, + const std::vector& x_shapes, const OutputList& ys, + const std::vector& y_shapes, const JAC_T delta, + std::vector* x_datas, std::vector* jacobian_ts) { size_t y_num = y_shapes.size(); size_t x_num = x_shapes.size(); // x_stride and y_stride are used to calculate the correct jacobian row and @@ -332,12 +330,11 @@ void InitJacobians(const OutputList& xs, } template -Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, - const std::vector& x_shapes, - const OutputList& ys, - const std::vector& y_shapes, - std::vector* x_datas, - JAC_T* max_error) { +absl::Status ComputeGradientErrorInternal( + const Scope& scope, const OutputList& xs, + const std::vector& x_shapes, const OutputList& ys, + const std::vector& y_shapes, std::vector* x_datas, + JAC_T* max_error) { // Initialize theoretical Jacobians to zeros. std::vector jacobian_ts; InitJacobians(xs, x_shapes, y_shapes, &jacobian_ts); @@ -378,11 +375,11 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, } // namespace template -Status ComputeGradientError(const Scope& scope, const OutputList& xs, - const std::vector& x_shapes, - const OutputList& ys, - const std::vector& y_shapes, - JAC_T* max_error) { +absl::Status ComputeGradientError(const Scope& scope, const OutputList& xs, + const std::vector& x_shapes, + const OutputList& ys, + const std::vector& y_shapes, + JAC_T* max_error) { if (xs.size() != x_shapes.size()) { return errors::InvalidArgument("xs(size ", xs.size(), ") and x_shapes(size ", x_shapes.size(), @@ -406,9 +403,10 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs, } template -Status ComputeGradientError(const Scope& scope, const Output& x, - const Tensor& x_init_value, const Output& y, - const TensorShape& y_shape, JAC_T* max_error) { +absl::Status ComputeGradientError(const Scope& scope, const Output& x, + const Tensor& x_init_value, const Output& y, + const TensorShape& y_shape, + JAC_T* max_error) { // Initialize 'x_data' from 'x_init_value'. std::vector x_datas(1, Tensor(x_init_value)); // Compute gradient error. diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h index b8db767f77cc58..20b6545f1f51d7 100644 --- a/tensorflow/cc/framework/gradient_checker.h +++ b/tensorflow/cc/framework/gradient_checker.h @@ -48,17 +48,17 @@ namespace tensorflow { /// if y = Complex(x, x) where x is DT_FLOAT (so y is DT_COMPLEX64) /// should be template -Status ComputeGradientError(const Scope& scope, const OutputList& xs, - const std::vector& x_shapes, - const OutputList& ys, - const std::vector& y_shapes, - JAC_T* max_error); +absl::Status ComputeGradientError(const Scope& scope, const OutputList& xs, + const std::vector& x_shapes, + const OutputList& ys, + const std::vector& y_shapes, + JAC_T* max_error); /// Overload of ComputeGradientError which takes an initial value for 'x'. template -Status ComputeGradientError(const Scope& scope, const Output& x, - const Tensor& x_init_value, const Output& y, - const TensorShape& y_shape, JAC_T* max_error); +absl::Status ComputeGradientError(const Scope& scope, const Output& x, + const Tensor& x_init_value, const Output& y, + const TensorShape& y_shape, JAC_T* max_error); } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 548f5c04833a2e..876a259925910c 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -58,31 +58,31 @@ class SymbolicGradientBuilder { const std::vector& grad_inputs, std::vector* grad_outputs); - Status AddGradients(); + absl::Status AddGradients(); static Output NoGradient() { return Output(nullptr, -1); } private: - Status Initialize(); + absl::Status Initialize(); // For each forward edge from `src` to `dst` in the initial/forward graph: // propagates gradients `dst_grad` backwards along the edge from `src` // to `dst` in the graph. This will add `dst_grad` to the list of pending // gradients for the node associated with `src`. - Status BackpropAlongEdge(const Output& dst_grad, const Output& src); + absl::Status BackpropAlongEdge(const Output& dst_grad, const Output& src); // Adds a node to the graph (returned in `grad`) that sums the in-bound // gradients to `src` (if there are more than one). - Status SumGradients(const Output& src, Output* grad); + absl::Status SumGradients(const Output& src, Output* grad); // Returns true if `opname` is registered in `registry_` with no gradient // function, false otherwise. bool IsPrimitiveOpWithNoGrad(const string& opname); // Call the gradient function for `op`, storing the result in `grad_outputs`. - Status CallGradFunction(const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs); + absl::Status CallGradFunction(const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs); // Returns a list mapping whether each node in the graph is reachable // from outputs_. Keyed by node id. @@ -93,7 +93,7 @@ class SymbolicGradientBuilder { // nodes (which are the first nodes of a loop encountered in the backwards // pass) are passed to this function rather than processed normally. // `summed_grads` is the sum of `exit_node`s gradients. - Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads); + absl::Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads); // Gets the set of node ids at which to stop backprop. These are all elements // of `outputs_` that do not get transitively consumed by other `outputs_`. @@ -153,8 +153,8 @@ SymbolicGradientBuilder::SymbolicGradientBuilder( grad_inputs_(grad_inputs), grad_outputs_(grad_outputs) {} -Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, - const Output& src) { +absl::Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, + const Output& src) { if (src.node() == nullptr) { return errors::Internal("Attempted to backprop along an invalid edge."); } @@ -251,7 +251,7 @@ std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( return stop_backprop_nodes; } -Status SymbolicGradientBuilder::Initialize() { +absl::Status SymbolicGradientBuilder::Initialize() { if (outputs_.size() != grad_inputs_.size()) { return errors::InvalidArgument( "Must specify a gradient input for each output."); @@ -344,7 +344,8 @@ Status SymbolicGradientBuilder::Initialize() { return absl::OkStatus(); } -Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) { +absl::Status SymbolicGradientBuilder::SumGradients(const Output& src, + Output* grad) { auto iter = backprops_.find(src); if (iter == backprops_.end()) { return errors::Internal("Unable to find backprop list for node.id ", @@ -377,11 +378,11 @@ Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) { bool SymbolicGradientBuilder::IsPrimitiveOpWithNoGrad(const string& opname) { ops::GradFunc grad_fn; - Status s = registry_->Lookup(opname, &grad_fn); + absl::Status s = registry_->Lookup(opname, &grad_fn); return s.ok() && (grad_fn == nullptr); } -Status SymbolicGradientBuilder::CallGradFunction( +absl::Status SymbolicGradientBuilder::CallGradFunction( const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { ops::GradFunc grad_fn; @@ -391,8 +392,8 @@ Status SymbolicGradientBuilder::CallGradFunction( return absl::OkStatus(); } -Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, - const Output& summed_grads) { +absl::Status SymbolicGradientBuilder::ProcessWhileLoop( + Node* exit_node, const Output& summed_grads) { // TODO(skyewm): detect second-order gradient and return bad status // TODO(skyewm): handle (or at least detect) nested while loops @@ -439,7 +440,7 @@ Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, return absl::OkStatus(); } -Status SymbolicGradientBuilder::AddGradients() { +absl::Status SymbolicGradientBuilder::AddGradients() { // Initialize backprops. TF_RETURN_IF_ERROR(Initialize()); @@ -559,20 +560,20 @@ Status SymbolicGradientBuilder::AddGradients() { } // namespace -Status AddSymbolicGradients(const Scope& scope, - const std::vector& outputs, - const std::vector& inputs, - const std::vector& grad_inputs, - std::vector* grad_outputs) { +absl::Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + const std::vector& grad_inputs, + std::vector* grad_outputs) { SymbolicGradientBuilder builder(scope, ops::GradOpRegistry::Global(), outputs, inputs, grad_inputs, grad_outputs); return builder.AddGradients(); } -Status AddSymbolicGradients(const Scope& scope, - const std::vector& outputs, - const std::vector& inputs, - std::vector* grad_outputs) { +absl::Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + std::vector* grad_outputs) { std::vector grad_inputs; grad_inputs.reserve(outputs.size()); for (const Output& output : outputs) { diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h index d404bd34c4a3d8..c79269fde3a7b3 100644 --- a/tensorflow/cc/framework/gradients.h +++ b/tensorflow/cc/framework/gradients.h @@ -29,18 +29,18 @@ namespace tensorflow { /// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes /// to the graph associated with 'scope', which compute (and return in /// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'. -Status AddSymbolicGradients(const Scope& scope, - const std::vector& outputs, - const std::vector& inputs, - const std::vector& grad_inputs, - std::vector* grad_outputs); +absl::Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + const std::vector& grad_inputs, + std::vector* grad_outputs); // Same as above, but uses 'OnesLike' for all shapes in // 'outputs' as grad_inputs. -Status AddSymbolicGradients(const Scope& scope, - const std::vector& outputs, - const std::vector& inputs, - std::vector* grad_outputs); +absl::Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + std::vector* grad_outputs); /// Returns a sentinel Output that represents 'no gradient' (i.e. no gradient /// flows along some graph edge during backpropagation). diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 2256d795422ca3..d0f8217a8d62f0 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -456,7 +456,7 @@ TEST_F(GradientsTest, UnreachableInput) { // / \ / \ // z y x std::vector grad_outputs; - Status status = + absl::Status status = AddSymbolicGradients(scope_test_, {m1}, {z}, {dm1}, &grad_outputs); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.message(), diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 7bbb3b2bcb5236..e856e311ceb3ee 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -196,7 +196,7 @@ class Input { return tensor_proto; } - Status status; + absl::Status status; Tensor tensor; }; @@ -243,11 +243,11 @@ class Input { std::string node_name() const { return node_name_; } int32 index() const { return node_name_.empty() ? output_.index() : index_; } DataType data_type() const { return data_type_; } - Status status() const { return status_; } + absl::Status status() const { return status_; } const Tensor& tensor() const { return tensor_; } private: - Status status_; + absl::Status status_; Output output_ = Output(Operation(nullptr), 0); Tensor tensor_; const std::string node_name_ = ""; diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 0c972612089918..7cc8687ebbf8ac 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -41,7 +41,7 @@ const char kScopeSeparator[] = "/"; const char kSuffixSeparator[] = "_"; } // namespace -Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, +Scope::Impl::Impl(Graph* graph, absl::Status* status, NameMap* name_map, ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), status_(status), @@ -52,7 +52,7 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, disable_shape_inference_(disable_shape_inference) {} Scope::Impl::Impl(const std::shared_ptr& graph, - const std::shared_ptr& status, + const std::shared_ptr& status, const std::shared_ptr& name_map, const std::shared_ptr& refiner) : graph_(graph), @@ -67,7 +67,7 @@ Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = new ShapeRefiner(graph->versions(), graph->op_registry()); - return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner, + return Scope(new Impl(graph, new absl::Status, new Impl::NameMap, refiner, /* disable_shape_inference */ false)); } @@ -75,7 +75,7 @@ Scope Scope::DisabledShapeInferenceScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = new ShapeRefiner(graph->versions(), graph->op_registry()); - return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner, + return Scope(new Impl(graph, new absl::Status, new Impl::NameMap, refiner, /* disable_shape_inference */ true)); } @@ -274,7 +274,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( std::vector node_constraints; if (TryGetNodeAttr(attrs, kColocationAttrName, &node_constraints)) { for (const string& entry : node_constraints) { - StringPiece s(entry); + absl::string_view s(entry); if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) { current_constraints.emplace(s); } @@ -293,20 +293,20 @@ std::shared_ptr Scope::graph_as_shared_ptr() const { return impl()->graph_; } -Status Scope::status() const { return *impl()->status_; } +absl::Status Scope::status() const { return *impl()->status_; } const std::vector& Scope::control_deps() const { return impl()->control_deps_; } -void Scope::UpdateStatus(const Status& s) const { +void Scope::UpdateStatus(const absl::Status& s) const { impl()->status_->Update(s); if (impl()->exit_on_error_ && !ok()) { LOG(FATAL) << *impl()->status_; } } -Status Scope::ToGraphDef(GraphDef* gdef, bool include_debug_info) const { +absl::Status Scope::ToGraphDef(GraphDef* gdef, bool include_debug_info) const { if (!ok()) { return *impl()->status_; } @@ -314,7 +314,7 @@ Status Scope::ToGraphDef(GraphDef* gdef, bool include_debug_info) const { return absl::OkStatus(); } -Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const { +absl::Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const { if (ok()) { GraphDef graph_def; graph()->ToGraphDef(&graph_def); @@ -498,7 +498,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes( } } -Status Scope::DoShapeInference(Node* node) const { +absl::Status Scope::DoShapeInference(Node* node) const { if (impl_->disable_shape_inference_) return absl::OkStatus(); return impl_->refiner_->AddNode(node); } @@ -506,7 +506,8 @@ Status Scope::DoShapeInference(Node* node) const { class InternalScope { public: // NewScope doesn't take ownership of the inputs. - static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + static Scope NewScope(Graph* graph, absl::Status* status, + ShapeRefiner* refiner) { Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; for (const Node* node : graph->nodes()) { const string& name = node->name(); @@ -521,19 +522,20 @@ class InternalScope { // since the caller owns them and doesn't want the scope to destroy them. return Scope(new Scope::Impl( std::shared_ptr(graph, [](Graph*) {}), - std::shared_ptr(status, [](Status*) {}), + std::shared_ptr(status, [](absl::Status*) {}), std::shared_ptr(name_map), std::shared_ptr(refiner, [](ShapeRefiner*) {}))); } }; -Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) { +Scope NewInternalScope(Graph* graph, absl::Status* status, + ShapeRefiner* refiner) { return InternalScope::NewScope(graph, status, refiner); } -Status CreateOutputWithScope(string op_name, - absl::Span inputs, - const Scope& scope, Output* output) { +absl::Status CreateOutputWithScope(string op_name, + absl::Span inputs, + const Scope& scope, Output* output) { TF_RETURN_IF_ERROR(scope.status()); const auto unique_name = scope.GetUniqueNameForOp(op_name); auto builder = ::tensorflow::NodeBuilder(unique_name, op_name); diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 0b0f6871e7f27c..9b8896e4ad6ee9 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -176,7 +176,7 @@ class Scope { /// Note: The status object is shared between all children of this scope. /// If the resulting status is not OkStatus() and exit_on_error_ is set on /// this scope, this function exits by calling LOG(FATAL). - void UpdateStatus(const Status& s) const; + void UpdateStatus(const absl::Status& s) const; // START_SKIP_DOXYGEN @@ -196,14 +196,15 @@ class Scope { // TODO(skyewm): Graph is not part of public API std::shared_ptr graph_as_shared_ptr() const; - Status status() const; + absl::Status status() const; /// If status() is ok, convert the Graph object stored in this scope /// to a GraphDef proto and return an ok Status. Otherwise, return the error /// status as is without performing GraphDef conversion. If /// `include_debug_info` is true, populate the `debug_info` field of the /// GraphDef from stack traces in this Graph. - Status ToGraphDef(GraphDef* gdef, bool include_debug_info = false) const; + absl::Status ToGraphDef(GraphDef* gdef, + bool include_debug_info = false) const; // START_SKIP_DOXYGEN @@ -214,14 +215,14 @@ class Scope { // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds // edges from the source and to the sink node, resolves back edges // by name), and makes sure the resulting graph is valid. - Status ToGraph( + absl::Status ToGraph( Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const; // Calls AddNode() using this scope's ShapeRefiner. This exists in the public // API to prevent custom op wrappers from needing access to shape_refiner.h or // scope_internal.h. // TODO(skyewm): remove this from public API - Status DoShapeInference(Node* node) const; + absl::Status DoShapeInference(Node* node) const; // Creates a new root scope that causes all DoShapeInference() calls to return // OkStatus() (on the returned scope and any subscopes). Used for testing. @@ -259,9 +260,9 @@ struct CompositeOpScopes { // Creates a node of the given operation, with the given inputs, and assigns the // result to output. This does not support the ability to add additional // attributes. -Status CreateOutputWithScope(string op_name, - absl::Span inputs, - const Scope& scope, Output* output); +absl::Status CreateOutputWithScope(string op_name, + absl::Span inputs, + const Scope& scope, Output* output); /// @} } // namespace tensorflow diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 586165ee4eb2b8..0cf6af6812c27a 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -34,7 +34,8 @@ class ShapeRefiner; // bindings) to create a Scope and access C++ functionality (i.e. gradients). // // Shape inference is disabled if `refiner` is nullptr. -Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner); +Scope NewInternalScope(Graph* graph, absl::Status* status, + ShapeRefiner* refiner); class Scope::Impl { public: @@ -46,7 +47,7 @@ class Scope::Impl { typedef std::unordered_map NameMap; Impl(const std::shared_ptr& graph, - const std::shared_ptr& status, + const std::shared_ptr& status, const std::shared_ptr& name_map, const std::shared_ptr& refiner); @@ -70,8 +71,8 @@ class Scope::Impl { enum class XlaCluster; }; - Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, - bool disable_shape_inference); + Impl(Graph* graph, absl::Status* status, NameMap* name_map, + ShapeRefiner* refiner, bool disable_shape_inference); Impl(const Scope& other, Tags::ScopeName, const string& name, bool copy_names); Impl(const Scope& other, Tags::OpName, const string& name, @@ -101,7 +102,7 @@ class Scope::Impl { // Scope::NewRootScope function, which creates a new graph, a new status and // the name maps. std::shared_ptr graph_ = nullptr; - std::shared_ptr status_ = nullptr; + std::shared_ptr status_ = nullptr; std::shared_ptr name_map_ = nullptr; std::shared_ptr refiner_ = nullptr; diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index 9f966994ea2066..107f82a605be97 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -56,8 +56,8 @@ string BackPropFrameName(const string& forward_frame_name) { // Creates a loop that counts the number of iterations performed by the // while loop associated with `while_ctx`. The returned output yields the // iteration count. -Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, - Output* count) { +absl::Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, + Output* count) { // Create while loop: // i = 0 // while forward loop predicate is true: @@ -95,9 +95,10 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, // boolean predicate indicating if the loop is still executing. This is used to // drive the gradient computation for the while loop associated with // `while_ctx`. -Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count, - const Scope& scope, - Output* backprop_execution_pred) { +absl::Status AddBackPropLoopCounter(WhileContext* while_ctx, + const Output& loop_count, + const Scope& scope, + Output* backprop_execution_pred) { // Create while loop: // n = loop_count // while n > 0: @@ -135,11 +136,11 @@ Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count, // the predicate to use for the backprop loop (see AddBackPropLoopCounter()). // The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are // returned in `grad_outputs`. -Status AddWhileGradientLoop(WhileContext* while_ctx, - const std::vector& grad_inputs, - const Output& backprop_execution_pred, - const Scope& parent_scope, - std::vector* grad_outputs) { +absl::Status AddWhileGradientLoop(WhileContext* while_ctx, + const std::vector& grad_inputs, + const Output& backprop_execution_pred, + const Scope& parent_scope, + std::vector* grad_outputs) { DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size()); DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size()); @@ -178,9 +179,9 @@ Status AddWhileGradientLoop(WhileContext* while_ctx, } // namespace -Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, - const std::vector& grad_inputs, - std::vector* grad_outputs) { +absl::Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, + const std::vector& grad_inputs, + std::vector* grad_outputs) { Output forward_loop_count; TF_RETURN_IF_ERROR(AddForwardLoopCounter( while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count)); diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h index 6d33d49dbb3d9e..1f31de15ebab6f 100644 --- a/tensorflow/cc/framework/while_gradients.h +++ b/tensorflow/cc/framework/while_gradients.h @@ -33,9 +33,9 @@ namespace tensorflow { // `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined // by the original inputs to BuildWhileLoop(). // TODO(skyewm): maybe comment on NoGradient once it's supported -Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, - const std::vector& grad_inputs, - std::vector* grad_outputs); +absl::Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, + const std::vector& grad_inputs, + std::vector* grad_outputs); } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index c70616158f2a11..c8f4db108d4589 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -397,7 +397,7 @@ REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); // Templated constructor for FusedBatchNormGrad[..]::Attrs. template -T FusedBatchNormGradAttrs(float epsilon, StringPiece data_format, +T FusedBatchNormGradAttrs(float epsilon, absl::string_view data_format, bool is_training) { T result; result.epsilon_ = epsilon; @@ -409,7 +409,7 @@ T FusedBatchNormGradAttrs(float epsilon, StringPiece data_format, using BatchNormGradFn = std::function& reserve_spaces, float epsilon, - StringPiece data_format, bool is_training, + absl::string_view data_format, bool is_training, std::vector* grad_outputs)>; absl::Status BaseFusedBatchNormGrad(const Scope& scope, const Operation& op, @@ -465,7 +465,7 @@ absl::Status BaseFusedBatchNormGrad(const Scope& scope, const Operation& op, grad_y = Transpose(scope, grad_y, {0, 2, 3, 4, 1}); } - StringPiece target_data_format; + absl::string_view target_data_format; if (data_format == "NCHW" || data_format == "NHWC") { target_data_format = "NHWC"; } else { @@ -491,7 +491,7 @@ absl::Status FusedBatchNormV3Grad(const Scope& scope, const Operation& op, scope, op, grad_inputs, [](const Scope& scope, Output x, Output grad_y, Output scale, const std::vector& reserve_spaces, float epsilon, - StringPiece data_format, bool is_training, + absl::string_view data_format, bool is_training, std::vector* grad_outputs) { FusedBatchNormGradV3 grad( scope, grad_y, x, scale, reserve_spaces[0], reserve_spaces[1], diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc index c5489afd9633d8..ac85bd728cb7e4 100644 --- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -37,7 +37,7 @@ using tensorflow::experimental::cc::Status; constexpr char kTestData[] = "cc/saved_model/testdata"; -std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { +std::string SavedModelPath(absl::string_view saved_model_dir) { return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), kTestData, saved_model_dir); } diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index a1f4170adafc29..0031cffb820cbd 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -148,7 +148,7 @@ Tensor CreateStringTensor(const string& value) { return tensor; } -void AddAssetsTensorsToInputs(const StringPiece export_dir, +void AddAssetsTensorsToInputs(const absl::string_view export_dir, const std::vector& asset_file_defs, std::vector>* inputs) { if (asset_file_defs.empty()) { @@ -229,8 +229,8 @@ absl::Status RunInitOp(const RunOptions& run_options, const string& export_dir, } absl::Status RunRestore(const RunOptions& run_options, const string& export_dir, - const StringPiece restore_op_name, - const StringPiece variable_filename_const_op_name, + const absl::string_view restore_op_name, + const absl::string_view variable_filename_const_op_name, const std::vector& asset_file_defs, Session* session) { LOG(INFO) << "Restoring SavedModel bundle."; diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index c51a9f639abb38..d1ce9afb542fb9 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -49,6 +49,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", ], ) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index c23f9161a448fd..3fab4536106b72 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,8 +15,11 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" -#include +#include #include +#include +#include +#include #include "absl/log/log.h" #include "absl/status/status.h" @@ -127,7 +130,7 @@ void GetReachableNodesAndVariables( } // Gets a map from variable name to variable value. -Status GetVariableNameToTensorMap( +absl::Status GetVariableNameToTensorMap( Session* session, const std::unordered_map& name_to_node_map, std::unordered_set variable_names_set, @@ -220,9 +223,9 @@ StatusOr GetHandleNameIfNeedsToFreeze( } // Freezes the subgraph of all nodes needed by `outputs`. -Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, - const std::unordered_set& outputs, - GraphDef* frozen_graph_def) { +absl::Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, + const std::unordered_set& outputs, + GraphDef* frozen_graph_def) { GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def(); // Copy versions and library as-is from original graph. *frozen_graph_def->mutable_versions() = graph_def.versions(); @@ -282,10 +285,10 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, } // namespace -Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, - GraphDef* frozen_graph_def, - std::unordered_set* inputs, - std::unordered_set* outputs) { +absl::Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs) { GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); TF_RETURN_IF_ERROR( FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h index 284a038278fa13..8a35bafe069924 100644 --- a/tensorflow/cc/tools/freeze_saved_model.h +++ b/tensorflow/cc/tools/freeze_saved_model.h @@ -34,10 +34,10 @@ namespace tensorflow { // in the SavedModelBundle. // WARNING: Only the variable checkpoints will be reflected in the frozen // graph_def. All saved_model assets will be ignored. -Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, - GraphDef* frozen_graph_def, - std::unordered_set* inputs, - std::unordered_set* outputs); +absl::Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs); } // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index a64aab9e0bb5f5..6fd6fff1836d14 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -15,6 +15,12 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include +#include +#include +#include + +#include "absl/status/status.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" @@ -70,7 +76,7 @@ class FreezeTest : public ::testing::Test { // Adds an initialized session to `saved_model_bundle` using `graph_def` and // initializing with `init_node`. - Status InitializeSavedModelBundleSession( + absl::Status InitializeSavedModelBundleSession( const GraphDef& graph_def, const string& init_node, SavedModelBundle* saved_model_bundle) { SessionOptions session_options; @@ -86,9 +92,9 @@ class FreezeTest : public ::testing::Test { // Adds `graph_def` to `saved_model_bundle` and initializes a session with // `init_node`. - Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def, - const string& init_node, - SavedModelBundle* saved_model_bundle) { + absl::Status AddGraphDefToSavedModelBundle( + const GraphDef& graph_def, const string& init_node, + SavedModelBundle* saved_model_bundle) { MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; *meta_graph_def->mutable_graph_def() = graph_def; return InitializeSavedModelBundleSession(graph_def, init_node, @@ -97,7 +103,7 @@ class FreezeTest : public ::testing::Test { // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in // `saved_model_bundle` and initializes a session with `init_node`. - Status AddGraphDefWithOutputsToSavedModelBundle( + absl::Status AddGraphDefWithOutputsToSavedModelBundle( const GraphDef& graph_def, const std::unordered_set& outputs, const string& init_node, SavedModelBundle* saved_model_bundle) { SignatureDef signature_def = diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 4666ddd5db9ed6..baf2cef3e80b45 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -51,7 +51,7 @@ bool IsAlpha(char c) { bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } // Convert an XLA type into a C++ type. -Status XLATypeToCpp(xla::PrimitiveType type, string* str) { +absl::Status XLATypeToCpp(xla::PrimitiveType type, string* str) { switch (type) { case xla::PRED: *str = "bool"; @@ -127,8 +127,9 @@ std::vector ExtractTempBufferInfos( // Add (from,to) rewrite pairs based on the given shape. These rewrite pairs // are used to generate methods for args and results. -Status AddRewritesForShape(int i, const xla::Shape& shape, - std::vector>* rewrites) { +absl::Status AddRewritesForShape( + int i, const xla::Shape& shape, + std::vector>* rewrites) { string type; TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; @@ -171,9 +172,10 @@ string RewriteWithName(const string& name, string code, } // Generate methods for args (inputs). -Status GenArgMethods(const tf2xla::Config& config, - const xla::ProgramShapeProto& ps, - const CompileResult& compile_result, string* methods) { +absl::Status GenArgMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, + const CompileResult& compile_result, + string* methods) { const int num_args = ps.parameters_size(); // feed_size() + variable_size() is the maximum number of args as an // implementation may not create an argument for an unused variable. @@ -220,8 +222,9 @@ Status GenArgMethods(const tf2xla::Config& config, } // Generate methods for results (outputs). -Status GenResultMethods(const tf2xla::Config& config, - const xla::ProgramShapeProto& ps, string* methods) { +absl::Status GenResultMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, + string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. @@ -274,8 +277,9 @@ Status GenResultMethods(const tf2xla::Config& config, } // Generate methods for variables. -Status GenVariableMethods(const tf2xla::Config& config, - const xla::ProgramShapeProto& ps, string* methods) { +absl::Status GenVariableMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, + string* methods) { const int num_args = ps.parameters_size(); for (int i = config.feed_size(); i < num_args; ++i) { std::vector> rewrites; @@ -315,7 +319,7 @@ Status GenVariableMethods(const tf2xla::Config& config, } // Generate shape infos for args (inputs). -Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { +absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { for (int i = 0; i < ps.parameters_size(); ++i) { const xla::ShapeProto& shape = ps.parameters(i); if (shape.element_type() == xla::TUPLE) { @@ -352,7 +356,8 @@ Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { } // Generate shape infos for results. -Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { +absl::Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, + string* infos) { if (ps.result().element_type() != xla::TUPLE) { return absl::InternalError("codegen requires the XLA result to be a tuple"); } @@ -417,7 +422,7 @@ string GenNameToIndexCode(const T& entries, bool generate) { return code; } -Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { +absl::Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name")); @@ -462,7 +467,7 @@ std::vector BufferInfosToCppExpression( return buffer_infos_as_strings; } -Status CheckEqual(size_t a, size_t b, absl::string_view error_msg) { +absl::Status CheckEqual(size_t a, size_t b, absl::string_view error_msg) { if (a != b) { return absl::InternalError( absl::StrCat(error_msg, ". Expected ", a, ", got ", b, ".")); @@ -471,9 +476,11 @@ Status CheckEqual(size_t a, size_t b, absl::string_view error_msg) { } } // namespace -Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, - const CompileResult& compile_result, - const MetadataResult& metadata_result, string* header) { +absl::Status GenerateHeader(const CodegenOpts& opts, + const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, + string* header) { TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64_t result_index = compile_result.aot->result_buffer_index(); @@ -858,9 +865,9 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts, return result; } -Status GenerateMetadata(const CodegenOpts& opts, - const CompileResult& compile_result, - MetadataResult* metadata_result) { +absl::Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result) { std::unique_ptr program_shape; if (opts.gen_program_shape) { @@ -904,8 +911,8 @@ Status GenerateMetadata(const CodegenOpts& opts, return absl::OkStatus(); } -Status ParseCppClass(const string& cpp_class, string* class_name, - std::vector* namespaces) { +absl::Status ParseCppClass(const string& cpp_class, string* class_name, + std::vector* namespaces) { class_name->clear(); namespaces->clear(); if (cpp_class.empty()) { @@ -930,7 +937,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, return absl::OkStatus(); } -Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { +absl::Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { if (ident.empty()) { return errors::InvalidArgument("empty identifier: ", msg); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index a0caceaf4c6af0..993196b114da6b 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -76,9 +76,9 @@ struct MetadataResult { // Generates a metadata object file according to `opts` and `compile_result`. // The generated object file is returned via `metadata_result`. -Status GenerateMetadata(const CodegenOpts& opts, - const CompileResult& compile_result, - MetadataResult* metadata_result); +absl::Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result); // GenerateHeader uses the meta-information from compile_result to generate a // C++ header giving access to the function in the generated object file. The @@ -86,20 +86,22 @@ Status GenerateMetadata(const CodegenOpts& opts, // // metadata_result is an instance of MetadataResult obtained by a previous // invocation to GenerateMetadata. -Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, - const CompileResult& compile_result, - const MetadataResult& metadata_result, string* header); +absl::Status GenerateHeader(const CodegenOpts& opts, + const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, + string* header); // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` // components. The syntax is [[::],...]. This // mirrors the C++ syntax for referring to a class, where multiple namespaces // may precede the class name, separated by double-colons. -Status ParseCppClass(const string& cpp_class, string* class_name, - std::vector* namespaces); +absl::Status ParseCppClass(const string& cpp_class, string* class_name, + std::vector* namespaces); // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. -Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); +absl::Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 7880ba7e235026..7056d85590143f 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -39,7 +39,7 @@ namespace { using ::xla::cpu_function_runtime::BufferInfo; -void ExpectErrorContains(const Status& status, absl::string_view str) { +void ExpectErrorContains(const absl::Status& status, absl::string_view str) { EXPECT_NE(absl::OkStatus(), status); EXPECT_TRUE(absl::StrContains(status.message(), str)) << "expected error: " << status.message() << " to contain: " << str; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 9dee02eb8e2548..0074d61baa7373 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -59,10 +59,10 @@ bool RegisterQuantizeFn(const QuantizeXlaFn& fn) { namespace { // Compiles the XLA computation into executable code. -Status CompileXla(xla::CompileOnlyClient* client, - const xla::XlaComputation& computation, - const xla::cpu::CpuAotCompilationOptions& aot_opts, - CompileResult* compile_result) { +absl::Status CompileXla(xla::CompileOnlyClient* client, + const xla::XlaComputation& computation, + const xla::cpu::CpuAotCompilationOptions& aot_opts, + CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. // TODO(toddw): Should we let the user choose the major/minor ordering? absl::StatusOr> pshape_or = @@ -105,8 +105,9 @@ Status CompileXla(xla::CompileOnlyClient* client, } // namespace -Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, - const MainFlags& flags, CompileResult* compile_result) { +absl::Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, + const MainFlags& flags, + CompileResult* compile_result) { // Converts the graph into an XLA computation, and compiles the // computation. // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client? @@ -170,7 +171,8 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, return CompileXla(client, computation, aot_opts, compile_result); } -static Status ReadProtoFile(const string& fname, protobuf::Message* proto) { +static absl::Status ReadProtoFile(const string& fname, + protobuf::Message* proto) { if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { @@ -243,7 +245,7 @@ static std::string InterpolateErrorMessage(std::string message) { return message; } -Status Main(const MainFlags& flags) { +absl::Status Main(const MainFlags& flags) { absl::call_once(targets_init, &InitializeTargets); // Process config. @@ -270,7 +272,7 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def)); CompileResult compile_result; - Status status = + absl::Status status = CompileGraph(std::move(graph_def), config, flags, &compile_result); if (!status.ok()) { return errors::CreateWithUpdatedMessage( diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 0acb39fda98a75..9d3ff78af89a92 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -43,11 +43,12 @@ struct CompileResult { // that performs the graph operations. // // The XLA compilation options are specified in the flags. -Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, - const MainFlags& flags, CompileResult* compile_result); +absl::Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, + const MainFlags& flags, + CompileResult* compile_result); // The full compilation method, for reuse in a library setting. -Status Main(const MainFlags& flags); +absl::Status Main(const MainFlags& flags); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/quantize.h b/tensorflow/compiler/aot/quantize.h index e2412749290e77..62f03808798779 100644 --- a/tensorflow/compiler/aot/quantize.h +++ b/tensorflow/compiler/aot/quantize.h @@ -28,8 +28,8 @@ limitations under the License. namespace tensorflow { namespace tfcompile { -using QuantizeXlaFn = std::function; +using QuantizeXlaFn = std::function; // Set the static quantization function to the `fn` if it hasn't been set. // Return false if the static function has been set. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 82fdb603138136..99c8541c55488c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -212,7 +212,6 @@ def _tf_library( ] + freeze_saver_srcs, outs = [freeze_file], cmd = ( - "PYWRAP_TARGET='//tensorflow/python:_pywrap_tensorflow' " + "CUDA_VISIBLE_DEVICES='' " + "$(location " + "//tensorflow/python/tools:freeze_graph)" + diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index b6b70a6f04d0f5..a2a00afe47f0fc 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -76,7 +76,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " "other than flags. See --help.\n\n"; - tensorflow::Status status = tensorflow::tfcompile::Main(flags); + absl::Status status = tensorflow::tfcompile::Main(flags); if (status.code() == absl::StatusCode::kInvalidArgument) { std::cerr << "INVALID ARGUMENTS: " << status.message() << "\n\n"; return 1; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 6022f4b2618e02..acc62243f90488 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -164,6 +164,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@local_xla//xla/stream_executor:platform_manager", diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc index e4efb8922089c6..e70be48f0b7341 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc @@ -60,7 +60,7 @@ std::optional GetXlaInternalScope(Node* node) { return std::nullopt; } -void SetXlaInternalScope(Node* node, StringPiece scope) { +void SetXlaInternalScope(Node* node, absl::string_view scope) { node->AddAttr(kXlaInternalScopeAttr, scope); } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 1adbac0e5e187a..0e59bf0c19d93e 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -146,8 +146,8 @@ absl::Status RewriteSubgraph( bool a_is_resource = (a->output_type(0) == DT_RESOURCE); bool b_is_resource = (b->output_type(0) == DT_RESOURCE); // Uses the name as a tiebreaker so the output is deterministic. - StringPiece a_name(a->name()); - StringPiece b_name(b->name()); + absl::string_view a_name(a->name()); + absl::string_view b_name(b->name()); return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); }); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 35f3c82eca3b27..d07a21035a7844 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -60,6 +60,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/tsl/concurrency:async_value", ], diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 86cb79d981ee85..f50c5a7f610d41 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/device_compiler.h" @@ -205,7 +205,7 @@ XlaComputationLaunchContext GetLaunchContext( return launch_context; } -absl::Status GetTaskName(const std::string_view device_name, +absl::Status GetTaskName(const absl::string_view device_name, std::string* task_name) { string ignored; if (!DeviceNameUtils::SplitDeviceName(device_name, task_name, &ignored)) { diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index dcc661e4f73cf5..462c5c446b28c7 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -587,7 +587,7 @@ absl::Status XlaDevice::RefreshStatus() { XlaDeviceOpRegistrations* RegisterXlaDeviceKernels( const char* device, const char* jit_device, OpKernel* (*factory)(OpKernelConstruction*), - StringPiece kernel_class_name) { + absl::string_view kernel_class_name) { XlaOpRegistry::RegisterCompilationKernels(); XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index cbaa97dc15e1c0..877d208d2ad220 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -308,7 +308,8 @@ struct XlaDeviceOpRegistrations { XlaDeviceOpRegistrations* RegisterXlaDeviceKernels( const char* device, const char* jit_device, - OpKernel* (*factory)(OpKernelConstruction*), StringPiece kernel_class_name); + OpKernel* (*factory)(OpKernelConstruction*), + absl::string_view kernel_class_name); XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device); diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 64f98698ccd951..f5ecde6aba2149 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform_manager.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -157,11 +158,29 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, - DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, - DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_INT4, DT_UINT4}}; +constexpr std::array kAllXlaGpuTypes = {{DT_UINT8, + DT_QUINT8, + DT_UINT16, + DT_INT8, + DT_QINT8, + DT_INT16, + DT_INT32, + DT_QINT32, + DT_INT64, + DT_HALF, + DT_FLOAT, + DT_DOUBLE, + DT_COMPLEX64, + DT_COMPLEX128, + DT_BOOL, + DT_BFLOAT16, + DT_FLOAT8_E5M2, + DT_FLOAT8_E4M3FN, + DT_FLOAT8_E4M3FNUZ, + DT_FLOAT8_E4M3B11FNUZ, + DT_FLOAT8_E5M2FNUZ, + DT_INT4, + DT_UINT4}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.cc b/tensorflow/compiler/jit/xla_host_recv_device_context.cc index 479abe923e0fb8..27cb1c67e4293f 100644 --- a/tensorflow/compiler/jit/xla_host_recv_device_context.cc +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU( - const Tensor* device_tensor, StringPiece tensor_name, Device* device, + const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { DataType dtype = EncodePrimitiveTypeAsDataType(shape_.element_type()).value(); TensorShape tensor_shape; diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.h b/tensorflow/compiler/jit/xla_host_recv_device_context.h index 028fd4efd68091..d6dfc6f1906e0c 100644 --- a/tensorflow/compiler/jit/xla_host_recv_device_context.h +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.h @@ -66,7 +66,7 @@ class XlaHostRecvDeviceContext : public DeviceContext { // Copies `device_memory_base_` with `shape_` into `cpu_tensor`. // `device_tensor` is unused. void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_host_send_device_context.h b/tensorflow/compiler/jit/xla_host_send_device_context.h index f4e4e9a2535341..52ca612570a2c7 100644 --- a/tensorflow/compiler/jit/xla_host_send_device_context.h +++ b/tensorflow/compiler/jit/xla_host_send_device_context.h @@ -64,7 +64,7 @@ class XlaHostSendDeviceContext : public DeviceContext { bool sync_dst_compute) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override { done(errors::Internal("host->device copy not implemented.")); } diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 90def2e63c7029..217dfdf7b5b5a5 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -65,14 +65,10 @@ cc_library( "//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", - "//tensorflow/compiler/mlir/tosa:tf_passes", - "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", - "//tensorflow/compiler/mlir/tosa:tfl_passes", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir/framework/ir:xla_framework", "@local_xla//xla/mlir/framework/transforms:passes", "@local_xla//xla/mlir_hlo:all_passes", ], @@ -101,8 +97,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tosa:tf_passes", - "//tensorflow/compiler/mlir/tosa:tfl_passes", ], ) @@ -126,7 +120,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:device_util", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", @@ -241,6 +234,7 @@ tf_cc_binary( "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index cc62d279af1a93..cb4e736d4e6c83 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -93,6 +93,7 @@ td_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_optimize_td_files", "@llvm-project//mlir:ArithOpsTdFiles", "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", ], ) @@ -1298,6 +1299,8 @@ cc_library( "transforms/default_quant_params.cc", "transforms/generated_post_quantize.inc", "transforms/generated_quantize.inc", + "transforms/lower_quant_annotations_helper.cc", + "transforms/lower_quant_annotations_pass.cc", "transforms/modify_io_nodes.cc", "transforms/optimize_op_order.cc", "transforms/post_quantize.cc", @@ -1309,6 +1312,7 @@ cc_library( "utils/generated_op_quant_spec_getters.inc", ], hdrs = [ + "transforms/lower_quant_annotations_helper.h", "transforms/passes.h", "transforms/prepare_quantize_helper.h", ], @@ -1330,6 +1334,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1338,6 +1343,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/lite/converter_flags.proto b/tensorflow/compiler/mlir/lite/converter_flags.proto index 155bf748095e1e..5b6b9e2ca752a6 100644 --- a/tensorflow/compiler/mlir/lite/converter_flags.proto +++ b/tensorflow/compiler/mlir/lite/converter_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 67. +// Next ID to use: 68. message ConverterFlags { // Input file format optional FileFormat input_format = 1; @@ -380,4 +380,9 @@ message ConverterFlags { // When set to true, debug metadata will be generated and attached to // serialized TFLite flatbuffer. optional bool serialize_debug_metadata = 66 [default = false]; + + // When set, adheres to the QDQ annotations added by the framework when + // possible rather than quantizing any op that is possible to quantize. + // WARNING: Experimental interface, subject to change. + optional bool strict_qdq_mode = 67 [default = false]; } diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 1941fbd7e63105..6869783209e2fa 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include #include #include #include diff --git a/tensorflow/compiler/mlir/lite/core/c/BUILD b/tensorflow/compiler/mlir/lite/core/c/BUILD index 6448a3a8f8638b..55e349ce6cab86 100644 --- a/tensorflow/compiler/mlir/lite/core/c/BUILD +++ b/tensorflow/compiler/mlir/lite/core/c/BUILD @@ -49,9 +49,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), - visibility = [ - "//tensorflow/lite/ios:__subpackages__", - ], + visibility = ["//visibility:public"], ) # LINT.IfChange(common) diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h index 002b745ce8fd02..e7892cc06ae266 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.h +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h @@ -386,20 +386,9 @@ class FlatBufferModelBase { size_t allocation_size = std::min(allocation->bytes(), static_cast(FLATBUFFERS_MAX_BUFFER_SIZE - 1)); - flatbuffers::Verifier::Options options; - // TODO(b/366118885): Remove after the root cause of the crash on Windows - // is found. -#if defined(_WIN32) - options.assert = true; -#if defined(FLATBUFFER_VERIFIER_HAS_CHECK_BUFFER_ALIGNMENT) - // `check_buf_alignment` is not supported in all implementations of - // `flatbuffers::Verifier`. - options.check_buf_alignment = true; -#endif -#endif flatbuffers::Verifier base_verifier( reinterpret_cast(allocation->base()), allocation_size, - options); + flatbuffers::Verifier::Options()); if (!VerifyModelBuffer(base_verifier)) { TF_LITE_REPORT_ERROR(error_reporter, "The model is not a valid Flatbuffer buffer"); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc index 902c59c9b69eb3..bd51fdd9c12b80 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc index 38b0f5b18b8737..19cd2e081a7d1e 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h" +#include +#include #include #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h index 1e6c65333f8e10..149c2076a6154a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_GPU_HARDWARE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_GPU_HARDWARE_H_ +#include + #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc index 094eda7f31cf08..4f6a7f834ee692 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.h" +#include #include #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.cc index 8fb602123bbd31..8fd92ba66ff16d 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h" +#include + #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h index 39e8b5c4f143d2..ca3715448e8a77 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_SIMPLE_HARDWARE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_SIMPLE_HARDWARE_H_ +#include + #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index 0399fa00caf3e1..40b1a1ac905e51 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -57,7 +57,6 @@ pybind_extension( "@compute_library//:__subpackages__", "@cpuinfo//:__subpackages__", "@curl//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@farmhash_archive//:__subpackages__", "@farmhash_gpu_archive//:__subpackages__", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc index 81bf1477ff0077..306fed5f74d104 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc @@ -211,7 +211,7 @@ void OptimizeQuantizedOpToFloat(func::FuncOp func, MLIRContext* context) { patterns .add( context); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace tac diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.cc index c3b4f811ffa0d1..8af57f268c838d 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.cc @@ -64,7 +64,7 @@ void DeviceTransformGPUPass::runOnOperation() { auto func = getOperation(); auto* ctx = &getContext(); RewritePatternSet patterns = GetHardwareRewritePatternsGPU(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_nnapi.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_nnapi.cc index 4336c660191f9e..e9bdf1f82ffd3b 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_nnapi.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_nnapi.cc @@ -63,7 +63,7 @@ void DeviceTransformNNAPIPass::runOnOperation() { auto* ctx = &getContext(); NNAPIHardware nnapi_hardware; RewritePatternSet patterns = nnapi_hardware.GetTransformations(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc index 7ccf26d3baca41..fd4852b34ed3cf 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc @@ -215,7 +215,7 @@ void AlternativeSubgraphPass::Optimize(func::FuncOp func, const std::string& hardware) { auto* ctx = &getContext(); RewritePatternSet patterns = GetHardwareRewritePatterns(ctx, hardware); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } // Get the alternative view of the func for the given device_inference_type. diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 73e7986b1a6a74..721b787c3e5c9b 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -40,7 +42,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -69,6 +70,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectResourceBlobManager.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project @@ -121,7 +123,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/tstring.h" #include "tsl/platform/fingerprint.h" -#include "tsl/platform/status.h" #include "tsl/platform/tstring.h" using absl::StatusOr; @@ -737,7 +738,10 @@ class Translator { // Append constant and custom op buffers at the end of the flatbuffer and // calculate the offsets - void AppendBufferData(absl::Cord& result); + void AppendBufferData(std::string& result, int64_t offset); + + // Utility function to return the size of the buffer data. + int64_t GetBufferDataSize(); // Update constant & custom op buffer offsets // Return false if fail to update offset @@ -850,8 +854,19 @@ class Translator { // Maps buffer data to corresponding buffer index // in the idx map, the value is a pair of offset and size absl::flat_hash_map> buffer_idx_map_; - absl::flat_hash_map buffer_data_map_; + // Maps buffer index to buffer data. Prefer string_view to avoid one extra + // copy. As it is, this data will be copied at least once to the flatbuffer. + // We need to find a way to avoid this copy. + absl::flat_hash_map buffer_data_map_; bool buffer_data_exported_ = false; + // strings, buffers, and Tensors that need to be deleted after the flatbuffer + // is built. We're currently using these to hold the data that are created + // from DenseResourceElementsAttr or DenseElementsAttr and hold constant data. + std::vector> tf_tensors_to_delete_; + std::vector>> + string_buffers_to_delete_; + std::vector>> + packed_int4_buffers_to_delete_; // Maps custom options data to corresponding node // Key is set to be the list of input tensor indices and list of output tensor @@ -1001,24 +1016,33 @@ std::optional> Translator::BuildBuffer( for (mlir::APInt v : attr.getValues()) { data.emplace_back(static_cast(*(v.getRawData()))); } - auto packed_buffer = tflite::PackInt4ValuesDensely(data); + auto packed_buffer = std::make_unique>( + tflite::PackInt4ValuesDensely(data)); if (use_buffer_offset_) { buffer_data_map_[index] = - std::string(packed_buffer.begin(), packed_buffer.end()); + absl::string_view(reinterpret_cast(packed_buffer->data()), + packed_buffer->size()); + packed_int4_buffers_to_delete_.emplace_back(std::move(packed_buffer)); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { - if (IsModelBiggerThan2GB(packed_buffer.size())) { + if (IsModelBiggerThan2GB(packed_buffer->size())) { require_use_buffer_offset_ = true; return empty_buffer_; } auto buffer_data = - builder_.CreateVector(packed_buffer.data(), packed_buffer.size()); + builder_.CreateVector(packed_buffer->data(), packed_buffer->size()); return tflite::CreateBuffer(builder_, buffer_data); } } - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(attr, &tensor); + auto tensor = std::make_unique(); + auto status = tensorflow::ConvertToTensor(attr, tensor.get()); + // Reset the attribute after copying it to a tensorflow::Tensor because the + // attribute is not needed anymore. + if (auto dense_resource_attr = + dyn_cast(attr)) { + dense_resource_attr.getRawHandle().getResource()->setBlob({}); + } if (!status.ok()) { inst->emitError( Twine("failed to convert value attribute to tensor with error: " + @@ -1028,9 +1052,9 @@ std::optional> Translator::BuildBuffer( // TensorFlow and TensorFlow Lite use different string encoding formats. // Convert to TensorFlow Lite format is it's a constant string tensor. - if (tensor.dtype() == tensorflow::DT_STRING) { + if (tensor->dtype() == tensorflow::DT_STRING) { ::mlir::TFL::SimpleDynamicBuffer dynamic_buffer; - auto flat = tensor.flat<::tensorflow::tstring>(); + auto flat = tensor->flat<::tensorflow::tstring>(); for (int i = 0; i < flat.size(); ++i) { const auto& str = flat(i); if (!dynamic_buffer.AddString(str.c_str(), str.length())) { @@ -1043,10 +1067,11 @@ std::optional> Translator::BuildBuffer( char* tensor_buffer; int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer); if (use_buffer_offset_) { - std::vector buffer_data(tensor_buffer, tensor_buffer + bytes); - free(tensor_buffer); - buffer_data_map_[index] = - std::string(buffer_data.begin(), buffer_data.end()); + // Avoid creating std::vector and std::string + buffer_data_map_[index] = absl::string_view(tensor_buffer, bytes); + string_buffers_to_delete_.push_back( + std::unique_ptr>(tensor_buffer, + free)); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { if (IsModelBiggerThan2GB(bytes)) { @@ -1060,9 +1085,10 @@ std::optional> Translator::BuildBuffer( } } - absl::string_view tensor_data = tensor.tensor_data(); + absl::string_view tensor_data = std::move(tensor->tensor_data()); if (use_buffer_offset_) { - buffer_data_map_[index] = std::string(tensor_data); + buffer_data_map_[index] = std::move(tensor_data); + tf_tensors_to_delete_.push_back(std::move(tensor)); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { if (IsModelBiggerThan2GB(tensor_data.size())) { @@ -1072,6 +1098,10 @@ std::optional> Translator::BuildBuffer( auto buffer_data = builder_.CreateVector( reinterpret_cast(tensor_data.data()), tensor_data.size()); + // Delete the tensor as the call to CreateVector copies the + // data. We need a better design for this so that we don't have to + // delete the tensor based on the implementation details. + tensor.reset(); return tflite::CreateBuffer(builder_, buffer_data); } } @@ -4068,90 +4098,168 @@ std::optional Translator::TranslateInternal() { tflite::UpdateOpVersion(builder_.GetBufferPointer()); tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); - absl::Cord result; + std::string result_string; + int64_t final_result_size = builder_.GetSize(); + + // If we need to use buffer offset, we need to add the buffer data size to the + // final result size. This is because the buffer data size is not included in + // the flatbuffer size. + if (use_buffer_offset_) { + final_result_size += GetBufferDataSize(); + } + result_string.reserve(final_result_size); + + int64_t offset = 0; auto fbs = absl::string_view( reinterpret_cast(builder_.GetBufferPointer()), builder_.GetSize()); - result.Append(fbs); + result_string.replace(offset, fbs.size(), fbs); // Return serialized string for the built FlatBuffer. if (use_buffer_offset_) { + offset += fbs.size(); // Pad to be 16 bytes aligned { - std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0'); - result.Append(std::move(pad)); + std::string pad(kFbAlignment - offset % kFbAlignment, '\0'); + size_t pad_size = pad.size(); + result_string.replace(offset, pad_size, std::move(pad)); + offset += pad_size; } - AppendBufferData(result); - std::string result_str = std::string(std::move(result)); - auto mutable_model = tflite::GetMutableModel(result_str.data()); + AppendBufferData(result_string, offset); + auto mutable_model = tflite::GetMutableModel(result_string.data()); bool ret = UpdateBufferOffsets(mutable_model); if (!ret) { return std::nullopt; } - return result_str; + return result_string; } - return std::string(result); + + // Free all the buffers/tensors, etc. that were created but were kept around + // to copy into the flatbuffer. + for (auto& packed_int4_buffer : packed_int4_buffers_to_delete_) { + packed_int4_buffer.reset(); + } + packed_int4_buffers_to_delete_.clear(); + + for (auto& str_buffer : string_buffers_to_delete_) { + str_buffer.reset(); + } + string_buffers_to_delete_.clear(); + + for (auto& tensor : tf_tensors_to_delete_) { + auto tensor_ptr = tensor.release(); + delete tensor_ptr; + } + tf_tensors_to_delete_.clear(); + + return std::move(result_string); } -void Translator::AppendBufferData(absl::Cord& result) { +int64_t Translator::GetBufferDataSize() { + int64_t final_size = 0; + // 1. FlatBuffer Size, which will be included prior to the buffer data. + + // 2. Alignment Padding for FlatBuffer (if needed) + if (use_buffer_offset_) { + final_size += 16; + } + + // 3. Buffer Data Size (with deduplication) + absl::flat_hash_set unique_buffer_hashes; + for (const auto& [_, buffer] : buffer_data_map_) { + uint64_t hash = tsl::Fingerprint64(buffer); + if (unique_buffer_hashes.insert(hash).second) { // Unique buffer + final_size += buffer.size(); + final_size += 16; // Alignment + } + } + + // 4. Additional Padding for XNNPack + final_size += 16; // Assuming 16 bytes of padding + + // 5. Custom Op Data Size + for (const auto& [_, custom_data] : custom_op_data_map_) { + final_size += 16; // Alignment + if (custom_option_alignment_.has_value()) { + final_size += custom_option_alignment_.value() - + final_size % custom_option_alignment_.value(); + } + final_size += custom_data.size(); + } + + // 6. Final Alignment Padding + final_size += 16; + + return final_size; +} + +void Translator::AppendBufferData(std::string& result, int64_t offset) { std::unordered_map> hashcode_to_pos; // Buffer data should be exported only once. assert(!buffer_data_exported_); - auto it = buffer_data_map_.begin(); - while (it != buffer_data_map_.end()) { - std::string buffer = it->second; - int64_t index = it->first; - int64_t offset = result.size(); + for (const auto& [index, buffer] : buffer_data_map_) { int64_t size = buffer.size(); uint64_t hash = tsl::Fingerprint64(buffer); if (hashcode_to_pos.find(hash) == hashcode_to_pos.end()) { hashcode_to_pos[hash] = std::make_pair(offset, size); buffer_idx_map_[index] = std::make_pair(offset, size); - result.Append(std::move(buffer)); + result.replace(offset, size, std::move(buffer)); + offset += size; // Pad to be 16 bytes aligned. { - std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0'); - result.Append(std::move(pad)); + std::string pad(kFbAlignment - offset % kFbAlignment, '\0'); + size_t pad_size = pad.size(); + result.replace(offset, pad_size, std::move(pad)); + offset += pad_size; } } else { // only update offset/index. buffer_idx_map_[index] = hashcode_to_pos[hash]; } - buffer_data_map_.erase(it); - it = buffer_data_map_.begin(); buffer_data_exported_ = true; } - // pad 16 bytes for the last buffer for XNNPack - result.Append("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); + { + // pad 16 bytes for the last buffer for XNNPack + std::string pad(16, '\0'); + size_t pad_size = pad.size(); + result.replace(offset, pad_size, std::move(pad)); + offset += pad_size; + } // pad to be 16 bytes aligned { - std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0'); - result.Append(std::move(pad)); + std::string pad(kFbAlignment - offset % kFbAlignment, '\0'); + size_t pad_size = pad.size(); + result.replace(offset, pad_size, std::move(pad)); + offset += pad_size; } for (auto& it : custom_op_data_map_) { { - std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0'); - result.Append(std::move(pad)); + std::string pad(kFbAlignment - offset % kFbAlignment, '\0'); + size_t pad_size = pad.size(); + result.replace(offset, pad_size, std::move(pad)); + offset += pad_size; } if (custom_option_alignment_.has_value()) { { auto alignment = custom_option_alignment_.value(); - std::string pad(alignment - result.size() % alignment, '\0'); - result.Append(std::move(pad)); + std::string pad(alignment - offset % alignment, '\0'); + size_t pad_size = pad.size(); + result.replace(offset, pad_size, std::move(pad)); + offset += pad_size; } } auto buffer = std::string(it.second.begin(), it.second.end()); - int64_t offset = result.size(); int64_t size = it.second.size(); custom_op_idx_map_[it.first] = std::make_pair(offset, size); - result.Append(std::move(buffer)); + result.replace(offset, size, std::move(buffer)); + offset += size; } // pad to be 16 bytes aligned { - std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0'); - result.Append(std::move(pad)); + std::string pad(kFbAlignment - offset % kFbAlignment, '\0'); + result.replace(offset, pad.size(), std::move(pad)); } } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index 014142b131e8eb..f0afe15f8d5657 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -137,7 +137,7 @@ llvm::MinMax OperandNumbersMinMax(llvm::StringRef op_name); // `custom_code` is used to identify CustomOp. // `custom_options` are opaque attribute used to store infomations for this // custom op. -tensorflow::Status CustomOptionsToAttributes( +absl::Status CustomOptionsToAttributes( const std::string &custom_code, const std::vector &custom_options, mlir::Builder builder, // NOLINTNEXTLINE diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 23db5dd0b41a49..953390e699e1e9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -413,6 +413,24 @@ struct RemoveOptionalZeroBias : public OpRewritePattern { } }; +struct SetAsymmetricQuantizeInput : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FullyConnectedOp op, + PatternRewriter& rewriter) const override { + if (op.getAsymmetricQuantizeInputs() == std::nullopt || + op.getAsymmetricQuantizeInputs() == false) { + auto new_op = rewriter.create( + op.getLoc(), op.getOutput().getType(), op.getInput(), op.getFilter(), + op.getBias(), op.getFusedActivationFunction(), op.getWeightsFormat(), + op.getKeepNumDims(), rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, new_op.getOutput()); + return success(); + } + return failure(); + } +}; + // Return true if the given Add operation has the CPU kernel supported shapes. bool VerifyAddOpShapeConstraints(AddOp op) { auto element_type = getElementTypeOrSelf(op.getOutput().getType()); @@ -1624,6 +1642,7 @@ LogicalResult FullyConnectedOp::fold(FoldAdaptor adaptor, void FullyConnectedOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add>(context); + results.add(context); } int64_t FullyConnectedOp::GetArithmeticCount(Operation* op) { diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc index f16ade3c0066c9..8b238abe0e3162 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc @@ -14,9 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" -#include #include -#include #include #include #include diff --git a/tensorflow/compiler/mlir/lite/metrics/types_util.cc b/tensorflow/compiler/mlir/lite/metrics/types_util.cc index d13df105fcf322..f6707b71cb2d82 100644 --- a/tensorflow/compiler/mlir/lite/metrics/types_util.cc +++ b/tensorflow/compiler/mlir/lite/metrics/types_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/metrics/types_util.h" +#include #include #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/lite/metrics/types_util.h b/tensorflow/compiler/mlir/lite/metrics/types_util.h index aa85396aed4012..7fe31a38e24b56 100644 --- a/tensorflow/compiler/mlir/lite/metrics/types_util.h +++ b/tensorflow/compiler/mlir/lite/metrics/types_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_METRICS_TYPES_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_METRICS_TYPES_UTIL_H_ +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi b/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi index 989d4f1dbe56fb..7557dee725f4c6 100644 --- a/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi +++ b/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -def Convert(model_flags_proto_txt_raw: object, converter_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., quantization_py_function_library = ...) -> object: ... +def Convert(model_flags_proto_txt_raw: object, converter_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., quantization_py_function_library=...) -> object: ... def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ..., debug_options_proto_txt_raw: object = ...) -> object: ... def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ... def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ... diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc index c7059d721a062f..2f6e98623cff3c 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc @@ -136,7 +136,7 @@ PyObject* Convert(PyObject* model_flags_proto_txt_raw, } std::string output_file_contents_txt; - tensorflow::Status status; + absl::Status status; // Convert model. if (model_flags.use_hlo_import() && model_flags.has_saved_model_dir()) { @@ -387,8 +387,7 @@ PyObject* RegisterCustomOpdefs(PyObject* list) { // Register extra opdefs to TensorFlow global op registry. tensorflow::OpRegistry::Global()->Register( - [opdef]( - tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status { + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> absl::Status { *op_reg_data = tensorflow::OpRegistrationData(opdef); return absl::OkStatus(); }); diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h index 94657a52b1436f..9008560f24ed2c 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h @@ -28,10 +28,9 @@ namespace tensorflow { // Converts the given Jax model to a TF Lite FlatBuffer // string according to the given model flags, converter flags and tags. Returns // error status if it fails to convert the input. -Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, - const tflite::ModelFlags& model_flags, - tflite::ConverterFlags& converter_flags, - string* result); +absl::Status ConvertJaxToTFLiteFlatBuffer( + const std::string& input, const tflite::ModelFlags& model_flags, + tflite::ConverterFlags& converter_flags, string* result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index f43139021b74af..3959901428c3d5 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -56,7 +56,7 @@ namespace tensorflow { using tensorflow::quantization::PyFunctionLibrary; -Status HandleInputOutputArraysWithModule( +absl::Status HandleInputOutputArraysWithModule( const tflite::ModelFlags& model_flags, mlir::OwningOpRef* module) { mlir::func::FuncOp entry_function = nullptr; @@ -132,7 +132,7 @@ Status HandleInputOutputArraysWithModule( return absl::OkStatus(); } -Status ConvertSavedModelToTFLiteFlatBuffer( +absl::Status ConvertSavedModelToTFLiteFlatBuffer( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, std::string* result, const PyFunctionLibrary* quantization_py_function_lib) { @@ -218,6 +218,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.canonicalizing_inf_as_min_max_float = converter_flags.canonicalizing_inf_as_min_max_float(); + pass_config.quant_specs.strict_qdq_mode = converter_flags.strict_qdq_mode(); + if (converter_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = mlir::quant::QDQConversionMode::kQDQStatic; diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index 39a97a93ea82a7..9280104763849f 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -29,7 +29,7 @@ namespace tensorflow { // Converts the given saved_model(either v1 or v2) to a TF Lite FlatBuffer // string according to the given model flags, converter flags and tags. Returns // error status if it fails to convert the input. -Status ConvertSavedModelToTFLiteFlatBuffer( +absl::Status ConvertSavedModelToTFLiteFlatBuffer( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, string* result, const quantization::PyFunctionLibrary* quantization_py_function_lib); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 20c63dc41f016b..de1e33f01cfbea 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -37,11 +37,12 @@ namespace tensorflow { namespace internal { // Register all custom ops including user specified custom ops. -Status RegisterAllCustomOps(const tflite::ConverterFlags& converter_flags); +absl::Status RegisterAllCustomOps( + const tflite::ConverterFlags& converter_flags); // Populate quantization specs (or not) given user specified ranges for each // input arrays. -Status PopulateQuantizationSpecs( +absl::Status PopulateQuantizationSpecs( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, mlir::quant::QuantizationSpecs* quant_specs, @@ -52,7 +53,7 @@ Status PopulateQuantizationSpecs( // Convert imported MLIR file to TfLite flatbuffer. // This will also run relevant passes as well. -Status ConvertMLIRToTFLiteFlatBuffer( +absl::Status ConvertMLIRToTFLiteFlatBuffer( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, std::unique_ptr&& context, diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h index 2fc4f248dbc0d0..01072c50677821 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.h +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -160,7 +160,7 @@ class DeviceTarget { // Adds the kernel spec with the scale constraint type for the kernel. LogicalResult RegisterKernel(llvm::StringRef kernel, const KernelSpecs::Signature& signature, - const ScaleConstraintType constraint); + ScaleConstraintType constraint); // Adds the kernel with the name. Retrun an existing one if it has been // added before. diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc index 22d5fb6743f4ef..4e67350bdaee84 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc @@ -113,7 +113,7 @@ void ConvertConstPass::runOnOperation() { auto func = getOperation(); auto *context = &getContext(); patterns.add(context); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } std::unique_ptr> diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc index d38b9e39423c8a..ddede29c0d7ed6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc @@ -145,7 +145,7 @@ void ConvertSimulatedQuantPass::runOnOperation() { auto *ctx = func.getContext(); patterns.add( ctx, &hadFailure); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 366cbc254bdf7a..2a7d16e2512aab 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -214,6 +214,7 @@ tf_cc_test( "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 008a454b851705..f274a84470d71b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index a1f40f867787e0..2b33e1e65b5837 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -130,7 +130,7 @@ class QuantizeContext { // ops, which have the parameters propagated to, are collected by `new_items`, // so they can be added to the working queue. `changed` is set to true if // there are any new elements being added to `new_items`. - LogicalResult PropagateQuantParams(Operation *op, const QuantParams params, + LogicalResult PropagateQuantParams(Operation *op, QuantParams params, AdjacentOperations *new_items, bool *changed); diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc index c30c91ae8180dd..1d3abaa9570a2a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc @@ -296,7 +296,7 @@ void FallbackToFlexOps::runOnOperation() { // Convert binary ops to BiasAdd ops if possible. RewritePatternSet patterns(ctx); populateWithGenerated(patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); // Convert unsupported ops to Flex ops. auto tf_dialect = ctx->getLoadedDialect(); diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 12856137123f63..a60ac436b56b63 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -176,7 +176,7 @@ void LegalizeTFToQuant::runOnOperation() { auto func = getOperation(); auto *ctx = func.getContext(); patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h b/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h index b9756dd7517548..ebf9219f4a249b 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h +++ b/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h @@ -20,8 +20,7 @@ limitations under the License. namespace tflite { -int8_t ConvertBuiltinCodeToDeprecatedBuiltinCode( - const BuiltinOperator builtin_code); +int8_t ConvertBuiltinCodeToDeprecatedBuiltinCode(BuiltinOperator builtin_code); // The following methods are for backward compatibility for the early version // three, which does not have an extended builtin code. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 4773611054df6f..153ce23dae7a67 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -508,6 +508,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc index 668fe06515812e..8c3881a954e01b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.cc @@ -44,7 +44,7 @@ class SHLOSimplifyPass : public impl::SHLOSimplifyPassBase { RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); PopulateFolderPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index bdba7dc58a379f..2bcf63a9313422 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -198,9 +198,9 @@ absl::StatusOr> ImportSavedModelOrMLIR( saved_model_bundle); } -tensorflow::Status ExportModule(mlir::ModuleOp module, - const std::string& output_filename, - bool elide_large_elements_attrs) { +absl::Status ExportModule(mlir::ModuleOp module, + const std::string& output_filename, + bool elide_large_elements_attrs) { std::string error_msg; auto output = mlir::openOutputFile(output_filename, &error_msg); if (output == nullptr) { @@ -227,8 +227,8 @@ tensorflow::Status ExportModule(mlir::ModuleOp module, return absl::OkStatus(); } -tensorflow::Status ConvertTFToStableHLO( - ModuleOp tf_module, const PassPipelineCLParser& pass_pipeline) { +absl::Status ConvertTFToStableHLO(ModuleOp tf_module, + const PassPipelineCLParser& pass_pipeline) { PassManager pm(tf_module.getContext()); if (failed(applyPassManagerCLOptions(pm))) { return tensorflow::errors::Aborted( @@ -273,7 +273,7 @@ tensorflow::Status ConvertTFToStableHLO( return absl::OkStatus(); } -tensorflow::Status RunConverter(const PassPipelineCLParser& pass_pipeline) { +absl::Status RunConverter(const PassPipelineCLParser& pass_pipeline) { DialectRegistry registry; registerAllDialects(registry); RegisterAllTensorFlowDialects(registry); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir index b15175f6602547..60f94c69014604 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -258,6 +258,24 @@ func.func private @XlaCallModule_odml.upsample_bilinear2d.impl_21_0(%arg0: tenso // CHECK: return %[[VAL_6]] : tensor<1x64x32x32xf32> // CHECK: } +func.func private @XlaCallModule_tfl.gelu.impl_0(%arg0: tensor<1x4x4x1xf32>) -> (tensor<1x4x4x1xf32>) +func.func @jax_gelu_approx(%arg0: tensor<1x4x4x1xf32>) -> (tensor<1x4x4x1xf32>) { + %2 = mhlo.composite "tfl.gelu" %arg0 {composite_attributes = {approximate = true}, decomposition = @XlaCallModule_tfl.gelu.impl_0} : (tensor<1x4x4x1xf32>) -> tensor<1x4x4x1xf32> + return %2 : tensor<1x4x4x1xf32> +} + +// CHECK-LABEL: jax_gelu_approx +// CHECK: %0 = "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<1x4x4x1xf32>) -> tensor<1x4x4x1xf32> + +func.func private @XlaCallModule_tfl.gelu.impl_1(%arg0: tensor<1x4x4x1xf32>) -> (tensor<1x4x4x1xf32>) +func.func @jax_gelu(%arg0: tensor<1x4x4x1xf32>) -> (tensor<1x4x4x1xf32>) { + %2 = mhlo.composite "tfl.gelu" %arg0 {composite_attributes = {approximate = false}, decomposition = @XlaCallModule_tfl.gelu.impl_1} : (tensor<1x4x4x1xf32>) -> tensor<1x4x4x1xf32> + return %2 : tensor<1x4x4x1xf32> +} + +// CHECK-LABEL: jax_gelu +// CHECK: %0 = "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<1x4x4x1xf32>) -> tensor<1x4x4x1xf32> + func.func private @gelu_decomp_1(%arg0: tensor<5x10xf32>) -> tensor<5x10xf32> func.func @gelu_aten(%arg0: tensor<5x10xf32>) -> (tensor<*xf32>) { %0 = mhlo.composite "aten.gelu.default" %arg0 {composite_attributes = {approximate = "none"}, decomposition = @gelu_decomp_1} : (tensor<5x10xf32>) -> tensor<5x10xf32> @@ -409,3 +427,28 @@ func.func private @XlaCallModule_odml.embedding_lookup.impl_0(%arg0: tensor<1xi3 // CHECK: return %[[VAL_1]] : tensor<1x2048xf32> // CHECK: } + +func.func @random_uniform(%arg0: tensor<3xi32>) -> tensor<1x2x3xf32> { + %0 = mhlo.composite "odml.random_uniform" %arg0 {composite_attributes = {seed = 0 : i64, seed2 = 1: i64}, decomposition = @XlaCallModule_odml.random_uniform.impl_0} : (tensor<3xi32>) -> tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> +} +func.func private @XlaCallModule_odml.random_uniform.impl_0(%arg0: tensor<3xi32>) -> tensor<1x2x3xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> +} +// CHECK-LABEL func.func @random_uniform +// CHECK: %0 = "tfl.random_uniform"(%arg0) <{seed = 0 : i64, seed2 = 1 : i64}> : (tensor<3xi32>) -> tensor<1x2x3xf32> +// CHECK: return %0 : tensor<1x2x3xf32> + + +func.func @random_standard_normal(%arg0: tensor<3xi32>) -> tensor<1x2x3xf32> { + %0 = mhlo.composite "odml.random_standard_normal" %arg0 {composite_attributes = {seed = 0 : i64, seed2 = 1: i64}, decomposition = @XlaCallModule_odml.random_standard_normal.impl_0} : (tensor<3xi32>) -> tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> +} +func.func private @XlaCallModule_odml.random_standard_normal.impl_0(%arg0: tensor<3xi32>) -> tensor<1x2x3xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<1x2x3xf32> + return %0 : tensor<1x2x3xf32> +} +// CHECK-LABEL func.func @random_standard_normal +// CHECK: %0 = "tfl.random_standard_normal"(%arg0) <{seed = 0 : i64, seed2 = 1 : i64}> : (tensor<3xi32>) -> tensor<1x2x3xf32> +// CHECK: return %0 : tensor<1x2x3xf32> \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index 3fddc46c755361..c55a93fb8f6dfe 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -3751,27 +3751,106 @@ func.func @convert_gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32> func.return %0 : tensor<1x1xi32> } -// CHECK-LABEL: func @convert_gather_trivial_batching_dims( -// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x128xf32>, -// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x128x1xi32>) -// CHECK: %[[VAL_0:.*]] = arith.constant dense<128> : tensor<1xi64> -// CHECK: %[[VAL_1:.*]] = "tf.Reshape"(%[[ARG_0]], %[[VAL_0]]) : {{.*}} -> tensor<128xf32> -// CHECK: %[[VAL_2:.*]] = "tf.GatherNd"(%[[VAL_1]], %[[ARG_1]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<1x128xf32> -// CHECK: return %[[VAL_2]] -// CHECK: } -func.func @convert_gather_trivial_batching_dims(%arg0: tensor<1x128xf32>, %arg1: tensor<1x128x1xi32>) -> tensor<1x128xf32> { +// CHECK-LABEL: func @convert_gather_batching_dims( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x3x128xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<3x2x128x1xi32>) +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[6, 128]> : tensor<2xi64> +// CHECK: %[[VAL_0:.*]] = "tf.Reshape"(%[[ARG_0]], %[[CST]]) : (tensor<2x3x128xf32>, tensor<2xi64>) -> tensor<6x128xf32> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_1:.*]] = "tf.Transpose"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x2x128x1xi32>, tensor<4xi64>) -> tensor<2x3x128x1xi32> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[6, 128, 1]> : tensor<3xi64> +// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST_1]]) : (tensor<2x3x128x1xi32>, tensor<3xi64>) -> tensor<6x128x1xi32> +// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() <{value = dense<6> : tensor}> : () -> tensor +// CHECK-DAG: %[[CST_4:.*]] = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tf.Range"(%[[CST_2]], %[[CST_3]], %[[CST_4]]) : (tensor, tensor, tensor) -> tensor<6xi32> +// CHECK-DAG: %[[CST_5:.*]] = "tf.Const"() <{value = dense<[6, 1, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Reshape"(%[[VAL_3]], %[[CST_5]]) : (tensor<6xi32>, tensor<3xi64>) -> tensor<6x1x1xi32> +// CHECK-DAG: %[[CST_6:.*]] = "tf.Const"() <{value = dense<[6, 128, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_5:.*]] = "tf.BroadcastTo"(%[[VAL_4]], %[[CST_6]]) : (tensor<6x1x1xi32>, tensor<3xi64>) -> tensor<6x128x1xi32> +// CHECK-DAG: %[[CST_7:.*]] = "tf.Const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = "tf.ConcatV2"(%[[VAL_5]], %[[VAL_2]], %[[CST_7]]) : (tensor<6x128x1xi32>, tensor<6x128x1xi32>, tensor) -> tensor<6x128x2xi32> +// CHECK: %[[VAL_7:.*]] = "tf.GatherNd"(%[[VAL_0]], %[[VAL_6]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<6x128xf32> +// CHECK-DAG: %[[CST_8:.*]] = arith.constant dense<[2, 3, 128]> : tensor<3xi64> +// CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_7]], %[[CST_8]]) : (tensor<6x128xf32>, tensor<3xi64>) -> tensor<2x3x128xf32> +// CHECK-DAG: %[[CST_9:.*]] = "tf.Const"() <{value = dense<[1, 0, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_8]], %[[CST_9]]) : (tensor<2x3x128xf32>, tensor<3xi64>) -> tensor<3x2x128xf32> +// CHECK: return %[[VAL_9]] +// CHECK: } +func.func @convert_gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> { %0 = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< - index_vector_dim = 2, - start_index_map = [1], - operand_batching_dims = [0], - start_indices_batching_dims = [0], - collapsed_slice_dims = [1], + index_vector_dim = 3, + start_index_map = [2], + operand_batching_dims = [0, 1], + start_indices_batching_dims = [1, 0], + collapsed_slice_dims = [2], >, indices_are_sorted = false, - slice_sizes = dense<1> : tensor<2xi64> - } : (tensor<1x128xf32>, tensor<1x128x1xi32>) -> tensor<1x128xf32> - func.return %0 : tensor<1x128xf32> + slice_sizes = dense<1> : tensor<3xi64> + } : (tensor<2x3x128xf32>, tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> + func.return %0 : tensor<3x2x128xf32> +} + +// CHECK-LABEL: func @convert_gather_non_collapsed_index_dim( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<10x5xi32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<2x1xi32>) -> tensor<2x1x5xi32> { +// CHECK: %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) <{bad_indices_policy = ""}> : (tensor<10x5xi32>, tensor<2x1xi32>) -> tensor<2x5xi32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[2, 1, 5]> : tensor<3xi64> +// CHECK: %[[VAL_1:.*]] = "tf.Reshape"(%[[VAL_0]], %[[CST]]) : (tensor<2x5xi32>, tensor<3xi64>) -> tensor<2x1x5xi32> +// CHECK: return %[[VAL_1]] : tensor<2x1x5xi32> +// CHECK: } +func.func @convert_gather_non_collapsed_index_dim(%arg0: tensor<10x5xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x1x5xi32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + index_vector_dim = 1, + offset_dims = [1, 2], + start_index_map = [0], + >, + indices_are_sorted = false, + slice_sizes = dense<[1, 5]> : tensor<2xi64> + } : (tensor<10x5xi32>, tensor<2x1xi32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// CHECK-LABEL: func @convert_gather_indexed_dimension_slice( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x5x6xi32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<2x2xi32>) -> tensor<2x1x5x6xi32> { +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_0:.*]] = "tf.Transpose"(%[[ARG_0]], %[[CST]]) : (tensor<4x5x6xi32>, tensor<3xi64>) -> tensor<4x6x5xi32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[2, 1, 2]> : tensor<3xi64> +// CHECK: %[[VAL_1:.*]] = "tf.Reshape"(%[[ARG_1]], %[[CST_0]]) : (tensor<2x2xi32>, tensor<3xi64>) -> tensor<2x1x2xi32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<6> : tensor}> : () -> tensor +// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor, tensor, tensor) -> tensor<6xi32> +// CHECK-DAG: %[[CST_4:.*]] = "tf.Const"() <{value = dense<[1, 6, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_2]], %[[CST_4]]) : (tensor<6xi32>, tensor<3xi64>) -> tensor<1x6x1xi32> +// CHECK-DAG: %[[CST_5:.*]] = "tf.Const"() <{value = dense<[1, 6, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_4:.*]] = "tf.BroadcastTo"(%[[VAL_3]], %[[CST_5]]) : (tensor<1x6x1xi32>, tensor<3xi64>) -> tensor<1x6x1xi32> +// CHECK-DAG: %[[CST_6:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_7:.*]] = arith.constant +// CHECK-SAME{LITERAL: dense<[[0, 0], [0, 0], [1, 0]]> : tensor<3x2xi64> +// CHECK: %[[VAL_5:.*]] = "tf.PadV2"(%[[VAL_4]], %[[CST_7]], %[[CST_6]]) : (tensor<1x6x1xi32>, tensor<3x2xi64>, tensor) -> tensor<1x6x2xi32> +// CHECK: %[[VAL_6:.*]] = "tf.Add"(%[[VAL_1]], %[[VAL_5]]) : (tensor<2x1x2xi32>, tensor<1x6x2xi32>) -> tensor<2x6x2xi32> +// CHECK: %[[VAL_7:.*]] = "tf.GatherNd"(%[[VAL_0]], %[[VAL_6]]) <{bad_indices_policy = ""}> : (tensor<4x6x5xi32>, tensor<2x6x2xi32>) -> tensor<2x6x5xi32> +// CHECK-DAG: %[[CST_8:.*]] = arith.constant dense<[2, 1, 6, 5]> : tensor<4xi64> +// CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_7]], %[[CST_8]]) : (tensor<2x6x5xi32>, tensor<4xi64>) -> tensor<2x1x6x5xi32> +// CHECK-DAG: %[[CST_9:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_8]], %[[CST_9]]) : (tensor<2x1x6x5xi32>, tensor<4xi64>) -> tensor<2x1x5x6xi32> +// CHECK: return %[[VAL_9]] : tensor<2x1x5x6xi32> +// CHECK: } +func.func @convert_gather_indexed_dimension_slice(%arg0: tensor<4x5x6xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x1x5x6xi32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + index_vector_dim = 1, + offset_dims = [1, 2, 3], + start_index_map = [0, 2], + >, + indices_are_sorted = false, + slice_sizes = dense<[1, 5, 6]> : tensor<3xi64> + } : (tensor<4x5x6xi32>, tensor<2x2xi32>) -> tensor<2x1x5x6xi32> + func.return %0 : tensor<2x1x5x6xi32> } // CHECK-LABEL: func @convert_gather_to_slice_batch_size_1( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index 4529b34448077c..83e9d9b3062187 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -1756,6 +1756,117 @@ func.func @gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32>) -> ten // ----- + +// CHECK-LABEL: gather_batching_dims +func.func @gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + index_vector_dim = 3, + start_index_map = [2], + operand_batching_dims = [0, 1], + start_indices_batching_dims = [1, 0], + collapsed_slice_dims = [2], + >, + indices_are_sorted = false, + slice_sizes = dense<1> : tensor<3xi64> + } : (tensor<2x3x128xf32>, tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> + func.return %0 : tensor<3x2x128xf32> +} + +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[6, 128]> : tensor<2xi64> +// CHECK: %[[VAL_0:.*]] = "tfl.cast"(%[[CST]]) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_1:.*]] = "tfl.reshape"(%arg0, %[[VAL_0]]) : (tensor<2x3x128xf32>, tensor<2xi32>) -> tensor<6x128xf32> +// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = "tfl.cast"(%[[VAL_2]]) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %[[VAL_4:.*]] = "tfl.transpose"(%arg1, %[[VAL_3]]) : (tensor<3x2x128x1xi32>, tensor<4xi32>) -> tensor<2x3x128x1xi32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[6, 128, 1]> : tensor<3xi64> +// CHECK: %[[VAL_5:.*]] = "tfl.cast"(%[[CST_0]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_6:.*]] = "tfl.reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<2x3x128x1xi32>, tensor<3xi32>) -> tensor<6x128x1xi32> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<6> : tensor +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[VAL_7:.*]] = "tfl.range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor, tensor, tensor) -> tensor<6xi32> +// CHECK-DAG: %[[CST_4:.*]] = arith.constant dense<[6, 1, 1]> : tensor<3xi64> +// CHECK: %[[VAL_8:.*]] = "tfl.cast"(%[[CST_4]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_7]], %[[VAL_8]]) : (tensor<6xi32>, tensor<3xi32>) -> tensor<6x1x1xi32> +// CHECK-DAG: %[[CST_5:.*]] = arith.constant dense<[6, 128, 1]> : tensor<3xi64> +// CHECK: %[[VAL_10:.*]] = "tfl.broadcast_to"(%[[VAL_9]], %[[CST_5]]) : (tensor<6x1x1xi32>, tensor<3xi64>) -> tensor<6x128x1xi32> +// CHECK: %[[VAL_11:.*]] = "tfl.concatenation"(%[[VAL_10]], %[[VAL_6]]) <{axis = 2 : i32, fused_activation_function = "NONE"}> : (tensor<6x128x1xi32>, tensor<6x128x1xi32>) -> tensor<6x128x2xi32> +// CHECK: %[[VAL_12:.*]] = "tfl.gather_nd"(%[[VAL_1]], %[[VAL_11]]) : (tensor<6x128xf32>, tensor<6x128x2xi32>) -> tensor<6x128xf32> +// CHECK-DAG: %[[CST_6:.*]] = arith.constant dense<[2, 3, 128]> : tensor<3xi64> +// CHECK: %[[VAL_13:.*]] = "tfl.cast"(%[[CST_6]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<6x128xf32>, tensor<3xi32>) -> tensor<2x3x128xf32> +// CHECK: %[[VAL_15:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 0, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_16:.*]] = "tfl.cast"(%[[VAL_15]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_14]], %[[VAL_16]]) : (tensor<2x3x128xf32>, tensor<3xi32>) -> tensor<3x2x128xf32> + +// ----- + +// CHECK-LABEL: convert_gather_non_collapsed_index_dim +func.func @convert_gather_non_collapsed_index_dim(%arg0: tensor<10x5xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x1x5xi32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + index_vector_dim = 1, + offset_dims = [1, 2], + start_index_map = [0], + >, + indices_are_sorted = false, + slice_sizes = dense<[1, 5]> : tensor<2xi64> + } : (tensor<10x5xi32>, tensor<2x1xi32>) -> tensor<2x1x5xi32> + func.return %0 : tensor<2x1x5xi32> +} + +// CHECK: %[[VAL_0:.*]] = "tfl.gather_nd"(%arg0, %arg1) : (tensor<10x5xi32>, tensor<2x1xi32>) -> tensor<2x5xi32 +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[2, 1, 5]> : tensor<3xi64> +// CHECK: %[[VAL_1:.*]] = "tfl.cast"(%[[CST]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_2:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x5xi32>, tensor<3xi32>) -> tensor<2x1x5xi32> + +// ----- + +// CHECK-LABEL: convert_gather_indexed_dimension_slice +func.func @convert_gather_indexed_dimension_slice(%arg0: tensor<4x5x6xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x1x5x6xi32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + index_vector_dim = 1, + offset_dims = [1, 2, 3], + start_index_map = [0, 2], + >, + indices_are_sorted = false, + slice_sizes = dense<[1, 5, 6]> : tensor<3xi64> + } : (tensor<4x5x6xi32>, tensor<2x2xi32>) -> tensor<2x1x5x6xi32> + func.return %0 : tensor<2x1x5x6xi32> +} + +// CHECK: %[[VAL_0:.*]] = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_1:.*]] = "tfl.cast"(%[[VAL_0]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_2:.*]] = "tfl.transpose"(%arg0, %[[VAL_1]]) : (tensor<4x5x6xi32>, tensor<3xi32>) -> tensor<4x6x5xi32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[2, 1, 2]> : tensor<3xi64> +// CHECK: %[[VAL_3:.*]] = "tfl.cast"(%[[CST]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_4:.*]] = "tfl.reshape"(%arg1, %[[VAL_3]]) : (tensor<2x2xi32>, tensor<3xi32>) -> tensor<2x1x2xi32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<6> : tensor +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[VAL_5:.*]] = "tfl.range"(%[[CST_0]], %[[CST_1]], %[[CST_2]]) : (tensor, tensor, tensor) -> tensor<6xi32> +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[1, 6, 1]> : tensor<3xi64> +// CHECK: %[[VAL_6:.*]] = "tfl.cast"(%[[CST_3]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[VAL_7:.*]] = "tfl.reshape"(%[[VAL_5]], %[[VAL_6]]) : (tensor<6xi32>, tensor<3xi32>) -> tensor<1x6x1xi32> +// CHECK-DAG: %[[CST_4:.*]] = arith.constant dense<[1, 6, 1]> : tensor<3xi64> +// CHECK: %[[VAL_8:.*]] = "tfl.broadcast_to"(%[[VAL_7]], %[[CST_4]]) : (tensor<1x6x1xi32>, tensor<3xi64>) -> tensor<1x6x1xi32> +// CHECK-DAG: %[[CST_5:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_6:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[0, 0], [0, 0], [1, 0]]> : tensor<3x2xi64> +// CHECK: %[[VAL_9:.*]] = "tfl.pad"(%[[VAL_8]], %[[CST_6]]) : (tensor<1x6x1xi32>, tensor<3x2xi64>) -> tensor<1x6x2xi32> +// CHECK: %[[VAL_10:.*]] = tfl.add(%[[VAL_4]], %[[VAL_9]]) <{fused_activation_function = "NONE"}> : (tensor<2x1x2xi32>, tensor<1x6x2xi32>) -> tensor<2x6x2xi32> +// CHECK: %[[VAL_11:.*]] = "tfl.gather_nd"(%[[VAL_2]], %[[VAL_10]]) : (tensor<4x6x5xi32>, tensor<2x6x2xi32>) -> tensor<2x6x5xi32> +// CHECK-DAG: %[[CST_7:.*]] = arith.constant dense<[2, 1, 6, 5]> : tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tfl.cast"(%[[CST_7]]) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_11]], %[[VAL_12]]) : (tensor<2x6x5xi32>, tensor<4xi32>) -> tensor<2x1x6x5xi32> +// CHECK: %[[VAL_14:.*]] = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_15:.*]] = "tfl.cast"(%[[VAL_14]]) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %[[VAL_16:.*]] = "tfl.transpose"(%[[VAL_13]], %[[VAL_15]]) : (tensor<2x1x6x5xi32>, tensor<4xi32>) -> tensor<2x1x5x6xi32> + +// ----- + // CHECK-LABEL: gather_to_slice_batch_size_1 func.func @gather_to_slice_batch_size_1(%arg0: tensor<1x2944xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x1504xi32> { %0 = "mhlo.gather"(%arg0, %arg1) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index a8323cddc31037..7fe70321a1dd45 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -126,6 +126,13 @@ def LegalizeCompositeApproximateAtenGELU : Pat< (TFL_GeluOp $inputs, ConstBoolAttrTrue), [(IsStrCompositeAttribute<"approximate", "tanh"> $attrs)]>; +def LegalizeCompositeGELU : Pat< + (MHLO_CompositeOp:$composite + (variadic $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + def LegalizeCompositeOdmlEmbeddingLookup : Pat< (MHLO_CompositeOp:$composite (variadic $indices, $table), @@ -143,3 +150,19 @@ def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped : Pat< [(HasRank<1> $indices), (I32ElementsVal $indices), (HasRankAtLeast<2> $table)]>; + +def LegalizeCompositeOdmlRandomUniform : Pat< + (MHLO_CompositeOp:$composite + (variadic $shape), + ConstantStrAttr, $attrs, $_, $_), + (TFL_RandomUniformOp $shape, + (GetCompositeAttributeAs<"seed", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>; + +def LegalizeCompositeOdmlRandomStandardNormal : Pat< + (MHLO_CompositeOp:$composite + (variadic $shape), + ConstantStrAttr, $attrs, $_, $_), + (TFL_RandomStandardNormalOp $shape, + (GetCompositeAttributeAs<"seed", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>; \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index ae5a7439390c74..67763345add880 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -285,7 +285,7 @@ class ConvertNdConvOp : public OpConversionPattern { int64_t output_size; int64_t pad_low_int64; int64_t pad_high_int64; - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( mlir::cast(conv_op.getLhs().getType()) .getDimSize(input_spatial_dim[i]), mlir::cast(conv_op.getRhs().getType()) @@ -2695,22 +2695,49 @@ bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr, return false; } +// Tries to convert an mhlo::GatherOp into a TF::GatherNdOp (or TF::SliceOp). +// +// Consider the following example: +// operand_shape = [B1, I1, O1, B2, I2, O2] +// operand_batching_dims = [0, 3] +// +// start_indices_shape = [B2, B3, B1, 2] +// start_indices_batching_dims = [3, 0] +// index_vector_dim = 3 +// start_index_map = [4, 1] +// +// offset_dims: [2, 4] +// slice_sizes = [1, 1, O1, 1, 1, O2] +// collapsed_slice_dims = [1, 4] +// result_shape = [B2, B3, O1, B3, O2] +// +// To implement this with a tf.GatherNd, we canonicalize the operand s.t. the +// operand batching dimensions are flattened into the leading dimensions, +// followed by the indexed dimensions in order: +// canonical_operand_shape = [B1 * B2, I2, I1, O1, O2] +// +// We canonicalize the start indices so the start indices batching dimensions +// are flattened (in order) into a leading dimension. In addition, we add iota +// indices to appropriately offset into the flattened operand batching +// dimension: +// canonical_start_indices_shape = [B1 * B2, B3, 3] +// (index_vector_dim is expanded to included indices for the operand +// batching dimensions) +// +// The result of tf.GatherNd(canonical_operand, canonical_start_indices) has the +// following shape: +// canonical_result_shape = [B1 * B2, B3, O1, O2] +// +// The canonical result is unflattened and transpose as needed to get back to +// the original result shape. class ConvertGatherOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - // Helper params for representing the transpose params for the "canonicalized" - // output to the real output. - struct TransposeParams { - std::vector permutation; - // The following are the "canonicalized" output shape with offset dims. - std::vector canonicalized_output_shape; - std::vector canonicalized_offset_dims; - }; - LogicalResult matchAndRewrite( mhlo::GatherOp gather_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { + // First see if we can convert the gather to a tf.Slice. if (succeeded(ConvertGatherOpToSlice(gather_op, rewriter))) { return success(); } @@ -2729,6 +2756,39 @@ class ConvertGatherOp : public OpConversionPattern { return failure(); } + llvm::ArrayRef operand_batching_dims = + gather_op.getDimensionNumbers().getOperandBatchingDims(); + llvm::ArrayRef start_indices_batching_dims = + gather_op.getDimensionNumbers().getStartIndicesBatchingDims(); + llvm::ArrayRef start_index_map = + gather_op.getDimensionNumbers().getStartIndexMap(); + llvm::ArrayRef collapsed_slice_dims = + gather_op.getDimensionNumbers().getCollapsedSliceDims(); + if (!start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { + // Dynamic dimensions aren't supported in certain cases that require + // reshaping the indices or result. + if (!start_indices_batching_dims.empty()) { + gather_op.emitOpError() + << "Dynamic shaped start indices aren't supported when there are " + "batching dimensions."; + } + + // Verify that start_index_map and collapsed_slice_dims contains the same + // values. + if (start_index_map.size() != collapsed_slice_dims.size()) { + return rewriter.notifyMatchFailure( + gather_op, + "different size for start index map and collapsed slice dims"); + } + for (auto c : collapsed_slice_dims) { + if (llvm::count(start_index_map, c) == 0) { + return rewriter.notifyMatchFailure( + gather_op, + "collapsed slice dim isn't present in start index map"); + } + } + } + // Normalize start_indices so index_vector_dim == start_indices.rank() - 1. int64_t index_vector_dim = gather_op.getDimensionNumbers().getIndexVectorDim(); @@ -2737,151 +2797,78 @@ class ConvertGatherOp : public OpConversionPattern { rewriter))) { return failure(); } + start_indices_type = mlir::cast(start_indices.getType()); - // Verify that start_index_map and collapsed_slice_dims contains the same - // values. - auto start_index_map = gather_op.getDimensionNumbers().getStartIndexMap(); - auto collapsed_slice_dims = - gather_op.getDimensionNumbers().getCollapsedSliceDims(); - if (start_index_map.size() != collapsed_slice_dims.size()) { - return rewriter.notifyMatchFailure( - gather_op, - "different size for start index map and collapsed slice dims"); - } - for (auto c : collapsed_slice_dims) { - if (llvm::count(start_index_map, c) == 0) { - return rewriter.notifyMatchFailure( - gather_op, "collapsed slice dim isn't present in start index map"); - } - } - - // Verify that slice_sizes is 1 for the indexed dimensions and the full - // shape for the rest of the dimensions. + // Verify that slice_sizes is 1 for the batching dimensions and the full + // shape for non-indexed dimensions. auto slice_sizes = gather_op.getSliceSizes(); - int64_t index = 0; + llvm::SmallVector slice_sizes_vector; + slice_sizes_vector.reserve(slice_sizes.size()); for (int64_t s : slice_sizes.getValues()) { - if (llvm::count(start_index_map, index)) { + slice_sizes_vector.push_back(s); + } + for (int i = 0; i < slice_sizes_vector.size(); ++i) { + int s = slice_sizes_vector[i]; + if (llvm::count(start_indices_batching_dims, i)) { if (s != 1) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); } - } else { - if (s != operand_type.getShape()[index]) { + } else if (llvm::count(start_index_map, i) == 0) { + if (s != operand_type.getShape()[i]) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); } } - ++index; } - // Verify that offset_dims are the tailing dimensions in the output tensor. - auto offset_dims = gather_op.getDimensionNumbers().getOffsetDims(); - SmallVector offset_dims_vector(offset_dims.begin(), - offset_dims.end()); - const TransposeParams& transpose_params = - CanonicalizeOffset(/*result_type=*/result_type, - /*original_offset_dims=*/offset_dims_vector); - - int64_t offset = start_indices_type.getRank() - 1; - for (int64_t o : transpose_params.canonicalized_offset_dims) { - if (o != offset) { - return rewriter.notifyMatchFailure(gather_op, - "unsupported offset dims"); - } - ++offset; - } + // Canonicalize the operand and start indices. + auto canonical_operand = + CanonicalizeOperand(gather_op, operand, operand_type, + operand_batching_dims, start_index_map, rewriter); + auto canonical_operand_type = + mlir::cast(canonical_operand.getType()); - // Verify that operand_batching_dims and start_indices_batching_dims are - // leading dimensions of the operand and start_indices, respectively, and - // that all batching dimensions are trivial. - llvm::ArrayRef operand_batching_dims = - gather_op.getDimensionNumbers().getOperandBatchingDims(); - llvm::ArrayRef start_indices_batching_dims = - gather_op.getDimensionNumbers().getStartIndicesBatchingDims(); - if (operand_batching_dims.size() != start_indices_batching_dims.size()) { - return rewriter.notifyMatchFailure( - gather_op, - "different size for operand and start_indices batching dims"); - } - for (int64_t i = 0; i < operand_batching_dims.size(); ++i) { - if (operand_batching_dims[i] != i || - start_indices_batching_dims[i] != i || - operand_type.getShape()[i] != 1 || - start_indices_type.getShape()[i] != 1) { - return rewriter.notifyMatchFailure(gather_op, - "unsupported batching dims"); - } - } - const int64_t num_batch_dims = operand_batching_dims.size(); - - // Transpose the operand to handle non-iota start index map, such that - // the start index dimensions are in order and follow the batching - // dimensions. - llvm::SmallVector transpose_dimensions; - llvm::SmallVector transpose_shape; - for (int64_t i = 0; i < num_batch_dims; ++i) { - transpose_dimensions.push_back(i); - transpose_shape.push_back(operand_type.getShape()[i]); - } - for (int64_t s : start_index_map) { - transpose_dimensions.push_back(s); - transpose_shape.push_back(operand_type.getShape()[s]); - } - for (int64_t i = num_batch_dims, e = operand_type.getRank(); i < e; ++i) { - if (llvm::count(start_index_map, i) == 0) { - transpose_dimensions.push_back(i); - transpose_shape.push_back(operand_type.getShape()[i]); - } - } - operand_type = - RankedTensorType::get(transpose_shape, operand_type.getElementType()); - operand = rewriter.create( - gather_op.getLoc(), operand_type, operand, - rewriter.getI64TensorAttr(transpose_dimensions)); - - // Reshape away the batching dimensions (trivial) from the operand. - operand_type = RankedTensorType::get( - operand_type.getShape().drop_front(num_batch_dims), - operand_type.getElementType()); - operand = rewriter.create(gather_op->getLoc(), - operand_type, operand); - - // Check whether we need to append a transpose op after the gather nd. - bool need_transpose_after = false; - for (int i = 0; i < transpose_params.permutation.size(); ++i) { - if (i != transpose_params.permutation[i]) { - need_transpose_after = true; - break; - } - } - - auto tf_gather_nd_result_type = - RankedTensorType::get(transpose_params.canonicalized_output_shape, - result_type.getElementType()); + auto canonical_start_indices = + CanonicalizeStartIndices(gather_op, start_indices, start_indices_type, + start_indices_batching_dims, start_index_map, + slice_sizes_vector, rewriter); + auto canonical_start_indices_type = + mlir::cast(canonical_start_indices.getType()); TF::CastOp cast_op = nullptr; - if (start_indices_type.getElementType().isUnsignedInteger(32)) { + if (canonical_start_indices_type.getElementType().isUnsignedInteger(32)) { cast_op = rewriter.create( gather_op->getLoc(), - RankedTensorType::get(start_indices_type.getShape(), + RankedTensorType::get(canonical_start_indices_type.getShape(), rewriter.getI64Type()), - start_indices); + canonical_start_indices); } - auto tf_gather_nd_op = rewriter.create( - gather_op->getLoc(), tf_gather_nd_result_type, operand, - cast_op ? cast_op.getResult() : start_indices); - - if (!need_transpose_after) { - rewriter.replaceOp(gather_op, tf_gather_nd_op->getOpResults()); - return success(); + llvm::SmallVector canonical_result_shape; + for (int64_t i = 0; i < canonical_start_indices_type.getRank() - 1; ++i) { + canonical_result_shape.push_back( + canonical_start_indices_type.getDimSize(i)); } + for (int64_t i = canonical_start_indices_type.getDimSize( + canonical_start_indices_type.getRank() - 1); + i < canonical_operand_type.getRank(); ++i) { + canonical_result_shape.push_back(canonical_operand_type.getDimSize(i)); + } + + auto canonical_result_type = RankedTensorType::get( + canonical_result_shape, result_type.getElementType()); + auto canonical_result = rewriter.create( + gather_op->getLoc(), canonical_result_type, canonical_operand, + cast_op ? cast_op.getResult() : canonical_start_indices); - // Insert the transpose op after the gather_nd. - rewriter.replaceOpWithNewOp( - gather_op, result_type, tf_gather_nd_op, - rewriter.getI64TensorAttr(transpose_params.permutation)); + auto offset_dims = gather_op.getDimensionNumbers().getOffsetDims(); + auto final_result = UncanonicalizeResult( + gather_op, canonical_result, canonical_result_type, result_type, + offset_dims, operand_batching_dims, start_indices_batching_dims, + start_index_map, slice_sizes_vector, collapsed_slice_dims, rewriter); + rewriter.replaceOp(gather_op, final_result); return success(); } @@ -3037,75 +3024,303 @@ class ConvertGatherOp : public OpConversionPattern { } private: - // Canonicalize the offset dims to make sure the offset dims are the trailing - // dimensions of the output tensor. - // We will also return the permutation for (the transpose op). - // However, it's not guaranteed the canonicalized offset dims can make it - // always legalizable to tf. - TransposeParams CanonicalizeOffset( - ShapedType result_type, ArrayRef original_offset_dims) const { - TransposeParams transpose_params; - int output_rank = result_type.getRank(); - // The canonicalized offset should be the trailing of the output rank. - for (int start = output_rank - original_offset_dims.size(); - start < output_rank; ++start) { - transpose_params.canonicalized_offset_dims.push_back(start); - } - + // Transform the canonicalized result produced by tf.GatherNd with the + // canonicalized operand and start indices back into the original result. + // The canonicalized result will have the start indices batching dimensions + // flattened as leading dimension, and the offset dimensions as trailing + // dimensions. To transform back, we: + // - Unflatten the start indices batching dimensions. + // - Introduce trivial index dimensions that aren't in `collapsed_slice_dims`. + // - Transpose dimensions back based on `offset_dims` and + // `start_indices_batching_dims`. + Value UncanonicalizeResult(mhlo::GatherOp gather_op, Value canonical_result, + ShapedType canonical_result_type, + ShapedType original_result_type, + ArrayRef offset_dims, + ArrayRef operand_batching_dims, + ArrayRef start_indices_batching_dims, + ArrayRef start_index_map, + ArrayRef slice_sizes, + ArrayRef collapsed_slice_dims, + ConversionPatternRewriter& rewriter) const { // For those dims NOT inside the original_offset_dims are considered "batch // dims". std::vector batch_dims; // Offset dims are guaranteed to be sorted. int offset_index = 0; - for (int64_t i = 0; i < output_rank; ++i) { - if (offset_index >= original_offset_dims.size() || - original_offset_dims[offset_index] != i) { + for (int64_t i = 0; i < original_result_type.getRank(); ++i) { + if (offset_index >= offset_dims.size() || + offset_dims[offset_index] != i) { batch_dims.push_back(i); } else { ++offset_index; } } - // Populate the trnaspose permutation params from a "canonicalized" output - // to the real output. - // The canonicalized layout would be batch_dims followed by sliced_dims. - // The current layout is essentially a transpose after the canonicalized - // layout. - // Take the following as an example: - // If we have the: - // original_offset_dims like [1, 2, 4] - // batch_dims like [0, 3] - // It's like performing transpose on a "canonicalized" - // [batch_dims, sliced_dims]: [B1, B2, O1, O2, O3] - // into the current layout: [B1, O1, O2, B2, O3] - // where the permutation is [0, 2, 3, 1, 4] - int batch_idx = 0; - int offset_idx = 0; - int batch_dim_size = batch_dims.size(); - for (int i = 0; i < output_rank; ++i) { - if (batch_idx >= batch_dims.size()) { - transpose_params.permutation.push_back(batch_dim_size + offset_idx); - ++offset_idx; - } else if (offset_idx < original_offset_dims.size() && - original_offset_dims[offset_idx] < batch_dims[batch_idx]) { - transpose_params.permutation.push_back(batch_dim_size + offset_idx); - ++offset_idx; + // Determine the canonical shape after unflattening the start indices + // batching dimensions (if they exist) and introducing any trivial index + // dimensions that weren't collapsed. Also compute the permutation to + // transform the original shape to the unflattened canonical shape. + llvm::SmallVector permutation_to_canonical; + llvm::SmallVector unflattened_shape; + for (int64_t i : start_indices_batching_dims) { + int64_t dim = batch_dims[i]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } + for (int64_t i = 0; i < batch_dims.size(); ++i) { + if (llvm::count(start_indices_batching_dims, i) == 0) { + int64_t dim = batch_dims[i]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } + } + // The remaining dimensions are the offset dims. We expect non-collapsed + // indexed dimensions first, followed by the rest of the operand dimensions. + llvm::SmallVector operand_dim_to_offset_dim_map(slice_sizes.size(), + -1); + int offset_dim_index = 0; + llvm::SmallVector remaining_operand_dims; + for (int64_t operand_dim = 0; operand_dim < slice_sizes.size(); + ++operand_dim) { + if (llvm::count(collapsed_slice_dims, operand_dim) || + llvm::count(operand_batching_dims, operand_dim)) { + continue; } else { - transpose_params.permutation.push_back(batch_idx++); + if (llvm::count(start_index_map, operand_dim) == 0) { + remaining_operand_dims.push_back(operand_dim); + } + operand_dim_to_offset_dim_map[operand_dim] = + offset_dims[offset_dim_index++]; + } + } + for (int64_t s : start_index_map) { + if (llvm::count(collapsed_slice_dims, s) == 0) { + int64_t dim = operand_dim_to_offset_dim_map[s]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } + } + for (int64_t operand_dim : remaining_operand_dims) { + int64_t dim = operand_dim_to_offset_dim_map[operand_dim]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } + + // Reshape the result to unflatten the batching dimensions and add back any + // non-collapsed indexed dimensions. The caller should ensure that a + // reshape is not needed if the result has dynamic dimensions. + if (canonical_result_type.hasStaticShape()) { + auto unflattened_result_type = RankedTensorType::get( + unflattened_shape, original_result_type.getElementType()); + canonical_result = rewriter.create( + gather_op.getLoc(), unflattened_result_type, canonical_result); + } + // Transpose back to the original result shape. + return rewriter.create( + gather_op.getLoc(), original_result_type, canonical_result, + rewriter.getI64TensorAttr( + GetInversePermutationArray(permutation_to_canonical))); + } + + // Canonicalize `operand` to handle operand batching dimensions and non-iota + // start index map, so it can be used by tf.GatherNd: + // - Transpose so that the leading dimensions are the operand batching + // dimensions followed by the indexed dimensions (in order). + // - Flatten the batching dimensions. + Value CanonicalizeOperand(mhlo::GatherOp gather_op, Value operand, + ShapedType operand_type, + ArrayRef operand_batching_dims, + ArrayRef start_index_map, + ConversionPatternRewriter& rewriter) const { + int batch_size = 1; + llvm::SmallVector permutation; + llvm::SmallVector transposed_shape; + llvm::SmallVector flattened_shape; + // First add the batching dimensions. + for (int64_t batch_dim : operand_batching_dims) { + permutation.push_back(batch_dim); + transposed_shape.push_back(operand_type.getDimSize(batch_dim)); + batch_size *= operand_type.getDimSize(batch_dim); + } + if (!operand_batching_dims.empty()) { + flattened_shape.push_back(batch_size); + } + // Add the indexed dimensions. + for (int64_t s : start_index_map) { + permutation.push_back(s); + transposed_shape.push_back(operand_type.getDimSize(s)); + flattened_shape.push_back(operand_type.getDimSize(s)); + } + // Finally, add the remaining dimensions. + for (int64_t i = 0; i < operand_type.getRank(); i++) { + if (llvm::count(operand_batching_dims, i) == 0 && + llvm::count(start_index_map, i) == 0) { + permutation.push_back(i); + transposed_shape.push_back(operand_type.getDimSize(i)); + flattened_shape.push_back(operand_type.getDimSize(i)); } } - // Finally, let's find out what are the "canonicalized" output shape looks - // like. - for (auto dim : batch_dims) { - transpose_params.canonicalized_output_shape.push_back( - result_type.getDimSize(dim)); + // Transpose the dimensions and flatten the batching dimensions. + RankedTensorType transposed_type = + RankedTensorType::get(transposed_shape, operand_type.getElementType()); + auto transposed_operand = rewriter.create( + gather_op.getLoc(), transposed_type, operand, + rewriter.getI64TensorAttr(permutation)); + auto flattened_type = + RankedTensorType::get(flattened_shape, operand_type.getElementType()); + auto flattened_operand = rewriter.create( + gather_op.getLoc(), flattened_type, transposed_operand); + return flattened_operand; + } + + // Canonicalize `start_indices` to handle start indices batching dimensions so + // it can be used by tf.GatherNd: + // - Transpose so that the batching dimensions are the leading dimensions. + // - Flatten the batching dimensions if they exist. + // - For each indexed dimension with non-trivial slicing, introduce a new + // dimension, and broadcast and add iota values to the indices. + // - Add iota index values for the operand batching dimensions. + Value CanonicalizeStartIndices(mhlo::GatherOp gather_op, Value start_indices, + ShapedType start_indices_type, + ArrayRef start_indices_batching_dims, + ArrayRef start_index_map, + ArrayRef slice_sizes, + ConversionPatternRewriter& rewriter) const { + int batch_size = 1; + llvm::SmallVector permutation; + llvm::SmallVector transposed_shape; + llvm::SmallVector reshaped_shape; + + // First add the batching dimensions. + for (int64_t batch_dim : start_indices_batching_dims) { + permutation.push_back(batch_dim); + transposed_shape.push_back(start_indices_type.getDimSize(batch_dim)); + batch_size *= start_indices_type.getDimSize(batch_dim); + } + if (!start_indices_batching_dims.empty()) { + reshaped_shape.push_back(batch_size); + } + + // Add remaining dimensions before the final index vector dim. + for (int64_t dim = 0; dim < start_indices_type.getRank() - 1; dim++) { + if (llvm::count(start_indices_batching_dims, dim) == 0) { + permutation.push_back(dim); + transposed_shape.push_back(start_indices_type.getDimSize(dim)); + reshaped_shape.push_back(start_indices_type.getDimSize(dim)); + } } - for (auto dim : original_offset_dims) { - transpose_params.canonicalized_output_shape.push_back( - result_type.getDimSize(dim)); + + // Introduce new dimensions associated with each indexed operand dimension + // that is taking a non-trivial slice. We will broadcast and add iota values + // after reshaping. See comment below for more details. + int64_t first_non_trivial_sliced_dim = reshaped_shape.size(); + for (int64_t operand_dim : start_index_map) { + if (slice_sizes[operand_dim] > 1) { + reshaped_shape.push_back(1); + } } - return transpose_params; + + // Add the index vector dimension. + int64_t index_vector_size = + start_indices_type.getDimSize(start_indices_type.getRank() - 1); + permutation.push_back(permutation.size()); + transposed_shape.push_back(index_vector_size); + reshaped_shape.push_back(index_vector_size); + + // Transpose the dimensions and flatten the batching dimensions. + auto transposed_start_indices = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(transposed_shape, + start_indices_type.getElementType()), + start_indices, rewriter.getI64TensorAttr(permutation)); + start_indices = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(reshaped_shape, + start_indices_type.getElementType()), + transposed_start_indices); + + // Because tf.GatherNd does not support non-trivial slicing on indexed + // dimensions, we introduce new dimensions in start_indices and broadcast + // and add iota values to the indices. For example: + // + // operand_shape = [10, 10, 10] + // start_indices_original_shape = [1, 3] + // start_index_map = [0, 1, 2] + // slice_sizes = [1, 5, 1] + // + // We then transform the start indices by broadcasting the shape to + // [1, 5, 3], and adding the iota tensor with the following values: + // + // [[[ 0 0 0 ] + // [ 0 1 0 ] + // [ 0 2 0 ] + // [ 0 3 0 ] + // [ 0 4 0 ]]] + // + // This allows us to take trivial slices when indexing into operand + // dimension 1. + llvm::SmallVector start_indices_shape = reshaped_shape; + int64_t non_trivial_sliced_dim = first_non_trivial_sliced_dim; + for (int i = 0; i < start_index_map.size(); ++i) { + int64_t operand_dim = start_index_map[i]; + if (slice_sizes[operand_dim] == 1) { + continue; + } + // Create iota values along the sliced dimension. + llvm::SmallVector offsets_shape(start_indices_shape.size(), 1); + offsets_shape[non_trivial_sliced_dim] = slice_sizes[operand_dim]; + start_indices_shape[non_trivial_sliced_dim] = slice_sizes[operand_dim]; + auto offsets = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(offsets_shape, + start_indices_type.getElementType()), + rewriter.getI64IntegerAttr(non_trivial_sliced_dim)); + non_trivial_sliced_dim++; + + // Pad with 0s on the other operand dimensions. + Value zero = rewriter.create( + gather_op.getLoc(), rewriter.getZeroAttr(RankedTensorType::get( + {}, start_indices_type.getElementType()))); + int rank = offsets_shape.size(); + llvm::SmallVector padding_low(rank, 0); + llvm::SmallVector padding_high(rank, 0); + llvm::SmallVector padding_interior(rank, 0); + padding_low.back() = i; + padding_high.back() = start_indices_shape.back() - i - 1; + auto padded_offsets = rewriter.create( + gather_op.getLoc(), offsets, zero, + GetI64ElementsAttr(padding_low, &rewriter), + GetI64ElementsAttr(padding_high, &rewriter), + GetI64ElementsAttr(padding_interior, &rewriter)); + + // Add the padded offsets to the start indices (with broadcasting). + start_indices = rewriter.create(gather_op.getLoc(), + start_indices, padded_offsets); + } + + if (!start_indices_batching_dims.empty()) { + // Concat iota values for indexing into the batching dimensions of the + // operand. + llvm::SmallVector offsets_shape = start_indices_shape; + offsets_shape.back() = 1; + auto offsets = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(offsets_shape, + start_indices_type.getElementType()), + rewriter.getI64IntegerAttr(0)); + + start_indices_shape.back()++; + start_indices = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(start_indices_shape, + start_indices_type.getElementType()), + ValueRange{offsets, start_indices}, + rewriter.getI32IntegerAttr(start_indices_shape.size() - 1)); + } + + return start_indices; } }; @@ -3660,6 +3875,10 @@ void LegalizeHloToTf::runOnOperation() { void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, MLIRContext* context) { + // Add mhlo::GatherOp canonicalization patterns first before the complicated + // ConvertGatherOp legalization pattern. + mhlo::GatherOp::getCanonicalizationPatterns(*patterns, context); + patterns ->add permutation; - // The following are the "canonicalized" output shape with offset dims. - std::vector canonicalized_output_shape; - std::vector canonicalized_offset_dims; -}; +namespace { + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} -// Canonicalize the offset dims to make sure the offset dims are the -// trailing -// dimensions of the output tensor. -// We will also return the permutation for (the transpose op). -// However, it's not guaranteed the canonicalized offset dims can make it -// always legalizable to tf. -TransposeParams CanonicalizeOffset(ShapedType result_type, - ArrayRef original_offset_dims) { - TransposeParams transpose_params; - int output_rank = result_type.getRank(); - // The canonicalized offset should be the trailing of the output rank. - for (int start = output_rank - original_offset_dims.size(); - start < output_rank; ++start) { - transpose_params.canonicalized_offset_dims.push_back(start); - } - - // For those dims NOT inside the original_offset_dims are considered - // "batch +// Transform the canonicalized result produced by tf.GatherNd with the +// canonicalized operand and start indices back into the original result. +// The canonicalized result will have the start indices batching dimensions +// flattened as leading dimension, and the offset dimensions as trailing +// dimensions. To transform back, we: +// - Unflatten the start indices batching dimensions. +// - Introduce trivial index dimensions that aren't in `collapsed_slice_dims`. +// - Transpose dimensions back based on `offset_dims` and +// `start_indices_batching_dims`. +Value UncanonicalizeResult(mhlo::GatherOp gather_op, Value canonical_result, + ShapedType canonical_result_type, + ShapedType original_result_type, + ArrayRef offset_dims, + ArrayRef operand_batching_dims, + ArrayRef start_indices_batching_dims, + ArrayRef start_index_map, + ArrayRef slice_sizes, + ArrayRef collapsed_slice_dims, + ConversionPatternRewriter& rewriter) { + // For those dims NOT inside the original_offset_dims are considered "batch // dims". std::vector batch_dims; // Offset dims are guaranteed to be sorted. int offset_index = 0; - for (int64_t i = 0; i < output_rank; ++i) { - if (offset_index >= original_offset_dims.size() || - original_offset_dims[offset_index] != i) { + for (int64_t i = 0; i < original_result_type.getRank(); ++i) { + if (offset_index >= offset_dims.size() || offset_dims[offset_index] != i) { batch_dims.push_back(i); } else { ++offset_index; } } - // Populate the trnaspose permutation params from a "canonicalized" - // output - // to the real output. - // The canonicalized layout would be batch_dims followed by sliced_dims. - // The current layout is essentially a transpose after the canonicalized - // layout. - // Take the following as an example: - // If we have the: - // original_offset_dims like [1, 2, 4] - // batch_dims like [0, 3] - // It's like performing transpose on a "canonicalized" - // [batch_dims, sliced_dims]: [B1, B2, O1, O2, O3] - // into the current layout: [B1, O1, O2, B2, O3] - // where the permutation is [0, 2, 3, 1, 4] - int batch_idx = 0; - int offset_idx = 0; - int batch_dim_size = batch_dims.size(); - for (int i = 0; i < output_rank; ++i) { - if (batch_idx >= batch_dims.size()) { - transpose_params.permutation.push_back(batch_dim_size + offset_idx); - ++offset_idx; - } else if (offset_idx < original_offset_dims.size() && - original_offset_dims[offset_idx] < batch_dims[batch_idx]) { - transpose_params.permutation.push_back(batch_dim_size + offset_idx); - ++offset_idx; + // Determine the canonical shape after unflattening the start indices + // batching dimensions (if they exist) and introducing any trivial index + // dimensions that weren't collapsed. Also compute the permutation to + // transform the original shape to the unflattened canonical shape. + llvm::SmallVector permutation_to_canonical; + llvm::SmallVector unflattened_shape; + for (int64_t i : start_indices_batching_dims) { + int64_t dim = batch_dims[i]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } + for (int64_t i = 0; i < batch_dims.size(); ++i) { + if (llvm::count(start_indices_batching_dims, i) == 0) { + int64_t dim = batch_dims[i]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } + } + // The remaining dimensions are the offset dims. We expect non-collapsed + // indexed dimensions first, followed by the rest of the operand dimensions. + llvm::SmallVector operand_dim_to_offset_dim_map(slice_sizes.size(), + -1); + int offset_dim_index = 0; + llvm::SmallVector remaining_operand_dims; + for (int64_t operand_dim = 0; operand_dim < slice_sizes.size(); + ++operand_dim) { + if (llvm::count(collapsed_slice_dims, operand_dim) || + llvm::count(operand_batching_dims, operand_dim)) { + continue; } else { - transpose_params.permutation.push_back(batch_idx++); + if (llvm::count(start_index_map, operand_dim) == 0) { + remaining_operand_dims.push_back(operand_dim); + } + operand_dim_to_offset_dim_map[operand_dim] = + offset_dims[offset_dim_index++]; + } + } + for (int64_t s : start_index_map) { + if (llvm::count(collapsed_slice_dims, s) == 0) { + int64_t dim = operand_dim_to_offset_dim_map[s]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); } } + for (int64_t operand_dim : remaining_operand_dims) { + int64_t dim = operand_dim_to_offset_dim_map[operand_dim]; + permutation_to_canonical.push_back(dim); + unflattened_shape.push_back(original_result_type.getDimSize(dim)); + } - // Finally, let's find out what are the "canonicalized" output shape - // looks - // like. - for (auto dim : batch_dims) { - transpose_params.canonicalized_output_shape.push_back( - result_type.getDimSize(dim)); + // Reshape the result to unflatten the batching dimensions and add back any + // non-collapsed indexed dimensions. The caller should ensure that a + // reshape is not needed if the result has dynamic dimensions. + if (canonical_result_type.hasStaticShape()) { + auto unflattened_result_type = RankedTensorType::get( + unflattened_shape, original_result_type.getElementType()); + canonical_result = rewriter.create( + gather_op.getLoc(), unflattened_result_type, canonical_result); + } + // Transpose back to the original result shape. + return rewriter.create( + gather_op.getLoc(), original_result_type, canonical_result, + rewriter.getI64TensorAttr( + GetInversePermutationArray(permutation_to_canonical))); +} + +// Canonicalize `operand` to handle operand batching dimensions and non-iota +// start index map, so it can be used by tf.GatherNd: +// - Transpose so that the leading dimensions are the operand batching +// dimensions followed by the indexed dimensions (in order). +// - Flatten the batching dimensions. +Value CanonicalizeOperand(mhlo::GatherOp gather_op, Value operand, + ShapedType operand_type, + ArrayRef operand_batching_dims, + ArrayRef start_index_map, + ConversionPatternRewriter& rewriter) { + int batch_size = 1; + llvm::SmallVector permutation; + llvm::SmallVector transposed_shape; + llvm::SmallVector flattened_shape; + // First add the batching dimensions. + for (int64_t batch_dim : operand_batching_dims) { + permutation.push_back(batch_dim); + transposed_shape.push_back(operand_type.getDimSize(batch_dim)); + batch_size *= operand_type.getDimSize(batch_dim); + } + if (!operand_batching_dims.empty()) { + flattened_shape.push_back(batch_size); } - for (auto dim : original_offset_dims) { - transpose_params.canonicalized_output_shape.push_back( - result_type.getDimSize(dim)); + // Add the indexed dimensions. + for (int64_t s : start_index_map) { + permutation.push_back(s); + transposed_shape.push_back(operand_type.getDimSize(s)); + flattened_shape.push_back(operand_type.getDimSize(s)); } - return transpose_params; + // Finally, add the remaining dimensions. + for (int64_t i = 0; i < operand_type.getRank(); i++) { + if (llvm::count(operand_batching_dims, i) == 0 && + llvm::count(start_index_map, i) == 0) { + permutation.push_back(i); + transposed_shape.push_back(operand_type.getDimSize(i)); + flattened_shape.push_back(operand_type.getDimSize(i)); + } + } + + // Transpose the dimensions and flatten the batching dimensions. + RankedTensorType transposed_type = + RankedTensorType::get(transposed_shape, operand_type.getElementType()); + auto transposed_operand = rewriter.create( + gather_op.getLoc(), transposed_type, operand, + rewriter.getI64TensorAttr(permutation)); + auto flattened_type = + RankedTensorType::get(flattened_shape, operand_type.getElementType()); + auto flattened_operand = rewriter.create( + gather_op.getLoc(), flattened_type, transposed_operand); + return flattened_operand; } +// Canonicalize `start_indices` to handle start indices batching dimensions so +// it can be used by tf.GatherNd: +// - Transpose so that the batching dimensions are the leading dimensions. +// - Flatten the batching dimensions if they exist. +// - For each indexed dimension with non-trivial slicing, introduce a new +// dimension, and broadcast and add iota values to the indices. +// - Add iota index values for the operand batching dimensions. +Value CanonicalizeStartIndices(mhlo::GatherOp gather_op, Value start_indices, + ShapedType start_indices_type, + ArrayRef start_indices_batching_dims, + ArrayRef start_index_map, + ArrayRef slice_sizes, + ConversionPatternRewriter& rewriter) { + int batch_size = 1; + llvm::SmallVector permutation; + llvm::SmallVector transposed_shape; + llvm::SmallVector reshaped_shape; + + // First add the batching dimensions. + for (int64_t batch_dim : start_indices_batching_dims) { + permutation.push_back(batch_dim); + transposed_shape.push_back(start_indices_type.getDimSize(batch_dim)); + batch_size *= start_indices_type.getDimSize(batch_dim); + } + if (!start_indices_batching_dims.empty()) { + reshaped_shape.push_back(batch_size); + } + + // Add remaining dimensions before the final index vector dim. + for (int64_t dim = 0; dim < start_indices_type.getRank() - 1; dim++) { + if (llvm::count(start_indices_batching_dims, dim) == 0) { + permutation.push_back(dim); + transposed_shape.push_back(start_indices_type.getDimSize(dim)); + reshaped_shape.push_back(start_indices_type.getDimSize(dim)); + } + } + + // Introduce new dimensions associated with each indexed operand dimension + // that is taking a non-trivial slice. We will broadcast and add iota values + // after reshaping. See comment below for more details. + int64_t first_non_trivial_sliced_dim = reshaped_shape.size(); + for (int64_t operand_dim : start_index_map) { + if (slice_sizes[operand_dim] > 1) { + reshaped_shape.push_back(1); + } + } + + // Add the index vector dimension. + int64_t index_vector_size = + start_indices_type.getDimSize(start_indices_type.getRank() - 1); + permutation.push_back(permutation.size()); + transposed_shape.push_back(index_vector_size); + reshaped_shape.push_back(index_vector_size); + + // Transpose the dimensions and flatten the batching dimensions. + auto transposed_start_indices = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(transposed_shape, + start_indices_type.getElementType()), + start_indices, rewriter.getI64TensorAttr(permutation)); + start_indices = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(reshaped_shape, + start_indices_type.getElementType()), + transposed_start_indices); + + // Because tf.GatherNd does not support non-trivial slicing on indexed + // dimensions, we introduce new dimensions in start_indices and broadcast + // and add iota values to the indices. For example: + // + // operand_shape = [10, 10, 10] + // start_indices_original_shape = [1, 3] + // start_index_map = [0, 1, 2] + // slice_sizes = [1, 5, 1] + // + // We then transform the start indices by broadcasting the shape to + // [1, 5, 3], and adding the iota tensor with the following values: + // + // [[[ 0 0 0 ] + // [ 0 1 0 ] + // [ 0 2 0 ] + // [ 0 3 0 ] + // [ 0 4 0 ]]] + // + // This allows us to take trivial slices when indexing into operand + // dimension 1. + llvm::SmallVector start_indices_shape = reshaped_shape; + int64_t non_trivial_sliced_dim = first_non_trivial_sliced_dim; + for (int i = 0; i < start_index_map.size(); ++i) { + int64_t operand_dim = start_index_map[i]; + if (slice_sizes[operand_dim] == 1) { + continue; + } + // Create iota values along the sliced dimension. + llvm::SmallVector offsets_shape(start_indices_shape.size(), 1); + offsets_shape[non_trivial_sliced_dim] = slice_sizes[operand_dim]; + start_indices_shape[non_trivial_sliced_dim] = slice_sizes[operand_dim]; + auto offsets = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(offsets_shape, + start_indices_type.getElementType()), + rewriter.getI64IntegerAttr(non_trivial_sliced_dim)); + non_trivial_sliced_dim++; + + // Pad with 0s on the other operand dimensions. + Value zero = rewriter.create( + gather_op.getLoc(), rewriter.getZeroAttr(RankedTensorType::get( + {}, start_indices_type.getElementType()))); + int rank = offsets_shape.size(); + llvm::SmallVector padding_low(rank, 0); + llvm::SmallVector padding_high(rank, 0); + llvm::SmallVector padding_interior(rank, 0); + padding_low.back() = i; + padding_high.back() = start_indices_shape.back() - i - 1; + auto padded_offsets = rewriter.create( + gather_op.getLoc(), offsets, zero, + GetI64ElementsAttr(padding_low, &rewriter), + GetI64ElementsAttr(padding_high, &rewriter), + GetI64ElementsAttr(padding_interior, &rewriter)); + + // Add the padded offsets to the start indices (with broadcasting). + start_indices = rewriter.create( + gather_op.getLoc(), start_indices, padded_offsets, + /*fused_activation_function=*/ + mlir::StringAttr::get(rewriter.getContext(), "NONE")); + } + + if (!start_indices_batching_dims.empty()) { + // Concat iota values for indexing into the batching dimensions of the + // operand. + llvm::SmallVector offsets_shape = start_indices_shape; + offsets_shape.back() = 1; + auto offsets = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(offsets_shape, + start_indices_type.getElementType()), + rewriter.getI64IntegerAttr(0)); + + start_indices_shape.back()++; + start_indices = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get(start_indices_shape, + start_indices_type.getElementType()), + ValueRange{offsets, start_indices}, + rewriter.getI32IntegerAttr(start_indices_shape.size() - 1)); + } + + return start_indices; +} +} // namespace + +// Tries to convert an mhlo::GatherOp into a TFL::GatherNdOp. +// +// Consider the following example: +// operand_shape = [B1, I1, O1, B2, I2, O2] +// operand_batching_dims = [0, 3] +// +// start_indices_shape = [B2, B3, B1, 2] +// start_indices_batching_dims = [3, 0] +// index_vector_dim = 3 +// start_index_map = [4, 1] +// +// offset_dims: [2, 4] +// slice_sizes = [1, 1, O1, 1, 1, O2] +// collapsed_slice_dims = [1, 4] +// result_shape = [B2, B3, O1, B3, O2] +// +// To implement this with a tfl.GatherNd, we canonicalize the operand s.t. the +// operand batching dimensions are flattened into the leading dimensions, +// followed by the indexed dimensions in order: +// canonical_operand_shape = [B1 * B2, I2, I1, O1, O2] +// +// We canonicalize the start indices so the start indices batching dimensions +// are flattened (in order) into a leading dimension. In addition, we add iota +// indices to appropriately offset into the flattened operand batching +// dimension: +// canonical_start_indices_shape = [B1 * B2, B3, 3] +// (index_vector_dim is expanded to included indices for the operand +// batching dimensions) +// +// The result of tf.GatherNd(canonical_operand, canonical_start_indices) has the +// following shape: +// canonical_result_shape = [B1 * B2, B3, O1, O2] +// +// The canonical result is unflattened and transpose as needed to get back to +// the original result shape. class LegalizeGatherToGatherND : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -300,6 +563,38 @@ LogicalResult LegalizeGatherToGatherND::matchAndRewrite( return failure(); } + llvm::ArrayRef operand_batching_dims = + gather_op.getDimensionNumbers().getOperandBatchingDims(); + llvm::ArrayRef start_indices_batching_dims = + gather_op.getDimensionNumbers().getStartIndicesBatchingDims(); + llvm::ArrayRef start_index_map = + gather_op.getDimensionNumbers().getStartIndexMap(); + llvm::ArrayRef collapsed_slice_dims = + gather_op.getDimensionNumbers().getCollapsedSliceDims(); + if (!start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { + // Dynamic dimensions aren't supported in certain cases that require + // reshaping the indices or result. + if (!start_indices_batching_dims.empty()) { + gather_op.emitOpError() + << "Dynamic shaped start indices aren't supported when there are " + "batching dimensions."; + } + + // Verify that start_index_map and collapsed_slice_dims contains the same + // values. + if (start_index_map.size() != collapsed_slice_dims.size()) { + return rewriter.notifyMatchFailure( + gather_op, + "different size for start index map and collapsed slice dims"); + } + for (auto c : collapsed_slice_dims) { + if (llvm::count(start_index_map, c) == 0) { + return rewriter.notifyMatchFailure( + gather_op, "collapsed slice dim isn't present in start index map"); + } + } + } + // Normalize start_indices so index_vector_dim == start_indices.rank() - 1. int64_t index_vector_dim = gather_op.getDimensionNumbers().getIndexVectorDim(); @@ -307,118 +602,86 @@ LogicalResult LegalizeGatherToGatherND::matchAndRewrite( index_vector_dim, rewriter))) { return failure(); } + start_indices_type = mlir::cast(start_indices.getType()); - // Verify that start_index_map and collapsed_slice_dims contains the same - // values. - auto start_index_map = gather_op.getDimensionNumbers().getStartIndexMap(); - auto collapsed_slice_dims = - gather_op.getDimensionNumbers().getCollapsedSliceDims(); - if (start_index_map.size() != collapsed_slice_dims.size()) { - return rewriter.notifyMatchFailure( - gather_op, - "different size for start index map and collapsed slice dims"); - } - for (auto c : collapsed_slice_dims) { - if (llvm::count(start_index_map, c) == 0) { - return rewriter.notifyMatchFailure( - gather_op, "collapsed slice dim isn't present in start index map"); - } - } - - // Verify that slice_sizes is 1 for the indexed dimensions and the full - // shape for the rest of the dimensions. + // Verify that slice_sizes is 1 for the batching dimensions and the full + // shape for non-indexed dimensions. auto slice_sizes = gather_op.getSliceSizes(); - int64_t index = 0; + llvm::SmallVector slice_sizes_vector; + slice_sizes_vector.reserve(slice_sizes.size()); for (int64_t s : slice_sizes.getValues()) { - if (llvm::count(start_index_map, index)) { + slice_sizes_vector.push_back(s); + } + for (int i = 0; i < slice_sizes_vector.size(); ++i) { + int s = slice_sizes_vector[i]; + if (llvm::count(start_indices_batching_dims, i)) { if (s != 1) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); } - } else { - if (s != operand_type.getShape()[index]) { + } else if (llvm::count(start_index_map, i) == 0) { + if (s != operand_type.getShape()[i]) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); } } - ++index; - } - - // Verify that offset_dims are the tailing dimensions in the output tensor. - auto offset_dims = gather_op.getDimensionNumbers().getOffsetDims(); - SmallVector offset_dims_vector(offset_dims.begin(), - offset_dims.end()); - const TransposeParams& transpose_params = - CanonicalizeOffset(/*result_type=*/result_type, - /*original_offset_dims=*/offset_dims_vector); - - int64_t offset = start_indices_type.getRank() - 1; - for (int64_t o : transpose_params.canonicalized_offset_dims) { - if (o != offset) { - return rewriter.notifyMatchFailure(gather_op, "unsupported offset dims"); - } - ++offset; - } - - // Transpose the operand to handle non-iota start index map. - llvm::SmallVector transpose_dimensions; - llvm::SmallVector transpose_shape; - for (auto s : start_index_map) { - transpose_dimensions.push_back(s); - transpose_shape.push_back(operand_type.getShape()[s]); - } - for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) { - if (llvm::count(start_index_map, i) == 0) { - transpose_dimensions.push_back(i); - transpose_shape.push_back(operand_type.getShape()[i]); - } } - operand_type = - RankedTensorType::get(transpose_shape, operand_type.getElementType()); - operand = rewriter.create( - gather_op.getLoc(), operand_type, operand, - rewriter.getI64TensorAttr(transpose_dimensions)); - - // Check whether we need to append a transpose op after the gather nd. - bool need_transpose_after = false; - for (int i = 0; i < transpose_params.permutation.size(); ++i) { - if (i != transpose_params.permutation[i]) { - need_transpose_after = true; - break; - } - } - - auto tf_gather_nd_result_type = - RankedTensorType::get(transpose_params.canonicalized_output_shape, - result_type.getElementType()); - if (start_indices_type.getElementType().isUnsignedInteger(32)) { - start_indices = rewriter.create( + // Canonicalize the operand and start indices. + auto canonical_operand = + CanonicalizeOperand(gather_op, operand, operand_type, + operand_batching_dims, start_index_map, rewriter); + auto canonical_operand_type = + mlir::cast(canonical_operand.getType()); + + auto canonical_start_indices = CanonicalizeStartIndices( + gather_op, start_indices, start_indices_type, start_indices_batching_dims, + start_index_map, slice_sizes_vector, rewriter); + auto canonical_start_indices_type = + mlir::cast(canonical_start_indices.getType()); + + TFL::CastOp cast_op = nullptr; + if (canonical_start_indices_type.getElementType().isUnsignedInteger(32)) { + cast_op = rewriter.create( gather_op->getLoc(), - RankedTensorType::get(start_indices_type.getShape(), + RankedTensorType::get(canonical_start_indices_type.getShape(), rewriter.getI64Type()), - start_indices); + canonical_start_indices); } - auto tf_gather_nd_op = rewriter.create( - gather_op->getLoc(), tf_gather_nd_result_type, operand, start_indices); - - if (!need_transpose_after) { - rewriter.replaceOp(gather_op, tf_gather_nd_op->getOpResults()); - return success(); + llvm::SmallVector canonical_result_shape; + for (int64_t i = 0; i < canonical_start_indices_type.getRank() - 1; ++i) { + canonical_result_shape.push_back( + canonical_start_indices_type.getDimSize(i)); + } + for (int64_t i = canonical_start_indices_type.getDimSize( + canonical_start_indices_type.getRank() - 1); + i < canonical_operand_type.getRank(); ++i) { + canonical_result_shape.push_back(canonical_operand_type.getDimSize(i)); } - // Insert the transpose op after the gather_nd. - rewriter.replaceOpWithNewOp( - gather_op, result_type, tf_gather_nd_op, - rewriter.getI64TensorAttr(transpose_params.permutation)); + auto canonical_result_type = RankedTensorType::get( + canonical_result_shape, result_type.getElementType()); + auto canonical_result = rewriter.create( + gather_op->getLoc(), canonical_result_type, canonical_operand, + cast_op ? cast_op.getResult() : canonical_start_indices); + + auto offset_dims = gather_op.getDimensionNumbers().getOffsetDims(); + auto final_result = UncanonicalizeResult( + gather_op, canonical_result, canonical_result_type, result_type, + offset_dims, operand_batching_dims, start_indices_batching_dims, + start_index_map, slice_sizes_vector, collapsed_slice_dims, rewriter); + rewriter.replaceOp(gather_op, final_result); return success(); } void PopulateGatherPatterns(MLIRContext* ctx, RewritePatternSet& patterns, ConversionTarget& target) { - patterns.add(ctx); + // Prefer `LegalizeGatherToSlice` for the cases it handles, since it produces + // simpler IR. + patterns.add(ctx, /*benefit=*/2); + patterns.add(ctx); target.addDynamicallyLegalOp(IsGatherLegal); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index e628cfded7fe6b..e1f1681a3d7ae1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td index 185216448a15ed..322fcc44ed4a9f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td @@ -542,14 +542,14 @@ def ArgTypesMatchCallee : Constraint< foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f, $config, $config_proto, $executor_type), - (CallOp $f, $args), + (CallOp $f, $args, ConstantAttr), [(ArgTypesMatchCallee $op, $args, $f)]>; } // The extra attr on this op is _disable_call_shape_inference, which we ignore // in the bridge. def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), - (CallOp $f, $args), + (CallOp $f, $args, ConstantAttr), [(ArgTypesMatchCallee $op, $args, $f)]>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc index 5a63a339e460b9..9f931e1bc4bfdf 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc @@ -15,12 +15,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" +#include #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -49,8 +50,9 @@ limitations under the License. namespace mlir { namespace odml { -static constexpr std::string_view kStablehloModuleDefaultEntryFuncName = "main"; -static constexpr std::string_view kStablehloFuncNamePrefix = "XlaCallModule"; +static constexpr absl::string_view kStablehloModuleDefaultEntryFuncName = + "main"; +static constexpr absl::string_view kStablehloFuncNamePrefix = "XlaCallModule"; static constexpr char kShardingAttr[] = "mhlo.sharding"; static constexpr char kShardingName[] = "Sharding"; @@ -103,21 +105,6 @@ bool ContainsPlatformIndexArg(TF::XlaCallModuleOp xla_call_module_op) { return xla_call_module_op.getPlatforms().size() > 1; } -// Removes the platform index argument from the function. It is equivalent to -// removing the first argument from `func_op` (see the comments at -// `ContainsPlatformIndexArg`). This function assumes that `func_op` is a valid -// function deserialized from XlaCallModule op. -void RemovePlatformIndexArg(MLIRContext *ctx, func::FuncOp func_op) { - // If there are multiple platforms, the first argument is reserved for - // passing the platform index. - FunctionType function_type = func_op.getFunctionType(); - ArrayRef new_input_types = - function_type.getInputs().take_back(func_op.getNumArguments() - 1); - func_op.setFunctionType( - FunctionType::get(ctx, new_input_types, function_type.getResults())); - func_op.getBody().eraseArgument(0); -} - } // namespace class ConvertTFXlaCallModuleOp : public OpRewritePattern { @@ -181,12 +168,20 @@ class ConvertTFXlaCallModuleOp : public OpRewritePattern { } // When the `XlaCallModuleOp`'s callee accepts a platform index argument, - // remove it. This is because when converted to `CallOp` there will be a - // mismatch btw. the number of arguments passed and number of parameters - // accepted (the platform index argument is an extra argument that is not - // expressed by the operands of XlaCallModuleOp). + // add a dummy platform index argument in order to match the number of + // the arguments of the callee function. + // + // This is because `XlaCallModuleOp` doesn't explicitly take it as an + // operand. See: + // https://github.com/tensorflow/tensorflow/blob/eba24f41ba9d661d2f58a515921720cf90708cd4/tensorflow/compiler/tf2xla/ops/xla_ops.cc#L1376-L1385 + + SmallVector call_op_operands(op.getOperands()); if (ContainsPlatformIndexArg(op)) { - RemovePlatformIndexArg(getContext(), main_fn); + Value dummy_const = rewriter.create( + op.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getIntegerType(32)), {0})); + call_op_operands.insert(call_op_operands.begin(), dummy_const); } // The stablehlo module main function's input tensor types might be @@ -195,8 +190,9 @@ class ConvertTFXlaCallModuleOp : public OpRewritePattern { // argument type is tensor<1x2f32>. SmallVector casted_operands; casted_operands.reserve(main_fn.getNumArguments()); + assert(call_op_operands.size() == main_fn.getNumArguments()); for (const auto &operand_and_type : - zip(op.getOperands(), main_fn.getFunctionType().getInputs())) { + zip(call_op_operands, main_fn.getFunctionType().getInputs())) { Value operand = std::get<0>(operand_and_type); Type expected_type = std::get<1>(operand_and_type); if (operand.getType() != expected_type) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index f52ca0a40553c5..7ff1ce6cc29df0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -15,20 +15,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" -#include #include #include #include #include -#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/strings/ascii.h" -#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h index acc3ca0e7923b1..8d57016bc7cf3b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "mlir/Pass/Pass.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index 70f62c3e0b582e..d0e6fb4b3e9a77 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc index 81c6fc47473d43..23b2ccdc83a6bc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" +#include #include -#include #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc index 06754ea72b580c..249a1018e091f4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h" +#include #include -#include #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc index e38cad1d4c7edc..a3b2b47ac9f76a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include -#include #include #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc index f5d756d971610e..d12b4f75a8211e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h index 13ff4c4767721d..fc7c2316655df9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ +#include + #include "llvm/ADT/ArrayRef.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index ab374046bbfd2e..46b92f06ceb409 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -186,7 +186,7 @@ func.func @WhileCanonicalizeBug(%arg0: tensor, %arg1: tensor) -> tenso // result. Canonicalize will think it can remove both slot#0 and slot#1 and do // so without replacing all operands, and in assert builds it will fail an // assert failure ( op->use_empty() && "expected 'op' to have no uses") -// CHECK-LABEL: WhileCanonicalizeBug1 +// CHECK-LABEL: @WhileCanonicalizeBug1 func.func @WhileCanonicalizeBug1(%arg0: tensor, %arg1: tensor) -> tensor { %0:2 = "tfl.while"(%arg0, %arg1) ({ ^bb0(%carg0: tensor, %carg1: tensor): @@ -242,6 +242,17 @@ func.func @RemoveFcZeroBias(%arg0: tensor<1x37xf32>, %arg1: tensor<40x37xf32>) - func.return %1 : tensor<1x40xf32> } +// CHECK-LABEL: forceAsymmetricQuantizeInput +func.func @forceAsymmetricQuantizeInput(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { + %cst0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %cst1 = arith.constant dense<2.0> : tensor<2xf32> + + %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> + // CHECK %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + // CHECK return %0 +} + // CHECK-LABEL: RemoveLstmQuantZeroBias func.func @RemoveLstmQuantZeroBias( %arg0: tensor<1x528xf32>, @@ -373,6 +384,7 @@ func.func @OptimizeTranposeWithRank7orMoreEffectiveRank4(%arg0: tensor<56x8x56x1 // CHECK: return %2 } +// CHECK-LABEL: @ConstPadToI32 func.func @ConstPadToI32(%arg0: tensor<15600xf32>) -> tensor<15602xf32> { %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<1x2xi64>} : () -> tensor<1x2xi64> %1 = "tfl.pad"(%arg0, %0) : (tensor<15600xf32>, tensor<1x2xi64>) -> tensor<15602xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 487c9311e42e04..b758e0567d2cea 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -1132,7 +1132,7 @@ func.func @ConstantFoldFullyConnectedSmall() -> tensor<3xf32> { %cst_weights = arith.constant dense<[[5.0, 7.0], [11.0, 13.0], [17.0, 19.0]]> : tensor<3x2xf32> %cst_bias = arith.constant dense<[23.0, 29.0, 31.0]> : tensor<3xf32> - %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2xf32>, tensor<3x2xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2xf32>, tensor<3x2xf32>, tensor<3xf32>) -> tensor<3xf32> func.return %0 : tensor<3xf32> // [54, 90, 122] @@ -1146,7 +1146,7 @@ func.func @ConstantFoldFullyConnectedLarge() -> tensor<1024xf32> { %cst_weights = arith.constant dense<2.0> : tensor<1024x512xf32> %cst_bias = arith.constant dense<4.0> : tensor<1024xf32> - %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<1024xf32> + %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<1024xf32> func.return %0 : tensor<1024xf32> @@ -1161,7 +1161,7 @@ func.func @ConstantFoldFullyConnectedNoBias() -> tensor<1024xf32> { %cst_weights = arith.constant dense<2.0> : tensor<1024x512xf32> %cst_bias = "tfl.no_value"() {value = unit} : () -> none - %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xf32>, none) -> tensor<1024xf32> + %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xf32>, none) -> tensor<1024xf32> func.return %0 : tensor<1024xf32> @@ -1176,13 +1176,13 @@ func.func @NoFoldFullyConnectedNonFloat() -> tensor<1024xf32> { %cst_weights = arith.constant dense<2> : tensor<1024x512xi8> %cst_bias = arith.constant dense<4.0> : tensor<1024xf32> - %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32> + %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32> func.return %0 : tensor<1024xf32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<512xf32> // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<2> : tensor<1024x512xi8> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<4.000000e+00> : tensor<1024xf32> - // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32> + // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<512xf32>, tensor<1024x512xi8>, tensor<1024xf32>) -> tensor<1024xf32> // CHECK: return %[[VAL]] : tensor<1024xf32> } @@ -1192,13 +1192,13 @@ func.func @NoFoldFullyConnectedHighRank() -> tensor<2x1024xf32> { %cst_weights = arith.constant dense<2.0> : tensor<1024x512xf32> %cst_bias = arith.constant dense<4.0> : tensor<1024xf32> - %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> func.return %0 : tensor<2x1024xf32> // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2x512xf32> // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<2.000000e+00> : tensor<1024x512xf32> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<4.000000e+00> : tensor<1024xf32> - // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> + // CHECK: %[[VAL:.*]] = "tfl.fully_connected"(%[[CST]], %[[CST_0]], %[[CST_1]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<2x512xf32>, tensor<1024x512xf32>, tensor<1024xf32>) -> tensor<2x1024xf32> // CHECK: return %[[VAL]] : tensor<2x1024xf32> } @@ -1208,7 +1208,7 @@ func.func @ConstantFoldFullyConnectedCheckPrecision() -> tensor<1xf32> { %cst_weights = arith.constant dense<[[1.0, 1.0e38, 1.0, -1.0e38]]> : tensor<1x4xf32> %cst_bias = arith.constant dense<0.0> : tensor<1xf32> - %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4xf32>, tensor<1x4xf32>, tensor<1xf32>) -> tensor<1xf32> + %0 = "tfl.fully_connected" (%cst_input, %cst_weights, %cst_bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4xf32>, tensor<1x4xf32>, tensor<1xf32>) -> tensor<1xf32> func.return %0 : tensor<1xf32> // CHECK: %[[CST:.*]] = arith.constant dense<2.000000e+00> : tensor<1xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt index 293fe283ee2685..a4bd43d9c01651 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt @@ -83,8 +83,8 @@ versions { # CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : tensor # CHECK: %[[VAL_7:.*]]:2 = "tfl.split"(%[[VAL_6]], %[[VAL_0]]) <{num_splits = 2 : i32}> : (tensor, tensor<2x5x3xf32>) -> (tensor<1x5x3xf32>, tensor<1x5x3xf32>) # CHECK: %[[VAL_9:.*]] = "tfl.transpose"(%[[VAL_1]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32> -# CHECK: %[[VAL_10:.*]] = "tfl.fully_connected"(%[[VAL_7]]#0, %[[VAL_9]], %[[VAL_3]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> -# CHECK: %[[VAL_11:.*]] = "tfl.fully_connected"(%[[VAL_7]]#1, %[[VAL_9]], %[[VAL_3]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_10:.*]] = "tfl.fully_connected"(%[[VAL_7]]#0, %[[VAL_9]], %[[VAL_3]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_11:.*]] = "tfl.fully_connected"(%[[VAL_7]]#1, %[[VAL_9]], %[[VAL_3]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> # CHECK: %[[VAL_12:.*]] = "tfl.pack"(%[[VAL_10]], %[[VAL_11]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> # CHECK: return %[[VAL_12]] : tensor<2x5x7xf32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dense_constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dense_constants.mlir new file mode 100644 index 00000000000000..2d0d83c7d2aa55 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dense_constants.mlir @@ -0,0 +1,55 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s +// Ensure constants roundtrip exactly + +func.func @f32() -> tensor<4xf32> { + // CHECK-LABEL: @f32 + // CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> + %0 = "tfl.pseudo_const"() { value = dense_resource : tensor<4xf32> } : () -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +func.func @i8() -> tensor<4xi8> { + // CHECK-LABEL: @i8 + // CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8> + %0 = "tfl.pseudo_const" () { value = dense_resource : tensor<4xi8> } : () -> tensor<4xi8> + func.return %0 : tensor<4xi8> +} + +func.func @i16() -> tensor<4xi16> { + // CHECK-LABEL: @i16 + // CHECK: value = dense<[1, 2, 3, 258]> : tensor<4xi16> + %0 = "tfl.pseudo_const" () { value = dense_resource : tensor<4xi16> } : () -> tensor<4xi16> + func.return %0 : tensor<4xi16> +} + +func.func @i32() -> tensor<4xi32> { + // CHECK-LABEL: @i32 + // CHECK: value = dense<[1, 2, 3, 16909060]> : tensor<4xi32> + // Check bytes come back in the right order + %0 = "tfl.pseudo_const" () { value = dense_resource : tensor<4xi32> } : () -> tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +func.func @uint8() -> tensor<4xui8> { + // CHECK-LABEL: @uint8 + // CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8> + %0 = "tfl.pseudo_const"() {value = dense_resource : tensor<4xui8>} : () -> tensor<4xui8> + func.return %0 : tensor<4xui8> +} + +// Identity function to make the exporter happy +func.func @main(%arg0: tensor<4xi8>) -> tensor<4xi8> { + func.return %arg0 : tensor<4xi8> +} + +{-# + dialect_resources: { + builtin: { + dense_elements_f32: "0x400000000000803F000000400000404000008040", + dense_elements_i16: "0x400000000100020003000201", + dense_elements_i32: "0x4000000001000000020000000300000004030201", + dense_elements_i8: "0x4000000001020304", + dense_elements_i8_1: "0x40000000DEADBEEF" + } + } +#-} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dense_constants_offset.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dense_constants_offset.mlir new file mode 100644 index 00000000000000..b2fe9a8a463101 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dense_constants_offset.mlir @@ -0,0 +1,55 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer --use-buffer-offset %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s +// Ensure constants roundtrip exactly + +func.func @f32() -> tensor<4xf32> { + // CHECK-LABEL: @f32 + // CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> + %0 = "tfl.pseudo_const"() { value = dense_resource : tensor<4xf32> } : () -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +func.func @i8() -> tensor<4xi8> { + // CHECK-LABEL: @i8 + // CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8> + %0 = "tfl.pseudo_const" () { value = dense_resource : tensor<4xi8> } : () -> tensor<4xi8> + func.return %0 : tensor<4xi8> +} + +func.func @i16() -> tensor<4xi16> { + // CHECK-LABEL: @i16 + // CHECK: value = dense<[1, 2, 3, 258]> : tensor<4xi16> + %0 = "tfl.pseudo_const" () { value = dense_resource : tensor<4xi16> } : () -> tensor<4xi16> + func.return %0 : tensor<4xi16> +} + +func.func @i32() -> tensor<4xi32> { + // CHECK-LABEL: @i32 + // CHECK: value = dense<[1, 2, 3, 16909060]> : tensor<4xi32> + // Check bytes come back in the right order + %0 = "tfl.pseudo_const" () { value = dense_resource : tensor<4xi32> } : () -> tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +func.func @uint8() -> tensor<4xui8> { + // CHECK-LABEL: @uint8 + // CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8> + %0 = "tfl.pseudo_const"() {value = dense_resource : tensor<4xui8>} : () -> tensor<4xui8> + func.return %0 : tensor<4xui8> +} + +// Identity function to make the exporter happy +func.func @main(%arg0: tensor<4xi8>) -> tensor<4xi8> { + func.return %arg0 : tensor<4xi8> +} + +{-# + dialect_resources: { + builtin: { + dense_elements_f32: "0x400000000000803F000000400000404000008040", + dense_elements_i16: "0x400000000100020003000201", + dense_elements_i32: "0x4000000001000000020000000300000004030201", + dense_elements_i8: "0x4000000001020304", + dense_elements_i8_1: "0x40000000DEADBEEF" + } + } +#-} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index fbee322810a6eb..b991f62ff0aeb0 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -875,7 +875,7 @@ func.func @convert_bmm_rhs_transpose_into_fc(%arg0: tensor<8x256xf32>, %arg1: te // CHECK: return %2 : tensor<8x256xf32> // FOLD: %0 = "tfl.no_value"() <{value}> : () -> none - // FOLD: %1 = "tfl.fully_connected"(%arg0, %arg1, %0) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<8x256xf32>, tensor<256x256xf32>, none) -> tensor<8x256xf32> + // FOLD: %1 = "tfl.fully_connected"(%arg0, %arg1, %0) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<8x256xf32>, tensor<256x256xf32>, none) -> tensor<8x256xf32> // FOLD: return %1 : tensor<8x256xf32> } @@ -1218,7 +1218,7 @@ func.func @MoveReshapeAfterFullyConnected(%arg0: tensor<4x4x10xf32>)->(tensor<16 // FOLD: %[[BIAS:.*]] = "tfl.no_value"() <{value}> : () -> none // FOLD: %[[SHAPE:.*]] = arith.constant dense<[16, 10]> : tensor<2xi32> // FOLD: %[[INPUT:.*]] = "tfl.reshape"(%arg0, %[[SHAPE]]) : (tensor<4x4x10xf32>, tensor<2xi32>) -> tensor<16x10xf32> - // FOLD: %[[RESULT:.*]] = "tfl.fully_connected"(%[[INPUT]], %[[FILTER]], %[[BIAS]]) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<16x10xf32>, tensor<20x10xf32>, none) -> tensor<16x20xf32> + // FOLD: %[[RESULT:.*]] = "tfl.fully_connected"(%[[INPUT]], %[[FILTER]], %[[BIAS]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<16x10xf32>, tensor<20x10xf32>, none) -> tensor<16x20xf32> // FOLD: return %[[RESULT]] : tensor<16x20xf32> } @@ -1272,7 +1272,7 @@ func.func @fuse_fc_and_lhs_reshape(%arg0: tensor<1x128x14336xf32>) -> tensor<128 //FOLD: %cst = arith.constant dense<9.000000e+00> : tensor<1792x14336xf32> //FOLD: %0 = "tfl.no_value"() <{value}> : () -> none - //FOLD: %1 = "tfl.fully_connected"(%arg0, %cst, %0) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x128x14336xf32>, tensor<1792x14336xf32>, none) -> tensor<128x1792xf32> + //FOLD: %1 = "tfl.fully_connected"(%arg0, %cst, %0) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<1x128x14336xf32>, tensor<1792x14336xf32>, none) -> tensor<128x1792xf32> //FOLD: return %1 : tensor<128x1792xf32> } @@ -1314,7 +1314,7 @@ func.func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf // CHECK: return %[[rs2]] // FOLD: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<40x40xf32> - // FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> + // FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) <{asymmetric_quantize_inputs = true, fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}> // FOLD: return %[[fc]] } @@ -3959,6 +3959,7 @@ func.func @fuseSigmoid(%arg0: tensor<10xf32>) -> tensor<10xf32> { %3 = tfl.div %cst, %2 {fused_activation_function = "NONE"} : tensor<10xf32> return %3 : tensor<10xf32> } + // CHECK-LABEL: func @fuseElu func.func @fuseElu(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "args_tf_0", outputs = "Identity_1"}} { // CHECK: "tfl.elu" @@ -3984,6 +3985,7 @@ func.func @fuseHardSwishJAX(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes %4 = tfl.mul %arg0, %3 {fused_activation_function = "NONE"} : tensor<10xf32> return %4 : tensor<10xf32> } + // CHECK-LABEL: func @fuseLeakyRelu func.func @fuseLeakyRelu(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "args_tf_0", outputs = "Identity_1"}} { // CHECK: "tfl.leaky_relu" @@ -4488,3 +4490,13 @@ func.func @reorder_gather_cast(%arg0: tensor<2x3x5xi8>, %arg1: tensor<2x7xi32>) // CHECK: %0 = "tfl.gather"(%arg0, %arg1) <{axis = 1 : i32, batch_dims = 1 : i32}> : (tensor<2x3x5xi8>, tensor<2x7xi32>) -> tensor<2x7x5xi8> // CHECK: %1 = "tfl.cast"(%0) : (tensor<2x7x5xi8>) -> tensor<2x7x5xf32> + +// CHECK-LABEL: @RealDivWithConstDivisor +func.func @RealDivWithConstDivisor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = arith.constant dense<5.000000e+00> : tensor + %1 = tfl.div(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + func.return %1 : tensor<2x3xf32> + // CHECK: %cst = arith.constant dense<2.000000e-01> : tensor + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + // CHECK: return %0 : tensor<2x3xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 18cafc2f8f094f..f36d65c358e0a6 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -61,6 +61,28 @@ namespace { constexpr mlir::StringRef kTFLiteDataLayout = "NHWC"; } // namespace +void AddStrictQDQQuantizationPasses(const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager) { + mlir::quant::QuantizationSpecs updated_quant_specs; + updated_quant_specs = pass_config.quant_specs; + // TODO(majiddadashi): setting QDQCOnversionMode to static to enable per-axis + // propagation of parameters for transpose in the prepare quantize pass. The + // flag likely should become an enum value of QDQConversionMode. + updated_quant_specs.qdq_conversion_mode = + mlir::quant::QDQConversionMode::kQDQStatic; + pass_manager.addNestedPass( + mlir::TFL::CreatePrepareQuantizePass(updated_quant_specs)); + + pass_manager.addNestedPass( + mlir::TFL::CreateQuantizePass(pass_config.quant_specs)); + pass_manager.addNestedPass( + mlir::TFL::CreatePostQuantizePass(true)); + + // So that quantized clipping activations get fused into preceding ops. + pass_manager.addNestedPass( + mlir::TFL::CreateOptimizePass()); +} + void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager& pass_manager) { const mlir::quant::QuantizationSpecs& quant_specs = pass_config.quant_specs; @@ -558,6 +580,13 @@ void AddPostVariableFreezingTFToTFLConversionPasses( pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass()); + if (pass_config.quant_specs.strict_qdq_mode) { + pass_manager->addPass(mlir::TFL::CreateLowerQuantAnnotationsPass()); + + // To remove the quant annotation decompositions. + pass_manager->addPass(mlir::createSymbolDCEPass()); + } + // Run TFL optimization passes set multiple times as op fusion and // reordering in later passes may enable further optimizations with earlier // passes. @@ -576,19 +605,24 @@ void AddPostVariableFreezingTFToTFLConversionPasses( mlir::createCanonicalizerPass()); pass_manager->addNestedPass(mlir::createCSEPass()); - // Run quantization after all the floating point model conversion is - // completed. Add either full integer quantization or dynamic range - // quantization passes based on quant_specs. - if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses() || - pass_config.quant_specs.qdq_conversion_mode != - mlir::quant::QDQConversionMode::kQDQNone) { - AddQuantizationPasses(pass_config, *pass_manager); - // Remove unnecessary QDQs while handling QAT models. - pass_manager->addNestedPass( - mlir::TFL::CreatePostQuantizeRemoveQDQPass()); - } else if (pass_config.quant_specs - .RunAndRewriteDynamicRangeQuantizationPasses()) { - AddDynamicRangeQuantizationPasses(pass_config, *pass_manager); + if (pass_config.quant_specs.strict_qdq_mode) { + AddStrictQDQQuantizationPasses(pass_config, *pass_manager); + } else { + // Run quantization after all the floating point model conversion is + // completed. Add either full integer quantization or dynamic range + // quantization passes based on quant_specs. + if (pass_config.quant_specs + .RunPropagationAndRewriteQuantizationPasses() || + pass_config.quant_specs.qdq_conversion_mode != + mlir::quant::QDQConversionMode::kQDQNone) { + AddQuantizationPasses(pass_config, *pass_manager); + // Remove unnecessary QDQs while handling QAT models. + pass_manager->addNestedPass( + mlir::TFL::CreatePostQuantizeRemoveQDQPass()); + } else if (pass_config.quant_specs + .RunAndRewriteDynamicRangeQuantizationPasses()) { + AddDynamicRangeQuantizationPasses(pass_config, *pass_manager); + } } pass_manager->addPass(mlir::createCanonicalizerPass()); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 1910d51105cc20..afd8f440684b6b 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -608,7 +608,7 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( return status_handler->Combine(status); } } else { - *result = translated_result; + *result = std::move(translated_result); } if (mlir::failed(module->verifyInvariants())) { diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index ba5a2e7c2b5e0b..ec8569a14c7920 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -76,7 +76,7 @@ absl::StatusOr> ImportSavedModel( // * `session` pointer may provided, it will be used to freeze resource // variables. If the `saved_model_dir` directory path is provided, then the // `tf_saved_model.asset` ops will be freezed. -Status ConvertTFExecutorToTFLOrFlatbuffer( +absl::Status ConvertTFExecutorToTFLOrFlatbuffer( std::unique_ptr&& context, mlir::OwningOpRef module, tflite::ConverterFlags& converter_flags, diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc new file mode 100644 index 00000000000000..0ce6cc9d8a9fe7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc @@ -0,0 +1,174 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h" + +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir::TFL { + +LogicalResult FillCompositeParams(stablehlo::CompositeOp op, + SmallVector& scales, + SmallVector& zero_points, + int& num_bits, bool& is_signed) { + auto scale_attr = llvm::dyn_cast_or_null( + op.getCompositeAttributes().get("scale")); + if (scale_attr == nullptr) { + return failure(); + } + for (auto float_attr : scale_attr.getValues()) { + scales.push_back(float_attr.getValueAsDouble()); + } + + auto zero_point_attr = llvm::dyn_cast_or_null( + op.getCompositeAttributes().get("zero_point")); + if (zero_point_attr == nullptr) { + for (int i = 0; i < scales.size(); ++i) { + zero_points.push_back(0); + } + } else { + for (int64_t zp : zero_point_attr.getValues()) { + zero_points.push_back(zp); + } + } + + auto dtype_attr = llvm::dyn_cast_or_null( + op.getCompositeAttributes().get("dtype")); + if (dtype_attr == nullptr) { + return failure(); + } + std::string dtype = dtype_attr.getValue().str(); + if (dtype == "i8") { + num_bits = 8; + is_signed = true; + } else { + // TODO(majiddadashi) currently only tested with i8. + return failure(); + } + return success(); +} + +LogicalResult GetStorageParams(unsigned num_bits, bool narrow_range, + bool is_signed, MLIRContext* ctx, + Type& storage_type, int64_t& qmin, + int64_t& qmax) { + if (num_bits <= 4) { + storage_type = IntegerType::get(ctx, 4); + if (is_signed) { + qmin = -8; + qmax = 7; + } else { + qmin = 0; + qmax = 15; + } + } else if (num_bits <= 8) { + storage_type = IntegerType::get(ctx, 8); + if (is_signed) { + qmin = -128; + qmax = 127; + } else { + qmin = 0; + qmax = 255; + } + } else if (num_bits <= 16) { + storage_type = IntegerType::get(ctx, 16); + if (is_signed) { + qmin = -32768; + qmax = 32767; + } else { + qmin = 0; + qmax = 65535; + } + } else if (num_bits <= 32) { + storage_type = IntegerType::get(ctx, 32); + if (is_signed) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } else { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } + } else { + return failure(); + } + + // Handle narrow_range. + if (narrow_range) { + qmin += 1; + } + return success(); +} + +Type GetPerTensorQuantizedTensorType(Builder& builder, double scale, + int64_t zero_point, Type expressed_type, + int num_bits, Location loc, + bool narrow_range, bool is_signed) { + unsigned flags = is_signed ? quant::QuantizationFlags::Signed : 0; + MLIRContext* ctx = builder.getContext(); + Type storage_type; + int64_t qmin; + int64_t qmax; + if (failed(GetStorageParams(num_bits, narrow_range, is_signed, ctx, + storage_type, qmin, qmax))) { + return (emitError(loc, "unsupported FakeQuant number of bits: ") + << num_bits, + nullptr); + } + + return quant::UniformQuantizedType::getChecked( + loc, flags, storage_type, expressed_type, scale, zero_point, qmin, qmax); +} + +Type GetPerAxisQuantizedTensorType(Builder& builder, + SmallVector scales, + SmallVector zero_points, + int32_t quantized_dimension, + Type expressed_type, int num_bits, + Location loc, bool narrow_range, + bool is_signed) { + unsigned flags = is_signed ? quant::QuantizationFlags::Signed : 0; + + MLIRContext* ctx = builder.getContext(); + Type storage_type; + int64_t qmin; + int64_t qmax; + if (failed(GetStorageParams(num_bits, narrow_range, is_signed, ctx, + storage_type, qmin, qmax))) { + return (emitError(loc, "unsupported FakeQuant number of bits: ") + << num_bits, + nullptr); + } + + return quant::UniformQuantizedPerAxisType::getChecked( + loc, flags, storage_type, expressed_type, scales, zero_points, + quantized_dimension, qmin, qmax); +} + +} // namespace mlir::TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h new file mode 100644 index 00000000000000..85fffcf2ba07a5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h @@ -0,0 +1,55 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LOWER_QUANT_ANNOTATIONS_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LOWER_QUANT_ANNOTATIONS_HELPER_H_ + +#include + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir::TFL { + +LogicalResult FillCompositeParams(stablehlo::CompositeOp op, + SmallVector& scales, + SmallVector& zero_points, + int& num_bits, bool& is_signed); + +LogicalResult GetStorageParams(unsigned num_bits, bool narrow_range, + bool is_signed, MLIRContext* ctx, + Type& storage_type, int64_t& qmin, + int64_t& qmax); + +Type GetPerTensorQuantizedTensorType(Builder& builder, double scale, + int64_t zero_point, Type expressed_type, + int num_bits, Location loc, + bool narrow_range, bool is_signed); + +Type GetPerAxisQuantizedTensorType(Builder& builder, + SmallVector scales, + SmallVector zero_points, + int32_t quantized_dimension, + Type expressed_type, int num_bits, + Location loc, bool narrow_range, + bool is_signed); + +} // namespace mlir::TFL +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LOWER_QUANT_ANNOTATIONS_HELPER_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc new file mode 100644 index 00000000000000..d27e22f460e6c1 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc @@ -0,0 +1,160 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass applies quantization on TFLite dialect. + +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/utils/utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace TFL { +namespace { + +#define GEN_PASS_DEF_LOWERQUANTANNOTATIONSPASS +#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" + +class RewriteFakeQuantCompositeOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + public: + explicit RewriteFakeQuantCompositeOp(MLIRContext* context) + : OpRewritePattern(context) { + setHasBoundedRewriteRecursion(); + } + + LogicalResult matchAndRewrite(stablehlo::CompositeOp op, + PatternRewriter& rewriter) const final { + if (op.getName() != "quant.fake_quant") { + return failure(); + } + + SmallVector scales; + SmallVector zero_points; + int num_bits; + bool is_signed; + + if (failed(FillCompositeParams(op, scales, zero_points, num_bits, + is_signed))) { + return failure(); + } + + ShapedType input_shaped_type = cast(op.getOperand(0).getType()); + Type input_element_type = input_shaped_type.getElementType(); + Type quantized_element_type; + if (scales.size() == 1) { + quantized_element_type = GetPerTensorQuantizedTensorType( + rewriter, scales[0], zero_points[0], + /*expressed_type=*/input_element_type, num_bits, op->getLoc(), + /*narrow_range=*/false, is_signed); + } else { + int32_t quantized_dimension; + if (auto quantized_dimension_attr = llvm::dyn_cast_or_null( + op.getCompositeAttributes().get("quantization_dimension"))) { + quantized_dimension = + quantized_dimension_attr.getValue().getSExtValue(); + } else { + return failure(); + } + quantized_element_type = GetPerAxisQuantizedTensorType( + rewriter, scales, zero_points, quantized_dimension, + /*expressed_type=*/input_element_type, num_bits, op->getLoc(), + /*narrow_range=*/false, is_signed); + } + RankedTensorType intermediate_type = RankedTensorType::get( + input_shaped_type.getShape(), quantized_element_type); + TFL::QuantizeOp tfl_quantize_op = rewriter.create( + op.getLoc(), intermediate_type, + /*input=*/op.getOperand(0), + /*qtype=*/TypeAttr::get(intermediate_type)); + + Type output_type = op.getType(0); + TFL::DequantizeOp tfl_dequantize_op = rewriter.create( + op.getLoc(), output_type, /*input=*/tfl_quantize_op); + + rewriter.replaceAllOpUsesWith(op, tfl_dequantize_op.getOutput()); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct LowerQuantAnnotationsPass + : public impl::LowerQuantAnnotationsPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerQuantAnnotationsPass) + + void runOnOperation() override; +}; + +void LowerQuantAnnotationsPass::runOnOperation() { + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + // Declare all the MHLO ops as legal except for the quantization composites we + // want to lower. + target.addDynamicallyLegalDialect( + [](Operation* op) { + auto mhlo_op = dyn_cast_or_null(op); + if (!mhlo_op) { + return true; + } + return mhlo_op.getName() != "quant.fake_quant"; + }); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + getOperation().emitError("Composite lowering pass failed."); + signalPassFailure(); + } +} +} // namespace +std::unique_ptr> CreateLowerQuantAnnotationsPass() { + return std::make_unique(); +} +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 28426956451ff7..4f665472b140a7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -23,6 +23,7 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/lite/utils/utils.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td" +include "mlir/IR/CommonAttrConstraints.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< @@ -354,6 +355,23 @@ def MatchHardSwishPattern5 : Pat< (FloatValueEquals<"6"> $cst_6), ]>; +def MatchHardSwishPattern6 : Pat< + (TFL_MulOp + $arg, + (TFL_MulOp + (TFL_AddOp + $arg, + (Arith_ConstantOp F32ElementsAttr:$cst_3), + TFL_AF_Relu6), + (Arith_ConstantOp F32ElementsAttr:$cst_one_sixth), + TFL_AF_None), + TFL_AF_None), + (TFL_HardSwishOp $arg), + [ + (FloatValueEquals<"3"> $cst_3), + (FloatValueEquals<"0.166666672"> $cst_one_sixth), + ]>; + // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< CPred<"$0.isa() && " @@ -1913,3 +1931,14 @@ def ReorderGatherAndCast : Pat< (TFL_GatherOp (TFL_CastOp:$cast $params), $indices, $axis, $batch_dims), (TFL_CastOp (TFL_GatherOp $params, $indices, $axis, $batch_dims)), [(HasOneUse $cast)]>; + +// Replace division by a constant with a multiplication by a reciprocal of that +// constant. Floating point division can be ~10x more expensive than a +// multiplication. +def RealDivWithF32ConstDivisor : Pat< + (TFL_DivOp:$src $arg0, (Arith_ConstantOp FloatElementsAttr<32>:$value), $activation), + (TFL_MulOp:$dest1 $arg0, + (TFL_DivOp (Arith_ConstantOp + (GetScalarOfType<1> (Arith_ConstantOp $value))), + (Arith_ConstantOp $value), TFL_AF_None), + $activation)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 50356378be9ab0..4d8ecccaa5f3f7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -116,6 +116,8 @@ std::unique_ptr> CreateQuantizePass( std::unique_ptr> CreateDefaultQuantizePass(); +std::unique_ptr> CreateLowerQuantAnnotationsPass(); + // Overloading of CreateQuantizePass which takes only necessary flags to reduce // the binary size. std::unique_ptr> CreateQuantizePass( diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 06bee2f85638d3..8ea13964fe64b8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -340,6 +340,17 @@ def QuantizePass : Pass<"tfl-quantize", "mlir::func::FuncOp"> { ]; } +def LowerQuantAnnotationsPass : Pass<"tfl-lower-quant-annotations", "mlir::ModuleOp"> { + let summary = "Lowers the quantization annotations marked by composites to the TFLite dialect."; + let constructor = "CreateLowerQuantAnnotationsPass()"; + let dependentDialects = [ + "TFL::TensorFlowLiteDialect", + "mlir::quant::QuantDialect", + "TF::TensorFlowDialect", + "stablehlo::StablehloDialect" + ]; +} + def QuantizeVariablesPass : Pass<"tfl-quantize-variables", "mlir::ModuleOp"> { let summary = "Quantize variables"; let constructor = "CreatePrepareQuantizeVariablesPass()"; diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 867eecff15818f..a2e58c81c54d9a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -234,9 +236,11 @@ struct FoldTransposeOp : public OpRewritePattern { DenseIntElementsAttr perm_tensor; if (!matchPattern(op.getPerm(), m_Constant(&perm_tensor))) return failure(); - if (!mlir::isa( - (getElementTypeOrSelf(op.getOutput().getType())))) + auto output_element_type = getElementTypeOrSelf(op.getOutput().getType()); + if (!mlir::isa(output_element_type) && + !mlir::isa(output_element_type)) { return failure(); + } ElementsAttr input_tensor = qconst_op.getValue(); @@ -265,10 +269,19 @@ struct FoldTransposeOp : public OpRewritePattern { /*output_axis=*/0, &input_indices, &new_values); auto result_type = RankedTensorType::get(output_shape, output_type.getElementType()); - auto values_type = RankedTensorType::get( - output_shape, - mlir::cast(output_type.getElementType()) - .getStorageType()); + RankedTensorType values_type; + if (mlir::isa(output_element_type)) { + values_type = RankedTensorType::get( + output_shape, + mlir::cast(output_type.getElementType()) + .getStorageType()); + } else { + values_type = RankedTensorType::get( + output_shape, mlir::cast( + output_type.getElementType()) + .getStorageType()); + } + rewriter.replaceOpWithNewOp( op, TypeAttr::get(result_type), DenseIntElementsAttr::get(values_type, new_values)); diff --git a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc index 942f62a2725c97..0fe96f4b0b71cc 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc @@ -258,7 +258,7 @@ void LegalizeTensorListPass::runOnOperation() { patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); - (void)applyPatternsAndFoldGreedily(module, std::move(patterns)); + (void)applyPatternsGreedily(module, std::move(patterns)); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index d3ca3179b2b818..d1b15341c51bb7 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "Eigen/Core" // from @eigen_archive #include "llvm/ADT/APInt.h" diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 76b00825628b2e..30b79d91a7a900 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -76,8 +76,8 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { return mlir::TF::TensorProtoAttr::get(shaped_type, mangled); } else { - return tensorflow::Status(absl::StatusCode::kInvalidArgument, - "Unsupported type"); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unsupported type"); } } else if (auto itype = mlir::dyn_cast(element_type)) { if (element_type.isSignedInteger()) { @@ -99,8 +99,8 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { static_cast(value)); break; default: - return tensorflow::Status(absl::StatusCode::kInvalidArgument, - "Unsupported type"); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unsupported type"); } } else { switch (itype.getWidth()) { @@ -121,13 +121,12 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { static_cast(value)); break; default: - return tensorflow::Status(absl::StatusCode::kInvalidArgument, - "Unsupported type"); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unsupported type"); } } } else { - return tensorflow::Status(absl::StatusCode::kInvalidArgument, - "Unsupported type"); + return absl::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } } diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 940d30c9c7929b..5ab34b85cb1601 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" #include -#include +#include #include #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index f2266f8920669a..8d9a5ab17f5095 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -19,6 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ +#include + #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index f85ea68d621ef6..504c10861f7b1e 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" +#include #include -#include #include #include diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc index 6677f57c6fdd0d..211336de124075 100644 --- a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h" +#include #include #include "flatbuffers/flexbuffers.h" // from @flatbuffers diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc index 650c372e42b2b4..e94819afa3612f 100644 --- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h" +#include #include #include -#include #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/region_isolation.h b/tensorflow/compiler/mlir/lite/utils/region_isolation.h index 06a1776ae86104..b32b2df210f962 100644 --- a/tensorflow/compiler/mlir/lite/utils/region_isolation.h +++ b/tensorflow/compiler/mlir/lite/utils/region_isolation.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_REGION_ISOLATION_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_REGION_ISOLATION_H_ +#include + #include "llvm/ADT/SetVector.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 24314630e65154..fa191c6c69d984 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" -#include +#include +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index c74fe250638eff..53f6a038678d1e 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ #include +#include #include #include #include @@ -24,11 +25,13 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project @@ -375,6 +378,30 @@ inline bool OperandsBroadcastToOutputType(Type a, Type b, return broadcasted_type != Type() && broadcasted_type == expected_output; } +// Returns int, float or complex DenseElementsAttr with scalar shape with the +// given element type and the integer value. +template +DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + if (auto float_ty = mlir::dyn_cast(ty)) { + FloatAttr attr = FloatAttr::get(float_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto int_ty = mlir::dyn_cast(ty)) { + IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto complex_ty = mlir::dyn_cast(ty)) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } + } + llvm_unreachable("unsupported type"); +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index f0dd36366d9808..fb7baadc6fc85d 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -218,4 +218,9 @@ def IsNoneType : Constraint()">>; def ConstantLikePred : CPred<"::mlir::matchPattern($0, ::mlir::m_Constant())">; def IsConstantLike : Constraint; -def NotConstantLike : Constraint>; \ No newline at end of file +def NotConstantLike : Constraint>; + +// Here, the element type can be any integer or float type. But, note that only +// 32 bit integers are supported for the values. +class GetScalarOfType : NativeCodeCall< + "GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc index 536762c3e44292..f824f22ffb72e7 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.cc +++ b/tensorflow/compiler/mlir/lite/utils/validators.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include +#include #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index 86306ab5a454ce..be24f40fc2ec01 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -19,6 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 6e9af113ac669e..0f463a5996cb51 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -84,6 +84,14 @@ auto* mlir_function_pass_graph_conversion_count = monitoring::Counter<1>::New( "optimization pass", /* metric field */ "status"); +auto* mlir_v1_compat_graph_conversion_count = monitoring::Counter<1>::New( + /* metric name */ + "/tensorflow/core/mlir_v1_compat_graph_conversion_count", + /* metric description */ + "Track success/failure of Graph to MLIR conversions in MLIR V1 compat " + "optimization pass", + /* metric field */ "status"); + // The status metric field is used to record success/failure of mlir // function/graph optimization passes. constexpr char kSuccess[] = "kSuccess"; @@ -155,7 +163,7 @@ static void RegisterDialects(mlir::DialectRegistry& registry) { // clang-format on } -Status MlirFunctionOptimizationPass::Run( +absl::Status MlirFunctionOptimizationPass::Run( const std::string& function_name, const DeviceSet& device_set, const ConfigProto& config_proto, const FunctionOptimizationPass::FunctionOptions& function_options, @@ -239,7 +247,9 @@ Status MlirFunctionOptimizationPass::Run( {kTfMlirCategory, "convert_graph_to_mlir"}); auto module_ref_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( - **graph, debug_info, *flib_def, import_config, &context); + **graph, debug_info, *flib_def, import_config, &context, + /*tf_name_to_mlir_name*/ nullptr, config_proto, + tensorflow::TF2XLABridgeVersion::kNominal); mlir_function_pass_graph_conversion_count ->GetCell(absl::StatusCodeToString(module_ref_status.status().code())) ->IncrementBy(1); @@ -277,7 +287,7 @@ Status MlirFunctionOptimizationPass::Run( *module_ref, llvm::StringRef(), nullptr); } - Status pass_status = absl::OkStatus(); + absl::Status pass_status = absl::OkStatus(); auto pass_state = per_pass_state[per_pass_state_index++]; if (pass_state == MlirOptimizationPassState::Enabled) { VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name); @@ -361,7 +371,7 @@ Status MlirFunctionOptimizationPass::Run( timings.Reset({kTfMlirCategory, "convert_mlir_to_graph"}); // Some or all passes are enabled. Convert MLIR module and return back // resulted graph. - Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + absl::Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( *module_ref, export_config, graph, flib_def, &control_ret_nodes); if (!status.ok()) { errors::AppendToMessage(&status, @@ -387,7 +397,7 @@ MlirV1CompatOptimizationPassRegistry::Global() { return *global; } -Status MlirV1CompatGraphOptimizationPass::Run( +absl::Status MlirV1CompatGraphOptimizationPass::Run( const GraphOptimizationPassOptions& options) { // Skip MLIR V1 optimization pass if it is not enabled in compiling // SavedModel. @@ -429,7 +439,12 @@ Status MlirV1CompatGraphOptimizationPass::Run( import_config.restrict_functionalization_to_compiled_nodes = true; auto module_ref_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( - **options.graph, debug_info, *options.flib_def, import_config, &context); + **options.graph, debug_info, *options.flib_def, import_config, &context, + /*tf_name_to_mlir_name*/ nullptr, options.session_options->config, + tensorflow::TF2XLABridgeVersion::kV1Compat); + mlir_v1_compat_graph_conversion_count + ->GetCell(absl::StatusCodeToString(module_ref_status.status().code())) + ->IncrementBy(1); if (!module_ref_status.ok()) { if (pass_state == MlirOptimizationPassState::Enabled) { return module_ref_status.status(); @@ -452,7 +467,7 @@ Status MlirV1CompatGraphOptimizationPass::Run( if (VLOG_IS_ON(1)) { DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name)); } - Status pass_status = pass->Run(options, *module_ref); + absl::Status pass_status = pass->Run(options, *module_ref); bool is_module_updated = !mlir::OperationEquivalence::isEquivalentTo( module_ref_clone, *module_ref, diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index 5c463e32aef718..1e817d0ae3386d 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -76,10 +76,10 @@ class MlirOptimizationPass { const Graph& graph, const FunctionLibraryDefinition& function_library) const = 0; - virtual Status Run(const std::string& function_name, - const ConfigProto& config_proto, mlir::ModuleOp module, - const Graph& graph, - const FunctionLibraryDefinition& function_library) = 0; + virtual absl::Status Run( + const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, + const FunctionLibraryDefinition& function_library) = 0; }; class MlirOptimizationPassRegistry { @@ -129,12 +129,13 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass { : registry_(registry) {} // Executes all of the underlying registered MlirOptimizationPasses. - Status Run(const std::string& function_name, const DeviceSet& device_set, - const ConfigProto& config_proto, - const FunctionOptimizationPass::FunctionOptions& function_options, - std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - std::vector* control_ret_node_names, - bool* control_rets_updated) override; + absl::Status Run( + const std::string& function_name, const DeviceSet& device_set, + const ConfigProto& config_proto, + const FunctionOptimizationPass::FunctionOptions& function_options, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override; private: const MlirOptimizationPassRegistry* registry_; @@ -162,8 +163,8 @@ class MlirV1CompatOptimizationPass { const Graph& graph, const FunctionLibraryDefinition& function_library) const = 0; - virtual Status Run(const GraphOptimizationPassOptions& options, - mlir::ModuleOp module) = 0; + virtual absl::Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) = 0; }; class MlirV1CompatOptimizationPassRegistry { @@ -195,7 +196,7 @@ class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass { &MlirV1CompatOptimizationPassRegistry::Global()) : registry_(registry) {} - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; private: const MlirV1CompatOptimizationPassRegistry* registry_; diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 64e230f448f3fe..c00eba34a93ce7 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" @@ -66,7 +68,7 @@ class MockMlirOptimizationPass : public MlirOptimizationPass { const Graph& graph, const FunctionLibraryDefinition& function_library), (const, override)); - MOCK_METHOD(Status, Run, + MOCK_METHOD(absl::Status, Run, (const std::string& function_name, const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph, @@ -82,7 +84,7 @@ class MockMlirV1CompatOptimizationPass : public MlirV1CompatOptimizationPass { const Graph& graph, const FunctionLibraryDefinition& function_library), (const, override)); - MOCK_METHOD(Status, Run, + MOCK_METHOD(absl::Status, Run, (const GraphOptimizationPassOptions& options, mlir::ModuleOp module), (override)); @@ -90,7 +92,8 @@ class MockMlirV1CompatOptimizationPass : public MlirV1CompatOptimizationPass { class ModifyMlirModulePass : public MlirOptimizationPass { public: - explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {} + explicit ModifyMlirModulePass(absl::Status run_status) + : run_status_(run_status) {} MOCK_METHOD(llvm::StringRef, name, (), (const, override)); MOCK_METHOD(MlirOptimizationPassState, GetPassState, (const DeviceSet* device_set, const ConfigProto& config_proto, @@ -100,9 +103,10 @@ class ModifyMlirModulePass : public MlirOptimizationPass { // Just modify MLIR module so that we can check whether original TF graph // has changed or not. - Status Run(const std::string& function_name, const ConfigProto& config_proto, - mlir::ModuleOp module, const Graph& graph, - const FunctionLibraryDefinition& function_library) override { + absl::Status Run(const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, + const Graph& graph, + const FunctionLibraryDefinition& function_library) override { mlir::Builder b(module.getContext()); auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0)); auto min_consumer = b.getNamedAttr("min_consumer", b.getI32IntegerAttr(0)); @@ -116,7 +120,7 @@ class ModifyMlirModulePass : public MlirOptimizationPass { return run_status_; } - Status run_status_; + absl::Status run_status_; }; FunctionDef XTimesTwo() { @@ -140,7 +144,7 @@ FunctionDef XTimesTwo() { class MlirGraphOptimizationPassTest : public Test { public: - void Init(Status pass_run_result, + void Init(absl::Status pass_run_result, const std::vector& pass_states) { graph_ = std::make_unique(OpRegistry::Global()); @@ -162,7 +166,7 @@ class MlirGraphOptimizationPassTest : public Test { } void AddModuleModificationPass(MlirOptimizationPassState pass_state, - Status run_status) { + absl::Status run_status) { // Add FallbackEnabled pass that modifies the graph. auto optimization_pass = std::make_unique>(run_status); @@ -231,7 +235,7 @@ class MlirGraphOptimizationPassTest : public Test { }; TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoFallback) { - Init(Status(absl::StatusCode::kAborted, "aborted"), + Init(absl::Status(absl::StatusCode::kAborted, "aborted"), {MlirOptimizationPassState::Enabled}); GraphDef original_graph_def; @@ -241,13 +245,13 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoFallback) { function_optimization_pass_.Run( "test_func", device_set_, config_proto_, function_options_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), - Status(absl::StatusCode::kAborted, "aborted")); + absl::Status(absl::StatusCode::kAborted, "aborted")); verifyGraph(original_graph_def); verifyCounters(); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { - Init(Status(absl::StatusCode::kAborted, "aborted"), + Init(absl::Status(absl::StatusCode::kAborted, "aborted"), {MlirOptimizationPassState::Disabled, MlirOptimizationPassState::FallbackEnabled}); @@ -261,8 +265,9 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { GraphDef original_graph_def; graph_->ToGraphDef(&original_graph_def); - AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, - Status(absl::StatusCode::kAborted, "aborted")); + AddModuleModificationPass( + MlirOptimizationPassState::FallbackEnabled, + absl::Status(absl::StatusCode::kAborted, "aborted")); EXPECT_EQ( function_optimization_pass_.Run( @@ -329,7 +334,7 @@ TEST(MlirV1CompatOptimizationPassRegistry, RegisterMultiplePassesFails) { class MlirGraphOptimizationV1PassTest : public Test { public: - void Init(Status pass_run_result, + void Init(absl::Status pass_run_result, const std::vector& pass_states) { graph_ = std::make_unique(OpRegistry::Global()); MlirV1CompatOptimizationPassRegistry::Global().ClearPass(); @@ -381,6 +386,7 @@ class MlirGraphOptimizationV1PassTest : public Test { pass_result_expected_[MlirOptimizationPassState::FallbackEnabled] [false]); EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kOk), 0); + EXPECT_EQ(mlir_v1_compat_graph_conversion_count_.Read(kOk), 1); } void TearDown() override { @@ -413,6 +419,11 @@ class MlirGraphOptimizationV1PassTest : public Test { monitoring::testing::CellReader( /* metric name */ "/tensorflow/core/mlir_function_pass_graph_conversion_count"); + monitoring::testing::CellReader + mlir_v1_compat_graph_conversion_count_ = + monitoring::testing::CellReader( + /* metric name */ + "/tensorflow/core/mlir_v1_compat_graph_conversion_count"); }; TEST_F(MlirGraphOptimizationV1PassTest, OptimizationPassDoesNotFailFallback) { diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h index d8ff9cebc0108a..f8c596fff79f61 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_ #define TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_ +#include #include #include "absl/strings/string_view.h" diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 8fcdc2db1a7e51..24b5ef8cc8b85e 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -25,6 +25,7 @@ cc_library( srcs = ["mlir.cc"], hdrs = ["mlir.h"], deps = [ + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:loader", @@ -35,7 +36,6 @@ cc_library( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TosaDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", @@ -61,15 +61,10 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", - "//tensorflow/compiler/mlir/tosa:passes_header", - "//tensorflow/compiler/mlir/tosa:tf_passes", - "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", - "//tensorflow/compiler/mlir/tosa:tfl_passes", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tflite_portable_logging", - "//tensorflow/core/common_runtime:core_cpu_base_no_ops", "//tensorflow/core/common_runtime/eager:context", # (yongtang) The graph_optimization_pass_registration needs to be part # of a shared object that will be loaded whenever `import tensorflow` diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 24db5d87c008b7..49260f26e6abf4 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -62,10 +61,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/tosa/tf_passes.h" -#include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" -#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" -#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" #include "xla/mlir/framework/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -97,10 +92,6 @@ static void RegisterPasses() { mlir::mhlo::registerTfXlaPasses(); mlir::mhlo::registerLegalizeTFPass(); mlir::quant::stablehlo::registerBridgePasses(); - mlir::tosa::registerLegalizeTosaPasses(); - mlir::tosa::registerTFtoTOSALegalizationPipeline(); - mlir::tosa::registerTFLtoTOSALegalizationPipeline(); - mlir::tosa::registerTFTFLtoTOSALegalizationPipeline(); mlir::tf_saved_model::registerTensorFlowSavedModelPasses(); mlir::xla_framework::registerXlaFrameworkPasses(); tensorflow::RegisterMlProgramPasses(); @@ -197,7 +188,15 @@ std::string ImportFunction(const std::string& functiondef_proto, mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); mlir::MLIRContext context(registry); - auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context); + + tensorflow::GraphImportConfig specs; + specs.graph_func_name = fbody->record->fdef().signature().name(); + specs.enable_shape_inference = false; + specs.graph_as_function = true; + for (const auto* control_ret_node : fbody->control_ret_nodes) + specs.control_outputs.push_back(control_ret_node->name()); + auto module = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + *fbody->graph, {}, flib_def, specs, &context); if (!module.ok()) { tsl::Set_TF_Status_from_Status(status, module.status()); return "// error"; @@ -423,64 +422,4 @@ void ExperimentalWriteBytecode(const std::string& filename, } } -void ExperimentalTFLiteToTosaBytecode( - const std::string& flatbuffer_file, const std::string& tosa_bytecode_file, - bool use_external_constant, - const std::vector& ordered_input_arrays, - const std::vector& ordered_output_arrays, TF_Status* status) { - mlir::DialectRegistry registry; - mlir::RegisterAllTensorFlowDialects(registry); - registry.insert(); - mlir::MLIRContext context(registry); - mlir::OwningOpRef module; - mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); - { - mlir::Location loc = mlir::UnknownLoc::get(&context); - std::string error; - std::unique_ptr buffer = - mlir::openInputFile(flatbuffer_file, &error); - if (buffer == nullptr) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - ("Unable to load input file " + error).c_str()); - return; - } - - auto buffer_view = - std::string_view(buffer->getBufferStart(), buffer->getBufferSize()); - module = tflite::FlatBufferToMlir( - buffer_view, &context, loc, use_external_constant, ordered_input_arrays, - ordered_output_arrays); - mlir::PassManager pm(&context, module.get()->getName().getStringRef(), - mlir::PassManager::Nesting::Implicit); - mlir::tosa::TOSATFLLegalizationPipelineOptions opts; - // This flow is specific to compilation backend, so set to true. - opts.target_compilation_backend = true; - // Temporary work-around for https://github.com/openxla/iree/issues/8974 - opts.dequantize_tfl_softmax = true; - createTFLtoTOSALegalizationPipeline(pm, opts); - if (failed(pm.run(*module))) { - tsl::Set_TF_Status_from_Status(status, - diagnostic_handler.ConsumeStatus()); - return; - } - } - mlir::FallbackAsmResourceMap fallback_resource_map; - mlir::BytecodeWriterConfig writer_config(fallback_resource_map); - // TODO(jpienaar): Make this an option to the call. - writer_config.setDesiredBytecodeVersion(1); - std::string error; - std::unique_ptr outputFile = - mlir::openOutputFile(tosa_bytecode_file, &error); - if (!error.empty()) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - ("Unable to create output file" + error).c_str()); - return; - } - outputFile->keep(); - if (failed(mlir::writeBytecodeToFile(*module, outputFile->os(), - writer_config))) { - tsl::Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); - } -} - } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index a17f4f2843e470..99a17ca1ef2fc1 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -109,16 +109,6 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, void ExperimentalWriteBytecode(const std::string &filename, const std::string &mlir_txt, TF_Status *status); -// Loads a TFLite flatbuffer, convert to TOSA for backend compilation and -// produce an MLIR bytecode file as output. -// TODO(jpienaar): Refactor this when we use more implicit module passing -// between calls to avoid serialization overhead. -void ExperimentalTFLiteToTosaBytecode( - const std::string &flatbuffer_file, const std::string &tosa_bytecode_file, - bool use_external_constant, - const std::vector &ordered_input_arrays, - const std::vector &ordered_output_arrays, TF_Status *status); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index eafd86653603ec..6908a1d2d53058 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -42,6 +42,7 @@ tf_python_pybind_extension( pytype_srcs = [ "filecheck_wrapper.pyi", ], + starlark_only = True, visibility = ["//visibility:public"], deps = [ "//tensorflow/python/lib/core:pybind11_lib", diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc index ded07c7254e51b..86c019e689466f 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc index bdc0931e250bc7..58a3c9452edb07 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc @@ -15,6 +15,8 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project +#include + #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc index 8c82fc9bc12b42..e597ae85eeaaaa 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/FileCheck/FileCheck.h" #include "llvm/Support/SourceMgr.h" #include "pybind11/pybind11.h" // from @pybind11 diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 24e2f739529b81..60e980a6df201b 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include + #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.pyi b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.pyi index b3d75ba9a3ff9e..0961d12bdefb09 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.pyi +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.pyi @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -from typing import Any - from typing import overload class Attribute: @@ -22,7 +20,7 @@ class Attribute: class Block: def __init__(self, *args, **kwargs) -> None: ... - def addArgument(self, *args, **kwargs) -> Any: ... + def addArgument(self, *args, **kwargs): ... def end(self) -> Block_Iterator: ... def new(self) -> Block: ... @@ -85,11 +83,11 @@ class OpBuilder: def __init__(self, arg0) -> None: ... @overload def __init__(self, arg0: Block, arg1: Block_Iterator) -> None: ... - def create(self, *args, **kwargs) -> Any: ... + def create(self, *args, **kwargs): ... def getContext(self) -> MLIRContext: ... def getUnknownLoc(self) -> Location: ... def restoreInsertionPoint(self, arg0) -> None: ... - def saveInsertionPoint(self, *args, **kwargs) -> Any: ... + def saveInsertionPoint(self, *args, **kwargs): ... def setInsertionPoint(self, arg0: Block, arg1: Block_Iterator) -> None: ... class OpBuilder_InsertionPoint: @@ -119,8 +117,8 @@ class RankedTensorType(Type): class Region: def __init__(self, *args, **kwargs) -> None: ... def add_block(self) -> None: ... - def back(self, *args, **kwargs) -> Any: ... - def front(self, *args, **kwargs) -> Any: ... + def back(self, *args, **kwargs): ... + def front(self, *args, **kwargs): ... def push_back(self, arg0) -> None: ... def size(self) -> int: ... @@ -189,7 +187,7 @@ class UnrankedTensorType(Type): class Value: def __init__(self, *args, **kwargs) -> None: ... - def getType(self, *args, **kwargs) -> Any: ... + def getType(self, *args, **kwargs): ... def preloadTensorFlowDialects(arg0) -> None: ... def verify(arg0: str) -> bool: ... diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc index 049de333516d18..4e1ab6796e1cc8 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc index 46111c07ef6b3a..775ef48ffed3c0 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h index 9acff4230669e3..cb9dac201a0a96 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h @@ -215,6 +215,10 @@ struct QuantizationSpecs { // If other than kQDQNone, the model is a floating point graph with QDQ ops // to be eliminated and fused into quantized kernels. QDQConversionMode qdq_conversion_mode = QDQConversionMode::kQDQNone; + + // When set, adheres to the QDQ annotations added by the framework when + // possible rather than quantizing any op that is possible to quantize. + bool strict_qdq_mode = false; }; // Parses the command line flag strings to the CustomOpMap specification. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc index 180df43a62a249..c63a7158e5a93b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -58,7 +58,7 @@ class TestEnvBrokenFileSystem : public tsl::Env { tsl::string GetRunfilesDir() override { return tsl::string("dummy_path"); } - int32_t GetCurrentThreadId() override { return 0; } + int64_t GetCurrentThreadId() override { return 0; } tsl::Thread* StartThread(const tsl::ThreadOptions& thread_options, const tsl::string& name, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h index 8f1e4236e09823..9918b144a11fe3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h @@ -17,8 +17,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_ +#include #include #include +#include #include #include "absl/base/attributes.h" diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/optimize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/optimize.cc index dabf1d06a6e447..f641ea64cf0154 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/optimize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/optimize.cc @@ -45,7 +45,7 @@ void OptimizeIntGraph::runOnOperation() { RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc index 686204030c1fdc..0f4d2074e420f3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -281,7 +281,7 @@ void DeferActivationTransposePass::runOnOperation() { patterns.add(&ctx); - if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { func_op->emitWarning() << "Failed to converge patterns: " << getArgument(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc index 06e38c3935c417..24f5ab6a10fb64 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc @@ -189,7 +189,7 @@ void FoldConstantTransposePass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); - if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { func_op.emitError("Failed to fold constant->transpose pattern."); signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index 415496445f7f13..a9e13695fbdab0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -240,7 +240,7 @@ void InsertWeightParamPass::runOnOperation() { patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 2020fea5ea7146..cfe19f6af774f2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -210,7 +210,7 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { // Iterate over the sorted list of functions to keep order deterministic. for (func::FuncOp func : GetSortedFunctions(module_op)) { - if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { func.emitError() << "quant-stablehlo-lift-quantizable-spots-as-functions failed."; signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc index 24e148949215e8..293b4a19c6eb2c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc @@ -136,7 +136,7 @@ void MergeFusionWithDequantizePass::runOnOperation() { MLIRContext* ctx = module_op.getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 9e64756ddbf2a6..39546b33778242 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -179,7 +179,7 @@ void NchwConvolutionToNhwcPass::runOnOperation() { RewritePatternSet patterns(&ctx); patterns.add(&ctx); - if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to run NchwConvolutionToNhwcPass."; signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/optimize_graph.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/optimize_graph.cc index 8c4837673b2754..47ec6ab15fbb51 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/optimize_graph.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/optimize_graph.cc @@ -46,7 +46,7 @@ void OptimizeGraphPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc index 4052988230b108..167aad9da31492 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc @@ -140,7 +140,7 @@ void PostQuantizePass::runOnOperation() { // TODO: b/307463853 - Consider splitting passes for each pattern set. patterns.add, RemoveVolatileQdqPattern>(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { signalPassFailure(); } @@ -148,7 +148,7 @@ void PostQuantizePass::runOnOperation() { patterns_2 .add(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns_2)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns_2)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 824d24065e239b..7e5e0a9cd83dfa 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -162,7 +162,7 @@ void PrepareQuantizePass::runOnOperation() { // deal with the arith::ConstantOp instances. patterns.add(ctx); patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { signalPassFailure(); } @@ -180,7 +180,7 @@ void PrepareQuantizePass::runOnOperation() { patterns_2 .add( ctx); - if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns_2)))) { + if (failed(applyPatternsGreedily(func_op, std::move(patterns_2)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index e1b8812530f110..91d37dbe5d3d1c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -101,7 +101,7 @@ void QuantizePass::runOnOperation() { // Quantize all quantizable ops, including ops that are not compute-heavy. PopulateAllQuantizablePatterns(ctx, patterns); - if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // causing this to result in a convergence failure. Consider this as a // best-effort. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc index e0469cc8d14032..e339f0089248aa 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc @@ -231,7 +231,7 @@ void QuantizeWeightPass::runOnOperation() { FrozenRewritePatternSet frozen_patterns(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.cc index 5380b53b8ea0d0..675020271bc00e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.cc @@ -50,7 +50,7 @@ void RemoveShardingCustomCallPass::runOnOperation() { populateWithGenerated(patterns); FrozenRewritePatternSet frozen_patterns(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(func_op, frozen_patterns))) { + if (failed(applyPatternsGreedily(func_op, frozen_patterns))) { func_op.emitWarning() << "Failed to converge " << RemoveShardingCustomCallPass::getArgumentName(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/unfuse_mhlo_batch_norm.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/unfuse_mhlo_batch_norm.cc index 13fb470454ea3b..51f9858fd26f3c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/unfuse_mhlo_batch_norm.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/unfuse_mhlo_batch_norm.cc @@ -50,8 +50,7 @@ void UnfuseMhloBatchNormPass::runOnOperation() { RewritePatternSet patterns(ctx); mhlo::populateUnfuseBatchNormPatterns(ctx, &patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc index 123244db3b7dbb..6078237c53e2ac 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/xla_call_module_to_call.cc @@ -74,7 +74,7 @@ void XlaCallModuleToCallPass::runOnOperation() { MLIRContext* ctx = module_op.getContext(); RewritePatternSet patterns(&getContext()); patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index f55abc54056257..ddae8b2a8dac04 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -152,6 +152,7 @@ tf_python_pybind_extension( name = "pywrap_quantization", srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], + starlark_only = True, visibility = [ "//tensorflow/python:__pkg__", ], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc index 45fb47565ea9e3..4d95f799029225 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type_test.cc @@ -36,6 +36,10 @@ TEST(IsLargeFloatTypeTest, scalars) { auto context = CreateContext(); EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNType::get(context.get()))); + EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNUZType::get(context.get()))); + EXPECT_FALSE(IsLargeFloatType(Float8E4M3B11FNUZType::get(context.get()))); + EXPECT_FALSE(IsLargeFloatType(Float8E5M2FNUZType::get(context.get()))); + EXPECT_FALSE(IsLargeFloatType(Float8E5M2Type::get(context.get()))); EXPECT_FALSE(IsLargeFloatType(Float16Type::get(context.get()))); EXPECT_FALSE(IsLargeFloatType(BFloat16Type::get(context.get()))); EXPECT_TRUE(IsLargeFloatType(Float32Type::get(context.get()))); @@ -54,6 +58,14 @@ TEST(IsLargeFloatTypeTest, tensors) { RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get())))); EXPECT_FALSE(IsLargeFloatType( RankedTensorType::get({2, 2}, Float16Type::get(context.get())))); + EXPECT_FALSE(IsLargeFloatType( + RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get())))); + EXPECT_FALSE(IsLargeFloatType(RankedTensorType::get( + {2, 2}, Float8E4M3B11FNUZType::get(context.get())))); + EXPECT_FALSE(IsLargeFloatType( + RankedTensorType::get({2, 2}, Float8E5M2FNUZType::get(context.get())))); + EXPECT_FALSE(IsLargeFloatType( + RankedTensorType::get({2, 2}, Float8E5M2Type::get(context.get())))); EXPECT_FALSE(IsLargeFloatType( RankedTensorType::get({2, 2}, BFloat16Type::get(context.get())))); EXPECT_TRUE(IsLargeFloatType( @@ -76,6 +88,14 @@ TEST(ToBfloat16TypeTest, scalars) { EXPECT_EQ(ToBfloat16Type(Float8E4M3FNType::get(context.get())), Float8E4M3FNType::get(context.get())); + EXPECT_EQ(ToBfloat16Type(Float8E4M3FNUZType::get(context.get())), + Float8E4M3FNUZType::get(context.get())); + EXPECT_EQ(ToBfloat16Type(Float8E4M3B11FNUZType::get(context.get())), + Float8E4M3B11FNUZType::get(context.get())); + EXPECT_EQ(ToBfloat16Type(Float8E5M2FNUZType::get(context.get())), + Float8E5M2FNUZType::get(context.get())); + EXPECT_EQ(ToBfloat16Type(Float8E5M2Type::get(context.get())), + Float8E5M2Type::get(context.get())); EXPECT_EQ(ToBfloat16Type(Float16Type::get(context.get())), Float16Type::get(context.get())); EXPECT_EQ(ToBfloat16Type(BFloat16Type::get(context.get())), @@ -102,6 +122,21 @@ TEST(ToBfloat16TypeTest, tensors) { ToBfloat16Type( RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))), RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))); + EXPECT_EQ( + ToBfloat16Type(RankedTensorType::get( + {2, 2}, Float8E4M3FNUZType::get(context.get()))), + RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get()))); + EXPECT_EQ( + ToBfloat16Type(RankedTensorType::get( + {2, 2}, Float8E4M3B11FNUZType::get(context.get()))), + RankedTensorType::get({2, 2}, Float8E4M3B11FNUZType::get(context.get()))); + EXPECT_EQ( + ToBfloat16Type(RankedTensorType::get( + {2, 2}, Float8E5M2FNUZType::get(context.get()))), + RankedTensorType::get({2, 2}, Float8E5M2FNUZType::get(context.get()))); + EXPECT_EQ(ToBfloat16Type(RankedTensorType::get( + {2, 2}, Float8E5M2Type::get(context.get()))), + RankedTensorType::get({2, 2}, Float8E5M2Type::get(context.get()))); EXPECT_EQ(ToBfloat16Type( RankedTensorType::get({2, 2}, Float16Type::get(context.get()))), RankedTensorType::get({2, 2}, Float16Type::get(context.get()))); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 218e229828211a..d02a11fe8992dd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -196,6 +196,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc index aaaf088b507e07..6a86c88c46e5be 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold_test.cc @@ -155,8 +155,7 @@ TEST_F(ConstantFoldingTest, FoldDepthwiseConvWeight) { RewritePatternSet patterns(ctx_.get()); patterns.add(ctx_.get()); - EXPECT_TRUE( - succeeded(applyPatternsAndFoldGreedily(test_func, std::move(patterns)))); + EXPECT_TRUE(succeeded(applyPatternsGreedily(test_func, std::move(patterns)))); auto depthwise_conv_op = FindOperationOfType(test_func); @@ -188,8 +187,7 @@ TEST_F(ConstantFoldingTest, DepthwiseConvWeightNotFoldable) { RewritePatternSet patterns(ctx_.get()); patterns.add(ctx_.get()); - EXPECT_TRUE( - succeeded(applyPatternsAndFoldGreedily(test_func, std::move(patterns)))); + EXPECT_TRUE(succeeded(applyPatternsGreedily(test_func, std::move(patterns)))); auto depthwise_conv_op = FindOperationOfType(test_func); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc index 8ba632b66ae0f3..8deda7c6138303 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include #include -#include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -32,12 +32,12 @@ namespace quant { namespace { // Prefix and suffix to the QuantizationUnit string representation. -constexpr std::string_view kQuantizationUnitPrefix = "QuantizationUnit("; -constexpr std::string_view kQuantizationUnitSuffix = ")"; +constexpr absl::string_view kQuantizationUnitPrefix = "QuantizationUnit("; +constexpr absl::string_view kQuantizationUnitSuffix = ")"; // Concatenates node name and func name with a "@" separator. -std::string ConcatNodeAndFuncName(std::string_view node_name, - std::string_view func_name) { +std::string ConcatNodeAndFuncName(absl::string_view node_name, + absl::string_view func_name) { return absl::StrCat(node_name, "@", func_name); } diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 0425d7d4300f96..3c293b74b2624b 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -43,7 +43,7 @@ tsl_pybind_extension( "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", - "@pybind11", + "@nanobind", "@stablehlo//:stablehlo_capi", ], ) diff --git a/tensorflow/compiler/mlir/stablehlo/stablehlo.cc b/tensorflow/compiler/mlir/stablehlo/stablehlo.cc index af8f69b1298805..60185f3d53257b 100644 --- a/tensorflow/compiler/mlir/stablehlo/stablehlo.cc +++ b/tensorflow/compiler/mlir/stablehlo/stablehlo.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "pybind11/pybind11.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind #include "stablehlo/integrations/python/StablehloApi.h" // from @stablehlo namespace mlir { namespace stablehlo { -PYBIND11_MODULE(stablehlo_extension, m) { mlir::stablehlo::AddPortableApi(m); } +NB_MODULE(stablehlo_extension, m) { mlir::stablehlo::AddPortableApi(m); } } // namespace stablehlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 254fa7abbd9405..9e7466048c1663 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -891,13 +891,13 @@ cc_library( ":dynamic_shape_utils", ":mangling_util", ":tensorflow_attributes", - ":tensorflow_types", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -950,6 +950,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/util:managed_stack_trace", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_xla//xla/mlir/utils:error_util", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index a5bb0051cc8fe4..127210340114a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -327,6 +327,9 @@ def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">; def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">; def TF_Float8E4M3FNRef : TF_TensorFlowType<"Float8E4M3FNRef", "float8e4m3fnref">; def TF_Float8E5M2Ref : TF_TensorFlowType<"Float8E5M2Ref", "float8e5m2ref">; +def TF_Float8E4M3FNUZRef : TF_TensorFlowType<"Float8E4M3FNUZRef", "float8e4m3fnuzref">; +def TF_Float8E4M3B11FNUZRef : TF_TensorFlowType<"Float8E4M3B11FNUZRef", "float8e4m3b11fnuzref">; +def TF_Float8E5M2FNUZRef : TF_TensorFlowType<"Float8E5M2FNUZRef", "float8e5m2fnuzref">; // Complex reference types def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">; @@ -443,6 +446,9 @@ def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">; def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">; def TF_Float8E4M3FN : AnyTypeOf<[F8E4M3FN, TF_Float8E4M3FNRef], "float8e4m3fn">; def TF_Float8E5M2 : AnyTypeOf<[F8E5M2, TF_Float8E5M2Ref], "float8e5m2">; +def TF_Float8E4M3FNUZ : AnyTypeOf<[F8E4M3FNUZ, TF_Float8E4M3FNUZRef], "float8e4m3fnuz">; +def TF_Float8E4M3B11FNUZ : AnyTypeOf<[F8E4M3B11FNUZ, TF_Float8E4M3B11FNUZRef], "float8e4m3b11fnuz">; +def TF_Float8E5M2FNUZ : AnyTypeOf<[F8E5M2FNUZ, TF_Float8E5M2FNUZRef], "float8e5m2fnuz">; def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">; @@ -460,6 +466,9 @@ def TF_Float64Tensor : TensorOf<[TF_Float64]>; def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>; def TF_Float8E4M3FNTensor : TensorOf<[TF_Float8E4M3FN]>; def TF_Float8E5M2Tensor : TensorOf<[TF_Float8E5M2]>; +def TF_Float8E4M3FNUZTensor : TensorOf<[TF_Float8E4M3FNUZ]>; +def TF_Float8E4M3B11FNUZTensor : TensorOf<[TF_Float8E4M3B11FNUZ]>; +def TF_Float8E5M2FNUZTensor : TensorOf<[TF_Float8E5M2FNUZ]>; //===----------------------------------------------------------------------===// // Complex types (including corresponding reference types) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 373586ae837a3f..f432b6b1f612f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -260,7 +260,7 @@ void *TensorFlowDialect::getRegisteredInterfaceForOp( // Only use fallback interface for known not-stateful ops. const tensorflow::OpRegistrationData *op_reg_data = nullptr; - tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp( + absl::Status s = tensorflow::OpRegistry::Global()->LookUp( opName.stripDialect().str(), &op_reg_data); return (s.ok() && !op_reg_data->op_def.is_stateful()) ? fallback_effect_op_interface_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 5ad1642d2f064f..008767bda6cebd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -796,7 +796,7 @@ void GetOutputShapeForBroadcastGradientArgs(ArrayRef bcasted_shape, } // namespace // Verifies that, -// * Broadcast compatability for input shapes. +// * Broadcast compatibility for input shapes. // * Output shape dimension matches the expected dimension size for input // shapes. LogicalResult BroadcastGradientArgsOp::verify() { @@ -1635,15 +1635,13 @@ LogicalResult ConcatOffsetOp::fold(FoldAdaptor adaptor, if (concat_dim >= num_dims || concat_dim < 0) return failure(); // Check all elements besides at concat_dim match across all shape tensors. - SmallVector shape0; - shape0.reserve(num_dims); - for (int32_t dim : shapes.front().getValues()) shape0.push_back(dim); + DenseIntElementsAttr shape0 = shapes.front(); for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) { for (const auto& dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) { if (dims_and_idx.index() == concat_dim) continue; - if (std::get<0>(dims_and_idx.value()) != + if (std::get<0>(dims_and_idx.value()).getSExtValue() != std::get<1>(dims_and_idx.value()).getSExtValue()) return failure(); } @@ -1651,14 +1649,25 @@ LogicalResult ConcatOffsetOp::fold(FoldAdaptor adaptor, // Compute an exclusive cumulative sum of elements at concat_dim. results.reserve(shapes.size()); - SmallVector cumulative_sum(num_dims, 0); - RankedTensorType offset_type = tensorflow::GetTypeFromTFTensorShape( - {num_dims}, IntegerType::get(getContext(), 32)); - for (DenseIntElementsAttr shape : shapes) { - results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); - cumulative_sum[concat_dim] += shape.getValues()[concat_dim]; + if (getShapeType().isInteger(32)) { + SmallVector cumulative_sum(num_dims, 0); + RankedTensorType offset_type = tensorflow::GetTypeFromTFTensorShape( + {num_dims}, IntegerType::get(getContext(), 32)); + for (DenseIntElementsAttr shape : shapes) { + results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); + cumulative_sum[concat_dim] += shape.getValues()[concat_dim]; + } + } else if (getShapeType().isInteger(64)) { + SmallVector cumulative_sum(num_dims, 0); + RankedTensorType offset_type = tensorflow::GetTypeFromTFTensorShape( + {num_dims}, IntegerType::get(getContext(), 64)); + for (DenseIntElementsAttr shape : shapes) { + results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); + cumulative_sum[concat_dim] += shape.getValues()[concat_dim]; + } + } else { + return failure(); } - return success(); } @@ -1944,7 +1953,7 @@ static LogicalResult inferConvReturnTypeComponents( // Skip if input or filter size is dynamic. if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue; // Calculate the expected_output_size. - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( input_ty.getDimSize(dim), filter_ty.getDimSize(i), get_int(dilations[dim]), stride, padding, &expected_output_size, &pad_low, &pad_high); @@ -2278,7 +2287,7 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern { // TF::ConstOp, i.e., if `y` is defined by an op and it is the tf.Const op. // In that case, `yDefOp` stores this tf.Const op. // Note that if `y` is a block argument, `y.getDefiningOp()` will return - // null, which will get propogated by dyn_cast_or_null to `yDefOp`. + // null, which will get propagated by dyn_cast_or_null to `yDefOp`. // Further, if `y` is defined by an op other than tf.Const, // `y.getDefiningOp()` will not return null but dyn_cast_or_null will. if (auto yDefOp = dyn_cast_or_null(y.getDefiningOp())) { @@ -2630,7 +2639,8 @@ namespace { // Flips the incompatible_shape_error attribute to true if the shapes are known // to be compatible. template -static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter& rewriter) { +static LogicalResult flipCompatibleShapeError(Ty op, + PatternRewriter& rewriter) { if (op.getIncompatibleShapeError()) { return rewriter.notifyMatchFailure(op, "the attribute is already true"); } @@ -2663,12 +2673,12 @@ static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter& rewriter) { void EqualOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(flipComatibleShapeError); + results.add(flipCompatibleShapeError); } void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(flipComatibleShapeError); + results.add(flipCompatibleShapeError); } //===----------------------------------------------------------------------===// @@ -2861,9 +2871,6 @@ OpFoldResult FillOp::fold(FoldAdaptor adaptor) { // FusedBatchNormGradOp //===----------------------------------------------------------------------===// -// TODO(b/150954845): Add benchmarks to verify that layout preference didn't -// change in the latest GPU generations. - LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { return ::mlir::TF::UpdateDataFormat(data_format, this); } @@ -2923,7 +2930,7 @@ LogicalResult FusedBatchNormOp::verify() { template static LogicalResult InferenceFoldOperandsPermutation( ArrayRef permutation, Op* op) { - // FusedBatchNorm in training mode is a layout sentitive operation, and should + // FusedBatchNorm in training mode is a layout sensitive operation, and should // have already assigned an optimal data format. if (op->getIsTraining()) return failure(); return ::mlir::TF::FoldOperandsPermutation(permutation, op); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc index 247a85804d899e..1b6ef2f7112b80 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -90,7 +90,7 @@ LogicalResult _XlaHostComputeMlirOp::verify() { if (host_module.empty()) return success(); mlir::OwningOpRef module_for_func; - tensorflow::Status status = tensorflow::DeserializeMlirModule( + absl::Status status = tensorflow::DeserializeMlirModule( host_module.str(), op->getContext(), &module_for_func); if (!status.ok()) { return op.emitError() diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def index 17daa6afdcaf4b..2ec55558acbaaf 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def @@ -68,6 +68,9 @@ HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref") HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref") HANDLE_TF_REF_TYPE(Float8E4M3FNRef, FLOAT8_E4M3FN_REF, "float8e4m3fnref") HANDLE_TF_REF_TYPE(Float8E5M2Ref, FLOAT8_E5M2_REF, "float8e5m2ref") +HANDLE_TF_REF_TYPE(Float8E4M3FNUZRef, FLOAT8_E4M3FNUZ_REF, "float8e4m3fnuzref") +HANDLE_TF_REF_TYPE(Float8E4M3B11FNUZRef, FLOAT8_E4M3B11FNUZ_REF, "float8e4m3b11fnuzref") +HANDLE_TF_REF_TYPE(Float8E5M2FNUZRef, FLOAT8_E5M2FNUZ_REF, "float8e5m2fnuzref") #ifndef HANDLE_LAST_TF_TYPE #define HANDLE_LAST_TF_TYPE(class, enumerant, name) \ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir index 37e8e118ca4347..5d01be5bcc6757 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir @@ -1,3 +1,4 @@ +// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test=import-variables-as-dense-resources=true -split-input-file %s | FileCheck --check-prefix=CheckWithDense %s --dump-input=fail // RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { @@ -15,11 +16,23 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} } // CHECK: "tf_saved_model.global_tensor"() // CHECK: sym_name = "dense/kernel" + // CHECK: value = dense<0.000000e+00> // CHECK: "tf_saved_model.global_tensor"() // CHECK: sym_name = "dense/bias" + // CHECK: value = dense<0.000000e+00> // CHECK: func @serving_default( // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) + + // CheckWithDense: "tf_saved_model.global_tensor"() + // CheckWithDense: sym_name = "dense/kernel" + // CheckWithDense: value = dense_resource + // CheckWithDense: "tf_saved_model.global_tensor"() + // CheckWithDense: sym_name = "dense/bias" + // CheckWithDense: value = dense_resource + // CheckWithDense: func @serving_default( + // CheckWithDense: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CheckWithDense: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) } // ----- @@ -49,8 +62,10 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} } // CHECK: "tf_saved_model.global_tensor"() // CHECK: sym_name = "dense/kernel" + // CHECK: value = dense<0.000000e+00> // CHECK: "tf_saved_model.global_tensor"() // CHECK: sym_name = "dense/bias" + // CHECK: value = dense<0.000000e+00> // CHECK: func @f( // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) @@ -58,6 +73,20 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} // CHECK: func @f2( // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) + + // CheckWithDense: "tf_saved_model.global_tensor"() + // CheckWithDense: sym_name = "dense/kernel" + // CheckWithDense: value = dense_resource + // CheckWithDense: "tf_saved_model.global_tensor"() + // CheckWithDense: sym_name = "dense/bias" + // CheckWithDense: value = dense_resource + // CheckWithDense: func @f( + // CheckWithDense: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CheckWithDense: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) + + // CheckWithDense: func @f2( + // CheckWithDense: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CheckWithDense: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) } // ----- @@ -75,9 +104,21 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} } // CHECK: "tf_saved_model.global_tensor"() // CHECK: sym_name = "dense/kernel" + // CHECK: value = dense<0.000000e+00> // CHECK: "tf_saved_model.global_tensor"() // CHECK: sym_name = "dense/bias" + // CHECK: value = dense<0.000000e+00> // CHECK: func @serving_default( // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) + + // CheckWithDense: "tf_saved_model.global_tensor"() + // CheckWithDense: sym_name = "dense/kernel" + // CheckWithDense: value = dense_resource + // CheckWithDense: "tf_saved_model.global_tensor"() + // CheckWithDense: sym_name = "dense/bias" + // CheckWithDense: value = dense_resource + // CheckWithDense: func @serving_default( + // CheckWithDense: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CheckWithDense: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 83fdfdf95a5c61..7fdf1c8a6c1e12 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -196,6 +196,7 @@ cc_library( ":tf_pass_inc_gen", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:framework", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", @@ -224,8 +225,10 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:threadpool_options", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -248,6 +251,8 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -272,6 +277,8 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -313,6 +320,9 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -656,6 +666,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:variant", @@ -807,6 +818,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/ir/types:Dialect", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -985,6 +997,7 @@ cc_library( "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index c6e21cb1e03054..72697e4dd3f862 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -97,7 +97,7 @@ void BatchMatMulToEinsumPass::runOnOperation() { patterns.add, ConvertTFBatchMatMulToEinsumOp>( &getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc index a3266f58718837..a9b3b4f6809005 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc @@ -197,7 +197,7 @@ struct ConvertTfControlFlowToScf void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTfControlFlowToScfPatterns(&getContext(), &patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_optionals.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_optionals.cc index a5beaf06d6f349..012997de67bca3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_optionals.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_optionals.cc @@ -287,7 +287,7 @@ void DecomposeOptionalsPass::runOnOperation() { pattern_list.add(&getContext()); FrozenRewritePatternSet patterns(std::move(pattern_list)); - if (failed(applyPatternsAndFoldGreedily(module, patterns))) { + if (failed(applyPatternsGreedily(module, patterns))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc index 0c2d026abc60de..cd5ae2d2fdaa2d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -120,7 +120,7 @@ LogicalResult ApplyPatternsInClusterAndReachableFunctions( // Apply patterns to reachable functions. for (Operation* op : reachable_functions) { assert(isa(op)); - if (failed(applyPatternsAndFoldGreedily(op, patterns))) { + if (failed(applyPatternsGreedily(op, patterns))) { return op->emitError() << kBadDecompositionMessage; } } @@ -137,7 +137,7 @@ LogicalResult ApplyPatternsInClusterAndReachableFunctions( auto walk_result = func.walk([&](tf_device::ClusterOp cluster) { // Cluster ops are not isolated from above so we cannot use - // `applyPatternsAndFoldGreedily` utility. Instead we apply patterns + // `applyPatternsGreedily` utility. Instead we apply patterns // locally on each op within the cluster until convergence. if (failed(ApplyPatternsLocallyUntilConverged(cluster, patterns, max_iterations))) { @@ -162,8 +162,7 @@ struct DecomposeResourceOpsPass RewritePatternSet patterns(&getContext()); TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { getOperation().emitError() << kBadDecompositionMessage; signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 4cdc90376c2317..f28f3f1447e3fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -850,7 +850,7 @@ void TransformEinsumPass::runOnOperation() { auto func = getOperation(); patterns.add(&getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc index 6547b6f168c3bf..9ef0b9b89c34da 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -194,7 +194,7 @@ void BroadcastFoldPass::runOnOperation() { auto func = getOperation(); patterns.add(func.getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index 4eb791a909022d..2327bcb3e4140c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -360,7 +360,7 @@ void FusedKernelMatcherPass::runOnOperation() { auto func = getOperation(); patterns.add(&getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index 98e0f3b345466f..f943d0984617e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -123,7 +123,7 @@ void GpuOpFusionPass::runOnOperation() { func::FuncOp func = getOperation(); RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc index 67bf6fa422121e..a2c4a7031ed14b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include +#include #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Casting.h" @@ -148,7 +152,7 @@ void InitTextFileToImportPass::runOnOperation() { patterns.add( context, StringRef(saved_model_dir_)); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc index 5b1159801a6f12..a985cdc11611b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include "llvm/Support/Casting.h" #include "llvm/Support/FileSystem.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc index f65a5a6af59056..ec43d331191cd5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" @@ -29,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h" #include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/public/session.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init_test_pass.cc index 623051468e2d7e..61846b557abc67 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init_test_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/fake_session.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index 851a87e3620b10..931e6d9295cdbd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index e8c1d1997e195e..e9ed3b7ce8ae38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index 7a4d1bfffc19d7..fe1a8c5031b6af 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -14,14 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" -#include -#include +#include #include -#include -#include -#include #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" @@ -40,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/threadpool_options.h" @@ -66,7 +64,8 @@ constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input"; LogicalResult LiftVariablesFromSession( ModuleOp module, Session* session, - const SmallSet& resource_names) { + const SmallSet& resource_names, + bool import_variables_as_dense_resources) { OpBuilder builder(module.getBodyRegion()); if (!session) return module.emitOpError() << "no session provided"; @@ -129,11 +128,13 @@ LogicalResult LiftVariablesFromSession( const Tensor& tensor = std::get<1>(iter); // Create tensor attribute for this variable. - absl::StatusOr tensor_attr_or = - ConvertTensor(tensor, &builder); + absl::StatusOr tensor_attr_or = ConvertTensor( + tensor, &builder, + /*convert_to_dense_resource=*/import_variables_as_dense_resources); if (!tensor_attr_or.ok()) { return module.emitOpError() - << "failed to convert tensor (name: " << name.str() << ")"; + << "failed to convert tensor (name: " << name.str() << ")- " + << tensor_attr_or.status().ToString(); } ElementsAttr tensor_attr = tensor_attr_or.value(); @@ -148,7 +149,8 @@ LogicalResult LiftVariablesFromSession( } // namespace -LogicalResult LiftVariables(ModuleOp module, Session* session) { +LogicalResult LiftVariables(ModuleOp module, Session* session, + bool import_variables_as_dense_resources) { MLIRContext* context = module.getContext(); mlir::Builder builder(context); StringAttr resource_name_id = builder.getStringAttr(kResourceNameArgAttr); @@ -177,7 +179,9 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { if (resource_names.empty()) return success(); - if (failed(LiftVariablesFromSession(module, session, resource_names))) + if (failed(LiftVariablesFromSession(module, session, resource_names, + /*import_variables_as_dense_resources=*/ + import_variables_as_dense_resources))) return failure(); // Now that we have all global tensors created, we set the corresponding diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h index e86e2f570d01d4..a0a218f67a8184 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h @@ -26,7 +26,8 @@ namespace tf_saved_model { // Creates GlobalTensorOp for each variable from function arguments and converts // them to the corresponding saved model arguments. -LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session); +LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session, + bool import_variables_as_dense_resources = false); } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.cc index d3d04fdbf4f278..7ed4f82c579e2b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.cc @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/fake_session.h" namespace mlir { -namespace tf_saved_model { +namespace tf_test { namespace { using ::tensorflow::Session; @@ -37,7 +39,9 @@ class LiftVariablesTestPass void runOnOperation() override { ModuleOp module = getOperation(); - if (failed(tf_saved_model::LiftVariables(module, session_))) + if (failed(tf_saved_model::LiftVariables( + module, session_, /*import_variables_as_dense_resources=*/ + import_variables_as_dense_resources_))) signalPassFailure(); } @@ -62,18 +66,17 @@ class LiftVariablesInvalidSessionTestPass }; } // namespace -} // namespace tf_saved_model +} // namespace tf_test namespace tf_test { std::unique_ptr> CreateLiftVariablesTestPass() { - return std::make_unique(); + return std::make_unique(); } std::unique_ptr> CreateLiftVariablesInvalidSessionTestPass() { - return std::make_unique< - tf_saved_model::LiftVariablesInvalidSessionTestPass>(); + return std::make_unique(); } } // namespace tf_test diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h deleted file mode 100644 index 0cf52f98e809e3..00000000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_ - -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" -#include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/threadpool_options.h" -#include "tensorflow/core/public/session.h" - -namespace mlir { -namespace tf_saved_model { - -using ::tensorflow::DeviceMgr; -using ::tensorflow::Session; -using ::tensorflow::Status; -using ::tensorflow::Tensor; - -// FakeSession is for testing only. -class FakeSession : public tensorflow::Session { - public: - FakeSession() {} - ~FakeSession() override = default; - - Status Create(const tensorflow::GraphDef& graph) override { - return tensorflow::errors::Unimplemented("not available"); - } - Status Extend(const tensorflow::GraphDef& graph) override { - return tensorflow::errors::Unimplemented("not available"); - } - - Status Close() override { - return tensorflow::errors::Unimplemented("not available"); - } - - Status ListDevices( - std::vector* response) override { - return tensorflow::errors::Unimplemented("not available"); - } - - Status LocalDeviceManager( - const tensorflow::DeviceMgr** deviceMgrPtr) override { - // This method returns a null device manager without making an error. - // Users of this method will be notified since it will have a fake data. - *deviceMgrPtr = nullptr; - return OkStatus(); - } - - Status Run(const std::vector>& inputs, - const std::vector& output_names, - const std::vector& target_nodes, - std::vector* outputs) override { - tensorflow::RunMetadata run_metadata; - return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes, - outputs, &run_metadata); - } - - Status Run(const tensorflow::RunOptions& run_options, - const std::vector>& inputs, - const std::vector& output_names, - const std::vector& target_nodes, - std::vector* outputs, - tensorflow::RunMetadata* run_metadata) override { - return Run(run_options, inputs, output_names, target_nodes, outputs, - run_metadata, tensorflow::thread::ThreadPoolOptions()); - } - - Status Run(const tensorflow::RunOptions& run_options, - const std::vector>& inputs, - const std::vector& output_names, - const std::vector& target_nodes, - std::vector* outputs, - tensorflow::RunMetadata* run_metadata, - const tensorflow::thread::ThreadPoolOptions& thread_pool_options) - override { - for (const std::string& output_name : output_names) { - Tensor output; - if (output_name == "dense/bias") { - Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})); - t.flat().setZero(); - outputs->push_back(t); - } else if (output_name == "dense/kernel") { - Tensor t = - Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})); - t.flat().setZero(); - outputs->push_back(t); - } else { - // Create a scalar float tensor. - Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})); - t.flat()(0) = 1.0f; - outputs->push_back(t); - } - } - return OkStatus(); - } -}; - -// This pass is only available in the tf-opt binary for testing. -class LiftVariablesTestPass - : public PassWrapper> { - public: - LiftVariablesTestPass() { session_ = new FakeSession(); } - - ~LiftVariablesTestPass() override { delete session_; } - - void runOnOperation() override { - ModuleOp module = getOperation(); - if (failed(LiftVariables(module, session_))) signalPassFailure(); - } - - private: - Session* session_; -}; - -// This pass is only available in the tf-opt binary for testing. -class LiftVariablesInvalidSessionTestPass - : public PassWrapper> { - public: - void runOnOperation() override { - ModuleOp module = getOperation(); - // Pass an invalid session argument, which is a nullptr. - if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure(); - } -}; - -} // namespace tf_saved_model -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_quantized.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_quantized.cc index ab515860954a2e..cbd7ff56bfc053 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_quantized.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_quantized.cc @@ -16,6 +16,9 @@ limitations under the License. // Rewrites ops that require quantized inputs or outputs to ops that allow // non-quantized inputs and outputs. +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -36,7 +39,7 @@ class LowerQuantizedPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); mlir::TF::PopulateLoweringQuantizedPatterns(&getContext(), &patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index da565f00b45b99..bfb0e75db1579e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -15,7 +15,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" +#include +#include +#include +#include +#include +#include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_test_pass.cc index 1f3fafcb16bdf7..e128b10af5e0d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf_test_pass.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -39,7 +42,7 @@ struct LowerTF : public impl::TestTensorFlowLowerTFPassBase { mlir::TF::PopulateTFLoweringBeforeHLOPatterns(&getContext(), &patterns); } - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc index 54fef16b043e16..fc58732c3190f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc @@ -14,9 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h" -#include -#include - +#include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -30,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/public/session.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc index 6020ef19d824a9..0db17e5dfc79b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc index 709e4532c1239c..37122141afc784 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include #include #include -#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc index 0b41f4a8bdbe6c..21ae326dfe93dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.h" -#include -#include - #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/Twine.h" #include "mlir/Transforms/Passes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 80e7cd3991c727..ccafc3719a2705 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -148,8 +150,7 @@ struct TensorFlowOptimizePass void runOnOperation() override { auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, patterns))) - signalPassFailure(); + if (failed(applyPatternsGreedily(func, patterns))) signalPassFailure(); } FrozenRewritePatternSet patterns; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index bfed05448bd25a..fd4e631a4a7d4c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include +#include +#include #include "llvm/ADT/DenseMap.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc b/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc index 5a3f91c0d23b49..1212a960b96d23 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index b968923089cb8f..46a9f020ed7dde 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + #include "absl/container/flat_hash_set.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -166,7 +171,7 @@ LogicalResult RewriteCommunicationOps(ModuleOp module) { MLIRContext* ctx = module.getContext(); mlir::RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed(mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(mlir::applyPatternsGreedily(module, std::move(patterns)))) { return module.emitError("failed to apply tf export preparation patterns"); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index bc64c48c81a596..8de89f01748636 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerUnion.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc index 7c488b8992d2cb..e685f04a8336ef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc index 18f54d6b5826d3..493725c6cdcb43 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc index b4818592ef6f50..bb66bdb39c0148 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" @@ -114,7 +115,7 @@ void RemoveUnusedWhileResultsPass::runOnOperation() { MLIRContext* context = &getContext(); RewritePatternSet patterns(context); TF::WhileRegionOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc index 3a6377a3bb63e1..7d0d650b38bff8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replica_id_to_device_ordinal.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replica_id_to_device_ordinal.cc index 0294ba24d394fc..88628bf1a3c2fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replica_id_to_device_ordinal.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replica_id_to_device_ordinal.cc @@ -17,7 +17,6 @@ limitations under the License. // the replica id attribute. #include -#include #include "llvm/Support/Casting.h" #include "mlir/Pass/Pass.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 52c449d227c5ee..3928faaa280398 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -16,11 +16,13 @@ limitations under the License. // This pass forms `tf_executor.island` per replica from a single // `tf_device.replicate` island. +#include #include #include #include #include +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index 180da6a90e81d2..17796c18242090 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include #include #include -#include #include #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 90397e7f8237c9..c7ffc9c0dd462f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -15,8 +15,10 @@ limitations under the License. // This pass lifts resource variable operations outside of device computation. -#include +#include #include +#include +#include #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" @@ -30,11 +32,11 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc index faedd25114807e..deef690b4d9636 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h index daf8c04fbd9365..b8bc0a1d57cdec 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_ +#include + #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc index 7a93aac60fc7cb..6fb99069362162 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h" #include +#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index a8df162eb9fa17..f44c7e969ce6dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -16,9 +16,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" #include +#include +#include #include #include -#include #include #include #include @@ -27,6 +28,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index 1a84be115b355e..4c0a16857a2f18 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 476a67b496355f..a548b88d3f7c29 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include #include -#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/strip_noinline_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/strip_noinline_attribute.cc index 4ac965d57359e9..69c6ae88f926a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/strip_noinline_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/strip_noinline_attribute.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index b1b2733802234d..47b046d9fdaee2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include +#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc index 267f32daa9f6e6..18ddf3ef909dc2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc @@ -16,6 +16,8 @@ limitations under the License. // This pass folds the tf.Identity op if the operation has the same device as // its operand. +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index a9ad31a28461f7..857e6a29d8ffb8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include -#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_cluster_ops_by_policy.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_cluster_ops_by_policy.cc index 80e6bd739066e9..54dc049d7c0d83 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/test_cluster_ops_by_policy.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_cluster_ops_by_policy.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc index 22577b4dba1aa7..064ddd0a4fdcce 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include -#include +#include #include #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc index 2ad0e6bc946b57..52d3969c894029 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include +#include #include -#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc index 98cc5c4a756754..3b49e6d7c360f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -35,7 +36,7 @@ struct TFDataOptimization RewritePatternSet patterns(&getContext()); mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc index d233b5167451db..ff0cf6231e6213 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ // This file implements device assignment in TF dialect. +#include +#include + #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_functional_to_executor.cc index a6cad7fe77acee..f9be96c902a290 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_functional_to_executor.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 369840c888f4a2..3fef3ff9ba2020 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -15,9 +15,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" +#include +#include +#include +#include +#include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "llvm/Support/CommandLine.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -29,6 +36,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" @@ -75,7 +83,7 @@ void GraphOptPass::runOnOperation() { GraphExportConfig confs; auto graph = std::make_unique(flib_def); absl::flat_hash_set control_ret_nodes; - Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + absl::Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module_in, confs, &graph, &flib_def, &control_ret_nodes); if (!status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.message(); @@ -95,7 +103,7 @@ void GraphOptPass::runOnOperation() { for (auto pass : passes_) { assert(pass != nullptr); - Status status = pass->Run(options); + absl::Status status = pass->Run(options); if (!status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) << pass->name() << ": " << status.message(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h index 340444d4a329b7..2b60139557a2e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ +#include +#include +#include + #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/core/common_runtime/optimization_registry.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc index 963c7ed0c62084..4ecf7ccec9b2bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables_test_pass.cc index e8eb0a859ed1c8..162abc4e6b78e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables_test_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_test_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_test_passes.td index 8758a3631a96e0..df4deb30ff3a6b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_test_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_test_passes.td @@ -81,6 +81,11 @@ def LiftVariablesInvalidSessionTestPass : Pass<"tf-saved-model-lift-variables-in def LiftVariablesTestPass : Pass<"tf-saved-model-lift-variables-test", "ModuleOp"> { let summary = "Lift variables and save them as global tensors"; let constructor = "mlir::tf_test::CreateLiftVariablesTestPass()"; + + let options = [ + Option<"import_variables_as_dense_resources_", "import-variables-as-dense-resources", "bool", /*default=*/"false", + "Import variables as dense resources">, + ]; } def InitializeVariablesInSessionInitializerPass : Pass<"tf-saved-model-initialize-variables-in-session-init", "ModuleOp"> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc index 68d50e54a1bce0..d68157a052887b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/strings/match.h" +#include +#include +#include +#include + #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc index b4a98605a34ac2..f64019b08b1362 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc index bf9e1f4647a0d4..5a5d4677b1f63d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc index e2b9c62ee8e6bc..ccdf4a53ffc465 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 2281658efc5ed1..ae5710c3d74cea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc index b2a3b81f63a1a9..332512d00ba9be 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc index 17326d160368a4..03025d77675810 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc index 9ef8cda3d6f92a..bb4c951065f771 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc index 08165fb1435ff2..180fd8eaaed75e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc @@ -10,11 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include -#include -#include #include #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc index be4f986bf1ff26..559c625167c7ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc @@ -10,7 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc index 086fab19ac98a1..fdacf313d30240 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index ef16273e9eea45..eb11dbd722bf74 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include -#include +#include +#include #include +#include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc index ef6f03e0be355f..7bd71f3e48078a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index ff8ac1ad7cacd1..03618d23464b0a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -15,11 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" -#include #include +#include #include +#include -#include "absl/memory/memory.h" +#include "absl/container/inlined_vector.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -73,7 +74,7 @@ void UnrollBatchMatMulPass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); PopulateUnrollTfBatchMatMul(&getContext(), patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc index 9f36e838206804..63e255748c41ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc @@ -14,12 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include -#include -#include #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc index 623cf8af3d6ea9..80057322280230 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/verify_suitable_for_graph_export_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc index c80005e6ae3cb2..b6c930f5d08d1a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_inline_device_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_inline_device_ops.cc index f9318637fe9562..a55470bb8391d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_inline_device_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_inline_device_ops.cc @@ -15,6 +15,10 @@ limitations under the License. // This pass remove Cluster ops by inlining Cluster ops. +#include +#include +#include + #include "llvm/ADT/SmallVector.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc index 8ce264b47b57d4..3ccf4d9554330b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc @@ -16,7 +16,7 @@ limitations under the License. // This transformation pass converts stateful and stateless partitioned calls // with _xla_compile_device_type attribute to XLA launch ops. -#include +#include #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc index 9267607e7e342a..95250846eb0801 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index 2cb4cb5fdffc61..fd8f7a1970ba14 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -47,7 +47,10 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -77,6 +80,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:DerivedAttributeOpInterface", @@ -99,6 +104,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@local_xla//xla:status_macros", @@ -135,8 +142,10 @@ cc_library( "//tensorflow/core/grappler/utils:transitive_fanin", "//tensorflow/core/util/tensor_bundle:byteswaptensor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -154,7 +163,10 @@ cc_library( deps = [ "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index f1ce5038e23d01..e8c92e4e6c5f4e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include +#include #include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringRef.h" @@ -31,8 +36,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -46,8 +53,8 @@ template () .begin())>::value>::type> -Status SetTypeAttribute(absl::string_view name, ContainerT types, - AttrValueMap* values) { +absl::Status SetTypeAttribute(absl::string_view name, ContainerT types, + AttrValueMap* values) { AttrValue value; auto& type_list = *value.mutable_list(); for (auto type : types) { @@ -93,7 +100,7 @@ void SetShapeAttribute(absl::string_view name, ContainerT shapes, // Collects all the unregistered attributes for an TF dialect operation. // Attributes "name" and "device" are not included because they are not part // of an TF op attributes. -Status GetUnregisteredAttrs( +absl::Status GetUnregisteredAttrs( mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data, absl::flat_hash_set* attrs_to_ignore) { if (!op_reg_data) { @@ -166,10 +173,11 @@ absl::StatusOr> GetAttributesToIgnore( // Populates all derived attributes of a MLIR operation in a proto // map. -Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name, - mlir::DictionaryAttr derived_attrs, - bool ignore_unregistered_attrs, - AttrValueMap* attributes) { +absl::Status PopulateDerivedAttributes(mlir::Operation* inst, + llvm::StringRef name, + mlir::DictionaryAttr derived_attrs, + bool ignore_unregistered_attrs, + AttrValueMap* attributes) { if (derived_attrs) { TF_RETURN_WITH_CONTEXT_IF_ERROR( ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{}, @@ -219,7 +227,7 @@ void RemoveIdentityCast(NodeDef* node_def) { } // namespace -Status GetAttrValuesFromOperation( +absl::Status GetAttrValuesFromOperation( mlir::Operation* inst, llvm::StringRef name, const tensorflow::OpRegistrationData* op_reg_data, bool ignore_unregistered_attrs, AttrValueMap* attributes) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index f15e741b247340..47bc42e096dd2a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ +#include + +#include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" @@ -29,7 +32,7 @@ namespace tensorflow { // Extracts the attributes of a MLIR operation and populates the converted // attributes in a proto map. -Status GetAttrValuesFromOperation( +absl::Status GetAttrValuesFromOperation( mlir::Operation* inst, llvm::StringRef name, const tensorflow::OpRegistrationData* op_reg_data, bool ignore_unregistered_attrs, AttrValueMap* attributes); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 194abe76611d7e..ea9a997532cd85 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -32,8 +32,10 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/match.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -196,7 +198,8 @@ class NameUniquifier : public OpOrArgNameMapper { // the GraphDef. // - Replacing LegacyFedInput nodes with Placeholder nodes if // convert_legacy_fed_inputs option is enabled. -Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { +absl::Status PreprocessGraphDef(const GraphImportConfig* specs, + GraphDef* graph_def) { for (auto& node_def : *graph_def->mutable_node()) { const tensorflow::OpRegistrationData* op_reg_data = tensorflow::OpRegistry::Global()->LookUp(node_def.op()); @@ -209,9 +212,6 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { return absl::OkStatus(); } - - - // Determines the names used to reference objects in the SavedObjectGraph. class ObjectNames { public: @@ -431,7 +431,7 @@ const TensorProto* ExtractConstTensorFromGraph(const GraphDef& graph_def, const TrackableObjectGraph::TrackableObject::SerializedTensor* FindSerializedTensorInTrackable( const TrackableObjectGraph::TrackableObject& trackable_object, - StringPiece name) { + absl::string_view name) { for (const auto& maybe_serialized_tensor : trackable_object.attributes()) { if (maybe_serialized_tensor.name() == name) { return &maybe_serialized_tensor; @@ -440,8 +440,8 @@ FindSerializedTensorInTrackable( return nullptr; } -Status DiagnoseMultipleConcreteFunctions(const SavedObjectGraph& object_graph, - const ObjectNames& object_names) { +absl::Status DiagnoseMultipleConcreteFunctions( + const SavedObjectGraph& object_graph, const ObjectNames& object_names) { for (int node_id = 0; node_id < object_graph.nodes_size(); node_id++) { const SavedObject& object = object_graph.nodes(node_id); if (object_names.GetExportedNames(node_id).empty()) { @@ -750,7 +750,7 @@ void SortSavedModelModule(mlir::ModuleOp module) { } } -Status CreateSavedModelIR( +absl::Status CreateSavedModelIR( const ObjectNames& object_names, mlir::ModuleOp module, const SavedObjectGraph& object_graph, const std::unordered_map& tf_name_to_mlir_name, @@ -923,7 +923,11 @@ Status CreateSavedModelIR( saved_model->variable_reader()->Lookup(checkpoint_key, &value), "Could not read checkpoint key from variables bundle: ", checkpoint_key); - TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensor(value, &builder)); + TF_ASSIGN_OR_RETURN( + auto value_attr, + ConvertTensor(value, &builder, + /*convert_to_dense_resource=*/ + import_options.import_variables_as_dense_resources)); // A variable can have a partially known type, such as // tensor, even if the initializer is a specific static // shape. @@ -1191,8 +1195,8 @@ class SavedModelSignatureDefImporterLite { // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function // for each signature. absl::StatusOr> ConvertSignatures(); - Status ConvertSignature(const std::string& sig_def_key, - const SignatureDef& signature_def); + absl::Status ConvertSignature(const std::string& sig_def_key, + const SignatureDef& signature_def); struct AssetInfo { std::string tensor_name; @@ -1203,9 +1207,9 @@ class SavedModelSignatureDefImporterLite { // Converts the initialization graph in the SavedModel to an MLIR function. // Attaches `tf_saved_model.initializer_type` attribute with value // `initializer_type` to the created function. - Status ConvertInitializer(const std::string& target_node_name, - const std::vector& assets, - llvm::StringRef initializer_type); + absl::Status ConvertInitializer(const std::string& target_node_name, + const std::vector& assets, + llvm::StringRef initializer_type); // Converts a graph with feeds and fetches to an MLIR function. absl::StatusOr> ConvertGraph( @@ -1217,7 +1221,7 @@ class SavedModelSignatureDefImporterLite { // Moves the functions in `sub_module` to `module_` and skips the duplicate // functions. - Status MoveConvertedFunctionsToModule( + absl::Status MoveConvertedFunctionsToModule( absl::string_view name, mlir::ModuleOp sub_module, const std::unordered_map& tf_name_to_mlir_name); @@ -1262,7 +1266,7 @@ SavedModelSignatureDefImporterLite::ConvertAssets() { return results; } -Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( +absl::Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( absl::string_view name, mlir::ModuleOp sub_module, const std::unordered_map& tf_name_to_mlir_name) { mlir::Builder builder(sub_module.getContext()); @@ -1306,7 +1310,7 @@ Status SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule( return absl::OkStatus(); } -Status SavedModelSignatureDefImporterLite::ConvertInitializer( +absl::Status SavedModelSignatureDefImporterLite::ConvertInitializer( const std::string& target_node_name, const std::vector& assets, llvm::StringRef initializer_type) { std::vector> inputs; @@ -1386,7 +1390,7 @@ SavedModelSignatureDefImporterLite::ConvertGraph( module_->getContext(), tf_name_to_mlir_name); } -Status SavedModelSignatureDefImporterLite::ConvertSignature( +absl::Status SavedModelSignatureDefImporterLite::ConvertSignature( const std::string& sig_def_key, const SignatureDef& signature_def) { // Create local vectors for the input and output and sort them to be // deterministic. We don't want anyone to really depend on the order, client @@ -1495,7 +1499,7 @@ SavedModelSignatureDefImporterLite::ConvertSignatures() { } absl::Mutex error_status_mu; // Needed since `error_status` is non-atomic. - tensorflow::Status error_status; + absl::Status error_status; { // Start a threadpool to convert signatures, since signature conversion can // be time consuming especially for large models. Threadpool destructor @@ -1610,7 +1614,8 @@ class SavedModelSignatureDefImporter { builder.getUnitAttr()); TF_RETURN_IF_ERROR( LiftVariables(bundle, *module, options.lift_variables, - options.include_variables_in_initializers)); + options.include_variables_in_initializers, + options.import_variables_as_dense_resources)); (*module)->removeAttr("tf_saved_model.under_construction"); return module; @@ -1623,16 +1628,18 @@ class SavedModelSignatureDefImporter { // `tf_saved_model::SessionInitializerOp`) by running the // `RemoveVariablesInSessionInitializerPass`, regardless of whether // `lift_variable_ops_to_args` is true or not. - static Status LiftVariables(const SavedModelBundle& bundle, - mlir::ModuleOp module, - bool lift_varhandle_ops_to_args, - bool include_variables_in_initializers); + static absl::Status LiftVariables(const SavedModelBundle& bundle, + mlir::ModuleOp module, + bool lift_varhandle_ops_to_args, + bool include_variables_in_initializers, + bool import_variables_as_dense_resources); }; -Status SavedModelSignatureDefImporter::LiftVariables( +absl::Status SavedModelSignatureDefImporter::LiftVariables( const SavedModelBundle& bundle, mlir::ModuleOp module, const bool lift_varhandle_ops_to_args, - const bool include_variables_in_initializers) { + const bool include_variables_in_initializers, + const bool import_variables_as_dense_resources) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); mlir::PassManager pm(module.getContext()); @@ -1662,8 +1669,8 @@ Status SavedModelSignatureDefImporter::LiftVariables( if (mlir::failed(pm.run(module))) return diag_handler.Combine( errors::Internal("Failed to promote var handles to args.")); - if (failed( - mlir::tf_saved_model::LiftVariables(module, bundle.GetSession()))) + if (failed(mlir::tf_saved_model::LiftVariables( + module, bundle.GetSession(), import_variables_as_dense_resources))) return diag_handler.Combine( errors::Internal("Failed to lift variables.")); } else { @@ -1706,29 +1713,6 @@ absl::StatusOr> ConvertGraphdefToMlir( graph, debug_info, graph.flib_def(), specs, context); } -absl::StatusOr> ConvertGraphToMlir( - const Graph& graph, const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - mlir::MLIRContext* context, - std::unordered_map* tf_name_to_mlir_name) { - return tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( - graph, debug_info, flib_def, specs, context, tf_name_to_mlir_name); -} - -absl::StatusOr> ConvertFunctionToMlir( - const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, - mlir::MLIRContext* context) { - tensorflow::GraphDebugInfo dummy_debug_info; - tensorflow::GraphImportConfig specs; - specs.graph_func_name = fbody->record->fdef().signature().name(); - specs.enable_shape_inference = false; - specs.graph_as_function = true; - for (const auto* control_ret_node : fbody->control_ret_nodes) - specs.control_outputs.push_back(control_ret_node->name()); - return ConvertGraphToMlir(*fbody->graph, dummy_debug_info, flib_def, specs, - context); -} - absl::StatusOr> ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, MLIRImportOptions options) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index e24d7b140d5889..fe7684adc1f2cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -17,9 +17,14 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ #include +#include #include +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -33,6 +38,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { @@ -45,24 +51,6 @@ absl::StatusOr> ConvertGraphdefToMlir( const GraphDef& graphdef, const GraphDebugInfo& debug_info, const GraphImportConfig& specs, mlir::MLIRContext* context); -// Given a Graph, returns a MLIR module containing the graph, expressed with -// tf_executor dialect. -ABSL_DEPRECATED("Use tensorflow::tf2xla::v2::ConvertGraphToTfExecutor instead.") -absl::StatusOr> ConvertGraphToMlir( - const Graph& graph, const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - mlir::MLIRContext* context, - std::unordered_map* tf_name_to_mlir_name = - nullptr); - -// [Experimental] -// Given a Function, returns a MLIR module containing the graph, expressed with -// tf_executor dialect. -ABSL_DEPRECATED("Use tensorflow::tf2xla::v2::ConvertGraphToTfExecutor instead.") -absl::StatusOr> ConvertFunctionToMlir( - const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, - mlir::MLIRContext* context); - // Given a SavedModel, returns a MLIR module containing the functions, expressed // with tf_executor dialect. absl::StatusOr> ConvertSavedModelToMlir( diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h index 44262d0bd08d86..b49ed7bbfc6a35 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h @@ -49,6 +49,10 @@ struct MLIRImportOptions { // Load the model without restoring associated variables from disk. Enables // loading raw programs without checkpoints. bool allow_uninitialized_variables = false; + + // If true, variables are imported as DenseResourceElementsAttr; else, + // variables are imported as DenseElementsAttr. + bool import_variables_as_dense_resources = false; }; } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index 8664d080cd75c6..a7eee4a191e236 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -18,15 +18,17 @@ limitations under the License. #include #include #include +#include #include #include +#include -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "xla/status_macros.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -67,14 +69,14 @@ std::string GraphImportConfig::str() const { return ss.str(); } -Status ParseOutputArrayInfo(absl::string_view array_names, - std::vector* outputs) { +absl::Status ParseOutputArrayInfo(absl::string_view array_names, + std::vector* outputs) { TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs)); return absl::OkStatus(); } -Status ParseOutputArrayInfo(const std::vector& output_names, - std::vector* outputs) { +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs) { for (auto& output_name : output_names) { if (output_name.empty()) continue; outputs->push_back(output_name); @@ -82,10 +84,10 @@ Status ParseOutputArrayInfo(const std::vector& output_names, return absl::OkStatus(); } -Status ParseInputArrayInfo(absl::string_view array_names, - absl::string_view data_types, - absl::string_view shapes, - GraphImportConfig::InputArrays* inputs) { +absl::Status ParseInputArrayInfo(absl::string_view array_names, + absl::string_view data_types, + absl::string_view shapes, + GraphImportConfig::InputArrays* inputs) { std::vector node_names; std::vector node_dtypes; std::vector>> node_shapes; @@ -112,8 +114,8 @@ static absl::StatusOr> ParseShapeStr( return dims; } -static Status HandleSubtype(absl::string_view subtype, - ArrayInfo::SubTypeInfo* result) { +static absl::Status HandleSubtype(absl::string_view subtype, + ArrayInfo::SubTypeInfo* result) { std::vector shape_and_type = absl::StrSplit(subtype, ':'); std::vector dims; @@ -141,7 +143,7 @@ static Status HandleSubtype(absl::string_view subtype, return absl::OkStatus(); } -Status ParseInputArrayInfo( +absl::Status ParseInputArrayInfo( const std::vector& node_names, const std::vector& node_dtypes, const std::vector>>& node_shapes, @@ -217,7 +219,7 @@ Status ParseInputArrayInfo( return absl::OkStatus(); } -Status ParseNodeShapes( +absl::Status ParseNodeShapes( absl::string_view shapes_str, std::vector>>& shapes_vector) { shapes_vector.clear(); @@ -235,8 +237,8 @@ Status ParseNodeShapes( return absl::OkStatus(); } -Status ParseNodeNames(absl::string_view names_str, - std::vector& names_vector) { +absl::Status ParseNodeNames(absl::string_view names_str, + std::vector& names_vector) { names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty()); return absl::OkStatus(); } @@ -284,8 +286,8 @@ static absl::StatusOr> ParseDTypesHelper( return dtypes; } -Status ParseNodeDataTypes(absl::string_view data_types_str, - std::vector& data_type_vector) { +absl::Status ParseNodeDataTypes(absl::string_view data_types_str, + std::vector& data_type_vector) { data_type_vector.clear(); if (!data_types_str.empty()) { TF_ASSIGN_OR_RETURN(data_type_vector, ParseDTypesHelper(data_types_str)); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index 8873b0928b028f..cf90b7edf359b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -117,20 +119,20 @@ struct GraphExportConfig { // Parses the command line flag strings to the specification of nodes in // the Graph. -Status ParseOutputArrayInfo(absl::string_view array_names, - std::vector* outputs); +absl::Status ParseOutputArrayInfo(absl::string_view array_names, + std::vector* outputs); -Status ParseOutputArrayInfo(const std::vector& output_names, - std::vector* outputs); +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs); // Parses the command line flag strings to the specification of nodes in // the Graph. `data_types` input string can be empty since the flag is optional. -Status ParseInputArrayInfo(absl::string_view array_names, - absl::string_view data_types, - absl::string_view shapes, - GraphImportConfig::InputArrays* inputs); +absl::Status ParseInputArrayInfo(absl::string_view array_names, + absl::string_view data_types, + absl::string_view shapes, + GraphImportConfig::InputArrays* inputs); -Status ParseInputArrayInfo( +absl::Status ParseInputArrayInfo( const std::vector& node_names, const std::vector& node_dtypes, const std::vector>>& node_shapes, @@ -139,19 +141,19 @@ Status ParseInputArrayInfo( // Parses shapes from the given string into shapes_vector which is a structured // format. // NOTE: If shapes_str is empty, shapes_vector will also be empty. -Status ParseNodeShapes( +absl::Status ParseNodeShapes( absl::string_view shapes_str, std::vector>>& shapes_vector); // Parses names from the given string into the names_vector. // NOTE: If names_str is empty, names_vector will also be empty. -Status ParseNodeNames(absl::string_view names_str, - std::vector& names_vector); +absl::Status ParseNodeNames(absl::string_view names_str, + std::vector& names_vector); // Parses data types from the given string into the data_type_vector. // NOTE: If data_types_str is empty, data_type_vector will also be empty. -Status ParseNodeDataTypes(absl::string_view data_types_str, - std::vector& data_type_vector); +absl::Status ParseNodeDataTypes(absl::string_view data_types_str, + std::vector& data_type_vector); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 9951c58f5b6820..b64da3edc8867c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -15,13 +15,19 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include +#include #include #include +#include #include -#include "absl/memory/memory.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -41,11 +47,13 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" namespace tensorflow { @@ -217,7 +225,8 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, - bool unconditionally_use_set_output_shapes) { + bool unconditionally_use_set_output_shapes, + bool import_variables_as_dense_resources) { tensorflow::SavedModelV2Bundle bundle; auto load_status = tensorflow::SavedModelV2Bundle::Load( std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle); @@ -231,6 +240,8 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir, options.add_default_attributes = true; options.unconditionally_use_set_output_shapes = unconditionally_use_set_output_shapes; + options.import_variables_as_dense_resources = + import_variables_as_dense_resources; auto module_or = ConvertSavedModelToMlir(&bundle, context, exported_names, options); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index cd86b27e13550c..8d404575cbdcec 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -16,12 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ +#include #include #include #include #include +#include "absl/base/attributes.h" #include "absl/base/macros.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -105,7 +108,8 @@ SavedModelObjectGraphToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, - bool unconditionally_use_set_output_shapes = false); + bool unconditionally_use_set_output_shapes = false, + bool import_variables_as_dense_resources = false); // Converts a TensorFlow V1 SavedModel stored in the directory with the given // `saved_model_dir` into a MLIR module. Creates MLIR entities into the diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc index 509bd99d8930e9..17c0bd98cc6140 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.cc @@ -15,7 +15,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringSet.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { @@ -36,7 +44,7 @@ const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() { } // namespace -Status GenerateResourceSharedNameIfEmpty( +absl::Status GenerateResourceSharedNameIfEmpty( GraphDef& gdef, const OpRegistryInterface* default_registry) { auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def, const OpDef& op_def) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h index 33d48cb6bf8efb..31baee5514ee3a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h @@ -27,7 +27,7 @@ class MetaGraphDef; // Generate the shared_name for resource handle ops in the graph and functions // if their shared_names are empty. Resource handle ops with empty shared_name // may have undesired semantics. -Status GenerateResourceSharedNameIfEmpty( +absl::Status GenerateResourceSharedNameIfEmpty( GraphDef& gdef, const OpRegistryInterface* default_registry); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index b9fef486428977..b0ad4e265633d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -15,43 +15,49 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include +#include #include -#include +#include #include #include +#include #include -#include "absl/base/casts.h" -#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectResourceBlobManager.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "xla/tsl/platform/errors.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/bfloat16.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" #include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" namespace tensorflow { @@ -85,13 +91,120 @@ static std::string MangleTensor(const Tensor& tensor) { return mangling_util::MangleTensor(ConvertToProto(tensor)); } +template +static absl::Status CopyDataIntoBlob(mlir::AsmResourceBlob& blob, + absl::string_view raw_src_data) { + ArrayRef data = blob.getDataAs(); + llvm::MutableArrayRef raw_dest_data = + mlir::MutableArrayRef(const_cast(data.data()), + data.size()); + if (raw_src_data.size() != blob.getData().size()) { + return absl::InvalidArgumentError( + "Size mismatch between raw_src_data and blob data"); + } + // Memcpy. + std::memcpy(raw_dest_data.data(), raw_src_data.data(), raw_src_data.size()); + + return absl::OkStatus(); +} + // Converts a TensorFlow tensor into an MLIR elements attribute. -template +template absl::StatusOr ConvertFlatTensor(const Tensor& input_tensor, - ShapedType type) { - auto arr = input_tensor.flat(); - return ElementsAttr(mlir::DenseElementsAttr::get( - type, llvm::ArrayRef(arr.data(), arr.size()))); + ShapedType shaped_type, + bool convert_to_dense_resource) { + // Only convert to dense resource if the data type is integer or floating. + if (convert_to_dense_resource && DataTypeCanUseMemcpy(input_tensor.dtype()) && + (DataTypeIsInteger(input_tensor.dtype()) || + DataTypeIsFloating(input_tensor.dtype()))) { + auto element_type = shaped_type.getElementType(); + auto num_elements = shaped_type.getNumElements(); + auto bit_width = element_type.getIntOrFloatBitWidth(); + auto tensor_data = input_tensor.tensor_data(); + mlir::AsmResourceBlob blob; + + if (llvm::isa(element_type)) { + switch (bit_width) { + case 1: + blob = mlir::HeapAsmResourceBlob::allocate(num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_i1", std::move(blob)); + case 8: + blob = mlir::HeapAsmResourceBlob::allocate(num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_i8", std::move(blob)); + case 16: + blob = mlir::HeapAsmResourceBlob::allocate(2 * num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_i16", std::move(blob)); + case 32: + blob = mlir::HeapAsmResourceBlob::allocate(4 * num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_i32", std::move(blob)); + case 64: + blob = mlir::HeapAsmResourceBlob::allocate(8 * num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_i64", std::move(blob)); + default: + return absl::InvalidArgumentError("Unsupported bit width"); + } + } else if (llvm::isa(element_type)) { + mlir::AsmResourceBlob blob; + switch (bit_width) { + case 8: + blob = mlir::HeapAsmResourceBlob::allocate(num_elements, /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_f8", std::move(blob)); + case 16: + blob = mlir::HeapAsmResourceBlob::allocate(2 * num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_f16", std::move(blob)); + case 32: { + blob = mlir::HeapAsmResourceBlob::allocate(4 * num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_f32", std::move(blob)); + } + case 64: + blob = mlir::HeapAsmResourceBlob::allocate(8 * num_elements, + /*align=*/64, + /*dataIsMutable=*/true); + TF_RETURN_IF_ERROR(CopyDataIntoBlob(blob, tensor_data)); + return mlir::DenseResourceElementsAttr::get( + shaped_type, "dense_elements_f64", std::move(blob)); + default: + return absl::InvalidArgumentError("Unsupported bit width"); + } + } else { + return absl::InvalidArgumentError("Unsupported element type"); + } + } else { + auto tensor_data = llvm::ArrayRef(input_tensor.flat().data(), + input_tensor.flat().size()); + return ElementsAttr(mlir::DenseElementsAttr::get(shaped_type, tensor_data)); + } } ElementsAttr ConvertTensorOfCustomFloatType(const Tensor& tensor, @@ -116,7 +229,8 @@ absl::StatusOr ConvertStringTensor(const Tensor& input_tensor, } absl::StatusOr ConvertTensor(const Tensor& input_tensor, - Builder* builder) { + Builder* builder, + bool convert_to_dense_resource) { const auto& input_dtype = input_tensor.dtype(); const auto& input_shape = input_tensor.shape(); Type elt_type; @@ -125,9 +239,10 @@ absl::StatusOr ConvertTensor(const Tensor& input_tensor, ConvertToMlirShape(input_shape, &shape); auto type = RankedTensorType::get(shape, elt_type); -#define CONVERT_FLAT(DTYPE, CTYPE) \ - case DTYPE: \ - return ConvertFlatTensor(input_tensor, type); +#define CONVERT_FLAT(DTYPE, CTYPE) \ + case DTYPE: \ + return ConvertFlatTensor(input_tensor, type, \ + convert_to_dense_resource); // TODO(fengliuai): customize the conversions for quantized types. switch (input_dtype) { @@ -149,6 +264,9 @@ absl::StatusOr ConvertTensor(const Tensor& input_tensor, case DT_HALF: case DT_FLOAT8_E5M2: case DT_FLOAT8_E4M3FN: + case DT_FLOAT8_E4M3FNUZ: + case DT_FLOAT8_E4M3B11FNUZ: + case DT_FLOAT8_E5M2FNUZ: return ConvertTensorOfCustomFloatType(input_tensor, type); case DT_STRING: return ConvertStringTensor(input_tensor, type); @@ -166,10 +284,10 @@ absl::StatusOr ConvertTensor(const Tensor& input_tensor, // indicate, if we're storing a splat tensor. int NumberOfMaterializedElements(const TensorProto& tensor) { if (!tensor.tensor_content().empty()) return -1; - // We don't know which element type this protocol buffer is storing, and the - // metaprogramming facilities for TensorProto are too limited to check their - // number without knowing this, so we need to manually dispatch to each - // possible member of TensorProto, depening on its dtype. + // We don't know which element type this protocol buffer is storing, and the + // metaprogramming facilities for TensorProto are too limited to check their + // number without knowing this, so we need to manually dispatch to each + // possible member of TensorProto, depening on its dtype. #define MATCH(DTYPE, FIELD) \ case DTYPE: \ return tensor.FIELD##_val().size() @@ -202,8 +320,9 @@ int NumberOfMaterializedElements(const TensorProto& tensor) { } } -absl::StatusOr ConvertTensorProto(const TensorProto& input_tensor, - Builder* builder) { +absl::StatusOr ConvertTensorProto( + const TensorProto& input_tensor, Builder* builder, + bool convert_to_dense_resource) { // If there is only one actual element in the proto, but its shape would // indicate there are more values, then this is representing a splat tensor. // We can create an MLIR Attribute more efficiently in this case. @@ -231,7 +350,7 @@ absl::StatusOr ConvertTensorProto(const TensorProto& input_tensor, Tensor t; if (!t.FromProto(input_tensor)) return InvalidArgument("Failed to parse input_tensor."); - return ConvertTensor(t, builder); + return ConvertTensor(t, builder, convert_to_dense_resource); } void ConvertToTensorShapeProto(ArrayRef shape, @@ -300,58 +419,109 @@ absl::StatusOr ConvertTensorShapeProto( // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. -void ConvertStringElementsAttr( +absl::Status ConvertStringElementsAttr( const DenseStringElementsAttr attr, protobuf::RepeatedPtrField* output) { for (const auto& val : attr.getRawStringData()) output->Add({val.data(), val.size()}); + return absl::OkStatus(); } template -void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output) { - for (const auto& val : attr.getValues>()) { - output->Add(val.real()); - output->Add(val.imag()); +absl::Status ConvertComplexElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output) { + auto attr = llvm::dyn_cast(elem_attr); + if (!attr) + return absl::InvalidArgumentError("Unsupported elements attr found"); + + auto elementType = attr.getType().getElementType(); + if (!llvm::isa(elementType)) { + return absl::InvalidArgumentError("Complex elements attr not found"); } + + auto complex_elem_ty = + llvm::cast(elementType).getElementType(); + if (complex_elem_ty.isF32()) { + for (const auto& val : attr.getValues>()) { + output->Add(val.real().convertToFloat()); + output->Add(val.imag().convertToFloat()); + } + } else if (complex_elem_ty.isF64()) { + for (const auto& val : attr.getValues>()) { + output->Add(val.real().convertToDouble()); + output->Add(val.imag().convertToDouble()); + } + } else { + return absl::InvalidArgumentError("Unsupported complex element type"); + } + return absl::OkStatus(); } // Converts an Tensor proto attribute to a TensorFlow tensor proto. -Status ConvertTensorProtoAttr(const mlir::TF::TensorProtoAttr attr, - TensorProto* output_tensor) { +absl::Status ConvertTensorProtoAttr(const mlir::TF::TensorProtoAttr attr, + TensorProto* output_tensor) { auto mangled_tensor = attr.getValue(); absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size()); return mangling_util::DemangleTensor(tensor_view, output_tensor); } template -void ConvertElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output) { +absl::Status ConvertElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output) { + auto attr = llvm::dyn_cast(elem_attr); + if (!attr) + return absl::InvalidArgumentError("Unsupported elements attr found"); if (attr.isSplat()) { if (attr.getSplatValue() != T(0)) output->Add(attr.getSplatValue()); } else { output->Reserve(attr.getNumElements()); for (auto value : attr.getValues()) output->AddAlreadyReserved(value); } + return absl::OkStatus(); } // Converts an MLIR elements attribute and adds it to specified repeated field. template -void ConvertFloatElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output, - Cord* tensor_content) { - if (attr.isSplat()) { - if (attr.getSplatValue() != T(0)) output->Add(attr.getSplatValue()); +absl::Status ConvertFloatElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output, + Cord* tensor_content) { + if (auto attr = llvm::dyn_cast(elem_attr)) { + if (attr.isSplat()) { + if (attr.getSplatValue() != T(0)) output->Add(attr.getSplatValue()); + } else { + port::CopyFromArray(tensor_content, attr.getRawData().data(), + attr.getRawData().size()); + } + } else if (auto dense_resource_ttr = + llvm::dyn_cast(elem_attr)) { + mlir::AsmResourceBlob* blob = dense_resource_ttr.getRawHandle().getBlob(); + if (blob) { + size_t dst_block_length = blob->getData().size(); + const char* raw_dst_block = blob->getData().data(); + if constexpr (std::is_same_v) { + *tensor_content = absl::string_view(raw_dst_block, dst_block_length); + } else { + *tensor_content = absl::MakeCordFromExternal( + absl::string_view(raw_dst_block, dst_block_length), + [](absl::string_view data) {}); + } + } else { + return absl::InvalidArgumentError("No blob found in dense resource"); + } } else { - port::CopyFromArray(tensor_content, attr.getRawData().data(), - attr.getRawData().size()); + return absl::InvalidArgumentError("Unsupported elements attr found"); } + return absl::OkStatus(); } // Converts an MLIR elements attribute containing half values and adds it to // specified repeated field. -void ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output) { +absl::Status ConvertHalfElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output) { + auto attr = llvm::dyn_cast(elem_attr); + if (!attr) + return absl::InvalidArgumentError( + "DenseResourceElementsAttr of type half found"); if (attr.isSplat()) { if (attr.getSplatValue() != Eigen::half(0)) output->Add( @@ -361,40 +531,86 @@ void ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr, for (const Eigen::half value : attr.getValues()) output->AddAlreadyReserved(Eigen::numext::bit_cast(value)); } + return absl::OkStatus(); } // Converts an MLIR elements attribute containing signed int values and adds it // to specified repeated field. template -void ConvertIntElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output, - Cord* tensor_content) { - if (attr.isSplat()) { - if (attr.getSplatValue() != U(0)) - output->Add(static_cast(attr.getSplatValue())); +absl::Status ConvertIntElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output, + Cord* tensor_content) { + if (auto attr = llvm::dyn_cast(elem_attr)) { + if (attr.isSplat()) { + if (attr.getSplatValue() != U(0)) + output->Add(static_cast(attr.getSplatValue())); + } else { + port::CopyFromArray(tensor_content, attr.getRawData().data(), + attr.getRawData().size()); + } + } else if (auto dense_resource_ttr = + llvm::dyn_cast(elem_attr)) { + mlir::AsmResourceBlob* blob = dense_resource_ttr.getRawHandle().getBlob(); + if (blob) { + size_t dst_block_length = blob->getData().size(); + const char* raw_dst_block = blob->getData().data(); + if constexpr (std::is_same_v) { + *tensor_content = absl::string_view(raw_dst_block, dst_block_length); + } else { + *tensor_content = absl::MakeCordFromExternal( + absl::string_view(raw_dst_block, dst_block_length), + [](absl::string_view data) {}); + } + } else { + return absl::InvalidArgumentError("No blob found in dense resource"); + } } else { - port::CopyFromArray(tensor_content, attr.getRawData().data(), - attr.getRawData().size()); + return absl::InvalidArgumentError("Unsupported elements attr found"); } + return absl::OkStatus(); } // Converts an MLIR elements attribute containing unsigned int values and adds // it to specified repeated field. template -void ConvertUIntElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output, - Cord* tensor_content) { - if (attr.isSplat()) { - if (attr.getSplatValue() != U(0)) - output->Add(static_cast(attr.getSplatValue())); +absl::Status ConvertUIntElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output, + Cord* tensor_content) { + if (auto attr = llvm::dyn_cast(elem_attr)) { + if (attr.isSplat()) { + if (attr.getSplatValue() != U(0)) + output->Add(static_cast(attr.getSplatValue())); + } else { + port::CopyFromArray(tensor_content, attr.getRawData().data(), + attr.getRawData().size()); + } + } else if (auto dense_resource_ttr = + llvm::dyn_cast(elem_attr)) { + mlir::AsmResourceBlob* blob = dense_resource_ttr.getRawHandle().getBlob(); + if (blob) { + size_t dst_block_length = blob->getData().size(); + const char* raw_dst_block = blob->getData().data(); + if constexpr (std::is_same_v) { + *tensor_content = absl::string_view(raw_dst_block, dst_block_length); + } else { + *tensor_content = absl::MakeCordFromExternal( + absl::string_view(raw_dst_block, dst_block_length), + [](absl::string_view data) {}); + } + } else { + return absl::InvalidArgumentError("No blob found in dense resource"); + } } else { - port::CopyFromArray(tensor_content, attr.getRawData().data(), - attr.getRawData().size()); + return absl::InvalidArgumentError("Unsupported elements attr found"); } + return absl::OkStatus(); } -void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output) { +absl::Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr elem_attr, + protobuf::RepeatedField* output) { + auto attr = llvm::dyn_cast(elem_attr); + if (!attr) + return absl::InvalidArgumentError("Unsupported elements attr found"); if (attr.isSplat()) { if (attr.getSplatValue() != bfloat16(0)) output->Add( @@ -404,11 +620,15 @@ void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr, for (const bfloat16 value : attr.getValues()) output->AddAlreadyReserved(Eigen::numext::bit_cast(value)); } + return absl::OkStatus(); } template -void ConvertFloat8ElementsAttr(const mlir::DenseElementsAttr attr, - std::string* output) { +absl::Status ConvertFloat8ElementsAttr(const mlir::ElementsAttr elem_attr, + std::string* output) { + auto attr = llvm::dyn_cast(elem_attr); + if (!attr) + return absl::InvalidArgumentError("Unsupported elements attr found"); if (attr.isSplat()) { if (attr.getSplatValue() != T(0)) output->push_back( @@ -418,9 +638,11 @@ void ConvertFloat8ElementsAttr(const mlir::DenseElementsAttr attr, for (const T value : attr.getValues()) output->push_back(Eigen::numext::bit_cast(value)); } + return absl::OkStatus(); } -Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { +absl::Status ConvertToTensorProto(const ElementsAttr attr, + TensorProto* output) { auto type = attr.getShapedType(); auto shape = type.getShape(); DataType output_dtype; @@ -431,101 +653,113 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { if (auto tensor_attr = mlir::dyn_cast(attr)) return ConvertTensorProtoAttr(tensor_attr, output); - auto dense_attr = mlir::dyn_cast(attr); - if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); - switch (output_dtype) { case DT_BOOL: - ConvertElementsAttr(dense_attr, output->mutable_bool_val()); + TF_RETURN_IF_ERROR(ConvertElementsAttr(attr, output->mutable_bool_val())); break; case DT_BFLOAT16: - ConvertBfloat16ElementsAttr(dense_attr, output->mutable_half_val()); + TF_RETURN_IF_ERROR( + ConvertBfloat16ElementsAttr(attr, output->mutable_half_val())); break; case DT_COMPLEX64: - ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val()); + TF_RETURN_IF_ERROR( + ConvertComplexElementsAttr(attr, output->mutable_scomplex_val())); break; case DT_COMPLEX128: - ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val()); + TF_RETURN_IF_ERROR( + ConvertComplexElementsAttr(attr, output->mutable_dcomplex_val())); break; case DT_DOUBLE: - ConvertFloatElementsAttr(dense_attr, output->mutable_double_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR( + ConvertFloatElementsAttr(attr, output->mutable_double_val(), + output->mutable_tensor_content())); break; case DT_HALF: - ConvertHalfElementsAttr(dense_attr, output->mutable_half_val()); + TF_RETURN_IF_ERROR( + ConvertHalfElementsAttr(attr, output->mutable_half_val())); break; case DT_FLOAT: - ConvertFloatElementsAttr(dense_attr, output->mutable_float_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertFloatElementsAttr( + attr, output->mutable_float_val(), output->mutable_tensor_content())); break; case DT_FLOAT8_E5M2: - ConvertFloat8ElementsAttr(dense_attr, - output->mutable_float8_val()); + TF_RETURN_IF_ERROR(ConvertFloat8ElementsAttr( + attr, output->mutable_float8_val())); break; case DT_FLOAT8_E4M3FN: - ConvertFloat8ElementsAttr( - dense_attr, output->mutable_float8_val()); + TF_RETURN_IF_ERROR(ConvertFloat8ElementsAttr( + attr, output->mutable_float8_val())); + break; + case DT_FLOAT8_E4M3FNUZ: + TF_RETURN_IF_ERROR(ConvertFloat8ElementsAttr( + attr, output->mutable_float8_val())); + break; + case DT_FLOAT8_E4M3B11FNUZ: + TF_RETURN_IF_ERROR(ConvertFloat8ElementsAttr( + attr, output->mutable_float8_val())); + break; + case DT_FLOAT8_E5M2FNUZ: + TF_RETURN_IF_ERROR(ConvertFloat8ElementsAttr( + attr, output->mutable_float8_val())); break; case tensorflow::DT_INT4: - ConvertIntElementsAttr(dense_attr, - output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case tensorflow::DT_UINT4: - ConvertUIntElementsAttr( - dense_attr, output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertUIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case DT_QUINT8: case DT_INT8: - ConvertUIntElementsAttr(dense_attr, - output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case DT_QUINT16: case DT_INT16: - ConvertIntElementsAttr(dense_attr, - output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case DT_INT32: - ConvertIntElementsAttr(dense_attr, output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case DT_INT64: - ConvertIntElementsAttr(dense_attr, output->mutable_int64_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertIntElementsAttr( + attr, output->mutable_int64_val(), output->mutable_tensor_content())); break; case DT_STRING: - ConvertStringElementsAttr(mlir::cast(dense_attr), - output->mutable_string_val()); + TF_RETURN_IF_ERROR( + ConvertStringElementsAttr(mlir::cast(attr), + output->mutable_string_val())); break; case DT_UINT8: - ConvertUIntElementsAttr(dense_attr, - output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertUIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case DT_UINT16: - ConvertUIntElementsAttr(dense_attr, - output->mutable_int_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR(ConvertUIntElementsAttr( + attr, output->mutable_int_val(), output->mutable_tensor_content())); break; case DT_UINT32: - ConvertUIntElementsAttr(dense_attr, output->mutable_uint32_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR( + ConvertUIntElementsAttr(attr, output->mutable_uint32_val(), + output->mutable_tensor_content())); break; case DT_UINT64: - ConvertUIntElementsAttr(dense_attr, output->mutable_uint64_val(), - output->mutable_tensor_content()); + TF_RETURN_IF_ERROR( + ConvertUIntElementsAttr(attr, output->mutable_uint64_val(), + output->mutable_tensor_content())); break; default: - return errors::Unimplemented(absl::StrCat("Unimplemented data type ", - DataTypeString(output_dtype))); + return absl::UnimplementedError(absl::StrCat( + "Unimplemented data type ", DataTypeString(output_dtype))); } return absl::OkStatus(); } -Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { +absl::Status ConvertToTensor(const mlir::ElementsAttr attr, + Tensor* output_tensor) { TensorProto tensor_proto; TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto)); if (!output_tensor->FromProto(tensor_proto)) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index 92d6ee4bb65356..ba5cd3d81de1a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -32,11 +32,13 @@ using tsl::StatusOr; // Converts an TensorFlow tensor proto into an MLIR elements attribute. absl::StatusOr ConvertTensorProto( - const TensorProto& input_tensor, mlir::Builder* builder); + const TensorProto& input_tensor, mlir::Builder* builder, + bool convert_to_dense_resource = false); // Converts an TensorFlow tensor into an MLIR elements attribute. -absl::StatusOr ConvertTensor(const Tensor& input_tensor, - mlir::Builder* builder); +absl::StatusOr ConvertTensor( + const Tensor& input_tensor, mlir::Builder* builder, + bool convert_to_dense_resource = false); // Converts a shape from MLIR to a TensorFlow tensor shape proto. void ConvertToTensorShapeProto(llvm::ArrayRef shape, @@ -57,11 +59,11 @@ absl::StatusOr ConvertTensorShapeProto( const TensorShapeProto& shape, mlir::MLIRContext* context); // Converts an MLIR elements attribute to a TensorFlow tensor proto. -Status ConvertToTensorProto(mlir::ElementsAttr attr, - TensorProto* output_tensor); +absl::Status ConvertToTensorProto(mlir::ElementsAttr attr, + TensorProto* output_tensor); // Converts an MLIR elements attribute to a TensorFlow tensor. -Status ConvertToTensor(mlir::ElementsAttr attr, Tensor* output_tensor); +absl::Status ConvertToTensor(mlir::ElementsAttr attr, Tensor* output_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 3feed8904fab0e..c8eb131fc897d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -112,12 +112,13 @@ class ConvertTensorTest : public ::testing::Test { protected: template void VerifyConversion(std::initializer_list values, DataType dtype, - mlir::Type expected_ty) { + mlir::Type expected_ty, + bool convert_to_dense_resource = false) { mlir::Builder b(expected_ty.getContext()); Tensor tensor(dtype, TensorShape({static_cast(values.size())})); tensor.flat().setValues(values); - auto value_or = ConvertTensor(tensor, &b); + auto value_or = ConvertTensor(tensor, &b, convert_to_dense_resource); TF_ASSERT_OK(value_or.status()); auto attr = value_or.value(); @@ -148,6 +149,15 @@ TEST_F(ConvertTensorTest, Simple) { ASSERT_NO_FATAL_FAILURE(VerifyConversion( {tsl::float8_e4m3fn{1.0}, tsl::float8_e4m3fn{-1.0}}, DT_FLOAT8_E4M3FN, mlir::FloatType::getFloat8E4M3FN(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e4m3fnuz{1.0}, tsl::float8_e4m3fnuz{-1.0}}, + DT_FLOAT8_E4M3FNUZ, mlir::FloatType::getFloat8E4M3FNUZ(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e4m3b11fnuz{1.0}, tsl::float8_e4m3b11fnuz{-1.0}}, + DT_FLOAT8_E4M3B11FNUZ, mlir::FloatType::getFloat8E4M3B11FNUZ(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e5m2fnuz{1.0}, tsl::float8_e5m2fnuz{-1.0}}, + DT_FLOAT8_E5M2FNUZ, mlir::FloatType::getFloat8E5M2FNUZ(&context))); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {static_cast(1), static_cast(-1)}, DT_INT4, @@ -191,6 +201,73 @@ TEST_F(ConvertTensorTest, Simple) { mlir::ComplexType::get(mlir::FloatType::getF64(&context)))); } +TEST_F(ConvertTensorTest, SimpleDenseResourceElements) { + mlir::MLIRContext context; + RegisterDialects(context); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context), true)); + ASSERT_NO_FATAL_FAILURE( + VerifyConversion({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16, + mlir::FloatType::getBF16(&context), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e5m2{1.0}, tsl::float8_e5m2{-1.0}}, DT_FLOAT8_E5M2, + mlir::FloatType::getFloat8E5M2(&context), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e4m3fn{1.0}, tsl::float8_e4m3fn{-1.0}}, DT_FLOAT8_E4M3FN, + mlir::FloatType::getFloat8E4M3FN(&context), true)); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {static_cast(1), static_cast(-1)}, DT_INT4, + mlir::IntegerType::get(&context, 4, + mlir::IntegerType::SignednessSemantics::Signed), + true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64), true)); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {static_cast(1), static_cast(2)}, DT_UINT4, + mlir::IntegerType::get(&context, 4, + mlir::IntegerType::SignednessSemantics::Unsigned), + true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT8, + mlir::IntegerType::get(&context, 8, + mlir::IntegerType::SignednessSemantics::Unsigned), + true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT16, + mlir::IntegerType::get(&context, 16, + mlir::IntegerType::SignednessSemantics::Unsigned), + true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT32, + mlir::IntegerType::get(&context, 32, + mlir::IntegerType::SignednessSemantics::Unsigned), + true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT64, + mlir::IntegerType::get(&context, 64, + mlir::IntegerType::SignednessSemantics::Unsigned), + true)); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64, + mlir::ComplexType::get(mlir::FloatType::getF32(&context)), true)); + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128, + mlir::ComplexType::get(mlir::FloatType::getF64(&context)))); +} + bool IsSplat(mlir::ElementsAttr attr) { return mlir::cast(attr).isSplat(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index e3404d613c9f83..d9caee612bca24 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -35,7 +35,7 @@ using mlir::Builder; using mlir::ShapedType; using mlir::Type; -Status ConvertDataType(DataType dtype, Builder builder, Type* type) { +absl::Status ConvertDataType(DataType dtype, Builder builder, Type* type) { switch (dtype) { case DT_HALF: *type = builder.getF16Type(); @@ -88,6 +88,15 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { case tensorflow::DT_FLOAT8_E5M2: *type = builder.getFloat8E5M2Type(); return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E4M3FNUZ: + *type = builder.getFloat8E4M3FNUZType(); + return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E4M3B11FNUZ: + *type = builder.getFloat8E4M3B11FNUZType(); + return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E5M2FNUZ: + *type = builder.getFloat8E5M2FNUZType(); + return absl::OkStatus(); case DT_INT4: *type = builder.getIntegerType(4, /*isSigned=*/true); return absl::OkStatus(); @@ -106,7 +115,7 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { } } -Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { +absl::Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { if (type.isF16()) { *dtype = DT_HALF; return absl::OkStatus(); @@ -125,6 +134,15 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } else if (type.isFloat8E5M2()) { *dtype = DT_FLOAT8_E5M2; return absl::OkStatus(); + } else if (type.isFloat8E4M3FNUZ()) { + *dtype = DT_FLOAT8_E4M3FNUZ; + return absl::OkStatus(); + } else if (type.isFloat8E4M3B11FNUZ()) { + *dtype = DT_FLOAT8_E4M3B11FNUZ; + return absl::OkStatus(); + } else if (type.isFloat8E5M2FNUZ()) { + *dtype = DT_FLOAT8_E5M2FNUZ; + return absl::OkStatus(); } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: @@ -174,7 +192,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { absl::StrCat("Converting ", debugString(type), " to DataType")); } -Status ConvertToDataType(Type type, DataType* dtype) { +absl::Status ConvertToDataType(Type type, DataType* dtype) { if (auto stype = mlir::dyn_cast(type)) { TF_RETURN_IF_ERROR( ConvertScalarTypeToDataType(stype.getElementType(), dtype)); @@ -192,8 +210,8 @@ void ConvertToMlirShape(const TensorShape& input_shape, } } -Status ConvertToMlirShape(const TensorShapeProto& input_shape, - llvm::SmallVectorImpl* shape) { +absl::Status ConvertToMlirShape(const TensorShapeProto& input_shape, + llvm::SmallVectorImpl* shape) { shape->reserve(input_shape.dim_size()); auto& dims = input_shape.dim(); for (auto& d : dims) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h index 3c21aa260499c1..1ce9d054b981a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -27,22 +27,23 @@ namespace tensorflow { using tsl::StatusOr; // Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. -Status ConvertDataType(DataType dtype, mlir::Builder builder, mlir::Type* type); +absl::Status ConvertDataType(DataType dtype, mlir::Builder builder, + mlir::Type* type); // Converts a scalar MLIR type to a TensorFlow Datatype. -Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype); +absl::Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype); // Converts an MLIR type to TensorFlow DataType. If 'type' is a scalar type, it // is converted directly. If it is a shaped type, the element type is converted. -Status ConvertToDataType(mlir::Type type, DataType* dtype); +absl::Status ConvertToDataType(mlir::Type type, DataType* dtype); // Converts an TensorFlow shape to the one used in MLIR. void ConvertToMlirShape(const TensorShape& input_shape, llvm::SmallVectorImpl* shape); // Converts an TensorFlow shape proto to the one used in MLIR. -Status ConvertToMlirShape(const TensorShapeProto& input_shape, - llvm::SmallVectorImpl* shape); +absl::Status ConvertToMlirShape(const TensorShapeProto& input_shape, + llvm::SmallVectorImpl* shape); // Given a tensor shape and dtype, get the corresponding MLIR tensor type. absl::StatusOr ConvertToMlirTensorType( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index f089ec111991e7..c3e7ae75022348 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -48,7 +48,9 @@ class FakeDevice : public Device { explicit FakeDevice(const DeviceAttributes& device_attributes) : Device(nullptr, device_attributes) {} - Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + absl::Status Sync() override { + return errors::Unimplemented("FakeDevice::Sync()"); + } static std::unique_ptr Make(const string& name, const string& desc = "") { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc index d705049629b765..d8d94f3dfb858f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc @@ -49,7 +49,7 @@ struct WritableFileRawStream : public llvm::raw_ostream { void write_impl(const char* ptr, size_t size) override { // If an error is encountered, null out the file. if (file) { - Status s = file->Append(StringPiece(ptr, size)); + absl::Status s = file->Append(absl::string_view(ptr, size)); if (!s.ok()) { LOG(WARNING) << "Write failed: " << s; file = nullptr; @@ -62,16 +62,17 @@ struct WritableFileRawStream : public llvm::raw_ostream { }; } // namespace -Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph, - const FunctionLibraryDefinition* flib_def, - WritableFile* file) { +absl::Status DumpTextualIRToFile(const MlirDumpConfig& config, + const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile* file) { WritableFileRawStream os(std::move(file)); mlir::MLIRContext context; mlir::OwningOpRef module; if (flib_def) { flib_def = &graph.flib_def(); } - auto convert = [&]() -> Status { + auto convert = [&]() -> absl::Status { mlir::StatusScopedDiagnosticHandler status_handler(&context); // TODO(jpienaar): Both the graph debug info and import config should be // specifiable. @@ -99,7 +100,7 @@ Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph, void UseMlirForGraphDump(const MlirDumpConfig& config) { SetGraphDumper( [config](const Graph& graph, const FunctionLibraryDefinition* flib_def, - WritableFile* file) -> Status { + WritableFile* file) -> absl::Status { return DumpTextualIRToFile(config, graph, flib_def, file); }, /*suffix=*/".mlir"); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h index 2c400925a88cb4..ae6e0b612ae0e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h @@ -30,9 +30,10 @@ struct MlirDumpConfig; // Dumps 'graph_def' to a file, as textual IR. Returns the file name chosen. // // Note: This is for debugging use and is not optimized for performance. -Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph, - const FunctionLibraryDefinition* flib_def, - WritableFile* file); +absl::Status DumpTextualIRToFile(const MlirDumpConfig& config, + const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile* file); // Config of the textual dump. struct MlirDumpConfig { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc index d09458f78b06c7..7e92860e5ff03e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc @@ -41,23 +41,23 @@ class StringWritableFile : public WritableFile { public: explicit StringWritableFile(string* str) : str_(*str) {} - Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { absl::StrAppend(&str_, data); return absl::OkStatus(); } - Status Close() override { return absl::OkStatus(); } + absl::Status Close() override { return absl::OkStatus(); } - Status Flush() override { return absl::OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } - Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = "(string)"; return absl::OkStatus(); } - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } - Status Tell(int64_t* position) override { + absl::Status Tell(int64_t* position) override { return errors::Unimplemented("Stream not seekable"); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 1270865e551d52..b970ca84b326cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -105,7 +105,7 @@ struct WritableFileRawStream : public llvm::raw_ostream { void write_impl(const char* ptr, size_t size) override { // Write the file if it is still valid. If the write fails, null out the // file to avoid encountering another error. - if (file && !file->Append(StringPiece(ptr, size)).ok()) { + if (file && !file->Append(absl::string_view(ptr, size)).ok()) { file = nullptr; } } @@ -150,9 +150,10 @@ struct CrashAnalysisCrashReproducerStream : public mlir::ReproducerStream { } // namespace -Status CreateFileForDumping(llvm::StringRef name, - std::unique_ptr* os, - std::string* filepath, llvm::StringRef dirname) { +absl::Status CreateFileForDumping(llvm::StringRef name, + std::unique_ptr* os, + std::string* filepath, + llvm::StringRef dirname) { std::string dir; if (!dirname.empty()) dir = std::string(dirname); @@ -160,24 +161,24 @@ Status CreateFileForDumping(llvm::StringRef name, dir = GetDumpDirFromEnvVar(); if (dir.empty()) { - return Status(absl::StatusCode::kInvalidArgument, - "(TF_DUMP_GRAPH_PREFIX not specified)"); + return absl::Status(absl::StatusCode::kInvalidArgument, + "(TF_DUMP_GRAPH_PREFIX not specified)"); } if (dir == kCrashReproducerStdErr) { *os = std::make_unique(); *filepath = llvm::formatv("(stderr; requested filename: '{0}')", name).str(); - return Status(); + return absl::Status(); } // Get a valid file path to dump with. Env* env = Env::Default(); - Status status = env->RecursivelyCreateDir(dir); + absl::Status status = env->RecursivelyCreateDir(dir); if (!status.ok()) { LOG(WARNING) << "Failed to create '" << dir << "' directory for dumping: " << status; - return Status(absl::StatusCode::kUnavailable, "(unavailable)"); + return absl::Status(absl::StatusCode::kUnavailable, "(unavailable)"); } *filepath = io::JoinPath(dir, MakeUniqueFilename(std::string(name))); @@ -186,11 +187,11 @@ Status CreateFileForDumping(llvm::StringRef name, status = env->NewWritableFile(*filepath, &file); if (!status.ok()) { LOG(WARNING) << "Failed to create file '" << filepath << "': " << status; - return Status(absl::StatusCode::kUnavailable, "(unavailable)"); + return absl::Status(absl::StatusCode::kUnavailable, "(unavailable)"); } file = std::make_unique(std::move(file)); *os = std::make_unique(std::move(file)); - return Status(); + return absl::Status(); } // Prints the pass pipeline of `pass_manager` to `os`. @@ -214,7 +215,7 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, const mlir::PassManager* pass_manager) { std::unique_ptr os; std::string filepath; - Status result = CreateFileForDumping(name, &os, &filepath, dirname); + absl::Status result = CreateFileForDumping(name, &os, &filepath, dirname); if (!result.ok()) return std::string(result.message()); LOG(INFO) << "Dumping MLIR operation '" << op->getName().getStringRef().str() @@ -248,7 +249,7 @@ std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, llvm::StringRef dirname) { std::unique_ptr os; std::string filepath; - Status result = CreateFileForDumping(name, &os, &filepath, dirname); + absl::Status result = CreateFileForDumping(name, &os, &filepath, dirname); if (!result.ok()) return std::string(result.message()); (*os) << content; @@ -314,7 +315,8 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { // Try to open the file and generate a raw_ostream. std::unique_ptr file; - Status status = tensorflow::Env::Default()->NewWritableFile(path, &file); + absl::Status status = + tensorflow::Env::Default()->NewWritableFile(path, &file); file = std::make_unique(std::move(file)); if (!status.ok()) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index a7760872d79315..87d53e8b476184 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -41,10 +41,10 @@ inline constexpr absl::string_view kCrashReproducerCrashAnalysis = // This will create a file name via prefixing `name` with the value of the // TF_DUMP_GRAPH_PREFIX environment variable if `dirname` is empty and // suffixing `name` with ".mlir". -Status CreateFileForDumping(llvm::StringRef name, - std::unique_ptr* os, - std::string* filepath, - llvm::StringRef dirname = ""); +absl::Status CreateFileForDumping(llvm::StringRef name, + std::unique_ptr* os, + std::string* filepath, + llvm::StringRef dirname = ""); // Dumps MLIR operation to a file and returns the file name used. // diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index 4a19c06154b6d6..3672fa9b5fee45 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -51,7 +51,7 @@ static bool IsOk(const TF_Status* s) { return false; } -static bool IsOk(const Status& s) { +static bool IsOk(const absl::Status& s) { if (s.ok()) return true; VLOG(2) << s.message(); return false; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 96ba0afd096a16..729fe90731ebbf 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -78,8 +78,8 @@ std::set* GlobalOpPrefixes() { } // Converts a location to the debug information for the node def. -Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name, - NodeDef::ExperimentalDebugInfo* debug_info) { +absl::Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name, + NodeDef::ExperimentalDebugInfo* debug_info) { mlir::Location unwrapped_inst_loc = GetLocationWithoutOpType(inst_loc); if (auto call_site = mlir::dyn_cast(unwrapped_inst_loc)) { @@ -109,43 +109,46 @@ Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name, return absl::OkStatus(); } -Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) { value->set_b(attr.getValue()); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) { value->set_i(attr.getInt()); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) { value->set_f(attr.getValueAsDouble()); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::ElementsAttr& attr, + AttrValue* value) { return ConvertToTensorProto(attr, value->mutable_tensor()); } -Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr, - AttrValue* value) { +absl::Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr, + AttrValue* value) { value->set_placeholder(attr.getValue().str()); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, + AttrValue* value) { SetTensorShapeProto(attr, value->mutable_shape()); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, + AttrValue* value) { value->mutable_func()->set_name(attr.getValue().str()); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type, - AttrValue* value) { +absl::Status ConvertAttribute(const mlir::TF::FuncAttr& attr, + bool remove_ref_type, AttrValue* value) { TF_RETURN_IF_ERROR(ConvertAttribute( mlir::cast(attr.getName()), value)); TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(), @@ -154,7 +157,7 @@ Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type, return absl::OkStatus(); } -Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { absl::string_view attr_value(attr.getValue().data(), attr.getValue().size()); switch (mangling_util::GetMangledKind(attr_value)) { case mangling_util::MangledKind::kUnknown: { @@ -177,8 +180,8 @@ Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { return absl::OkStatus(); } -Status ConvertAttribute(mlir::Type type, bool remove_ref_type, - AttrValue* value) { +absl::Status ConvertAttribute(mlir::Type type, bool remove_ref_type, + AttrValue* value) { DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype)); if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype); @@ -186,18 +189,18 @@ Status ConvertAttribute(mlir::Type type, bool remove_ref_type, return absl::OkStatus(); } -Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type, - AttrValue* value) { +absl::Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type, + AttrValue* value) { return ConvertAttribute(type.getValue(), remove_ref_type, value); } -Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { +absl::Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { value->clear_value(); return absl::OkStatus(); } -Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, - AttrValue* value) { +absl::Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, + AttrValue* value) { auto* list = value->mutable_list(); for (mlir::Attribute a : attr.getValue()) { if (auto attr = mlir::dyn_cast(a)) { @@ -373,7 +376,7 @@ absl::StatusOr> GetOperationNodeDef( return node_def; } -Status ConvertAttributes( +absl::Status ConvertAttributes( const llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, bool remove_ref_type, AttrValueMap* values) { @@ -411,7 +414,7 @@ Status ConvertAttributes( name_strref, "') unimplemented"); } TF_RETURN_IF_ERROR( - llvm::TypeSwitch(attr) + llvm::TypeSwitch(attr) .Case( @@ -448,8 +451,9 @@ Status ConvertAttributes( return absl::OkStatus(); } -Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, - AttrValueMap* values) { +absl::Status SetShapeAttribute(absl::string_view name, + mlir::ShapedType shaped_type, + AttrValueMap* values) { AttrValue value; SetTensorShapeProto(shaped_type, value.mutable_list()->add_shape()); @@ -475,7 +479,7 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) { return llvm::dyn_cast(inst); } -Status AddTensorFlowOpPrefix(std::string prefix) { +absl::Status AddTensorFlowOpPrefix(std::string prefix) { GlobalOpPrefixes()->insert(prefix); return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index c12c2507e1a03c..28d5df0c8c38ce 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -42,7 +42,7 @@ namespace tensorflow { using tsl::StatusOr; // Add custom op prefix for TensorFlow dialects. -Status AddTensorFlowOpPrefix(std::string); +absl::Status AddTensorFlowOpPrefix(std::string); // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control // dialect back into a TensorFlow valid op name. @@ -56,7 +56,7 @@ absl::StatusOr> GetOperationNodeDef( // Converts MLIR attributes with values to their tensorflow equivalent. // "name" and "device" attributes are ignored by default. Use attrs_to_ignore to // specify any other attributes that should be ignored. -Status ConvertAttributes( +absl::Status ConvertAttributes( llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, bool remove_ref_type, AttrValueMap* values); @@ -79,8 +79,8 @@ void SetTensorShapeProto(ShapeContainerT shape, TensorShapeProto* proto) { // Sets shape attribute with the given name. If the attribute already exists // with a different value, returns an error. -Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape, - AttrValueMap* values); +absl::Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape, + AttrValueMap* values); // Returns true if the given instruction is an mlir::TF::LegacyCallOp or the // result of such an operation transformed by the diff --git a/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h b/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h index 213cf4e66e16bd..6ded27b0ba7218 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h @@ -34,24 +34,24 @@ class FakeSession : public tensorflow::Session { public: FakeSession(); - ::tensorflow::Status Create(const tensorflow::GraphDef& graph) override; - ::tensorflow::Status Extend(const tensorflow::GraphDef& graph) override; + absl::Status Create(const tensorflow::GraphDef& graph) override; + absl::Status Extend(const tensorflow::GraphDef& graph) override; - ::tensorflow::Status Close() override; + absl::Status Close() override; - ::tensorflow::Status ListDevices( + absl::Status ListDevices( std::vector* response) override; - ::tensorflow::Status LocalDeviceManager( + absl::Status LocalDeviceManager( const tensorflow::DeviceMgr** deviceMgrPtr) override; - ::tensorflow::Status Run( + absl::Status Run( const std::vector>& inputs, const std::vector& output_names, const std::vector& target_nodes, std::vector<::tensorflow::Tensor>* outputs) override; - ::tensorflow::Status Run( + absl::Status Run( const tensorflow::RunOptions& run_options, const std::vector>& inputs, const std::vector& output_names, @@ -59,7 +59,7 @@ class FakeSession : public tensorflow::Session { std::vector<::tensorflow::Tensor>* outputs, tensorflow::RunMetadata* run_metadata) override; - ::tensorflow::Status Run( + absl::Status Run( const tensorflow::RunOptions& run_options, const std::vector>& inputs, const std::vector& output_names, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index 01e0784bff3351..50306edb28b067 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -33,7 +33,8 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) { } } // namespace -Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto) { +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::Message* proto) { // Attempt to parse as text. if (ParseTextProto(input, "", proto).ok()) return absl::OkStatus(); @@ -41,8 +42,8 @@ Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto) { return LoadProtoFromBuffer(input, static_cast(proto)); } -Status LoadProtoFromBuffer(absl::string_view input, - protobuf::MessageLite* proto) { +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto) { // Attempt to parse as binary. protobuf::io::ArrayInputStream binary_stream(input.data(), input.size()); if (proto->ParseFromZeroCopyStream(&binary_stream)) return absl::OkStatus(); @@ -52,7 +53,7 @@ Status LoadProtoFromBuffer(absl::string_view input, } template -Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { +absl::Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { const auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename)); if (std::error_code error = file_or_err.getError()) { @@ -67,13 +68,13 @@ Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { return LoadProtoFromBuffer(content, proto); } -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::Message* proto) { +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto) { return LoadProtoFromFileImpl(input_filename, proto); } -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::MessageLite* proto) { +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto) { return LoadProtoFromFileImpl(input_filename, proto); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h index ad1531dd4496eb..8b0aaa372b5450 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h @@ -26,18 +26,19 @@ namespace tensorflow { // buffer. Returns error status of the file is not found or malformed proto. // Note that text protos can only be parsed when full protobuf::Message protos // are used, and will fail for protobuf::MessageLite protos. -Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto); -Status LoadProtoFromBuffer(absl::string_view input, - protobuf::MessageLite* proto); +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::Message* proto); +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto); // Reads text (.pbtext) or binary (.pb) format of a proto message from the given // file path. Returns error status of the file is not found or malformed proto. // Note that text protos can only be parsed when full protobuf::Message protos // are used, and will fail for protobuf::MessageLite protos. -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::Message* proto); -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::MessageLite* proto); +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto); +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc index 6efa412dc43dc9..79efd048815117 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc @@ -70,7 +70,7 @@ string MangleShape(const TensorShapeProto& shape) { return absl::StrCat(kTensorShapePrefix, PrintShortTextProto(shape)); } -Status DemangleShape(absl::string_view str, TensorShapeProto* proto) { +absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto) { return ParseTextProto(str, kTensorShapePrefix, proto); } @@ -78,7 +78,7 @@ string MangleTensor(const TensorProto& tensor) { return absl::StrCat(kTensorPrefix, PrintShortTextProto(tensor)); } -Status DemangleTensor(absl::string_view str, TensorProto* proto) { +absl::Status DemangleTensor(absl::string_view str, TensorProto* proto) { return ParseTextProto(str, kTensorPrefix, proto); } @@ -86,7 +86,7 @@ string MangleDataType(const DataType& dtype) { return absl::StrCat(kDataTypePrefix, DataType_Name(dtype)); } -Status DemangleDataType(absl::string_view str, DataType* proto) { +absl::Status DemangleDataType(absl::string_view str, DataType* proto) { absl::string_view pbtxt; TF_RETURN_IF_ERROR(ConsumePrefix(str, kDataTypePrefix, &pbtxt)); if (!DataType_Parse(string(pbtxt), proto)) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h index d694009a25928b..a0c14f27b5b38f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h @@ -43,17 +43,17 @@ MangledKind GetMangledKind(absl::string_view str); // Return a TensorShapeProto mangled as a string. string MangleShape(const TensorShapeProto& shape); // Demangle a string mangled with MangleShape. -Status DemangleShape(absl::string_view str, TensorShapeProto* proto); +absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto); // Return a TensorProto mangled as a string. string MangleTensor(const TensorProto& tensor); // Demangle a string mangled with MangleTensor. -Status DemangleTensor(absl::string_view str, TensorProto* proto); +absl::Status DemangleTensor(absl::string_view str, TensorProto* proto); // Return a DataType mangled as a string. string MangleDataType(const DataType& dtype); // Demangle a string mangled with MangleDataType. -Status DemangleDataType(absl::string_view str, DataType* proto); +absl::Status DemangleDataType(absl::string_view str, DataType* proto); } // namespace mangling_util } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc index 906a058d04e02e..aa2d9406e91765 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc @@ -34,8 +34,8 @@ class NoOpErrorCollector : public protobuf::io::ErrorCollector { }; } // namespace -Status ConsumePrefix(absl::string_view str, absl::string_view prefix, - absl::string_view* output) { +absl::Status ConsumePrefix(absl::string_view str, absl::string_view prefix, + absl::string_view* output) { if (absl::StartsWith(str, prefix)) { *output = str.substr(prefix.size()); return absl::OkStatus(); @@ -43,9 +43,9 @@ Status ConsumePrefix(absl::string_view str, absl::string_view prefix, return errors::NotFound("No prefix \"", prefix, "\" in \"", str, "\""); } -Status ParseTextProto(absl::string_view text_proto, - absl::string_view prefix_to_strip, - protobuf::Message* parsed_proto) { +absl::Status ParseTextProto(absl::string_view text_proto, + absl::string_view prefix_to_strip, + protobuf::Message* parsed_proto) { protobuf::TextFormat::Parser parser; // Don't produce errors when attempting to parse text format as it would fail // when the input is actually a binary file. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h index c1f1e3b111d368..fdeec88c3e054d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h @@ -25,17 +25,17 @@ namespace tensorflow { // Sets output to the given input with `prefix` stripped, or returns an error if // the prefix doesn't exist. -Status ConsumePrefix(absl::string_view str, absl::string_view prefix, - absl::string_view* output); +absl::Status ConsumePrefix(absl::string_view str, absl::string_view prefix, + absl::string_view* output); // Strips `prefix_to_strip` from `text_proto`, parses, and returns the parsed // proto. -Status ParseTextProto(absl::string_view text_proto, - absl::string_view prefix_to_strip, - protobuf::Message* parsed_proto); -inline Status ParseTextProto(absl::string_view /* text_proto */, - absl::string_view /* prefix_to_strip */, - protobuf::MessageLite* /* parsed_proto */) { +absl::Status ParseTextProto(absl::string_view text_proto, + absl::string_view prefix_to_strip, + protobuf::Message* parsed_proto); +inline absl::Status ParseTextProto(absl::string_view /* text_proto */, + absl::string_view /* prefix_to_strip */, + protobuf::MessageLite* /* parsed_proto */) { return errors::Unavailable("Cannot parse text protos on mobile."); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc index ca250e4cab9b14..07adcb14286ece 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc @@ -36,9 +36,9 @@ std::string SerializeMlirModule(mlir::ModuleOp module_op) { return std::move(os.str()); } -Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module, - mlir::MLIRContext* mlir_context, - mlir::OwningOpRef* mlir_module) { +absl::Status DeserializeMlirModule( + llvm::StringRef serialized_mlir_module, mlir::MLIRContext* mlir_context, + mlir::OwningOpRef* mlir_module) { TF_RET_CHECK(!serialized_mlir_module.empty()) << "unexpected empty serialized MLIR module string"; TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer"; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h index 9f43603e3888f9..fc2044135ce369 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h @@ -31,9 +31,9 @@ std::string SerializeMlirModule(mlir::ModuleOp module_op); // Parses a MLIR module from `mlir_module_string` into `mlir_module` with // context `mlir_context`. -Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module, - mlir::MLIRContext* mlir_context, - mlir::OwningOpRef* mlir_module); +absl::Status DeserializeMlirModule( + llvm::StringRef serialized_mlir_module, mlir::MLIRContext* mlir_context, + mlir::OwningOpRef* mlir_module); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 4fa429db2ee7c5..130ed731348113 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -148,7 +148,7 @@ mlir::LogicalResult PrintHloModuleText( return mlir::success(); } -Status ParseArgumentShapes( +absl::Status ParseArgumentShapes( absl::string_view input_shapes_str, llvm::SmallVectorImpl& arg_shapes) { arg_shapes.clear(); @@ -168,8 +168,8 @@ Status ParseArgumentShapes( return absl::OkStatus(); } -Status ParseDataTypes(absl::string_view data_types_str, - llvm::SmallVectorImpl& data_types) { +absl::Status ParseDataTypes(absl::string_view data_types_str, + llvm::SmallVectorImpl& data_types) { data_types.clear(); std::vector input_dtypes_vector; TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types_str, input_dtypes_vector)); @@ -191,7 +191,7 @@ Status ParseDataTypes(absl::string_view data_types_str, return absl::OkStatus(); } -Status ParseArgumentKinds( +absl::Status ParseArgumentKinds( absl::string_view input_types_str, llvm::SmallVectorImpl& argument_kinds) { argument_kinds.clear(); @@ -216,10 +216,10 @@ Status ParseArgumentKinds( return absl::OkStatus(); } -Status ParseXlaArguments(absl::string_view input_shapes_str, - absl::string_view input_dtypes_str, - absl::string_view arg_kinds_str, - llvm::SmallVectorImpl& xla_arguments) { +absl::Status ParseXlaArguments( + absl::string_view input_shapes_str, absl::string_view input_dtypes_str, + absl::string_view arg_kinds_str, + llvm::SmallVectorImpl& xla_arguments) { xla_arguments.clear(); std::vector>> input_shapes_vector; TF_RETURN_IF_ERROR( @@ -270,7 +270,7 @@ Status ParseXlaArguments(absl::string_view input_shapes_str, // Test BuildHloFromTf. BuildHloFromTf only performs part of the conversion, so // to make this test comparable to other compile tests, the test implements // the remaining parts of the conversion. -Status CompileMlirToXlaHloViaBuilder( +absl::Status CompileMlirToXlaHloViaBuilder( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc index f65494a279560f..600d9906cd46ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" +#include + #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h index 1a399df89578ac..3ec239c4a33d7a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFICATION_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFICATION_UTILS_H_ +#include + #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc index 5e7768c3ce0fc3..ba4d1b71a857cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h" +#include +#include + #include "absl/log/log.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 699388de8457f9..8b87b1c29c999a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD b/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD index 12b29efb6ab76f..f7ec0f89181245 100644 --- a/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD @@ -1,3 +1,4 @@ +load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_pywrap") load( "//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", @@ -99,11 +100,13 @@ tf_python_pybind_extension( pytype_srcs = ["pywrap_tensorflow_to_stablehlo.pyi"], # Each dependency MUST be either header-only or exclusive. deps = [ - ":pywrap_tensorflow_to_stablehlo_lib_header_only", "//third_party/python_runtime:headers", "@com_google_absl//absl/strings:string_view", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:status_casters", - ], + ] + if_pywrap( + if_false = [":pywrap_tensorflow_to_stablehlo_lib_header_only"], + if_true = [":pywrap_tensorflow_to_stablehlo_lib_impl"], + ), ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index f36266ba3ec304..53431dfea4115b 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -68,8 +68,7 @@ absl::Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool enable_op_fallback, bool return_tuple, - const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns = - {}, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns = {}, llvm::MutableArrayRef> custom_legalization_passes = {}, llvm::StringRef module_name = llvm::StringRef()); @@ -135,7 +134,7 @@ ABSL_DEPRECATED("Not meant to be used directly and should be a util.") absl::Status PopulateResultIOInfo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, bool use_tuple_args, bool use_resource_updates_for_aliases, - const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* compilation_result); // Runs MLIR Bridge on an MLIR module. @@ -189,7 +188,7 @@ ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") absl::StatusOr CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, - const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> custom_legalization_passes = {}, @@ -206,7 +205,7 @@ absl::Status CompileGraphToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, bool use_return_tuple, - const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> custom_legalization_passes); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index 7ff6a3992aaade..fae6faf6a91140 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -338,8 +339,10 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/translate:upgrade_graph", + "//tensorflow/compiler/mlir/tf2xla/internal:graph_to_tf_executor_util", "//tensorflow/compiler/mlir/tf2xla/internal:node_order", "//tensorflow/compiler/tf2xla:functionalize_control_flow", "//tensorflow/compiler/tf2xla:functionalize_control_flow_util", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc index d731a03c3219dd..7e0f1aa6f27171 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h" #include "tensorflow/compiler/mlir/tf2xla/internal/node_order.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" @@ -118,6 +119,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stack_frame.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -841,7 +843,8 @@ absl::Status ImporterBase::AddNodesToShapeRefiner( // If it is the argument node, the shape handle is set explicitly, so it // can be propagated to the body nodes of the function. - if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) { + if (absl::string_view(node->type_string()) == + FunctionLibraryDefinition::kArgOp) { auto* node_context = shape_refiner_->GetContext(node); DCHECK(node_context != nullptr); if (const AttrValue* attr = node->attrs().Find("shape")) { @@ -2687,7 +2690,23 @@ absl::StatusOr> ConvertGraphToTfExecutor( const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, mlir::MLIRContext* context, - std::unordered_map* tf_name_to_mlir_name) { + std::unordered_map* tf_name_to_mlir_name, + const ConfigProto& config_proto, + tensorflow::TF2XLABridgeVersion bridge_version) { + if (bridge_version != tensorflow::TF2XLABridgeVersion::kNotBridgeUseCase) { + bool has_unsupported_features_in_mlir_bridge = + GraphHasUnsupportedFeaturesInMlirBridge( + graph, &flib_def, config_proto, + tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false); + if (has_unsupported_features_in_mlir_bridge) { + LOG(WARNING) + << "Graph contains unsupported features in MLIR bridge. " + << "Use MLIR bridge at your own risk or disable MLIR bridge, e.g., " + << "tf.config.experimental.disable_mlir_bridge."; + } + } + // TODO(jpienaar): Remove need to const_cast. if (specs.upgrade_legacy) { NodeFilter node_filter = specs.restrict_functionalization_to_compiled_nodes diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h index 4822edd85f7d90..1af93e6b163068 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" @@ -39,7 +40,10 @@ absl::StatusOr> ConvertGraphToTfExecutor( const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, mlir::MLIRContext* context, std::unordered_map* tf_name_to_mlir_name = - nullptr); + nullptr, + const ConfigProto& config_proto = {}, + tensorflow::TF2XLABridgeVersion bridge_version = + tensorflow::TF2XLABridgeVersion::kNotBridgeUseCase); } // namespace v2 } // namespace tf2xla diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc index 9c3d3f4aa74717..5e3e0a439ade1a 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include -#include #include #include #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringRef.h" @@ -121,7 +121,7 @@ void DumpComputationInput( } absl::Status DumpHloCompilationResult( - std::string_view name, XlaCompilationResult* compilation_result) { + absl::string_view name, XlaCompilationResult* compilation_result) { if (!VLOG_IS_ON(2) && !DEBUG_DATA_DUMPER()->ShouldDump(std::string(name), kDebugGroupMain)) { return absl::OkStatus(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index cbef35b4be949e..57ffbe06ae526e 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -371,3 +371,46 @@ tf_cc_test( "@com_google_googletest//:gtest", ], ) + +cc_library( + name = "graph_to_tf_executor_util", + srcs = ["graph_to_tf_executor_util.cc"], + hdrs = ["graph_to_tf_executor_util.h"], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:function_body", + "//tensorflow/core/platform:enable_tf2_utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + ], +) + +tf_cc_test( + name = "graph_to_tf_executor_util_test", + srcs = ["graph_to_tf_executor_util_test.cc"], + deps = [ + ":graph_to_tf_executor_util", + "//tensorflow/cc:array_ops", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/cc:tpu_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/platform:enable_tf2_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status", + "@local_xla//xla/tsl/lib/core:status_test_util", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.cc b/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.cc new file mode 100644 index 00000000000000..1ff482ea53233d --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.cc @@ -0,0 +1,329 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h" + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/enable_tf2_utils.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/errors.h" + +namespace tensorflow { + +namespace { +// Internal encapsulation of state for the MLIR bridge graph analyzer. Steps +// through the nodes in the graph and reachable functions, tracking whether +// each feature of interest is found. +// +// Tracks the presence of each feature of interest in the corresponding streamz +// metric. Note that the graph traversal does not terminate early so as to +// capture all of these features. +class MlirBridgeGraphAnalyzer { + public: + explicit MlirBridgeGraphAnalyzer(bool single_core_inference_mode) + : single_core_inference_mode_(single_core_inference_mode) {} + ~MlirBridgeGraphAnalyzer() = default; + // Not copyable or movable. + MlirBridgeGraphAnalyzer(const MlirBridgeGraphAnalyzer&) = delete; + MlirBridgeGraphAnalyzer& operator=(const MlirBridgeGraphAnalyzer&) = delete; + + // Analyzes whether the graph has features not guaranteed to be supported by + // the MLIR-based TF XLA bridge. + bool HasUnsupportedFeatures(const Graph& graph, + const FunctionLibraryDefinition* function_library, + std::optional config_proto, + tensorflow::TF2XLABridgeVersion bridge_version) { + // Non-ok status is considered as "unsupported" since this means something + // is wrong or unexpected with the graph itself. + invalid_graph_ = + invalid_graph_ || !AnalyzeGraphAndReachableFunctions( + graph, function_library, config_proto) + .ok(); + + // We conservatively consider the graph to be unsupported if it's not + // *known* to be TF2. That is, graphs that have kNotTracked construction + // context are considered unsupported, even though they might in fact be + // TF2 models. + auto construction_context = graph.GetConstructionContextInternal(); + bool is_tf2 = construction_context == ConstructionContext::kEagerRuntime; + auto is_tf2_execution_enabled = tensorflow::tf2_execution_enabled(); + auto has_unsupported_features = false; + auto is_v1_compat = bridge_version == TF2XLABridgeVersion::kV1Compat; + auto is_nominal_bridge = bridge_version == TF2XLABridgeVersion::kNominal; + auto is_tfrt_bridge = bridge_version == TF2XLABridgeVersion::kTFRTNominal; + is_eager_compliant_ = is_tf2_execution_enabled || is_tf2 || + is_nominal_bridge || is_tfrt_bridge; + + is_eager_compliant_ |= (is_v1_compat && contains_partitioned_call_); + + has_unsupported_features = contains_ref_type_ || invalid_graph_; + + // For non single core inference mode, checking conditions: + if (!single_core_inference_mode_) { + has_unsupported_features |= + !is_eager_compliant_ || uses_v1_control_flow_ || + HasTpuReplicatedCoreUnsupportedFeature(is_nominal_bridge, + is_v1_compat, is_tfrt_bridge); + } + + PrintGraphUnsupportedFeatures(is_tf2, is_tf2_execution_enabled, + is_v1_compat, is_tfrt_bridge, + is_nominal_bridge, has_unsupported_features); + + // Determine whether or not the graph contains unsupported features. + return has_unsupported_features; + } + + private: + static constexpr char kPartitionedCall[] = "TPUPartitionedCall"; + + bool HasTPUReplicatedCoreAttr(const Node& node) { + constexpr absl::string_view kTPUReplicatedCore = "TPU_REPLICATED_CORE"; + const std::string& device = node.requested_device(); + if (!device.empty()) { + DeviceNameUtils::ParsedName name; + if (DeviceNameUtils::ParseFullName(device, &name)) { + // The TPU_REPLICATED_CORE attrs is not relevant for single TPU core + // inference. + // TODO(b/201091475): this can be generalized to check + // num_cores_per_replica != 1, rather than being special cased for + // single core inference. + if (name.type == kTPUReplicatedCore && !single_core_inference_mode_) { + return true; + } + } + } + return false; + } + + bool HasTpuReplicatedCoreUnsupportedFeature(bool is_nominal_bridge, + bool is_v1_compat, + bool is_tfrt_bridge) { + if (!has_tpu_replicated_core_) { + return false; + } + return has_infeed_dequeue_tuple_with_tpu_replicated_core_; + } + + void PrintGraphUnsupportedFeatures(bool is_tf2, bool is_tf2_execution_enabled, + bool is_v1_compat, bool is_tfrt_bridge, + bool is_nominal_bridge, + bool has_unsupported_features) { + if (!has_unsupported_features) { + VLOG(1) << "Graph doesn't have unsupported features"; + return; + } + + LOG(INFO) + << "Graph has unsupported features: " << (is_tf2 ? "" : "not is_tf2, ") + << (is_tf2_execution_enabled ? "" : "not tf2_execution, ") + << (is_nominal_bridge ? "" : "not nominal bridge, ") + << (is_tfrt_bridge ? "" : "not tfrt bridge, ") + << (is_v1_compat && contains_partitioned_call_ + ? "contains partitioned calls at v1 compat bridge call site, " + : "") + << (contains_ref_type_ ? "contains ref variables, " : "") + << (invalid_graph_ ? "Invalid graph, " : "") + << (uses_v1_control_flow_ ? "uses control flow v1 " : "") + << ((has_tpu_replicated_core_ && + has_infeed_dequeue_tuple_with_tpu_replicated_core_) + ? "InfeedDequeueTuple op with TPU_REPLICATED_CORE attr, " + : ""); + } + + // Traverses each node in the graph and gathers information about each of the + // features. Specifically, sets the relevant class variable to true when a + // feature is found. + void AnalyzeGraphNodes(const Graph& graph) { + constexpr absl::string_view kIdentityOp = "Identity"; + constexpr absl::string_view kIdentityNOp = "IdentityN"; + constexpr absl::string_view kCastOp = "Cast"; + constexpr absl::string_view kInfeedDequeueTuple = "InfeedDequeueTuple"; + constexpr absl::string_view kOutsideCompilationAttr = + "_xla_outside_compilation"; + constexpr absl::string_view kAllowSoftPlacementAttr = + "allow_soft_placement"; + constexpr absl::string_view kManualControlDepsAttr = + "_has_manual_control_dependencies"; + + auto has_ref_type = [](const DataTypeVector& types) { + for (const DataType& dtype : types) + if (IsRefType(dtype)) return true; + return false; + }; + + for (const Node* node : graph.nodes()) { + contains_ref_type_ = + (contains_ref_type_ || has_ref_type(node->input_types()) || + has_ref_type(node->output_types())); + contains_partitioned_call_ = (contains_partitioned_call_ || + node->type_string() == kPartitionedCall); + uses_v1_control_flow_ = (uses_v1_control_flow_ || node->IsControlFlow()); + uses_outside_compilation_ = + (uses_outside_compilation_ || + node->attrs().Find(kOutsideCompilationAttr) != nullptr); + has_manual_control_deps_ = (has_manual_control_deps_ || + node->attrs().Find(kManualControlDepsAttr)); + + auto soft_placement_attr = node->attrs().Find(kAllowSoftPlacementAttr); + if (soft_placement_attr != nullptr) { + uses_outside_compilation_ = + (uses_outside_compilation_ || soft_placement_attr->b()); + } + + // TODO(b/187611527): Add support for the ops with explicit device + // assignment on the TPU_REPLICATED_CORE. + if (node->type_string() == kIdentityOp || + node->type_string() == kCastOp || + node->type_string() == kIdentityNOp) { + if (HasTPUReplicatedCoreAttr(*node)) { + has_tpu_replicated_core_ = true; + VLOG(2) << node->type_string() + << " node has TPU_REPLICATED_CORE attribute."; + } + } + if (node->type_string() == kInfeedDequeueTuple && + HasTPUReplicatedCoreAttr(*node)) { + has_infeed_dequeue_tuple_with_tpu_replicated_core_ = true; + } + } + } + + // Analyze all functions from the flib_def if there are any that belong to + // the inference graph. + void AnalyzeInferenceGraphs(const FunctionLibraryDefinition& flib_def) { + if (contains_partitioned_call_) return; + + for (const std::string& func_name : flib_def.ListFunctionNames()) { + const FunctionDef* func_def = flib_def.Find(func_name); + for (const NodeDef& node_def : func_def->node_def()) { + contains_partitioned_call_ = node_def.op() == kPartitionedCall; + if (contains_partitioned_call_) return; + } + } + } + + // Checks any reachable functions from `graph_def` in `flib_def` + // for unsupported features in the MLIR-based bridge. + // + // Returns failure in the event that the FunctionDef fails to convert to + // FunctionBody. Otherwise returns success. + absl::Status AnalyzeReachableFunctions( + const GraphDef& graph_def, const FunctionLibraryDefinition& flib_def) { + // Check the inputs and outputs of a function for reference variables. + auto signature_contains_ref_type = [](const OpDef& signature) { + for (const auto& args : {signature.input_arg(), signature.output_arg()}) { + for (const auto& arg : args) { + if (IsRefType(arg.type())) return true; + } + } + return false; + }; + + for (const std::string& func_name : + flib_def.ReachableDefinitions(graph_def).ListFunctionNames()) { + const FunctionDef* func_def = flib_def.Find(func_name); + if (func_def->has_signature()) { + contains_ref_type_ = contains_ref_type_ || + signature_contains_ref_type(func_def->signature()); + } + // Check the function body. + std::unique_ptr func_body; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *func_def, AttrSlice(&func_def->attr()), &flib_def, &func_body)); + AnalyzeGraphNodes(*func_body->graph); + } + return absl::OkStatus(); + } + + // Checks the inputted graph for any features which aren't supported in the + // MLIR-based bridge, stepping through each node in the graph as well as any + // reachable functions (inputs, outputs, and function body). + // + // Note that this analysis does not terminate early because we care about + // collecting all of these metrics. + // + // Returns failure in the event that the FunctionDef fails to convert to + // FunctionBody. Otherwise returns success. + absl::Status AnalyzeGraphAndReachableFunctions( + const Graph& graph, const FunctionLibraryDefinition* function_library, + std::optional config_proto) { + // First, check whether soft placement is enabled. This means that auto + // outside compilation may be used. + uses_outside_compilation_ = + uses_outside_compilation_ || + (config_proto.has_value() && config_proto->allow_soft_placement()); + + // Analyze each node in this graph. + AnalyzeGraphNodes(graph); + + // Then check any associated functions in the graph + // FunctionLibraryDefinition. + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + TF_RETURN_IF_ERROR(AnalyzeReachableFunctions(graph_def, graph.flib_def())); + // Analyze whether there is an inference graph, including non reachable + // from the `graph` itself. This happens when there is a sequence of + // TPUPartitionedCall()->main()->PartitionedCall() and only second part + // of the graph is processed by the MLIR bridge. + AnalyzeInferenceGraphs(graph.flib_def()); + + // Check any associated function in the graph defined in a separate + // FunctionLibraryDefinition. + if (function_library != nullptr) { + TF_RETURN_IF_ERROR( + AnalyzeReachableFunctions(graph_def, *function_library)); + AnalyzeInferenceGraphs(*function_library); + } + + return absl::OkStatus(); + } + + bool contains_partitioned_call_ = false; + bool contains_ref_type_ = false; + bool invalid_graph_ = false; + bool uses_outside_compilation_ = false; + bool uses_v1_control_flow_ = false; + bool has_manual_control_deps_ = false; + bool single_core_inference_mode_ = false; + bool is_eager_compliant_ = false; + bool has_tpu_replicated_core_ = false; + bool has_infeed_dequeue_tuple_with_tpu_replicated_core_ = false; +}; + +} // namespace + +bool GraphHasUnsupportedFeaturesInMlirBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library, + std::optional config_proto, TF2XLABridgeVersion bridge_version, + bool single_core_inference_mode) { + return MlirBridgeGraphAnalyzer(single_core_inference_mode) + .HasUnsupportedFeatures(graph, function_library, config_proto, + bridge_version); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h b/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h new file mode 100644 index 00000000000000..c08a2c39c61886 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_GRAPH_TO_TF_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_GRAPH_TO_TF_EXECUTOR_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// These are used for grouping the recorded stats appropriately. Specifically, +// we're considering different entrypoints to the bridge as having potentially +// interesting differences at least in the domain of accepted graphs so we want +// to separately track graph features based on these unique entrypoints. One key +// example of this distinction is for TFRT which uses the "nominal" TPU bridge +// pipeline, but may potentially allow graphs with v1 control flow. This +// separate grouping will allow us to dig into these differences granularly. +enum class TF2XLABridgeVersion { + kNominal = 0, + kV1Compat, + kTFRTNominal, + kNotBridgeUseCase, +}; + +// Analyzes whether the graph has features not guaranteed to be supported by the +// MLIR-based TF XLA bridge for phase 1. If MLIR bridge phase 1 is not used, +// then MLIR bridge phase 2 will not be used. The optional `function_library` +// can be provided if it contains function definitions not including in the +// `graph` FunctionLibraryDefinition. +// +// Conservatively, during the initial rollout, we are not supporting graphs for +// which any of the following are true: +// +// - Not known to be TF2 +// - Contains one or more reference variables +// - Contains one or more TPUPartitionedCall ops (which is a proxy for +// inference), but the graph is not v1 compat +// - Uses V1 control flow +// - Graph is invalid or otherwise encounters error during traversal +// If `single_core_inference_mode` is true, we skip some of check conditions +// because they are not applicable. +// TODO(b/241702857): remove single_core_inference_mode +bool GraphHasUnsupportedFeaturesInMlirBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library, + std::optional config_proto, TF2XLABridgeVersion bridge_version, + bool single_core_inference_mode); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_GRAPH_TO_TF_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util_test.cc new file mode 100644 index 00000000000000..66eb1cf1967ba8 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util_test.cc @@ -0,0 +1,732 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/tpu_functional_ops.h" +#include "tensorflow/cc/ops/tpu_replication_ops.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/enable_tf2_utils.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/status.h" + +namespace tensorflow { + +namespace { + +REGISTER_OP("OneRefOutput").Output("y: Ref(float)"); + +FunctionDef XTimesTwo() { + const Tensor kTwo = test::AsScalar(2); + return FunctionDefHelper::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesTwoFloat() { + const Tensor kTwo = test::AsScalar(2); + return FunctionDefHelper::Define( + // Name + "XTimesTwoFloat", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attr def + {}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, + "Cast", + {"two"}, + {{"SrcT", DT_INT64}, {"DstT", DT_FLOAT}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_FLOAT}}}, + }); +} + +FunctionDef XTimesTwoFloatRef() { + const Tensor kTwo = test::AsScalar(2); + return FunctionDefHelper::Define( + // Name + "XTimesTwoFloatRef", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attr def + {}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64_REF}}}, + {{"scale"}, + "Cast", + {"two"}, + {{"SrcT", DT_INT64_REF}, {"DstT", DT_FLOAT}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_FLOAT}}}, + }); +} + +Node* FromNodeDef(absl::string_view name, absl::string_view node_type, + int num_inputs, DataType dt, Graph& graph) { + auto builder = NodeDefBuilder(name, node_type); + for (int i = 0; i < num_inputs; ++i) { + builder = builder.Input(absl::StrCat("node_", i), i, dt); + } + + NodeDef node_def; + TF_CHECK_OK(builder.Finalize(&node_def)); + + absl::Status s; + Node* node = graph.AddNode(node_def, &s); + TF_CHECK_OK(s); + return node; +} + +TEST(SupportedGraphTest, SupportedGraphReturnsFalse) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + auto depth = tensorflow::ops::Placeholder(root.WithOpName("depth"), DT_INT32); + auto on = tensorflow::ops::Placeholder(root.WithOpName("on"), DT_UINT8); + auto off = tensorflow::ops::Placeholder(root.WithOpName("off"), DT_UINT8); + tensorflow::set_tf2_execution(true); + (void)tensorflow::ops::OneHot(root.WithOpName("output"), input, depth, on, + off); + + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(InvalidGraphTest, InvalidFuncBodyReturnsTrue) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwo(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwo"); + ops::PartitionedCall f(root.WithOpName("f"), {x}, {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + // The call to XTimesTwo is invalid (missing an attribute), so we expect the + // graph to be unsupported. + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(RefVarTest, RefVariablesReturnsTrue) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL); + Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL); + + // Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + tensorflow::set_tf2_execution(true); + const std::vector shape_array{2, 2}; + auto shape = TensorShape(); + TF_ASSERT_OK(TensorShapeUtils::MakeShape(shape_array, &shape)); + Output value = Output( + FromNodeDef("value", "OneRefOutput", 0, DT_FLOAT_REF, *root.graph())); + + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(RefVarTest, NoRefVariablesCalleeFuncReturnsFalse) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::PartitionedCall f(root.WithOpName("f"), {x}, {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(RefVarTest, RefVariablesInCalleeFunctionReturnsTrue) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloatRef(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloatRef"); + ops::PartitionedCall f(root.WithOpName("f"), {x}, {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(RefVarTest, RefVariablesInExternalCalleeFunctionReturnsTrue) { + tensorflow::set_tf2_execution(true); + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloatRef(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloatRef"); + ops::PartitionedCall f(root.WithOpName("f"), {x}, {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/&flib_def, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(InferenceTest, ContainsInferenceNodeEagerRuntimeReturnsTrue) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(InferenceTest, ContainsInferenceNodeTFRTBridgeReturnsTrue) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kTFRTNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(InferenceTest, ContainsInferenceNodeDirectSessionReturnsFalse) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kDirectSession); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +TEST(ControlFlowTest, ContainsV1ControlFlowReturnsTrue) { + tensorflow::set_tf2_execution(true); + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL); + Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL); + + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a); + ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b); + + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(ControlFlowTest, TFRTContainsV1ControlFlowReturnsTrue) { + tensorflow::set_tf2_execution(true); + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL); + Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL); + + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a); + ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b); + + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kTFRTNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF1ReturnsTrue) { + tensorflow::set_tf2_execution(false); + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + auto depth = tensorflow::ops::Placeholder(root.WithOpName("depth"), DT_INT32); + auto on = tensorflow::ops::Placeholder(root.WithOpName("on"), DT_UINT8); + auto off = tensorflow::ops::Placeholder(root.WithOpName("off"), DT_UINT8); + (void)tensorflow::ops::OneHot(root.WithOpName("output"), input, depth, on, + off); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + graph.SetConstructionContext(ConstructionContext::kDirectSession); + + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF2ExecutionFalseV1CompatBridgeReturnTrue) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + auto depth = tensorflow::ops::Placeholder(root.WithOpName("depth"), DT_INT32); + auto on = tensorflow::ops::Placeholder(root.WithOpName("on"), DT_UINT8); + auto off = tensorflow::ops::Placeholder(root.WithOpName("off"), DT_UINT8); + (void)tensorflow::ops::OneHot(root.WithOpName("output"), input, depth, on, + off); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + tensorflow::set_tf2_execution(false); + + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF2ExecutionTrueV1CompatBridgeReturnFalse) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + auto depth = tensorflow::ops::Placeholder(root.WithOpName("depth"), DT_INT32); + auto on = tensorflow::ops::Placeholder(root.WithOpName("on"), DT_UINT8); + auto off = tensorflow::ops::Placeholder(root.WithOpName("off"), DT_UINT8); + (void)tensorflow::ops::OneHot(root.WithOpName("output"), input, depth, on, + off); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + tensorflow::set_tf2_execution(true); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF2ExecutionFalseTfrtNominalBridgeReturnFalse) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + auto depth = tensorflow::ops::Placeholder(root.WithOpName("depth"), DT_INT32); + auto on = tensorflow::ops::Placeholder(root.WithOpName("on"), DT_UINT8); + auto off = tensorflow::ops::Placeholder(root.WithOpName("off"), DT_UINT8); + (void)tensorflow::ops::OneHot(root.WithOpName("output"), input, depth, on, + off); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + tensorflow::set_tf2_execution(false); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kTFRTNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF2ExecutionTrueTfrtNominalBridgeReturnFalse) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + auto depth = tensorflow::ops::Placeholder(root.WithOpName("depth"), DT_INT32); + auto on = tensorflow::ops::Placeholder(root.WithOpName("on"), DT_UINT8); + auto off = tensorflow::ops::Placeholder(root.WithOpName("off"), DT_UINT8); + (void)tensorflow::ops::OneHot(root.WithOpName("output"), input, depth, on, + off); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + tensorflow::set_tf2_execution(true); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kTFRTNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF2ExecutionFalseNominalBridgeReturnsFalse) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + tensorflow::set_tf2_execution(false); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(TFVersionTest, TF2ExecutionTrueNominalBridgeReturnsFalse) { + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = tensorflow::ops::Placeholder(root.WithOpName("input"), DT_UINT8); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + tensorflow::set_tf2_execution(true); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(UnsupportedOpTest, + InfeedDequeueTupleWithTPUReplicatedCoreAttrNotSupported) { + tensorflow::set_tf2_execution(true); + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = + tensorflow::ops::Placeholder(root.WithOpName("node_0"), DT_FLOAT); + + auto node = FromNodeDef("Identity", "Identity", 1, DT_FLOAT, *root.graph()); + ASSERT_NE(node, nullptr); + node->set_requested_device("/device:TPU_REPLICATED_CORE:0"); + + // Build InfeedDequeueTuple node with TPU_REPLICATED_CORE Attr + auto builder = NodeDefBuilder("InfeedDequeueTuple", "InfeedDequeueTuple"); + builder.Attr("dtypes", DT_FLOAT); + builder.Attr("shapes", 1); + NodeDef node_def; + TF_CHECK_OK(builder.Finalize(&node_def)); + absl::Status s; + Node* node_InfeedDequeueTuple = (*root.graph()).AddNode(node_def, &s); + node_InfeedDequeueTuple->set_requested_device( + "/device:TPU_REPLICATED_CORE:0"); + TF_CHECK_OK(s); + ASSERT_NE(node_InfeedDequeueTuple, nullptr); + + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/true)); +} + +TEST(ManualControlDependencyTest, + TPUReplicatedCoreWithManualControlDependencyReturnsFalse) { + tensorflow::set_tf2_execution(true); + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = + tensorflow::ops::Placeholder(root.WithOpName("node_0"), DT_FLOAT); + + auto node = FromNodeDef("Identity", "Identity", 1, DT_FLOAT, *root.graph()); + ASSERT_NE(node, nullptr); + node->set_requested_device("/device:TPU_REPLICATED_CORE:0"); + + auto metadata = tensorflow::ops::TPUReplicateMetadata(root, 2); + metadata.operation.node()->AddAttr("_has_manual_control_dependencies", true); + + Graph graph(OpRegistry::Global()); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + TF_ASSERT_OK(root.ToGraph(&graph)); + + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/true)); +} + +TEST(InferenceTest, + ContainsInferenceNodeTPUReplicatedCoreDirectSessionReturnsFalse) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kDirectSession); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = + tensorflow::ops::Placeholder(root.WithOpName("node_0"), DT_FLOAT); + auto node = FromNodeDef("Identity", "Identity", 1, DT_FLOAT, *root.graph()); + ASSERT_NE(node, nullptr); + node->set_requested_device("/device:TPU_REPLICATED_CORE:0"); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +TEST(InferenceTest, + ContainsInferenceNodeTPUReplicatedCoreEagerRuntimeReturnsTrue) { + tensorflow::set_tf2_execution(true); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kEagerRuntime); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = + tensorflow::ops::Placeholder(root.WithOpName("node_0"), DT_FLOAT); + auto node = FromNodeDef("Identity", "Identity", 1, DT_FLOAT, *root.graph()); + ASSERT_NE(node, nullptr); + node->set_requested_device("/device:TPU_REPLICATED_CORE:0"); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kNominal, + /*single_core_inference_mode=*/false)); +} + +TEST(InferenceTest, TF2ExecutionFalseV1CompatBridgeReturnFalse) { + tensorflow::set_tf2_execution(false); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kDirectSession); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = + tensorflow::ops::Placeholder(root.WithOpName("node_0"), DT_FLOAT); + auto node = FromNodeDef("Identity", "Identity", 1, DT_FLOAT, *root.graph()); + ASSERT_NE(node, nullptr); + node->set_requested_device("/device:TPU_REPLICATED_CORE:0"); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_FALSE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +TEST(InferenceTest, V1CompatBridgeVariableRefReturnTrue) { + tensorflow::set_tf2_execution(false); + FunctionDefLibrary flib; + *flib.add_function() = XTimesTwoFloat(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + Graph graph(flib_def); + graph.SetConstructionContext(ConstructionContext::kDirectSession); + + ConfigProto config = ConfigProto(); + Scope root = Scope::NewRootScope().ExitOnError(); + + auto input = + tensorflow::ops::Placeholder(root.WithOpName("node_0"), DT_FLOAT); + auto node = FromNodeDef("Identity", "Identity", 1, DT_FLOAT, *root.graph()); + ASSERT_NE(node, nullptr); + node->set_requested_device("/device:TPU_REPLICATED_CORE:0"); + + Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); + NameAttrList f_name_attr; + f_name_attr.set_name("XTimesTwoFloat"); + ops::TPUPartitionedCall f(root.WithOpName("f"), {x}, /*device_ordinal=*/0, + {DT_FLOAT}, f_name_attr); + + Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL); + Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL); + + tensorflow::set_tf2_execution(true); + const std::vector shape_array{2, 2}; + auto shape = TensorShape(); + TF_ASSERT_OK(TensorShapeUtils::MakeShape(shape_array, &shape)); + Output value = Output( + FromNodeDef("value", "OneRefOutput", 0, DT_FLOAT_REF, *root.graph())); + + TF_ASSERT_OK(root.ToGraph(&graph)); + EXPECT_TRUE(GraphHasUnsupportedFeaturesInMlirBridge( + graph, /*function_library=*/nullptr, config, + /*bridge_version=*/tensorflow::TF2XLABridgeVersion::kV1Compat, + /*single_core_inference_mode=*/false)); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc index 0aa5ece97722d2..959120866c722d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc @@ -55,7 +55,6 @@ namespace internal { // enable logging. constexpr char kBridgeComponent[] = "TFXLABridge"; -using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; absl::Status CompileFromMlirToXlaHlo( diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 2b5636629143ec..1619fd08bf430f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -121,6 +121,7 @@ cc_library( "//tensorflow/compiler/mlir:register_common_dialects", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", @@ -140,7 +141,6 @@ cc_library( ], # DEPRECATED: use v2/legalize_tf.h::LegalizeMlirToHlo instead. visibility = [ - "//tensorflow/compiler/mlir/lite/stablehlo:__pkg__", "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tf2xla/internal/passes:__pkg__", ], @@ -154,6 +154,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:Dialect", @@ -273,6 +274,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/util/quantization:uniform_quant_ops_params", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", @@ -340,7 +342,10 @@ cc_library( "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -354,6 +359,7 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla:xla_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/ir:hlo", @@ -379,6 +385,8 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:ops", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -456,6 +464,7 @@ tf_cc_test( "//tensorflow/core:core_cpu_base", "//tensorflow/core/framework:allocator", "//tensorflow/core/lib/monitoring:cell_reader", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", @@ -494,6 +503,7 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc b/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc index f1e843b81f5476..f27206dad6dcb8 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc @@ -13,13 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include -#include -#include -#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index b7dfe80419258d..9a7ef3232105e8 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 6a04acf8375e42..405b325fdcbe2a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -15,18 +15,21 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect to XLA dialect. #include -#include +#include #include #include #include +#include #include #include #include #include #include +#include #include #include +#include "absl/status/status.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index 34df8fc9759a5c..7061aaa4a5657b 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -16,10 +16,12 @@ limitations under the License. // This file implements logic for lowering TensorFlow dialect's collective // ops (TF/XLA) to the HLO dialect. +#include +#include #include -#include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project @@ -395,7 +397,7 @@ void LegalizeTFCollective::runOnOperation() { patterns.insert(context, &channel_id); patterns.insert(context, &channel_id); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index a3cbb4ba2cd763..1c7acd41db4be9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -17,6 +17,7 @@ limitations under the License. // ops (TF/XLA) to the HLO dialect. #include +#include #include #include #include @@ -47,6 +48,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/side_effect_util.h" +#include "xla/xla_data.pb.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 185216448a15ed..322fcc44ed4a9f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -542,14 +542,14 @@ def ArgTypesMatchCallee : Constraint< foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f, $config, $config_proto, $executor_type), - (CallOp $f, $args), + (CallOp $f, $args, ConstantAttr), [(ArgTypesMatchCallee $op, $args, $f)]>; } // The extra attr on this op is _disable_call_shape_inference, which we ignore // in the bridge. def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), - (CallOp $f, $args), + (CallOp $f, $args, ConstantAttr), [(ArgTypesMatchCallee $op, $args, $f)]>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index 9057e2406fab06..d41ecd7a262a65 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -12,15 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include #include -#include -#include "absl/container/inlined_vector.h" -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.cc b/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.cc index 50b9f7f2adad2f..ecf3aea5f65d48 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.h" -#include #include #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc index e43bcdf6d3a26e..e0e7103630fd99 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" -#include -#include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h index 13baaba06aadb9..0ad6e9af194518 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 6f864f8eb52736..16689caaa5573f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -14,17 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h" +#include #include #include #include -#include #include #include -#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" -#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -70,9 +72,11 @@ limitations under the License. #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index dc8b0ad459d2e1..c5c417e27ba022 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -16,10 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ #define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ +#include #include #include #include +#include "absl/status/statusor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index 2cd2f3591ba0cd..15834412165010 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -16,12 +16,11 @@ limitations under the License. #include #include -#include #include -#include #include -#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/Support/Casting.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tfxla_device_specific_transforms.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tfxla_device_specific_transforms.cc index 5531adad8501aa..a7e9726e7575a3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tfxla_device_specific_transforms.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tfxla_device_specific_transforms.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/tf2xla/kernels/rng_converter_utils.h" +#include "xla/xla_data.pb.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/utils.cc b/tensorflow/compiler/mlir/tf2xla/transforms/utils.cc index 0b186b6a22ef8d..0152cd1d1a7363 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/utils.h" +#include + #include "xla/mlir_hlo/utils/hlo_utils.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/utils.h b/tensorflow/compiler/mlir/tf2xla/transforms/utils.h index a4e6d323e47ab2..5dba4a4dcf894c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/utils.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/utils.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_UTILS_H_ +#include + #include "llvm/ADT/ArrayRef.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index a6435081820880..d99f80ff5eacd5 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include -#include -#include #include "mlir/IR/BuiltinOps.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc index 28d00b48628185..2b1c235c10dca5 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include -#include -#include #include #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc index f774781d376b87..635d7dc15bb72a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" -#include #include #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index aa38150e6a14c3..f5364586ec73c9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc index 4183d181fc5611..e2bda59448be85 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc @@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 4583fc9cd967e2..73e6e874555f42 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -33,10 +33,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/tosa/tf_passes.h" -#include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" -#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" -#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" #include "xla/mlir/framework/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -57,10 +53,6 @@ int main(int argc, char **argv) { mlir::quant::stablehlo::registerBridgePasses(); tensorflow::tf2xla::internal::registerTFXLABridgeClusteringPasses(); tensorflow::tf2xla::internal::registerTFXLABridgeMlirToGraphPasses(); - mlir::tosa::registerLegalizeTosaPasses(); - mlir::tosa::registerTFtoTOSALegalizationPipeline(); - mlir::tosa::registerTFLtoTOSALegalizationPipeline(); - mlir::tosa::registerTFTFLtoTOSALegalizationPipeline(); mlir::tf_test::registerTensorFlowTestPasses(); mlir::xla_framework::registerXlaFrameworkPasses(); tensorflow::RegisterConvertMlirToXlaHloPipelineWithDefaults(); diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index d23e3f346b1c5c..babd62f6b13f89 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include #include "absl/strings/str_split.h" +#include "absl/types/span.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 71cbd5066128b3..deed4d72b99329 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -272,6 +272,7 @@ tf_python_pybind_extension( pytype_srcs = [ "tfr_wrapper.pyi", ], + starlark_only = True, visibility = [ "//tensorflow/python:__pkg__", ], diff --git a/tensorflow/compiler/mlir/tfr/build_defs.bzl b/tensorflow/compiler/mlir/tfr/build_defs.bzl index d92bc2f625fb41..fca80aa5f63cec 100644 --- a/tensorflow/compiler/mlir/tfr/build_defs.bzl +++ b/tensorflow/compiler/mlir/tfr/build_defs.bzl @@ -49,7 +49,6 @@ def gen_op_libraries( srcs = [], outs = [name + ".inc.cc"], cmd = - "PYWRAP_TARGET='//tensorflow/python:_pywrap_tensorflow' " + "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, tools = [":" + gen_op_lib_exec], tags = tags, @@ -114,7 +113,6 @@ def gen_op_libraries( srcs = [], outs = [name + ".mlir"], cmd = - "PYWRAP_TARGET='//tensorflow/python:_pywrap_tensorflow' " + "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, tools = [":" + gen_tfr_lib_exec], tags = tags, diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc index 443781b6b63a7e..9cc555b7893563 100644 --- a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc +++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include -#include #include "llvm/Support/raw_ostream.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 988dc9e612b9c3..ce5ede14edea85 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -19,13 +19,10 @@ limitations under the License. #include #include #include -#include #include #include #include -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -136,7 +133,7 @@ void DecomposeTFOpsPass::ApplyCanonicalization() { populateWithGenerated(patterns); populateCanonicalizationPatterns(func, patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc index 61aa404847ee07..4f079395063a8f 100644 --- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc +++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc @@ -13,15 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include -#include #include #include #include -#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -508,7 +504,7 @@ void RaiseToTFOpsPass::runOnOperation() { populateCanonicalizationPatterns(func, patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc index 2d9f00bd390380..60f04e19cefec4 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" #include +#include +#include #include #include diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h index 7a62c70e9d6ac5..b27b6aa9ba3d84 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ +#include #include #include "absl/strings/string_view.h" diff --git a/tensorflow/compiler/mlir/tfrt/function/function.h b/tensorflow/compiler/mlir/tfrt/function/function.h index 71d046390da6ed..8d09f8cb3f51f1 100644 --- a/tensorflow/compiler/mlir/tfrt/function/function.h +++ b/tensorflow/compiler/mlir/tfrt/function/function.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc index 30f6aa234a2d59..405b668dc4c25a 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h" +#include +#include #include #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index 69d0525db1fba8..b54b809dedd6dd 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -69,12 +69,12 @@ ProcessTensorSpec(mlir::TensorType type) { } // namespace -Status MapFunctionSignaturesFromTFSavedModelMLIR( +absl::Status MapFunctionSignaturesFromTFSavedModelMLIR( mlir::ModuleOp module, llvm::function_ref map_fn) { // Create bound inputs for each functions. mlir::SymbolTable symbol_table(module); - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); module.walk([&symbol_table, map_fn, &status](mlir::func::FuncOp func) { // Use the exported name as the function name, and skip non-exported // functions. diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h index 091e6642650b25..087d50deec8cf6 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h @@ -71,7 +71,7 @@ struct TFRTSavedModelSignatureInfo { // Apply `map_fn` on every exported function in the module with the // corresponding signature metadata populated in TFRTSavedModelSignatureInfo for // the function. -Status MapFunctionSignaturesFromTFSavedModelMLIR( +absl::Status MapFunctionSignaturesFromTFSavedModelMLIR( mlir::ModuleOp module, llvm::function_ref map_fn); diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc b/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc index b8a071eb35bce6..f45a612a906006 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" #include -#include #include #include diff --git a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc index 5e46782cffb93e..57734794a11792 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h" #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h index cc1f2a04d6ea3a..be212e444d86dc 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc b/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc index 790df9f6ec01d4..f9845f53fac7d3 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc @@ -16,6 +16,10 @@ limitations under the License. // This pass inserts corert.transfer op to make sure any argument of any op is // on the same device of the op itself. +#include +#include +#include + #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc b/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc index 0ef6c18f0a3e5d..986b766d897859 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h b/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h index bb0b70a457d5d1..c1c1d42a91373e 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc index 4895c4fb584b1c..77de1e0eb48669 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 0e3b89ec8a08e9..05a791bf30aba3 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -156,19 +156,17 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:fingerprint", - "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/pjrt:pjrt_compiler", "@local_xla//xla/python/ifrt", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/stream_executor:platform_manager", - "@local_xla//xla/tsl/concurrency:ref_count", ], ) @@ -220,11 +218,10 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:pjrt_compiler", - "@local_xla//xla/pjrt/cpu:cpu_client", "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client", + "@local_xla//xla/pjrt/plugin/xla_cpu:cpu_topology_description", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:mock", "@local_xla//xla/python/ifrt:test_util", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 58543648b795f6..f7256bf0707844 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -31,12 +31,15 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h" @@ -232,6 +235,7 @@ absl::StatusOr CompileTfToHlo(const Tf2HloArg& arg) { std::vector arg_shapes; + arg_shapes.reserve(arg.input_dtypes_and_shapes.size()); for (const auto& input : arg.input_dtypes_and_shapes) { arg_shapes.push_back(input.shape); } @@ -269,8 +273,7 @@ absl::StatusOr CompileTfToHlo(const Tf2HloArg& arg) { return result; } -absl::StatusOr TfToHloCompiler::CompileTfToHlo( - const Tf2HloArg& arg) { +absl::StatusOr TfToHloCompiler::CompileTfToHlo(Tf2HloArg& arg) { return tensorflow::ifrt_serving::CompileTfToHlo(arg); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index dc6fd392d1aff6..7122f26e082291 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -19,22 +19,19 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.pb.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/compile_only_client.h" #include "xla/python/ifrt/client.h" -#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/topology.h" #include "xla/service/hlo.pb.h" -#include "xla/tsl/concurrency/ref_count.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" @@ -43,12 +40,16 @@ namespace ifrt_serving { struct Tf2HloArg { mlir::ModuleOp module; - absl::Span input_dtypes_and_shapes; + // `input_dtypes_and_shapes` can be mutable during Tf2HLO compilation. + std::vector input_dtypes_and_shapes; + absl::Span variable_arg_indices; absl::string_view entry_function_name; + // `compile_metadata` can be mutable during Tf2HLO compilation. tensorflow::tpu::TPUCompileMetadataProto compile_metadata; tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn; std::shared_ptr topology; absl::string_view platform_name; + bool enable_r1_optimization = true; absl::StatusOr Fingerprint() const; }; @@ -76,7 +77,7 @@ class TfToHloCompiler { // CompileTfToHlo. virtual absl::StatusOr Key(const Tf2HloArg& arg); - virtual absl::StatusOr CompileTfToHlo(const Tf2HloArg& arg); + virtual absl::StatusOr CompileTfToHlo(Tf2HloArg& arg); }; } // namespace ifrt_serving diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index dcd2ea5c5bbbf5..24252c40ae7da9 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -37,9 +37,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology_description.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/mock.h" #include "xla/python/ifrt/test_util.h" @@ -49,8 +50,6 @@ limitations under the License. #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace ifrt_serving { @@ -58,7 +57,6 @@ namespace { using ::testing::Eq; using ::testing::HasSubstr; using ::testing::Ne; -using ::testing::status::IsOkAndHolds; using tsl::testing::StatusIs; // TODO(b/229726259): Make EqualsProto available in OSS @@ -120,17 +118,19 @@ TEST_F(Tf2HloTest, Empty) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, {})); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + std::vector variable_arg_indices; Tf2HloArg arg{ .module = mlir_module.get(), .input_dtypes_and_shapes = {}, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -168,17 +168,20 @@ TEST_F(Tf2HloTest, Tuple) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + std::vector variable_arg_indices; Tf2HloArg arg{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -216,17 +219,19 @@ TEST_F(Tf2HloTest, Spmd) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + std::vector variable_arg_indices; Tf2HloArg arg{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -302,17 +307,19 @@ TEST_F(Tf2HloTest, UsingDefaultDeviceAssignment) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + std::vector variable_arg_indices; Tf2HloArg arg{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -413,17 +420,19 @@ TEST_F(Tf2HloTest, XlaCallHostCallback) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + std::vector variable_arg_indices; Tf2HloArg arg{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -471,9 +480,11 @@ TEST_F(Tf2HloTest, GpuCompile) { GetCompileMetadata(mlir_module.get(), mock_client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + std::vector variable_arg_indices; Tf2HloArg arg{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -519,17 +530,19 @@ TEST_F(Tf2HloTest, SameArgProduceSameKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + std::vector variable_arg_indices; Tf2HloArg arg0{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -540,6 +553,7 @@ TEST_F(Tf2HloTest, SameArgProduceSameKeyFingerprint) { Tf2HloArg arg1{ .module = mlir_module_clone.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -578,17 +592,19 @@ TEST_F(Tf2HloTest, DifferentCompileMetadataProduceDifferentKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::TfrtCpuTopologyDescription cpu_topology = - xla::TfrtCpuTopologyDescription::Create( + xla::CpuTopologyDescription cpu_topology = + xla::CpuTopologyDescription::Create( xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, + /*devices=*/std::vector>{}, /*machine_attributes=*/std::vector{}); - std::shared_ptr cpu_topology_ptr = - std::make_shared(cpu_topology); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + std::vector variable_arg_indices; Tf2HloArg arg0{ .module = mlir_module.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), @@ -600,6 +616,7 @@ TEST_F(Tf2HloTest, DifferentCompileMetadataProduceDifferentKeyFingerprint) { Tf2HloArg arg1{ .module = mlir_module_clone.get(), .input_dtypes_and_shapes = dtype_and_shapes, + .variable_arg_indices = variable_arg_indices, .entry_function_name = "main", .compile_metadata = compile_metadata, .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), diff --git a/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc b/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc index d6c87abeedd54a..a36ec754708f66 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc index 01ae5811b46b9a..34b37eeefe7843 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc @@ -14,8 +14,12 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include +#include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc index f2f3b9a3f84f1e..676ac471230a07 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include +#include #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index c554f8a26490e6..d2947825126915 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -1051,17 +1051,9 @@ class TfToMlrtConversionPass }; type_converter_.addTargetMaterialization(future_to_tensor_materialization); + type_converter_.addSourceMaterialization(future_to_tensor_materialization); type_converter_.addArgumentMaterialization( future_to_tensor_materialization); - type_converter_.addSourceMaterialization( - [](mlir::OpBuilder &builder, mlir::Type result_type, - mlir::ValueRange inputs, - mlir::Location loc) -> mlir::Value { - return builder - .create(loc, result_type, - inputs) - .getResult(0); - }); if (use_tpu_host_allocator_for_inputs_.hasValue()) { options_.use_tpu_host_allocator_for_inputs = diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc index 202aa9c8d2f9ec..5a1ae5a80dfd2b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc @@ -159,7 +159,7 @@ class OptimizeTfForTfrt EliminateCommonMultinomialOps(func.getBody().front()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(func, patterns_))) + if (mlir::failed(mlir::applyPatternsGreedily(func, patterns_))) signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc index 57b07c69bf2b55..222452b9e9c8a6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc @@ -95,8 +95,8 @@ struct TfrtXlaRewritePass patterns.add(&getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); return; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index d706babc9b3662..4644e09cfe2fd7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -119,7 +119,9 @@ tf_cc_binary( "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:CodeGen", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc index 3bcd745c6fa86e..6c57be1081da25 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc @@ -22,7 +22,9 @@ #include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/SmallString.h" #include "llvm/Analysis/TargetLibraryInfo.h" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h index e83f286c6a2788..15d105ca23d7de 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project #include "tensorflow/core/framework/resource_base.h" #include "tensorflow/core/framework/resource_op_kernel.h" @@ -40,7 +41,7 @@ class JITCache : public tensorflow::ResourceBase { std::string DebugString() const override; ExecutionEngine* LookupOrCompile( - const std::string code, + std::string code, std::function>()> compile_callback); size_t Size(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 2ddadfe5cf930c..88564d60422f6d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -165,13 +165,14 @@ cc_library( "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/service/gpu:gpu_asm_opts_util", "@local_xla//xla/service/gpu:target_constants", - "@local_xla//xla/service/gpu/llvm_gpu_backend", "@local_xla//xla/stream_executor:device_description", ] + if_cuda_is_configured([ "@local_tsl//tsl/platform:cuda_root_path", "@local_xla//xla/stream_executor/cuda:cuda_asm_compiler", + "@local_xla//xla/service/gpu/llvm_gpu_backend:nvptx_backend", ]) + if_rocm_is_configured([ "@local_xla//xla/stream_executor/gpu:asm_compiler", + "@local_xla//xla/service/gpu/llvm_gpu_backend:amdgpu_backend", "//tensorflow/core/platform:rocm_rocdl_path", ]), ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 10f7db030e7abf..c59edef81929c4 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -21,28 +21,29 @@ limitations under the License. #include #include "llvm/Transforms/Utils/Cloning.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" #include "xla/debug_options_flags.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h" #include "xla/service/gpu/target_constants.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" #if GOOGLE_CUDA +#include "xla/service/gpu/llvm_gpu_backend/nvptx_backend.h" #include "xla/stream_executor/cuda/cuda_asm_compiler.h" #elif TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/asm_compiler.h" #include "tensorflow/core/platform/rocm_rocdl_path.h" +#include "xla/stream_executor/gpu/asm_compiler.h" #endif namespace mlir { diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 930c3e904bbc86..e854098dee6ccf 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -16,17 +16,7 @@ package( package_group( name = "internal", packages = [ - "//tensorflow/compiler/mlir/...", - ], -) - -package_group( - name = "friends", - includes = [ - ":internal", - ], - packages = [ - "//third_party/iree/...", + "//tensorflow/compiler/mlir/tosa/...", ], ) @@ -41,6 +31,7 @@ filegroup( gentbl_cc_library( name = "tosa_passes_inc_gen", compatible_with = get_compatible_with_portable(), + tags = ["tf_tosa"], tbl_outs = [ ( [ @@ -64,6 +55,7 @@ cc_library( "transforms/passes.h.inc", ], compatible_with = get_compatible_with_portable(), + tags = ["tf_tosa"], deps = [ "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -82,6 +74,7 @@ cc_library( "transforms/legalize_utils.h", ], compatible_with = get_compatible_with_portable(), + tags = ["tf_tosa"], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/kernels/internal:common", @@ -90,6 +83,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", @@ -111,6 +105,7 @@ cc_library( gentbl_cc_library( name = "tosa_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), + tags = ["tf_tosa"], tbl_outs = [ ( ["-gen-rewriters"], @@ -141,7 +136,7 @@ cc_library( "transforms/passes.h", ], compatible_with = get_compatible_with_portable(), - visibility = [":friends"], + tags = ["tf_tosa"], deps = [ ":legalize_common", ":passes_header", @@ -166,6 +161,7 @@ cc_library( gentbl_cc_library( name = "tosa_legalize_tfl_inc_gen", compatible_with = get_compatible_with_portable(), + tags = ["tf_tosa"], tbl_outs = [ ( ["-gen-rewriters"], @@ -202,7 +198,7 @@ cc_library( "transforms/passes.h", ], compatible_with = get_compatible_with_portable(), - visibility = [":friends"], + tags = ["tf_tosa"], deps = [ ":legalize_common", ":passes_header", @@ -237,7 +233,7 @@ cc_library( "transforms/passes.h", ], compatible_with = get_compatible_with_portable(), - visibility = [":friends"], + tags = ["tf_tosa"], deps = [ ":legalize_common", ":passes_header", diff --git a/tensorflow/compiler/mlir/tosa/tests/BUILD b/tensorflow/compiler/mlir/tosa/tests/BUILD index a523ba82942c64..e936d924ef4abb 100644 --- a/tensorflow/compiler/mlir/tosa/tests/BUILD +++ b/tensorflow/compiler/mlir/tosa/tests/BUILD @@ -9,6 +9,7 @@ package( glob_lit_tests( name = "all_tests", data = [":test_utilities"], + default_tags = ["tf_tosa"], driver = "@llvm-project//mlir:run_lit.sh", size_override = { "tf-to-tosa-pipeline.mlir": "medium", diff --git a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir index 8a1c0615f6e03c..d44a968ac0ea60 100644 --- a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --tosa-convert-tfl-uint8 --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // Operations for testing --tosa-convert-tfl-uint8 diff --git a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir index 7fb03c7728c179..5d7c3316b19ef2 100644 --- a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(func.func(tosa-tflite-convert-function-metadata))' %s | FileCheck %s +// REQUIRES: tf_tosa module attributes {tfl.schema_version = 3 : i32} { // CHECK: func.func @main( diff --git a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir index 2850e123848332..f2c6c6cbeb9624 100644 --- a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --tosa-fuse-bias-tf --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // Operations for testing --tosa-fuse-bias-tf diff --git a/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir b/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir index fe6a8ea07b163e..c9b59c2201c313 100644 --- a/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --split-input-file --tosa-lower-complex-types --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // CHECK-LABEL: test_complex_input // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x4x2xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir index c513f2ec936aee..28f3192bae2f6d 100644 --- a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --tfl-to-tosa-pipeline=target-compilation-backend %s | FileCheck %s +// REQUIRES: tf_tosa // CHECK: tensor<1x8x8x3xf32> {ml_program.identifier = "a"} // CHECK-SAME: tensor<1x8x8x3xf32> {ml_program.identifier = "b"} diff --git a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir index 64fcdfc18d081f..8feb41f2631f0f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(tflite-retain-call-once-funcs)' %s | FileCheck %s +// REQUIRES: tf_tosa // CHECK-LABEL: module { module { diff --git a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir index 82b856c1ffaba9..cea7ec359b27d1 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --split-input-file --tosa-strip-quant-types --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir index f2198823a6dabf..5f75b923739d90 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --pass-pipeline='builtin.module(tosa-tflite-strip-module-metadata,func.func(tosa-tflite-strip-function-metadata))' %s | FileCheck %s +// REQUIRES: tf_tosa // CHECK-LABEL: module { // CHECK-NOT: tf.schema_version diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir index 4e0854ccd6f5a4..7eadb79b757bd4 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // These tests focus on TensorFlow and TensorFlow Lite hybrid lowering and focus // on tfl.custom operations that are Flex ops. diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 3f3b7bcc9ef7a9..d9ebc6ce5c357e 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -1,5 +1,7 @@ // RUN: tf-opt --tf-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // RUN: tf-opt --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // Operations for testing tf-to-tosa-pipeline // TODO: These tests are fairly minimal. Expand the checks to be more robust. diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir index f03b7b4e0dc257..936dbf7c69c630 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --tosa-dequantize-tfl-softmax %s | FileCheck %s +// REQUIRES: tf_tosa // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index 47aa6d56f57c59..dae91112503c55 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt --pass-pipeline='builtin.module(func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose}))' %s | FileCheck %s +// REQUIRES: tf_tosa // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index e12c0a9ae0b38e..6db3322258821a 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -1,5 +1,7 @@ // RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // Operations for testing tfl-to-tosa-pipeline diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir index 6ad25dca4b8abd..2453efb5ca90eb 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir @@ -1,5 +1,7 @@ // RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s +// REQUIRES: tf_tosa // Operations for testing tfl-to-tosa-pipeline diff --git a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir index 3783c379908a13..ac918b321356e8 100644 --- a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir @@ -1,4 +1,5 @@ // RUN: tf-opt %s --tosa-tflite-verify-fully-converted --split-input-file -verify-diagnostics +// REQUIRES: tf_tosa // CHECK-LABEL: func.func @main func.func @main(%arg0: tensor<2xf32>) -> (tensor<2xf32>) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc index cb7dd32799f8de..81c981a448d914 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc @@ -23,12 +23,9 @@ limitations under the License. // 3. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8 // typed. -#include #include #include -#include #include -#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -333,7 +330,7 @@ void ConvertUint8ToInt8::runOnOperation() { // Convert uint8 const tensor. const needs to be handled specifically. patterns.add(&ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); // Replace uint8 tensor in the graph and insert rescale as needed. (void)convert_graph_uint8_tensor(ctx, func); diff --git a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc index 5c8dd934fe8117..ea17d9160698eb 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc @@ -74,8 +74,7 @@ LogicalResult TosaDequantizeTFLSoftmaxPattern::matchAndRewrite( void TosaDequantizeTFLSoftmax::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc index ff07b9d6f91039..7a0077dd72c5ec 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc @@ -15,12 +15,7 @@ limitations under the License. // Fuse tf.Op + tf.BiasAdd and legalized to TOSA -#include -#include -#include -#include #include -#include #include #include @@ -148,7 +143,7 @@ void FuseBiasTF::runOnOperation() { // Add the generated patterns to the list. patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index cfe063408edea0..d368dcd8b81d6b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H_ #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H_ +#include #include #include "mlir/IR/PatternMatch.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 496a4275c0007b..a4dd0712626c63 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -15,13 +15,11 @@ limitations under the License. // Legalize TensorFlow to TOSA -#include -#include +#include #include #include #include #include -#include #include #include diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index af984eabf0fb70..3627546feb3239 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -16,11 +16,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" #include +#include +#include #include #include #include +#include #include +#include "absl/status/status.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -611,7 +615,7 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, ip_size = ip_size < 0 ? f_size * dim_dilation : ip_size; int64_t op_size, pad_before_tf, pad_after_tf; // Complains if using int64_T - tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( ip_size, f_size, dim_dilation, dim_stride, tf_pad, &op_size, &pad_before_tf, &pad_after_tf); if (!status.ok()) return false; @@ -792,7 +796,7 @@ LogicalResult ApplyPatternsWithShapeResolution( // during pattern rewrite. GreedyRewriteConfig config; config.useTopDownTraversal = true; - if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { + if (failed(applyPatternsGreedily(func, patterns, config))) { return failure(); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc index cf8655c4d59335..3ec3a37f167d5b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc @@ -30,7 +30,7 @@ limitations under the License. // any remaining "unrealized_conversion_cast" operations and ensures the // resulting graph is free of illegal complex tensors. -#include +#include #include #include @@ -157,7 +157,7 @@ void LowerComplexTypes::runOnOperation() { // We need to run folders post rewrite to cleanup conversion casts. RewritePatternSet emptyRewriters(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(emptyRewriters)))) { + if (failed(applyPatternsGreedily(func, std::move(emptyRewriters)))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc index c63173bd2e9182..cddcc8d614c8a5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc @@ -23,12 +23,7 @@ limitations under the License. // 3. insert tosa.RESCALE int8 -> uint8 if original returned tensor is uint8 // typed. -#include -#include -#include -#include #include -#include #include #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project diff --git a/tensorflow/compiler/tests/const_test.py b/tensorflow/compiler/tests/const_test.py index bb1f3e23a7306e..423c92f2abb015 100644 --- a/tensorflow/compiler/tests/const_test.py +++ b/tensorflow/compiler/tests/const_test.py @@ -48,6 +48,9 @@ def testConst(self): dtypes.float64, dtypes.float8_e5m2, dtypes.float8_e4m3fn, + dtypes.float8_e4m3fnuz, + dtypes.float8_e4m3b11fnuz, + dtypes.float8_e5m2fnuz, } for dtype in types: with self.subTest(dtype=dtype): diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index fe42e3f3807d0a..e2874487c2fb04 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -159,10 +159,10 @@ class OpTestBuilder { // sets it to the NodeDef of the operator under test. Fills 'inputs' and // 'outputs' with the names of the input placeholder nodes and the output // identity nodes, respectively. - Status BuildGraph(const string& name_prefix, const string& device, - bool use_jit, GraphDef* graphdef, NodeDef** test_node_def, - std::vector* inputs, - std::vector* outputs) const; + absl::Status BuildGraph(const string& name_prefix, const string& device, + bool use_jit, GraphDef* graphdef, + NodeDef** test_node_def, std::vector* inputs, + std::vector* outputs) const; struct InputDescription { Tensor tensor; @@ -245,11 +245,12 @@ OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, return *this; } -Status OpTestBuilder::BuildGraph(const string& name_prefix, - const string& device, bool use_jit, - GraphDef* graphdef, NodeDef** test_node_def, - std::vector* inputs, - std::vector* outputs) const { +absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, + const string& device, bool use_jit, + GraphDef* graphdef, + NodeDef** test_node_def, + std::vector* inputs, + std::vector* outputs) const { OpRegistryInterface* op_registry = OpRegistry::Global(); const OpDef* op_def; @@ -1260,7 +1261,7 @@ OpTest::WindowedSpatialDims OpTest::ChooseWindowedSpatialDims( d.output_dims.resize(num_spatial_dims); d.stride_dims.resize(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - Status s; + absl::Status s; // Repeatedly try different filter/stride sizes until we find a valid // combination. do { @@ -1388,8 +1389,8 @@ string Str(complex64 x) { } template -Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, - double rtol) { +absl::Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, + double rtol) { auto Tx = x.flat(); auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { @@ -1405,7 +1406,7 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, } template -Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { +absl::Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { auto Tx = x.flat(); auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { @@ -1418,7 +1419,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { return absl::OkStatus(); } -Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { +absl::Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { auto Tx = x.flat(); auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { @@ -1436,8 +1437,8 @@ Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { // close values. For floating-point tensors, the element-wise difference between // x and y must no more than atol + rtol * abs(x). For non-floating-point // tensors the values must match exactly. -Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, - double rtol) { +absl::Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, + double rtol) { if (a.dtype() != b.dtype()) { return errors::InvalidArgument(absl::StrCat( "Tensors have different types: ", DataTypeString(a.dtype()), " and ", @@ -1511,7 +1512,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( GraphDef graph; std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; - Status status = builder.BuildGraph( + absl::Status status = builder.BuildGraph( absl::StrCat("test", num_tests_, "_expected"), reference_device, /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs, &expected_fetches); @@ -1559,7 +1560,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_outputs, test_outputs; VLOG(1) << "Running expected graph"; - Status s = + absl::Status s = session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs); if (!s.ok()) { VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test"; diff --git a/tensorflow/compiler/tests/special_math_test.py b/tensorflow/compiler/tests/special_math_test.py index 989b6d57845462..0a07ce0be2a0e9 100644 --- a/tensorflow/compiler/tests/special_math_test.py +++ b/tensorflow/compiler/tests/special_math_test.py @@ -102,32 +102,24 @@ def _test_range(self, low, high, dtype, rtol, atol, is_negative=False): actual = sess.run(actual) self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 @parameterized.parameters((np.float32, 1e-7, 0.), (np.float64, 1e-15, 0.)) def testSmallX(self, dtype, rtol, atol): self._test_range(-40., -20., dtype, rtol, atol, is_negative=False) self._test_range(-40., -20., dtype, rtol, atol, is_negative=True) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 @parameterized.parameters((np.float32, 2e-7, 0.), (np.float64, 1e-15, 0.)) def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol): self._test_range(-20., -10., dtype, rtol, atol, is_negative=False) self._test_range(-20., -10., dtype, rtol, atol, is_negative=True) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 @parameterized.parameters((np.float32, 2e-7, 0.), (np.float64, 1e-15, 0.)) def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol): self._test_range(-10., -5., dtype, rtol, atol, is_negative=False) self._test_range(-10., -5., dtype, rtol, atol, is_negative=True) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 @parameterized.parameters((np.float32, 2e-7, 0.), (np.float64, 1e-15, 0.)) def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol): @@ -140,8 +132,6 @@ def testXGreaterThanOneTenth(self, dtype, rtol, atol): self._test_range(-1., 0., dtype, rtol, atol, is_negative=False) self._test_range(-1., 0., dtype, rtol, atol, is_negative=True) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 @parameterized.parameters((np.float32, 2e-7, 0.), (np.float64, 2e-15, 0.)) def testXGreaterThanOne(self, dtype, rtol, atol): diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 93fcd52c1e641e..206b6a9953d72f 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -894,7 +894,13 @@ def testCastFp8(self): # TODO(b/271327511): Fix issue where casts to FP8 very rarely result in # NaN on Mac self.skipTest("Casts to FP8 sometimes result in NaN on Mac") - fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn} + fp8_types = { + dtypes.float8_e5m2, + dtypes.float8_e4m3fn, + dtypes.float8_e4m3fnuz, + dtypes.float8_e4m3b11fnuz, + dtypes.float8_e5m2fnuz, + } other_types = { dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64 diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 9c7522a860ab19..9d7a3fc7f6e767 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -1068,7 +1068,6 @@ pybind_extension( "@com_google_protobuf//:__subpackages__", "@com_googlesource_code_re2//:__subpackages__", "@curl//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@farmhash_archive//:__subpackages__", "@fft2d//:__subpackages__", diff --git a/tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyi b/tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyi index 1ef7abbd7d14b6..865a06e069a7d3 100644 --- a/tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyi +++ b/tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -def get_linked_tensorrt_version() -> tuple[int,int,int]: ... -def get_loaded_tensorrt_version() -> tuple[int,int,int]: ... +def get_linked_tensorrt_version() -> tuple[int, int, int]: ... +def get_loaded_tensorrt_version() -> tuple[int, int, int]: ... def get_registered_op_converters() -> list[str]: ... def is_tensorrt_enabled() -> bool: ... diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc index c9050442e02ce5..5011ea8a39fbc8 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc index e457c64928e5df..9a592e4b5944fa 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include + #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc index 0aa5eb8f7d4ad0..ff216ab23767f9 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include + #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1f7961e2dac977..673b8182a35bdf 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -352,6 +352,10 @@ cc_library( # "@local_tsl//tsl/platform:thread_annotations", # "@local_tsl//tsl/platform:tstring", # "@local_tsl//tsl/platform:types", +# "@local_xla//xla/tsl/platform:env_time", +# "@local_xla//xla/tsl/platform:logging", +# "@local_xla//xla/tsl/platform:types", +# "@local_xla//xla/tsl/platform:macros", # "@local_xla//xla/tsl/platform/default:cord", # "@local_xla//xla/tsl/platform/default:env_time", # "@local_xla//xla/tsl/platform/default:logging", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 96a293b8676046..120deca79f84d3 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -34,7 +34,8 @@ namespace tensorflow { namespace { absl::Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, - const NodeDef& node, StringPiece func_attr_name, + const NodeDef& node, + absl::string_view func_attr_name, const FunctionBody** fbody) { NameAttrList name_attr_list; TF_RETURN_IF_ERROR(GetNodeAttr(node, func_attr_name, &name_attr_list)); @@ -47,7 +48,7 @@ absl::Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, absl::Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, const NodeDef& node, - StringPiece func_list_attr_name, + absl::string_view func_list_attr_name, std::vector* fbodies) { std::vector name_attr_lists; TF_RETURN_IF_ERROR(GetNodeAttr(node, func_list_attr_name, &name_attr_lists)); diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 92a644843c5d46..ba297127eae117 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -617,7 +617,8 @@ absl::Status Conditional::ExtractBodies(Graph* graph) { std::sort( in_edges.begin(), in_edges.end(), [](const Edge* a, const Edge* b) { int a_src_output = a->src_output(), b_src_output = b->src_output(); - StringPiece a_name(a->src()->name()), b_name(b->src()->name()); + absl::string_view a_name(a->src()->name()), + b_name(b->src()->name()); return std::tie(a_src_output, a_name) < std::tie(b_src_output, b_name); }); diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 73afe1909b4d92..2c02379c36cd45 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -79,7 +79,8 @@ absl::Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame, [](const Edge* a, const Edge* b) { int a_src_output = a->src_output(), b_src_output = b->src_output(); - StringPiece a_name(a->src()->name()), b_name(b->src()->name()); + absl::string_view a_name(a->src()->name()), + b_name(b->src()->name()); return std::tie(a_src_output, a_name) < std::tie(b_src_output, b_name); }); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 79f85051363aa5..326d9ab84e4dba 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -282,14 +282,19 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/core/platform:statusor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:literal_util", "@local_xla//xla:util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:arithmetic", "@local_xla//xla/hlo/builder/lib:constants", @@ -331,6 +336,7 @@ cc_library( "//tensorflow/core/common_runtime:function_body", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_xla//xla:literal", "@local_xla//xla/hlo/builder:value_inference", @@ -496,6 +502,8 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/types:span", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", @@ -516,6 +524,7 @@ tf_kernel_library( "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:logging_ops_op_lib", + "@com_google_absl//absl/log", ], alwayslink = 1, ) @@ -594,6 +603,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/core:protos_all_cc", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder/lib:math", @@ -853,6 +863,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -893,6 +906,11 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/builder/lib:math", ], @@ -925,6 +943,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "@com_google_absl//absl/status", "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1141,8 +1160,12 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@local_xla//xla:status_macros", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:slicing", @@ -1190,6 +1213,7 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@local_xla//xla:comparison_util", "@local_xla//xla:shape_util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:arithmetic", "@local_xla//xla/hlo/builder/lib:comparators", @@ -1215,6 +1239,7 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@local_xla//xla:shape_util", "@local_xla//xla/hlo/builder:xla_builder", ], @@ -1375,6 +1400,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@local_xla//xla:literal_util", "@local_xla//xla/hlo/builder:xla_builder", ], @@ -1532,7 +1558,12 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@local_xla//xla:util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/builder/lib:math", @@ -1571,6 +1602,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@local_xla//xla:literal_util", "@local_xla//xla/hlo/builder:xla_builder", ], @@ -1684,6 +1718,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", @@ -1784,9 +1820,11 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", "@local_xla//xla/hlo/builder:xla_builder", @@ -1874,6 +1912,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:arithmetic", ], @@ -1892,6 +1932,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1967,6 +2009,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:arithmetic", @@ -1988,7 +2031,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "@local_xla//xla:literal_util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2005,6 +2050,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/ops:xla_ops", + "@com_google_absl//absl/status:statusor", "@local_xla//xla:status_macros", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:arithmetic", @@ -2114,6 +2160,10 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:slicing", ], @@ -2317,6 +2367,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status:statusor", "@local_xla//xla:literal_util", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", @@ -2337,6 +2388,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/builder/lib:matrix", @@ -2356,6 +2408,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "@com_google_absl//absl/log", "@local_xla//xla:literal", "@local_xla//xla/hlo/builder:xla_builder", ], @@ -2442,6 +2495,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", @@ -2528,6 +2583,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", ], @@ -2545,7 +2602,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", + "@com_google_absl//absl/status:statusor", "@local_xla//xla:shape_util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/builder/lib:arithmetic", "@local_xla//xla/hlo/builder/lib:comparators", @@ -2606,6 +2665,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@local_xla//xla:literal", "@local_xla//xla/hlo/builder:value_inference", @@ -2747,6 +2808,8 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_xla//xla:literal_util", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", @@ -2988,6 +3051,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", @@ -3010,6 +3074,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3027,8 +3094,11 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:inlined_vector", "@local_xla//xla:literal_util", "@local_xla//xla:util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3124,7 +3194,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "@local_xla//xla:util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/builder/lib:math", @@ -3148,6 +3220,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@local_xla//xla:shape_util", "@local_xla//xla:util", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/builder/lib:matrix", @@ -3189,6 +3262,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:value_inference", "@local_xla//xla/hlo/builder:xla_builder", ], @@ -3255,6 +3329,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:matrix", ], @@ -3315,6 +3390,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder/lib:arithmetic", ], @@ -3333,6 +3410,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "@local_xla//xla/hlo/builder/lib:constants", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc index 95cd1f1a5c1c7d..a6ddbfd3a01fef 100644 --- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -24,6 +28,8 @@ limitations under the License. #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc index 19c65b653fb54e..4134356d92491b 100644 --- a/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include -#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/hlo/builder/lib/approx_topk.h" #include "xla/hlo/builder/xla_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 8d764de9b406a8..0c54ed8fdc576c 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc index 8a863ea978d4b6..341a48de4264a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/assert_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/assert_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/log/log.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 11cf4682e810bf..9f5139de1c4ffa 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tsl/platform/tensor_float_32_utils.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 9e4703163e0f13..0dd528e3dea173 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA implementation of BatchNorm operations. #include +#include #include #include #include @@ -29,6 +30,8 @@ limitations under the License. #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index b84733e7d55185..a4d9d37bd1ea09 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -14,13 +14,17 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 95d9280924a1ab..7c89720292b0a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,14 +16,17 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/value_inference.h" #include "xla/literal.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/beta_op.cc b/tensorflow/compiler/tf2xla/kernels/beta_op.cc index b504493b7ddb0e..4ead9f76fcee11 100644 --- a/tensorflow/compiler/tf2xla/kernels/beta_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/beta_op.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - +#include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index d0fb98c575f73d..2bf4ab52c8b59e 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 762f5a25c5f547..e9f571d830d619 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -15,9 +15,11 @@ limitations under the License. // Native XLA implementations of simple binary Ops +#include #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index 374f05fa918a8c..5e0bd1829f1c07 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,6 +27,7 @@ limitations under the License. #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index d7fc2be632cd29..975179466bf104 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/tf2xla/lib/broadcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index e3e64b14dc5302..510d5225d6f04b 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index ab0a26b2f9fe37..cead6d10c2a0eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -15,10 +15,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/case_op.h" +#include #include #include #include +#include "absl/log/log.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" @@ -28,6 +31,8 @@ limitations under the License. #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/dynamic_shaped_ops.h" #include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index ca7d3280cff15d..1779cfcc1ced40 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -26,6 +28,7 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index cf3dbfa2655f27..e8c804791299a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -16,7 +16,9 @@ limitations under the License. // XLA implementations of Categorical op. #include +#include +#include "absl/log/log.h" #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 7039fa55651a16..6b4f278c72beff 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 3d515693034ae3..bed3479941ca41 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -16,9 +16,10 @@ limitations under the License. // XLA-specific Concat Ops. #include -#include #include +#include "absl/log/log.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index a1eeea070f7f7d..d2463a9974b1bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index a202361a90b539..8d14995a11f3aa 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -18,11 +18,15 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include +#include #include #include #include +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -34,6 +38,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/node_def_util.h" @@ -42,6 +47,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -154,10 +160,11 @@ absl::Status CheckConvAttrs(const ConvOpAttrs& attrs) { // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes // to TensorShapes. absl::Status ConvBackpropComputeDimensionsV2XlaShapes( - StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, - const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, - absl::Span dilations, const std::vector& strides, - Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims, + absl::string_view label, int num_spatial_dims, + const xla::Shape& input_shape, const xla::Shape& filter_shape, + const xla::Shape& out_backprop_shape, absl::Span dilations, + const std::vector& strides, Padding padding, + TensorFormat data_format, ConvBackpropDimensions* dims, absl::Span explicit_paddings) { TensorShape input_tensor_shape, filter_tensor_shape, out_backprop_tensor_shape; @@ -236,10 +243,9 @@ absl::StatusOr ConvNDOpAttrs::Create(OpKernelConstruction* ctx) { return attrs; } -absl::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, - xla::XlaOp conv_input, - xla::XlaOp filter, - const ConvOpAttrs& attrs) { +absl::StatusOr MakeXlaForwardConvOp( + absl::string_view /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter, + const ConvOpAttrs& attrs) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); auto* builder = conv_input.builder(); @@ -346,8 +352,8 @@ absl::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, } absl::StatusOr MakeXlaBackpropInputConvOp( - StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, - xla::XlaOp out_backprop, const ConvOpAttrs& attrs, + absl::string_view type_string, const xla::Shape& input_shape, + xla::XlaOp filter, xla::XlaOp out_backprop, const ConvOpAttrs& attrs, xla::XlaOp* input_sizes) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); @@ -445,7 +451,7 @@ absl::StatusOr MakeXlaBackpropInputConvOp( } absl::StatusOr MakeXlaBackpropFilterConvOp( - StringPiece type_string, xla::XlaOp activations, + absl::string_view type_string, xla::XlaOp activations, const xla::Shape& filter_shape, xla::XlaOp gradients, const ConvOpAttrs& attrs) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index ff0272f43fca9f..f53f9fd047851c 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#include #include +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -73,16 +76,16 @@ struct ConvNDOpAttrs { // Creates a new XLA forward or backward convolution with the given inputs and // attributes. -absl::StatusOr MakeXlaForwardConvOp(StringPiece type_string, +absl::StatusOr MakeXlaForwardConvOp(absl::string_view type_string, xla::XlaOp conv_input, xla::XlaOp filter, const ConvOpAttrs& attrs); absl::StatusOr MakeXlaBackpropInputConvOp( - StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, - xla::XlaOp out_backprop, const ConvOpAttrs& attrs, + absl::string_view type_string, const xla::Shape& input_shape, + xla::XlaOp filter, xla::XlaOp out_backprop, const ConvOpAttrs& attrs, xla::XlaOp* input_sizes = nullptr); absl::StatusOr MakeXlaBackpropFilterConvOp( - StringPiece type_string, xla::XlaOp activations, + absl::string_view type_string, xla::XlaOp activations, const xla::Shape& filter_shape, xla::XlaOp gradients, const ConvOpAttrs& attrs); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 3d876be0042949..273c16f89c9df7 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index a7753644312856..42367723a40e89 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 9be97745d12023..d22e6eb74039b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -18,13 +18,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ +#include #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/client/client_library.h" #include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index 62c2ab5202f7a3..226d6248bd00d8 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -17,11 +17,15 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/slicing.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 93ca01039dda5f..e8e2babffd529c 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "absl/log/check.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index c8c1705a52f801..d383c7d0ab4aa3 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -21,6 +22,7 @@ limitations under the License. #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc index d2726af1a2b10f..141415bcd0d8c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 4edc4143f1a80a..404fa9f5e04e45 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index e5dcff94279c08..6e577f412fb304 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index f903d5fd130359..075002d39eed27 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -24,6 +26,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 8fb19b1c1c9dae..cb7e4f6f96437e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific dynamic stitch Op. #include +#include #include #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -25,10 +26,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc index d48d1fe84e67c9..c3e9b61962a388 100644 --- a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc index decc24126d0f10..c0befe5d20229b 100644 --- a/tensorflow/compiler/tf2xla/kernels/empty_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Empty Op. +#include #include #include "tensorflow/compiler/tf2xla/type_util.h" @@ -23,9 +24,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc index 11256663b59e97..3859779e8b52e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc @@ -15,8 +15,7 @@ limitations under the License. // XLA-specific ensure_shape Op. -#include - +#include "absl/log/log.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index ded81d938d2baa..4a1de78d9371b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc index 52412ee73f9ce8..57cdfe2cba4bf4 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/hlo/builder/lib/constants.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 2fa32e1112f8e1..96d3c9bf08cc68 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -14,14 +14,17 @@ limitations under the License. ==============================================================================*/ #include +#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index a9673934262d1f..8fb04773aafb49 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -15,9 +15,11 @@ limitations under the License. // XLA-specific Ops for FFT. +#include #include #include +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,6 +27,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -32,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 89824e7a3313b5..6e5a1430538365 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Fill Op. +#include #include #include "tensorflow/compiler/tf2xla/type_util.h" @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/value_inference.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc index 96aef937421f6d..b2b1eb3343e698 100644 --- a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc @@ -13,17 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/math.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 08285e0bccbc18..2108db386a7956 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" @@ -29,8 +31,10 @@ limitations under the License. #include "xla/hlo/builder/lib/slicing.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/status_macros.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc index 305557cd773faa..033144e9f308e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h index eb103954ac8683..1800e5a6fdb714 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/common_runtime/function_body.h" diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 2213074a89d42e..9d874a856b3275 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 5d8981dd5e6e3d..58811c10744131 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index a3d801a1a32819..f357262a39c35b 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 357ab3e9b0783d..2922fcf969d879 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/index_ops.h" +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 8b2e29e29ca8ec..bb90bc8397657b 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index eeb8617a61a39e..279007c8f64b2f 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -17,7 +17,7 @@ limitations under the License. // input. #include -#include +#include #include #include "absl/container/flat_hash_set.h" diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index b4ea95e04a43b8..a3ee768c04e186 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index af0f84aa2e1254..48e8f976cc67bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc index 4981751c489fa7..d733a1f7293c4d 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/matrix.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc index 9b5530c569dd27..e74bd516d16b13 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/matrix.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 91d2d344b07ad0..17b5ae7a70375a 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 3556900f49b670..f20c2384b5333c 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index e41db50beeec48..82dbfb3839312c 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA implementation of OneHot operator. +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a096b8f2a23e02..ba4e8bbef7b136 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -15,7 +15,6 @@ limitations under the License. // XLA Pack operator. -#include #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index fc19df334a4a75..1758451faf469f 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index e8e6ca0beb361e..6542abbd65433f 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA specific pooling ops. +#include +#include #include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index de7247399567e3..cac9f8a68f234e 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include "absl/types/span.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index ab83bbbe7120b3..0c6137f6254627 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,7 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include #include #include "absl/log/log.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 0c7e87015f940a..9c22222489f3a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ +#include #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index d1933ff4cff27c..58e1f992b9a74a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reduction Ops. +#include #include #include "absl/container/inlined_vector.h" diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index e4b08184ba5c43..6bf7cfc49560e8 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/resampler_ops.h" +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index df67f3f4938356..eb78eba56c11dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reshape Op. +#include #include #include "absl/log/log.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 5637d9091dd2fc..096241532bbb35 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reverse Op. +#include #include #include "absl/container/inlined_vector.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 17b0f35fad3b81..5cecbf37706283 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" diff --git a/tensorflow/compiler/tf2xla/kernels/roll_op.cc b/tensorflow/compiler/tf2xla/kernels/roll_op.cc index 870c3092865367..0fcc6bec56095b 100644 --- a/tensorflow/compiler/tf2xla/kernels/roll_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/roll_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/strings/str_cat.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 1444abda838008..c183c1d36b5a4a 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 29281a7696e589..694b4eb17ef298 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 73c5a9c6ed98e6..21eaac25f058ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/tf2xla/lib/scatter.h" diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index fc9e96939b2c38..85aaabe87076c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 60a4a1a5bc62d1..108bf3848aae93 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -15,6 +15,9 @@ limitations under the License. // XLA-specific sequence and range Ops. +#include +#include + #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index f217bc09ec79e1..57825657b205ab 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h index 4ec37b1fe7cfda..bfce0919a48bfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.h +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc index a56dd7ed74791c..eb5615056faa9d 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc index 63bdacfb795665..122aaacd5a4203 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/log/check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 35c936d5fb88db..844a31f97990fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Slice Op. +#include #include #include "absl/container/inlined_vector.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 406b79d9981846..330479bc8d4150 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific Ops for softmax. +#include +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 858233c28c8d03..d3804afd0f00d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 2648c0b077e689..ac33e0877200dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/log/check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index f3afba664bedbe..b4d589f183108e 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/tf2xla/lib/scatter.h" diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index ebef4cd81b2687..4f7c4ae99b6b6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Ops for split. +#include #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc index 496440e9cafbf3..124e36557f1429 100644 --- a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d8bd987232b569..69189b6b2ad9dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // XLA Stack operators. -#include +#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index 01a44c9d734448..2a090b35f6eadf 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/stateful_random_ops.h" -#include +#include #include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 7e8bf8f17e893c..aa71c5c34d2e1a 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index 021c22f247ff9e..ce1fee91ae6a51 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/core/kernels/stateless_random_ops_v2.h" -#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 2a31e5f15fe5e4..2189d6b035f3ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" #include +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 25110df1c7d733..5a5fe142b72008 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // XLA TensorArray operators. -#include +#include #include #include "absl/log/check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 76257c25a932c6..176844bf6f1289 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // XLA TensorList operators. -#include +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 37d0ae44178998..830c8b9abd49c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" +#include #include #include "absl/log/log.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index a86336ce79454c..e4aeb015034463 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ +#include #include +#include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index d6bf070137f226..6c39981ba5b937 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Tile Op. +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc index fddfbb288124f0..c53c06fa09953d 100644 --- a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index a8003fbb9927d5..422bef6ba3fbaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 3d6beb1c1a1120..039320573f4558 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -18,9 +18,12 @@ limitations under the License. // handles all transposes, while Eigen needs a restricted DoTranspose // helper. +#include +#include #include #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 5eb6438f89d322..c424236303b9d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -15,6 +15,8 @@ limitations under the License. // Native XLA implementations of simple unary Ops +#include + #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index 00d11ef7f34543..9730427dff3b5d 100644 --- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include #include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 0fc6e3e317c30b..cca29f7f585907 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -15,7 +15,7 @@ limitations under the License. // XLA Unpack operator. -#include +#include #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index 13ac54b85463df..a174af17ea465f 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc index 3a2e8015c1037e..33eae19ff81cfb 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc index 2341a820ea921a..c4a041acff5206 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index a087abc806e5d7..92664808961f63 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/data_format.h" +#include +#include + #include "absl/status/statusor.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 086684336b6de5..af347ca4949947 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include #include -#include #include #include "absl/log/log.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 550f77f0dccfb0..0f99dfac92cc19 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" +#include + #include "absl/log/log.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 24f66027dd3ce5..eaf5218847e873 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#include + #include "absl/types/span.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index c9906ada9c1254..9cc8787d44b6ca 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -55,7 +55,7 @@ AttrValue TypeAttrValue(DataType type) { return attr_value; } -AttrValue StringAttrValue(StringPiece str) { +AttrValue StringAttrValue(absl::string_view str) { AttrValue attr_value; SetAttrValue(str, &attr_value); return attr_value; diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 6383d277be852d..b2d8a878cc4568 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -71,6 +71,15 @@ absl::Status DataTypeToPrimitiveType(DataType data_type, case tensorflow::DT_FLOAT8_E4M3FN: *type = xla::F8E4M3FN; return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E4M3FNUZ: + *type = xla::F8E4M3FNUZ; + return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E4M3B11FNUZ: + *type = xla::F8E4M3B11FNUZ; + return absl::OkStatus(); + case tensorflow::DT_FLOAT8_E5M2FNUZ: + *type = xla::F8E5M2FNUZ; + return absl::OkStatus(); case tensorflow::DT_BFLOAT16: *type = xla::BF16; return absl::OkStatus(); @@ -103,6 +112,9 @@ absl::StatusOr EncodePrimitiveTypeAsDataType( {xla::PRED, DT_BOOL}, {xla::F8E5M2, DT_FLOAT8_E5M2}, {xla::F8E4M3FN, DT_FLOAT8_E4M3FN}, + {xla::F8E4M3FNUZ, DT_FLOAT8_E4M3FNUZ}, + {xla::F8E4M3B11FNUZ, DT_FLOAT8_E4M3B11FNUZ}, + {xla::F8E5M2FNUZ, DT_FLOAT8_E5M2FNUZ}, {xla::BF16, DT_BFLOAT16}, {xla::F16, DT_HALF}, {xla::F32, DT_FLOAT}, diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 11bbbf2b928871..5eaf0fb2d42bfa 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -65,19 +65,57 @@ constexpr std::array kNumericTypes = { DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, - DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, - DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, - DT_INT4, DT_UINT4}}; - -constexpr std::array kGpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, - DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, - DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, - DT_INT4, DT_UINT4}}; +constexpr std::array kCpuAllTypes = {{DT_UINT8, + DT_QUINT8, + DT_UINT16, + DT_UINT32, + DT_UINT64, + DT_INT8, + DT_QINT8, + DT_INT16, + DT_INT32, + DT_QINT32, + DT_INT64, + DT_HALF, + DT_FLOAT, + DT_DOUBLE, + DT_COMPLEX64, + DT_COMPLEX128, + DT_BOOL, + DT_BFLOAT16, + DT_FLOAT8_E5M2, + DT_FLOAT8_E4M3FN, + DT_FLOAT8_E4M3FNUZ, + DT_FLOAT8_E4M3B11FNUZ, + DT_FLOAT8_E5M2FNUZ, + DT_INT4, + DT_UINT4}}; + +constexpr std::array kGpuAllTypes = {{DT_UINT8, + DT_QUINT8, + DT_UINT16, + DT_UINT32, + DT_UINT64, + DT_INT8, + DT_QINT8, + DT_INT16, + DT_INT32, + DT_QINT32, + DT_INT64, + DT_HALF, + DT_FLOAT, + DT_DOUBLE, + DT_COMPLEX64, + DT_COMPLEX128, + DT_BOOL, + DT_BFLOAT16, + DT_FLOAT8_E5M2, + DT_FLOAT8_E4M3FN, + DT_FLOAT8_E4M3FNUZ, + DT_FLOAT8_E4M3B11FNUZ, + DT_FLOAT8_E5M2FNUZ, + DT_INT4, + DT_UINT4}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1d066b16f2b997..d85982681de263 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -166,6 +166,7 @@ package_group( name = "friends", packages = if_google([ "//learning/brain/...", + "//third_party/car/...", "//tensorflow/...", "@tf_runtime//...", "//third_party/tf_runtime_google/...", @@ -254,6 +255,7 @@ cc_library( "@local_tsl//tsl/platform:lib_proto_parsing_hdrs", ], copts = tf_copts(), + visibility = ["//visibility:public"], deps = tf_lib_proto_parsing_deps() + [ ":platform_base", "//tensorflow/core/lib/core:errors", @@ -271,7 +273,6 @@ cc_library( "//tensorflow/core/platform:tstring", "//tensorflow/core/platform:types", "@com_google_absl//absl/strings", - "@double_conversion//:double-conversion", ], ) @@ -1283,6 +1284,7 @@ cc_library( "@eigen_archive//:eigen3", "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ] + if_static([":lib_internal_impl"]), ) @@ -1311,6 +1313,7 @@ cc_library( "@eigen_archive//:eigen3", "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ], ) @@ -1455,11 +1458,11 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", - "@double_conversion//:double-conversion", "@eigen_archive//:eigen3", "@local_xla//xla/tsl/lib/math:math_util", "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", "@snappy", "@zlib", ] + select({ @@ -1519,6 +1522,7 @@ alias( alias( name = "jpeg_internal", actual = "//tensorflow/core/lib/jpeg:jpeg_internal", + visibility = ["//visibility:public"], ) cc_library( diff --git a/tensorflow/core/activity_watcher/BUILD b/tensorflow/core/activity_watcher/BUILD index 31c4408420efec..159d5ef7b0b938 100644 --- a/tensorflow/core/activity_watcher/BUILD +++ b/tensorflow/core/activity_watcher/BUILD @@ -23,7 +23,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/platform:types", + "@local_xla//xla/tsl/platform:types", ] + if_not_mobile([ ":activity_watcher_impl", ]), @@ -39,7 +39,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/platform:types", + "@local_xla//xla/tsl/platform:types", ], alwayslink = True, ) @@ -52,6 +52,6 @@ cc_library( ":activity_watcher", "//tensorflow/core:framework", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:types", + "@local_xla//xla/tsl/platform:types", ], ) diff --git a/tensorflow/core/activity_watcher/activity.h b/tensorflow/core/activity_watcher/activity.h index 334a58d45190ba..eecd207a33fe27 100644 --- a/tensorflow/core/activity_watcher/activity.h +++ b/tensorflow/core/activity_watcher/activity.h @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { class CoordinationServiceAgent; diff --git a/tensorflow/core/activity_watcher/activity_utils.cc b/tensorflow/core/activity_watcher/activity_utils.cc index fba695d97a53e3..b3631076c5c2d9 100644 --- a/tensorflow/core/activity_watcher/activity_utils.cc +++ b/tensorflow/core/activity_watcher/activity_utils.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/types.h" #include "tensorflow/core/activity_watcher/activity.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tsl/platform/types.h" namespace tensorflow { namespace activity_watcher { diff --git a/tensorflow/core/activity_watcher/activity_utils.h b/tensorflow/core/activity_watcher/activity_utils.h index 840f04fad7d393..64958cd5e09744 100644 --- a/tensorflow/core/activity_watcher/activity_utils.h +++ b/tensorflow/core/activity_watcher/activity_utils.h @@ -17,8 +17,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "tensorflow/core/activity_watcher/activity.h" -#include "tsl/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index f2269ba38722e0..b21e46196adee3 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -620,6 +620,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@local_tsl//tsl/platform:stacktrace", ], ) @@ -714,6 +715,7 @@ cc_library( "//tensorflow/core/framework:node_def_proto_cc", "//tensorflow/core/framework:tensor_proto_cc", "//tensorflow/core/platform:errors", + "@com_google_absl//absl/strings:string_view", ], ) @@ -2210,7 +2212,6 @@ cc_library( "//tensorflow:internal", # For xla_launch_util "//tensorflow/compiler/jit:__pkg__", - "//tensorflow_models:__subpackages__", ], deps = [ ":device", diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index ea16129c33cd42..cd1f2c18d6e118 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -339,7 +339,7 @@ bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) { for (int32_t rank = 0; rank < static_cast(split_gpu_ring_order_str.size()); ++rank) { int32_t tmp; - if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) { + if (absl::SimpleAtoi(split_gpu_ring_order_str[rank], &tmp)) { gpu_ranks[tmp] = rank; } else { return false; diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc index a97d03ef2f8d70..4daec16026ef78 100644 --- a/tensorflow/core/common_runtime/colocation_graph.cc +++ b/tensorflow/core/common_runtime/colocation_graph.cc @@ -59,8 +59,9 @@ namespace { // We hoist the conversion from C-style string literal to StringPiece here, // so that we can avoid the many repeated calls to strlen(). -const StringPiece kColocationAttrNameStringPiece(kColocationAttrName); -const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); +const absl::string_view kColocationAttrNameStringPiece(kColocationAttrName); +const absl::string_view kColocationGroupPrefixStringPiece( + kColocationGroupPrefix); // Using absl::StrJoin with lambda does not work in tf-lite builds. std::vector DevicesToString(const std::vector devices) { @@ -668,7 +669,7 @@ absl::Status ColocationGraph::ColocateAllNodes() { // 'string' values stored in NodeDef attribute lists, as well as StringPiece // values that refer to 'string' values from NodeDef::name(), without // performing any string allocations. - std::unordered_map + std::unordered_map colocation_group_root; for (const Node* node : graph_.op_nodes()) { @@ -685,7 +686,7 @@ absl::Status ColocationGraph::ColocateAllNodes() { if (attr_value != nullptr) { if (attr_value->has_list()) { for (const string& class_spec : attr_value->list().s()) { - StringPiece spec(class_spec); + absl::string_view spec(class_spec); if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) { TF_RETURN_IF_ERROR( ColocateNodeToGroup(&colocation_group_root, node, spec)); @@ -1071,9 +1072,9 @@ absl::Status ColocationGraph::ApplyIOColocationGroups( } absl::Status ColocationGraph::ColocateNodeToGroup( - std::unordered_map* + std::unordered_map* colocation_group_root, - const Node* node, StringPiece colocation_group) { + const Node* node, absl::string_view colocation_group) { const Node*& root_node = (*colocation_group_root)[colocation_group]; if (root_node == nullptr) { // This is the first node of the colocation group, so diff --git a/tensorflow/core/common_runtime/colocation_graph.h b/tensorflow/core/common_runtime/colocation_graph.h index 887ac205393f38..a31a2aadca83b2 100644 --- a/tensorflow/core/common_runtime/colocation_graph.h +++ b/tensorflow/core/common_runtime/colocation_graph.h @@ -333,9 +333,9 @@ class ColocationGraph { const Node& node); absl::Status ColocateNodeToGroup( - std::unordered_map* + std::unordered_map* colocation_group_root, - const Node* node, StringPiece colocation_group); + const Node* node, absl::string_view colocation_group); // Merge the (possibly disjoint) sets containing nodes "x" and // "y". Returns OK if the all nodes in the union of these sets can diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index d4b27716a217a7..481a85add4893c 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -691,7 +691,7 @@ class TestTFFileSystem : public ::tensorflow::NullFileSystem { return ::tensorflow::errors::Unimplemented( "NewReadOnlyMemoryRegionFromFile unimplemented"); } - const ::tensorflow::StringPiece sp = data_tensor_.tensor_data(); + const absl::string_view sp = data_tensor_.tensor_data(); *result = std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>( new TestReadOnlyMemoryRegion(sp.data(), sp.size())); return absl::OkStatus(); diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index dadaaf0cd61f2d..c396cd28dd085e 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -53,7 +53,7 @@ std::vector* MutableRegistry() { } void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, - Allocator* out_allocator, StringPiece edge_name, + Allocator* out_allocator, absl::string_view edge_name, Device* dst, Tensor* output, DeviceContext* recv_dev_context, StatusCallback done, bool sync_dst_compute) { @@ -199,7 +199,8 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, } // namespace // static -void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, +void CopyTensor::ViaDMA(absl::string_view edge_name, + DeviceContext* send_dev_context, DeviceContext* recv_dev_context, Device* src, Device* dst, const AllocatorAttributes src_alloc_attr, const AllocatorAttributes dst_alloc_attr, @@ -338,7 +339,7 @@ REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE); } // namespace void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, - Allocator* out_allocator, StringPiece edge_name, + Allocator* out_allocator, absl::string_view edge_name, Device* src, Tensor* output, DeviceContext* send_dev_context, StatusCallback done) { if (input->dtype() == DT_VARIANT) { diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h index 80187bde94b4b6..0f621603f2cd7d 100644 --- a/tensorflow/core/common_runtime/copy_tensor.h +++ b/tensorflow/core/common_runtime/copy_tensor.h @@ -40,7 +40,8 @@ class CopyTensor { // the type of devices and memory in use, the copy may be performed // synchronously or asynchronously. 'done' will be invoked only // after the copy is actually complete. - static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, + static void ViaDMA(absl::string_view edge_name, + DeviceContext* send_dev_context, DeviceContext* recv_dev_context, Device* src, Device* dst, const AllocatorAttributes src_alloc_attr, const AllocatorAttributes dst_alloc_attr, @@ -70,7 +71,7 @@ class CopyTensor { }; void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, - Allocator* out_allocator, StringPiece edge_name, + Allocator* out_allocator, absl::string_view edge_name, Device* src, Tensor* output, DeviceContext* send_dev_context, StatusCallback done); diff --git a/tensorflow/core/common_runtime/cost_constants.h b/tensorflow/core/common_runtime/cost_constants.h index 4eb71edccb2470..df01bf53826e0f 100644 --- a/tensorflow/core/common_runtime/cost_constants.h +++ b/tensorflow/core/common_runtime/cost_constants.h @@ -19,6 +19,7 @@ limitations under the License. namespace tensorflow { // Types of per-request cost. +inline constexpr char kGpuCostName[] = "gpu"; inline constexpr char kTpuCostName[] = "tpu"; inline constexpr char kGcuCostName[] = "gcu"; inline constexpr char kNoOpCostName[] = "no_op"; @@ -40,6 +41,13 @@ inline constexpr char kTpuDecodeNoSmearCostName[] = "tpu_decode_no_smear"; inline constexpr char kTpuPrefillWithSmearCostName[] = "tpu_prefill_with_smear"; inline constexpr char kTpuPrefillNoSmearCostName[] = "tpu_prefill_no_smear"; inline constexpr char kTpuNonBatchingCostName[] = "tpu_non_batching"; +inline constexpr char kGpuWithSmearCostName[] = "gpu_with_smear"; +inline constexpr char kGpuNoSmearCostName[] = "gpu_no_smear"; +inline constexpr char kGpuDecodeWithSmearCostName[] = "gpu_decode_with_smear"; +inline constexpr char kGpuDecodeNoSmearCostName[] = "gpu_decode_no_smear"; +inline constexpr char kGpuPrefillWithSmearCostName[] = "gpu_prefill_with_smear"; +inline constexpr char kGpuPrefillNoSmearCostName[] = "gpu_prefill_no_smear"; +inline constexpr char kGpuNonBatchingCostName[] = "gpu_non_batching"; inline constexpr char kGcuWithSmearCostName[] = "gcu_with_smear"; inline constexpr char kGcuNoSmearCostName[] = "gcu_no_smear"; inline constexpr char kGcuNonBatchingCostName[] = "gcu_non_batching"; diff --git a/tensorflow/core/common_runtime/device/device_utils.cc b/tensorflow/core/common_runtime/device/device_utils.cc index 60ec4cd0082a67..dcbd6c192d3f93 100644 --- a/tensorflow/core/common_runtime/device/device_utils.cc +++ b/tensorflow/core/common_runtime/device/device_utils.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace device_utils { -absl::Status ValidateDeviceType(StringPiece type) { +absl::Status ValidateDeviceType(absl::string_view type) { static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"}; bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx); if (!matches) { diff --git a/tensorflow/core/common_runtime/device/device_utils.h b/tensorflow/core/common_runtime/device/device_utils.h index 05c52e0aa92081..5447c7291d0404 100644 --- a/tensorflow/core/common_runtime/device/device_utils.h +++ b/tensorflow/core/common_runtime/device/device_utils.h @@ -33,7 +33,7 @@ namespace device_utils { // Note that lowercase "cpu" and "gpu" are currently supported only for // legacy reasons: // https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd -absl::Status ValidateDeviceType(StringPiece type); +absl::Status ValidateDeviceType(absl::string_view type); } // namespace device_utils } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index 87fe86835419c5..3e0abb149e8e9b 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -56,7 +56,7 @@ class DeviceMgr { // Assigns *device with pointer to Device of the given name. // Accepts either a full device name, or just the replica-local suffix. - virtual absl::Status LookupDevice(StringPiece name, + virtual absl::Status LookupDevice(absl::string_view name, Device** device) const = 0; // Check if the current device manager contains device with the given @@ -101,7 +101,8 @@ class DynamicDeviceMgr : public DeviceMgr { std::vector ListDevices() const override; string DebugString() const override; string DeviceMappingString() const override; - absl::Status LookupDevice(StringPiece name, Device** device) const override; + absl::Status LookupDevice(absl::string_view name, + Device** device) const override; bool ContainsDevice(int64_t device_incarnation) const override; void ClearContainers(absl::Span containers) const override; int NumDeviceType(const string& type) const override; diff --git a/tensorflow/core/common_runtime/device_propagation.h b/tensorflow/core/common_runtime/device_propagation.h index f70ac8001f262a..20f5f9164f7376 100644 --- a/tensorflow/core/common_runtime/device_propagation.h +++ b/tensorflow/core/common_runtime/device_propagation.h @@ -27,7 +27,7 @@ namespace tensorflow { namespace device_propagation { -typedef std::function DeviceFilter; +typedef std::function DeviceFilter; typedef std::function NodeFilter; } // namespace device_propagation diff --git a/tensorflow/core/common_runtime/device_propagation_test.cc b/tensorflow/core/common_runtime/device_propagation_test.cc index d38965b10c53c6..6b751d4841fafe 100644 --- a/tensorflow/core/common_runtime/device_propagation_test.cc +++ b/tensorflow/core/common_runtime/device_propagation_test.cc @@ -39,7 +39,7 @@ const char kTpu1[] = "/job:localhost/replica:0/task:0/device:TPU:1"; const char kTpu2[] = "/job:localhost/replica:0/task:0/device:TPU:2"; const char kGpu0[] = "/job:localhost/replica:0/task:0/device:GPU:0"; -bool IsTPUDevice(StringPiece device_name) { +bool IsTPUDevice(absl::string_view device_name) { return absl::StrContains(device_name, "device:TPU:"); } diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index 69e940398f0673..205f5c4bf1cf01 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -80,7 +80,7 @@ static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) { return a_priority > b_priority; } - return StringPiece(a.type()) < StringPiece(b.type()); + return absl::string_view(a.type()) < absl::string_view(b.type()); } std::vector DeviceSet::PrioritizedDeviceTypeList() const { @@ -134,7 +134,8 @@ void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) { return a.first->IsLocal(); } - return StringPiece(a.first->name()) < StringPiece(b.first->name()); + return absl::string_view(a.first->name()) < + absl::string_view(b.first->name()); }; std::sort(vector->begin(), vector->end(), device_sort); } diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index cb741eb2f862ba..b9b97de779901a 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1343,7 +1343,7 @@ absl::Status DirectSession::CreateExecutors( if (run_state_args->is_partial_run) { ek->graph = std::move(run_state_args->graph); - std::unordered_set names; + std::unordered_set names; for (const string& input : callable_options.feed()) { TensorId id(ParseTensorName(input)); names.emplace(id.first); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index a0ee4c32471d81..c43827eede496d 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -65,7 +65,8 @@ class DirectSession : public Session { ~DirectSession() override; typedef std::vector> NamedTensorList; - typedef std::unordered_map NameNodeMap; + typedef std::unordered_map + NameNodeMap; absl::Status Create(const GraphDef& graph) override; absl::Status Create(GraphDef&& graph) override; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 6850a0ef0082e3..4bfa85ccc44933 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -2569,8 +2569,8 @@ void TestFeedAndFetchTensorsInDeviceMemory( << DataType_Name(dtype); TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype); ASSERT_EQ(1, outputs.size()); - const StringPiece actual_data = outputs[0].tensor_data(); - const StringPiece expected_data = host_tensor.tensor_data(); + const absl::string_view actual_data = outputs[0].tensor_data(); + const absl::string_view expected_data = host_tensor.tensor_data(); EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype); EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(), std::min(expected_data.size(), actual_data.size()))) diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc index d1f8fd52c338d8..f3158c29c80392 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/stacktrace.h" namespace tensorflow { @@ -55,6 +56,8 @@ DynamicDeviceMgr::DynamicDeviceMgr(std::unique_ptr&& device) DynamicDeviceMgr::~DynamicDeviceMgr() { // Release resources ahead of destroying the device manager as the resource // destructors (e.g. ~IteratorResource) assume devices still exist. + VLOG(1) << "DynamicDeviceMgr::~DynamicDeviceMgr @@stacktrace\n " + << tsl::CurrentStackTrace(); mutex_lock l(devices_mu_); for (const auto& it : dynamic_devices_) { // TODO(tf-runtime-team): clear devices' resource mgr in devices' @@ -104,12 +107,12 @@ string DynamicDeviceMgr::DeviceMappingString() const { return out; } -absl::Status DynamicDeviceMgr::LookupDevice(StringPiece name, +absl::Status DynamicDeviceMgr::LookupDevice(absl::string_view name, Device** device) const { tf_shared_lock l(devices_mu_); auto iter = device_map_.find(string(name)); if (iter == device_map_.end()) { - std::vector device_names; + std::vector device_names; device_names.reserve(device_map_.size()); for (auto&& itr : device_map_) { device_names.push_back(itr.first); diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc index 1f27eaf6d64f19..9852cce5ee3413 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder.cc @@ -161,7 +161,7 @@ DEFINE_GET_ATTR(tensorflow::DataType, type, "type"); #undef DEFINE_GET_ATTR template <> -absl::Status AttrBuilder::Get(StringPiece attr_name, +absl::Status AttrBuilder::Get(absl::string_view attr_name, absl::InlinedVector* value) const { auto it = encoded_attrs_.find(string(attr_name)); if (it == encoded_attrs_.end()) { @@ -236,7 +236,7 @@ void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const { } } -void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name, +void AttrBuilder::AddAttrIfNotPresent(absl::string_view attr_name, const AttrValue& value) { encoded_attrs_.emplace(string(attr_name), value.SerializeAsString()); } @@ -284,19 +284,19 @@ void CombineUnordered(const tensorflow::Fprint128& a, b->high64 += a.high64; } -inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, +inline tensorflow::Fprint128 CacheKeyHelper(absl::string_view s, const tensorflow::Fprint128& b) { tensorflow::Fprint128 a = tensorflow::Fingerprint128(s); return FingerprintCat128(a, b); } -inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) { +inline tensorflow::Fprint128 CacheKeyHelper(absl::string_view s, uint64 b) { return CacheKeyHelper(s, {b, b}); } } // namespace -tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) { +tensorflow::Fprint128 AttrBuilder::CacheKey(const absl::string_view device) { if (!cached_cache_key_ || device != device_for_cached_cache_key_) { cached_cache_key_ = BuildCacheKeyForDevice(device); device_for_cached_cache_key_ = string(device); @@ -306,7 +306,7 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) { } tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice( - const StringPiece device) const { + const absl::string_view device) const { tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name()); f = tsl::FingerprintCat128(f, tensorflow::Fingerprint128(device)); for (const auto& p : encoded_attrs_) { diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h index 129841e8f90133..9dc480d8c8187a 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.h +++ b/tensorflow/core/common_runtime/eager/attr_builder.h @@ -118,7 +118,7 @@ class AttrBuilder : public AbstractOpAttrs { AttrBuilder& NumInputs(int n); template - AttrBuilder& Set(StringPiece attr_name, T&& value) { + AttrBuilder& Set(absl::string_view attr_name, T&& value) { SetAttrValue(value, &attr_tmp_); AddAttrIfNotPresent(attr_name, attr_tmp_); node_def_finalized_ = false; @@ -128,7 +128,7 @@ class AttrBuilder : public AbstractOpAttrs { size_t NumAttributes() const { return encoded_attrs_.size(); } - AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) { + AttrBuilder& Set(absl::string_view attr_name, const AttrValue& value) { AddAttrIfNotPresent(attr_name, value); cached_cache_key_ = std::nullopt; return *this; @@ -139,7 +139,7 @@ class AttrBuilder : public AbstractOpAttrs { // value type in this Node. This is not an issue, because Get is used rarely // and nodes have a small number of attributes. template - absl::Status Get(StringPiece attr_name, T* value) const { + absl::Status Get(absl::string_view attr_name, T* value) const { // Common attributes are stored in AttrVecs. This Get() template // is specialized for them below. If we end up here, the type must be // among those that we store in the node_def_. @@ -150,7 +150,7 @@ class AttrBuilder : public AbstractOpAttrs { return GetNodeAttr(AttrSlice(node_def_), attr_name, value); } - tensorflow::Fprint128 CacheKey(StringPiece device); + tensorflow::Fprint128 CacheKey(absl::string_view device); // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as // well as any default attr-value pairs from the associated op_def, if there @@ -183,7 +183,7 @@ class AttrBuilder : public AbstractOpAttrs { absl::InlinedVector* type_list) const override; private: - tensorflow::Fprint128 BuildCacheKeyForDevice(StringPiece device) const; + tensorflow::Fprint128 BuildCacheKeyForDevice(absl::string_view device) const; template void SetInAttrValueMap(AttrValueMap* m, const string& attr_name, @@ -194,7 +194,7 @@ class AttrBuilder : public AbstractOpAttrs { m->insert({attr_name, value}); } - void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value); + void AddAttrIfNotPresent(absl::string_view attr_name, const AttrValue& value); gtl::FlatMap encoded_attrs_; mutable AttrValue attr_tmp_; // For encoding @@ -210,13 +210,13 @@ class AttrBuilder : public AbstractOpAttrs { }; template <> -absl::Status AttrBuilder::Get(StringPiece attr_name, int* value) const; +absl::Status AttrBuilder::Get(absl::string_view attr_name, int* value) const; template <> -absl::Status AttrBuilder::Get(StringPiece attr_name, float* value) const; +absl::Status AttrBuilder::Get(absl::string_view attr_name, float* value) const; template <> -absl::Status AttrBuilder::Get(StringPiece attr_name, bool* value) const; +absl::Status AttrBuilder::Get(absl::string_view attr_name, bool* value) const; template <> -absl::Status AttrBuilder::Get(StringPiece attr_name, +absl::Status AttrBuilder::Get(absl::string_view attr_name, tensorflow::DataType* value) const; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index de8c208c4b9ef3..a210694b8fd3be 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -93,7 +93,7 @@ EagerContext* GetCEagerContext() { return global_c_eager_context; } namespace { -bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { +bool ReadBoolFromEnvVar(absl::string_view env_var_name, bool default_val) { bool val; if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) { return val; @@ -1297,7 +1297,7 @@ absl::Status EagerContext::FindDeviceFromName(const char* device_name, } absl::Status EagerContext::FindCompositeDeviceFromName( - StringPiece device_name, CompositeDevice** device) const { + absl::string_view device_name, CompositeDevice** device) const { tf_shared_lock l(composite_devices_mu_); for (const auto& d : composite_devices_) { if (d.second->name() == device_name) { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 9dac42c1921215..8440e298a95244 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -562,7 +562,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { absl::Status FindDeviceFromName(const char* device_name, Device** device) const; - absl::Status FindCompositeDeviceFromName(StringPiece device_name, + absl::Status FindCompositeDeviceFromName(absl::string_view device_name, CompositeDevice** device) const; bool IsCustomDevice(const string& device_name) override; diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index e13ee2ffac4a0a..2fc9c6c2523a48 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -126,6 +126,11 @@ class XlaKeyValueStore : public xla::KeyValueStoreInterface { absl::StrCat(key_prefix_, key), timeout); } + absl::StatusOr TryGet(std::string_view key) override { + return coordination_service_agent_->TryGetKeyValue( + absl::StrCat(key_prefix_, key)); + } + absl::Status Set(std::string_view key, std::string_view value) override { return coordination_service_agent_->InsertKeyValue( absl::StrCat(key_prefix_, key), value); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index ce4b8df85e473e..55860a66fbbdb0 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -56,7 +56,7 @@ absl::Status EagerOperation::SetAttrValue(const char* attr_name, absl::Status EagerOperation::SetAttrString(const char* attr_name, const char* data, size_t length) { - MutableAttrs()->Set(attr_name, StringPiece(data, length)); + MutableAttrs()->Set(attr_name, absl::string_view(data, length)); return absl::OkStatus(); } @@ -137,9 +137,9 @@ absl::Status EagerOperation::SetAttrStringList(const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - std::vector v(num_values); + std::vector v(num_values); for (int i = 0; i < num_values; ++i) { - v[i] = StringPiece(static_cast(values[i]), lengths[i]); + v[i] = absl::string_view(static_cast(values[i]), lengths[i]); } MutableAttrs()->Set(attr_name, v); diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 3cbc844dddbb74..e6d547d1e9832b 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -33,7 +33,7 @@ namespace eager { // These ops are not pinnable since they generate data. It can be slower to // generate and then copy the data instead of just generating the data on the // device directly. -static bool IsPinnableOp(StringPiece op_name) { +static bool IsPinnableOp(absl::string_view op_name) { static const gtl::FlatSet* unpinnable_ops = new gtl::FlatSet({ "RandomUniform", "RandomUniformInt", @@ -62,12 +62,12 @@ static absl::Status ValidateTensorHandleRemoteDevice( "workers have been restarted."); } -bool IsColocationExempt(StringPiece op_name) { +bool IsColocationExempt(absl::string_view op_name) { const auto& exempt_ops = InputColocationExemptionRegistry::Global()->Get(); return exempt_ops.find(string(op_name)) != exempt_ops.end(); } -bool IsFunction(StringPiece op_name) { +bool IsFunction(absl::string_view op_name) { const OpDef* op_def = nullptr; absl::Status s = OpDefForOp(string(op_name), &op_def); if (!s.ok()) { @@ -81,9 +81,9 @@ bool IsFunction(StringPiece op_name) { } absl::Status MaybePinSmallOpsToCpu( - bool* result, StringPiece op_name, + bool* result, absl::string_view op_name, absl::Span args, - StringPiece cpu_device_name) { + absl::string_view cpu_device_name) { if (IsFunction(op_name) || IsColocationExempt(op_name) || !IsPinnableOp(op_name)) { *result = false; diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h index 9064b86314aed7..fa51f1985a52f6 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.h +++ b/tensorflow/core/common_runtime/eager/placement_utils.h @@ -24,9 +24,9 @@ limitations under the License. namespace tensorflow { namespace eager { -bool IsColocationExempt(StringPiece op_name); +bool IsColocationExempt(absl::string_view op_name); -bool IsFunction(StringPiece op_name); +bool IsFunction(absl::string_view op_name); // TODO(b/154234908): Unify placement logic. @@ -34,9 +34,9 @@ bool IsFunction(StringPiece op_name); // integers (int32/int64). This can be disabled by setting the environment // variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false". absl::Status MaybePinSmallOpsToCpu( - bool* result, StringPiece op_name, + bool* result, absl::string_view op_name, absl::Span args, - StringPiece cpu_device_name); + absl::string_view cpu_device_name); // If a resource touching input is specified, all resource-touching ops run in // the device the resource is, regardless of anything else that has been diff --git a/tensorflow/core/common_runtime/function_body.cc b/tensorflow/core/common_runtime/function_body.cc index 60a6f41f1d8162..efd0415162f15b 100644 --- a/tensorflow/core/common_runtime/function_body.cc +++ b/tensorflow/core/common_runtime/function_body.cc @@ -52,7 +52,8 @@ FunctionBody::FunctionBody(core::RefCountPtr&& record, (*node_vec)[index] = n; } // 2. Find ControlRet nodes that must be always executed. - std::unordered_set control_ret_node_names; + std::unordered_set + control_ret_node_names; for (const auto& control_ret : this->record->fdef().control_ret()) { control_ret_node_names.insert(control_ret.second); } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 57204f8610ceac..2fa6cb296a920f 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -74,7 +74,7 @@ absl::Status GetOpSig(const string& op, const OpDef** sig) { } void HasError(const absl::Status& s, const error::Code code, - StringPiece substr) { + absl::string_view substr) { EXPECT_EQ(s.code(), code) << s; EXPECT_TRUE(absl::StrContains(s.message(), substr)) << s << ", expected substring " << substr; diff --git a/tensorflow/core/common_runtime/function_utils.cc b/tensorflow/core/common_runtime/function_utils.cc index 06b5c4af71e3c7..53fe6154e578df 100644 --- a/tensorflow/core/common_runtime/function_utils.cc +++ b/tensorflow/core/common_runtime/function_utils.cc @@ -49,7 +49,7 @@ struct Endpoint { // The following Add* routines are used to add a few graph nodes while // functions are transformed. -static Node* AddNoOp(StringPiece name, Graph* g) { +static Node* AddNoOp(absl::string_view name, Graph* g) { NodeDef ndef; ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); ndef.set_op("NoOp"); @@ -59,7 +59,7 @@ static Node* AddNoOp(StringPiece name, Graph* g) { return ret; } -static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { +static Node* AddIdentity(absl::string_view name, Graph* g, Endpoint input) { DCHECK_LT(0, input.dtype()); NodeDef ndef; ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); @@ -73,7 +73,7 @@ static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { return ret; } -void DumpGraph(StringPiece label, const Graph* g) { +void DumpGraph(absl::string_view label, const Graph* g) { // TODO(zhifengc): Change Graph to record #nodes. VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " << g->num_edges(); @@ -177,11 +177,12 @@ bool RemoveListArrayConverter(Graph* g) { } absl::InlinedVector identity_nodes(n->num_inputs(), nullptr); - const auto no_op = [&](StringPiece name) -> Node* { + const auto no_op = [&](absl::string_view name) -> Node* { return AddNoOp(absl::StrCat(n->name(), "/", name), g); }; - const auto identity = [&](StringPiece name, Endpoint input) -> Node* { + const auto identity = [&](absl::string_view name, + Endpoint input) -> Node* { Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input); node->set_requested_device(input.node->def().device()); return node; diff --git a/tensorflow/core/common_runtime/function_utils.h b/tensorflow/core/common_runtime/function_utils.h index 587274064fa768..cfbfe86936421b 100644 --- a/tensorflow/core/common_runtime/function_utils.h +++ b/tensorflow/core/common_runtime/function_utils.h @@ -38,7 +38,7 @@ string DebugString(const Graph* g); // Dump the contents of the "graph" to log files if the logging level is // sufficiently high. -void DumpGraph(StringPiece label, const Graph* g); +void DumpGraph(absl::string_view label, const Graph* g); // Convert the Graph of a function to a GraphDef. // diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index e7bf0d973f01e6..8655b0ed822f4d 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -67,7 +67,7 @@ cc_library( cc_library( name = "rocm", deps = [ - "@local_xla//xla/stream_executor/rocm:rocm_rpath", + "@local_config_rocm//rocm:rocm_rpath", ], ) @@ -201,7 +201,6 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", "@local_xla//xla/stream_executor/gpu:gpu_init_impl", "@local_xla//xla/tsl/framework:device_id_utils", diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index c5e001c216f194..fa461aed08c75c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -69,7 +69,7 @@ std::unique_ptr CreateGPUMemAllocator(size_t) { PlatformDeviceId gpu_id(0); return absl::WrapUnique(new DeviceMemAllocator( GPUMachineManager()->ExecutorForDevice(gpu_id.value()).value(), gpu_id, - stream_executor::MemoryType::kDevice, {}, {})); + stream_executor::MemoryType::kDevice, {})); } std::unique_ptr CreateSubAllocator( diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc index de65df20e2dad4..c5251e47a8fdce 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -49,7 +49,7 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_None) { GPUDebugAllocator a( new GPUBFCAllocator(absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); @@ -79,7 +79,7 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) { new GPUBFCAllocator( absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); @@ -118,7 +118,7 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) { new GPUBFCAllocator( absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); @@ -153,7 +153,7 @@ TEST(GPUDebugAllocatorTest, ResetToNan) { GPUNanResetAllocator a( new GPUBFCAllocator(absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); @@ -198,7 +198,7 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) { GPUNanResetAllocator a( new GPUBFCAllocator(absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); @@ -242,7 +242,7 @@ TEST(GPUDebugAllocatorTest, TracksSizes) { GPUDebugAllocator a( new GPUBFCAllocator(absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); EXPECT_EQ(true, a.TracksAllocationSizes()); @@ -254,7 +254,7 @@ TEST(GPUDebugAllocatorTest, AllocatedVsRequested) { GPUDebugAllocator a( new GPUBFCAllocator(absl::WrapUnique(new DeviceMemAllocator( stream_exec, platform_device_id, - stream_executor::MemoryType::kDevice, {}, {})), + stream_executor::MemoryType::kDevice, {})), 1 << 30, "", {}), platform_device_id); float* t1 = TypedAllocator::Allocate(&a, 1, {}); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index f8c8a2724cf452..9b873f72f8ba5c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -496,7 +496,7 @@ Status BaseGPUDevice::InitScratchBuffers() { } se::DeviceMemory mem( se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size)); - TF_RETURN_IF_ERROR(executor_->SynchronousMemZero( + TF_RETURN_IF_ERROR(stream_->compute->MemZero( &mem, Eigen::kGpuScratchSize + sizeof(unsigned int))); scratch_ = static_cast(scratch_buffer); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc index 3b2480784ab187..0ad42bb793ce5c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc @@ -21,21 +21,21 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_device.h" -#include "xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/tests/test_macros.h" -#include "xla/tsl/framework/device_id.h" -#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" +#include "xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/tests/test_macros.h" +#include "xla/tsl/framework/device_id.h" +#include "xla/tsl/lib/core/status_test_util.h" #ifdef TF_GPU_USE_PJRT -#include "xla/pjrt/pjrt_client.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" +#include "xla/pjrt/pjrt_client.h" #endif // TF_GPU_USE_PJRT #if GOOGLE_CUDA @@ -67,6 +67,15 @@ se::CudaComputeCapability GetComputeCapability() { .cuda_compute_capability(); } +bool IsRocm() { + return std::holds_alternative( + se::GPUMachineManager() + ->ExecutorForDevice(0) + .value() + ->GetDeviceDescription() + .gpu_compute_capability()); +} + void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) { EXPECT_TRUE(absl::StrContains(s.ToString(), substr)) << s << ", expected substring " << substr; @@ -144,7 +153,10 @@ class GPUDeviceTest : public ::testing::Test { } }; -TEST_F(GPUDeviceTest, DISABLED_ON_GPU_ROCM(CudaMallocAsync)) { +TEST_F(GPUDeviceTest, CudaMallocAsync) { + if (IsRocm()) { + GTEST_SKIP(); + } // cudaMallocAsync supported only when cuda toolkit and driver supporting // CUDA 11.2+ #ifndef GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc index 96d9ca758d67e0..c89f4fbab669c9 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc @@ -124,7 +124,7 @@ static std::unique_ptr CreateSubAllocator( executor, platform_device_id, use_unified_memory ? stream_executor::MemoryType::kUnified : stream_executor::MemoryType::kDevice, - alloc_visitors, {})); + alloc_visitors)); } Allocator* GPUProcessState::GetGPUAllocator( diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc index 0eddde84668c39..f50520e903c3ca 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -33,7 +33,7 @@ void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done); diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index a4799bf23b1167..e7486e971a4094 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -68,9 +68,9 @@ class GPUDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done, bool sync_dst_compute) const override; - void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name, - Device* device, Tensor* cpu_tensor, - StatusCallback done) override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view edge_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, Tensor* output_tensor, diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index dc8dbe5711fb2e..b83dc22641090f 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -73,7 +73,7 @@ inline bool IsNextIteration(const NodeDef& node_def) { node_def.op() == "RefNextIteration"; } -bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { +bool IsValidNodeName(absl::string_view s, bool allow_internal_ops) { using ::tensorflow::strings::Scanner; Scanner scanner(s); scanner @@ -275,15 +275,15 @@ class GraphConstructor { // Returns true if `name` already exists in `g_` (either as a node name or // prefix). - bool NameExistsInGraph(StringPiece name); + bool NameExistsInGraph(absl::string_view name); // Returns true if `name` already exists in the GraphDef being imported // (either as a node name or prefix). - bool NameExistsInGraphDef(StringPiece name); + bool NameExistsInGraphDef(absl::string_view name); // Returns a unique version of `original_name`, or `original_name` if it's // already unique in the graph. - string FindUniqueName(StringPiece original_name); + string FindUniqueName(absl::string_view original_name); // Decrement pending count for users of `processed` and add the ones that now // have all of their pending inputs satisfied to `ready_`. @@ -349,13 +349,13 @@ class GraphConstructor { absl::flat_hash_map gdef_nodes_; // Prefixes already used in the GraphDef being imported. - absl::flat_hash_set gdef_prefixes_; + absl::flat_hash_set gdef_prefixes_; // Mapping from node name to the existing node in g_. - absl::flat_hash_map existing_nodes_; + absl::flat_hash_map existing_nodes_; // Prefixes already used in the graph. - absl::flat_hash_set existing_prefixes_; + absl::flat_hash_set existing_prefixes_; // Imported node names that have been uniquified. The key is the original // name, the value is the new unique name. @@ -582,7 +582,7 @@ void GraphConstructor::UpdatePendingCountAndReady(int processed, // This could be expensive but we don't expect to call it often, if at all (only // if there are multiple nodes in g_ with the same name) bool NodeNameInValues(const std::map& input_map, - const StringPiece& node_name) { + const absl::string_view& node_name) { for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) { if (iter->second.first == node_name) return true; } @@ -590,17 +590,17 @@ bool NodeNameInValues(const std::map& input_map, } bool NodeNameInValues(const std::vector& control_dependencies, - const StringPiece& node_name) { + const absl::string_view& node_name) { return std::find(control_dependencies.begin(), control_dependencies.end(), node_name) != control_dependencies.end(); } // Adds any prefixes of `node_name` (not including the full name itself) to // `prefixes`. -void AddPrefixes(StringPiece node_name, - absl::flat_hash_set* prefixes) { +void AddPrefixes(absl::string_view node_name, + absl::flat_hash_set* prefixes) { size_t idx = -1; - while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) { + while ((idx = node_name.find('/', idx + 1)) != absl::string_view::npos) { prefixes->insert(node_name.substr(0, idx)); } } @@ -634,7 +634,7 @@ absl::Status GraphConstructor::EnsureNoNameCollisions() { } } } else if (!prefix_.empty()) { - StringPiece prefix_no_slash(prefix_); + absl::string_view prefix_no_slash(prefix_); prefix_no_slash.remove_suffix(1); if (!IsValidNodeName(prefix_no_slash, false)) { return errors::InvalidArgument("Imported node name prefix '", prefix_, @@ -703,7 +703,7 @@ absl::Status GraphConstructor::BuildNodeIndex() { // Validate control edges at end bool in_control_dependence = false; for (int i = 0; i < node_def.input_size(); ++i) { - StringPiece input_name = node_def.input(i); + absl::string_view input_name = node_def.input(i); if (!input_name.empty() && absl::StartsWith(input_name, "^")) { in_control_dependence = true; } else if (in_control_dependence) { @@ -742,7 +742,7 @@ absl::Status GraphConstructor::InitFromEdges() { int32_t num_control_edges = 0; bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { - StringPiece input_name(node_def.input(i)); + absl::string_view input_name(node_def.input(i)); if (absl::StartsWith(input_name, "^")) { num_control_edges++; } else { @@ -758,7 +758,7 @@ absl::Status GraphConstructor::InitFromEdges() { } } for (int i = 0; i < node_def.input_size(); ++i) { - StringPiece input_name = node_def.input(i); + absl::string_view input_name = node_def.input(i); TensorId id(ParseTensorName(input_name)); if (opts_.input_map.count(id) == 0) { // If an input is not mapped, then the input should appear in the graph @@ -792,7 +792,7 @@ absl::Status GraphConstructor::ValidateColocationConstraints( const auto iter = node_def.attr().find(kColocationAttrName); if (iter == node_def.attr().end()) return absl::OkStatus(); for (const string& c : iter->second.list().s()) { - StringPiece s(c); + absl::string_view s(c); if (absl::ConsumePrefix(&s, kColocationGroupPrefix) && gdef_nodes_.find(s) == gdef_nodes_.end()) { return errors::InvalidArgument( @@ -985,7 +985,7 @@ void GraphConstructor::AddPrefixToNodeDef( // Skip remapped inputs (which already exist in g_ and are not being // imported). if (input_already_exists[i]) continue; - StringPiece input(node_def->input(i)); + absl::string_view input(node_def->input(i)); if (absl::ConsumePrefix(&input, "^")) { node_def->set_input(i, strings::StrCat("^", prefix_, input)); } else { @@ -997,7 +997,7 @@ void GraphConstructor::AddPrefixToNodeDef( auto* list = node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); for (int i = 0; i < list->s_size(); ++i) { - StringPiece v(list->s(i)); + absl::string_view v(list->s(i)); if (absl::ConsumePrefix(&v, kColocationGroupPrefix)) { list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v)); } @@ -1039,7 +1039,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() { continue; bool updated = false; for (size_t i = 0; i < coloc_values.size(); ++i) { - StringPiece val(coloc_values[i]); + absl::string_view val(coloc_values[i]); if (absl::ConsumePrefix(&val, kColocationGroupPrefix)) { auto name_pair = uniquified_names_.find(string(val)); if (name_pair == uniquified_names_.end()) continue; @@ -1054,19 +1054,19 @@ void GraphConstructor::UpdateUniquifiedColocationNames() { } } -bool GraphConstructor::NameExistsInGraph(StringPiece name) { +bool GraphConstructor::NameExistsInGraph(absl::string_view name) { if (existing_nodes_.find(name) != existing_nodes_.end()) return true; if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true; return false; } -bool GraphConstructor::NameExistsInGraphDef(StringPiece name) { +bool GraphConstructor::NameExistsInGraphDef(absl::string_view name) { if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true; if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true; return false; } -string GraphConstructor::FindUniqueName(StringPiece original_name) { +string GraphConstructor::FindUniqueName(absl::string_view original_name) { string name(original_name); int count = 0; // Check that any generated names don't collide with imported NodeDefs (as @@ -1441,7 +1441,7 @@ absl::Status GraphConstructor::PopulateReturnTensors() { absl::Status GraphConstructor::PopulateReturnNodes() { if (opts_.return_nodes.empty()) return absl::OkStatus(); - for (StringPiece name : opts_.return_nodes) { + for (absl::string_view name : opts_.return_nodes) { auto iter = gdef_nodes_.find(name); if (iter == gdef_nodes_.end()) { return errors::InvalidArgument("Requested return node '", name, diff --git a/tensorflow/core/common_runtime/graph_constructor_test.cc b/tensorflow/core/common_runtime/graph_constructor_test.cc index 419f09c4d17c55..91c471f0705a55 100644 --- a/tensorflow/core/common_runtime/graph_constructor_test.cc +++ b/tensorflow/core/common_runtime/graph_constructor_test.cc @@ -167,7 +167,7 @@ class GraphConstructorTest : public ::testing::Test { "value for the _class attribute. Update it and its callers"; return ""; } - StringPiece loc(value[0]); + absl::string_view loc(value[0]); return absl::ConsumePrefix(&loc, kColocationGroupPrefix) ? string(loc) : ""; } diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index a6154ff06f301f..d7a9462e387d2d 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -811,7 +811,7 @@ absl::Status GraphExecutionState::OptimizeGraph( Device* cpu_device = nullptr; for (const auto& device : device_set_->devices()) { if (device->parsed_name().id == 0 && - StringPiece(device->parsed_name().type) == "CPU" && + absl::string_view(device->parsed_name().type) == "CPU" && device->GetAllocator(AllocatorAttributes()) != nullptr) { cpu_device = device; } diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc index c1fea615fba655..1e8a85207fa0b1 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.cc +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -96,7 +96,7 @@ struct EndpointEq { // The following Add* routines are used to add a few graph nodes while // functions are transformed. -static Node* AddNoOp(StringPiece name, Graph* g) { +static Node* AddNoOp(absl::string_view name, Graph* g) { NodeDef ndef; ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); ndef.set_op("NoOp"); @@ -106,7 +106,7 @@ static Node* AddNoOp(StringPiece name, Graph* g) { return ret; } -static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { +static Node* AddIdentity(absl::string_view name, Graph* g, Endpoint input) { DCHECK_LT(0, input.dtype()); NodeDef ndef; ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); @@ -506,7 +506,7 @@ absl::Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, // control nodes and inlined function inputs and outputs. // Add a NoOp node for function control inputs/outputs. - const auto no_op = [&](StringPiece name) -> Node* { + const auto no_op = [&](absl::string_view name) -> Node* { Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g); const absl::optional device = placer->ControlNodeDevice(); if (device.has_value()) node->set_requested_device(*device); @@ -514,7 +514,7 @@ absl::Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, }; // Add an Identity node for function input. - const auto input_identity = [&](StringPiece name, Endpoint input, + const auto input_identity = [&](absl::string_view name, Endpoint input, int index) -> Node* { Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input); const absl::optional device = placer->InputNodeDevice(index); @@ -529,7 +529,7 @@ absl::Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, }; // Add an Identity node for function output. - const auto output_identity = [&](StringPiece name, Endpoint input, + const auto output_identity = [&](absl::string_view name, Endpoint input, int index) -> Node* { Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input); const absl::optional device = placer->OutputNodeDevice(index); diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 7cf5af392d518f..49885ba8129e8e 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -90,7 +90,7 @@ const absl::flat_hash_set& DevicePropagationOpList() { return *op_list; } -bool IsPropagatableDevice(StringPiece device_string) { +bool IsPropagatableDevice(absl::string_view device_string) { DeviceNameUtils::ParsedName device; return DeviceNameUtils::ParseFullName(device_string, &device) && device.type == DEVICE_TPU; diff --git a/tensorflow/core/common_runtime/lower_functional_ops_test.cc b/tensorflow/core/common_runtime/lower_functional_ops_test.cc index 057cc4fe4c3e8c..2f16c6fef7e308 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops_test.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops_test.cc @@ -40,7 +40,7 @@ typedef FunctionDefHelper FDH; constexpr const char* const kLowerUsingSwitchMergeAttr = LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr; -static void AssertHasSubstr(StringPiece s, StringPiece expected) { +static void AssertHasSubstr(absl::string_view s, absl::string_view expected) { ASSERT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD index 6082429fc2585f..1be79c9f233d0b 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD @@ -231,6 +231,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -363,6 +364,7 @@ tf_cc_test( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD index 7862391ec43c6a..4b24910e748ab6 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD @@ -29,6 +29,8 @@ cc_library( deps = [ ":plugin_c_api_hdrs", "//tensorflow/core/platform:logging", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:env", "@tf_runtime//:hostcontext_alwayslink", ], diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.cc index 01a4e2e8c8de18..7f2b964a290f6b 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tsl/platform/env.h" #include "tfrt/host_context/async_dispatch.h" // from @tf_runtime diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc index 7d60b2881a2ae9..a4a1ac97a7bfa6 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc @@ -102,7 +102,8 @@ class FakeDeviceManager : public DeviceMgr { } std::string DebugString() const override { return ""; } std::string DeviceMappingString() const override { return ""; } - absl::Status LookupDevice(StringPiece name, Device** device) const override { + absl::Status LookupDevice(absl::string_view name, + Device** device) const override { *device = fake_device_.get(); return absl::OkStatus(); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc index 805418d4d5b4ba..3958f78f570d10 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include +#include #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -145,6 +147,11 @@ class TestCoordinationClient : public CoordinationClient { StatusCallback done) override { done(absl::UnimplementedError("CancelBarrierAsync")); } + void GetAliveTasksAsync(const tsl::GetAliveTasksRequest* request, + tsl::GetAliveTasksResponse* response, + StatusCallback done) override { + done(absl::UnimplementedError("GetAliveTasksAsync")); + } void RegisterTaskAsync(tsl::CallOptions*, const tsl::RegisterTaskRequest* request, tsl::RegisterTaskResponse* response, diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc index 18331eee70b4bb..f4587925e0e238 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/experimental/next_pluggable_device/c_api.h" diff --git a/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc b/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc index b544eb11ffb8a2..451200bbecf61d 100644 --- a/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc +++ b/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc @@ -29,7 +29,7 @@ namespace tensorflow { namespace { -absl::Status BuildNoopNode(const Node& source, StringPiece name, +absl::Status BuildNoopNode(const Node& source, absl::string_view name, const string& device, Graph* graph, Node** node) { NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source)); if (!device.empty()) { @@ -45,7 +45,7 @@ absl::Status BuildNoopNode(const Node& source, StringPiece name, return absl::OkStatus(); } -absl::Status BuildIdentityNNode(const Node& source, StringPiece name, +absl::Status BuildIdentityNNode(const Node& source, absl::string_view name, const string& device, Graph* graph, std::vector& inputs, Node** node) { @@ -65,7 +65,7 @@ absl::Status BuildIdentityNNode(const Node& source, StringPiece name, return absl::OkStatus(); } -absl::Status BuildIdentityNode(const Node& source, StringPiece name, +absl::Status BuildIdentityNode(const Node& source, absl::string_view name, const string& device, Graph* graph, std::vector& inputs, Node** node) { diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc index d501d2a6df2a41..1cfaffc7c3699f 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc @@ -150,7 +150,8 @@ const string* AssignedOrRequestedDeviceName(const Node& node) { void GetColocationGroup(const Node* node, string* group) { // We hoist the conversion from C-style string literal to string here, // so that we can avoid the many repeated calls to strlen(). - static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName); + static const absl::string_view kColocationAttrNameStringPiece( + kColocationAttrName); const AttrValue* attr_value = node->attrs().Find(kColocationAttrNameStringPiece); if (attr_value != nullptr && attr_value->has_list() && diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index a138f3216ceb1c..9963c89ea97973 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -927,12 +927,12 @@ TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) { "/job:a/replica:0/task:0/device:FakeGPU:0"); absl::Status s = Place(&g); - EXPECT_EQ(error::INTERNAL, s.code()) << s.ToString(); + EXPECT_EQ(error::INTERNAL, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Assigned device '/job:a/replica:0/task:0/device:FakeGPU:0' " "does not have registered OpKernel support for TestInput")) - << s.ToString(); + << s; } // Test that graphs with reference connections are correctly placed. @@ -1082,7 +1082,7 @@ TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) { } absl::Status s = Place(&g, allow_soft_placement, true); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; if (set_assigned) { EXPECT_TRUE(absl::StrContains( s.message(), @@ -1091,7 +1091,7 @@ TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) { "colocation groups with incompatible assigned devices: " "/job:a/replica:0/task:0/device:FakeGPU:0 vs " "/job:a/replica:0/task:0/device:FakeCPU:0")) - << s.ToString(); + << s; } else { EXPECT_TRUE(absl::StrContains( s.message(), @@ -1100,7 +1100,7 @@ TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) { "colocation groups with incompatible resource devices: " "/job:a/replica:0/task:0/device:FakeGPU:0 vs " "/job:a/replica:0/task:0/device:FakeCPU:0")) - << s.ToString(); + << s; } return absl::OkStatus(); @@ -1317,7 +1317,7 @@ TEST_P(SoftPlacementPlacerTest, TestInvalidMultipleColocationGroups) { bool allow_soft_placement = GetParam(); absl::Status s = Place(&g, allow_soft_placement, true); if (allow_soft_placement) { - EXPECT_EQ(error::OK, s.code()) << s.ToString(); + EXPECT_EQ(error::OK, s.code()) << s; EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); EXPECT_DEVICE_TYPE(g, "colocated_1", "FakeCPU"); EXPECT_DEVICE_TYPE(g, "foo", "FakeGPU"); @@ -1327,7 +1327,7 @@ TEST_P(SoftPlacementPlacerTest, TestInvalidMultipleColocationGroups) { "Cannot colocate nodes {{colocation_node foo}} and " "{{colocation_node in}} because no device type supports both of those " "nodes and the other nodes colocated with them")) - << s.ToString(); + << s; } } @@ -1401,15 +1401,15 @@ TEST_P(SoftPlacementPlacerTest, bool allow_soft_placement = GetParam(); absl::Status s = Place(&g, allow_soft_placement, true); if (allow_soft_placement) { - EXPECT_EQ(error::OK, s.code()) << s.ToString(); + EXPECT_EQ(error::OK, s.code()) << s; } else { - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Cannot colocate nodes {{colocation_node assign3}} and " "{{colocation_node var2}} because no device type supports both of " "those nodes and the other nodes colocated with them.")) - << s.ToString(); + << s; } } @@ -1757,12 +1757,11 @@ TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { } absl::Status s = Place(&g, false, false); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); - EXPECT_TRUE(absl::StrContains(s.message(), "/device:FakeCPU:0")) - << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; + EXPECT_TRUE(absl::StrContains(s.message(), "/device:FakeCPU:0")) << s; EXPECT_TRUE(absl::StrContains( s.message(), "no supported kernel for FakeCPU devices is available")) - << s.ToString(); + << s; } // Test that placement fails when a node requests an explicit device that is not @@ -1987,7 +1986,7 @@ TEST_P(SoftPlacementPlacerTest, bool allow_soft_placement = GetParam(); absl::Status s = Place(&g, allow_soft_placement, true); if (allow_soft_placement) { - EXPECT_EQ(error::OK, s.code()) << s.ToString(); + EXPECT_EQ(error::OK, s.code()) << s; EXPECT_DEVICE_TYPE(g, "a", "FakeGPU"); EXPECT_DEVICE_TYPE(g, "id1", "FakeGPU"); EXPECT_DEVICE_TYPE(g, "b", "FakeCPU"); @@ -1999,7 +1998,7 @@ TEST_P(SoftPlacementPlacerTest, "Cannot colocate nodes {{colocation_node id2}} and {{colocation_node " "id1}}: Cannot merge devices with incompatible types: " "'/device:FakeCPU:0' and '/device:FakeGPU:0'")) - << s.ToString(); + << s; } } @@ -2056,13 +2055,13 @@ TEST_F(PlacerTest, AssignedDeviceOfColocatedNodeIsRespected) { TF_ASSERT_OK(BuildGraph(graph, &g)); GetNodeByName(g, "a")->set_assigned_device_name(kFullCPU); absl::Status s = Place(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; EXPECT_TRUE( absl::StrContains(s.message(), "{{colocation_node iter}} was colocated with a " "group of nodes that required incompatible device " "'/job:a/replica:0/task:0/device:FakeCPU:0'")) - << s.ToString(); + << s; } TEST_P(SoftPlacementPlacerTest, @@ -2100,7 +2099,7 @@ TEST_P(SoftPlacementPlacerTest, absl::Status s = Place(&g, allow_soft_placement, false); if (allow_soft_placement) { - EXPECT_EQ(error::OK, s.code()) << s.ToString(); + EXPECT_EQ(error::OK, s.code()) << s; EXPECT_DEVICE_TYPE(g, "a", "FakeGPU"); EXPECT_DEVICE_TYPE(g, "id_a", "FakeGPU"); EXPECT_DEVICE_TYPE(g, "id1", "FakeGPU"); @@ -2115,7 +2114,7 @@ TEST_P(SoftPlacementPlacerTest, "id1}}: Cannot merge devices with incompatible types: " "'/job:a/replica:0/task:0/device:FakeCPU:0' and " "'/job:a/replica:0/task:0/device:FakeGPU:0'")) - << s.ToString(); + << s; } } @@ -2693,13 +2692,13 @@ TEST_F(NestedPlacerTest, ResourceConflictInvolvingPCO) { Graph g(OpRegistry::Global()); TF_EXPECT_OK(BuildGraph(graph, &g)); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Cannot place the graph because a reference or resource edge connects " "colocation groups with incompatible resource devices: /device:FakeCPU:0 " "vs /device:FakeGPU:0")) - << s.ToString(); + << s; } TEST_F(NestedPlacerTest, ResourceConflictInvolvingTwoPCOs) { @@ -2741,13 +2740,13 @@ TEST_F(NestedPlacerTest, ResourceConflictInvolvingTwoPCOs) { TF_EXPECT_OK(BuildGraph(graph, &g)); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Cannot place the graph because a reference or resource edge connects " "colocation groups with incompatible resource devices: /device:FakeCPU:0 " "vs /device:FakeGPU:0")) - << s.ToString(); + << s; } // Function that returns a resource that can be produced on CPU only. @@ -2802,12 +2801,12 @@ TEST_F(NestedPlacerTest, DeepDeviceConstraintsPropagated) { GetNodeByName(g, "id")->set_assigned_device_name(kFullGPU); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; // TODO(b/129057603): When better error messages are implemented, this should // change. EXPECT_TRUE(absl::StrContains( s.message(), "Could not satisfy explicit device specification")) - << s.ToString(); + << s; } FunctionDef NestedCPUResourceOutput() { @@ -2865,12 +2864,12 @@ TEST_F(NestedPlacerTest, NestedDeepDeviceConstraintsPropagated) { GetNodeByName(g, "id")->set_assigned_device_name(kFullGPU); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; // TODO(b/129057603): When better error messages are implemented, this should // change. EXPECT_TRUE(absl::StrContains( s.message(), "Could not satisfy explicit device specification")) - << s.ToString(); + << s; } TEST_F(NestedPlacerTest, TwoFunctionsBackToBack) { @@ -2919,13 +2918,13 @@ TEST_F(NestedPlacerTest, TwoFunctionsBackToBack) { TF_EXPECT_OK(BuildGraph(graph, &g)); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Cannot place the graph because a reference or resource edge connects " "colocation groups with incompatible resource devices: /device:FakeCPU:0 " "vs /device:FakeGPU:0")) - << s.ToString(); + << s; } FunctionDef NestedCallFunctionsBackToBack() { @@ -2986,13 +2985,13 @@ TEST_F(NestedPlacerTest, NestedTwoFunctionsBackToBack) { TF_EXPECT_OK(BuildGraph(graph, &g)); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString(); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Nodes were connected by a reference or resource connection (requiring " "them to be on the same device), but the two nodes were assigned two " "different devices")) - << s.ToString(); + << s; } FunctionDef RecursiveResourceIdentity() { @@ -3035,13 +3034,13 @@ TEST_F(NestedPlacerTest, DirectRecursion) { TF_EXPECT_OK(BuildGraph(graph, &g)); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::UNIMPLEMENTED, s.code()) << s.ToString(); + EXPECT_EQ(error::UNIMPLEMENTED, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Recursive function calls are not supported. Node {{node out}} inside " "the body of {{function_node RecursiveResourceIdentity}} calls function " "{{function_node RecursiveResourceIdentity}}")) - << s.ToString(); + << s; } FunctionDef RecursiveF1() { @@ -3107,14 +3106,14 @@ TEST_F(NestedPlacerTest, IndirectRecursion) { TF_EXPECT_OK(BuildGraph(graph, &g)); absl::Status s = CallOptPassesAndPlace(&g); - EXPECT_EQ(error::UNIMPLEMENTED, s.code()) << s.ToString(); + EXPECT_EQ(error::UNIMPLEMENTED, s.code()) << s; EXPECT_TRUE(absl::StrContains( s.message(), "Recursive function calls are not supported. Node {{node out}} inside " "the body of {{function_node RecursiveF2}} calls function " "{{function_node RecursiveF1}} which is already present in the call " "stack")) - << s.ToString(); + << s; } TEST_F(PlacerTest, IdentityMatchesInputAndOutputPlacement) { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc index 2c67fd687a74ba..c6c10b190f958c 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc @@ -35,11 +35,9 @@ void PluggableDeviceContext::CopyCPUTensorToDevice( cpu_tensor, this, device, device_tensor, done, sync_dst_compute); } -void PluggableDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, - Device* device, - Tensor* cpu_tensor, - StatusCallback done) { +void PluggableDeviceContext::CopyDeviceTensorToCPU( + const Tensor* device_tensor, absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) { PluggableDeviceUtil::CopyPluggableDeviceTensorToCPU( device, this, device_tensor, cpu_tensor, done); } diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h index 4c0eeb935b2aab..596341fdae9d20 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h @@ -60,7 +60,7 @@ class PluggableDeviceContext : public DeviceContext { bool sync_dst_compute) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc index d087e5df90a6ab..e1879f44c13566 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc @@ -113,9 +113,6 @@ Allocator* PluggableDeviceProcessState::GetPluggableDeviceAllocator( int bus_id = BusIdForPluggableDevice(tf_device_id); DCHECK_GE(bus_id, 0); - while (bus_id >= pluggable_device_visitors_.size()) { - pluggable_device_visitors_.push_back({}); - } bool use_unified_memory = options.per_process_gpu_memory_fraction() > 1.0 || options.experimental().use_unified_memory(); @@ -123,9 +120,7 @@ Allocator* PluggableDeviceProcessState::GetPluggableDeviceAllocator( platform->ExecutorForDevice(platform_device_id.value()).value(), platform_device_id, use_unified_memory ? stream_executor::MemoryType::kUnified - : stream_executor::MemoryType::kDevice, - pluggable_device_visitors_[bus_id], {}); - + : stream_executor::MemoryType::kDevice); Allocator* device_allocator = nullptr; auto cplatform = dynamic_cast(platform); if (cplatform == nullptr) { @@ -187,15 +182,8 @@ Allocator* PluggableDeviceProcessState::GetPluggableDeviceHostAllocator( while (static_cast(pluggable_device_host_allocators_.size()) <= numa_node) { - while (pluggable_device_host_alloc_visitors_.size() <= numa_node) { - pluggable_device_host_alloc_visitors_.push_back({}); - } - while (pluggable_device_host_free_visitors_.size() <= numa_node) { - pluggable_device_host_free_visitors_.push_back({}); - } SubAllocator* sub_allocator = new DeviceHostAllocator( - se, numa_node, pluggable_device_host_alloc_visitors_[numa_node], - pluggable_device_host_free_visitors_[numa_node]); + se, numa_node, /*alloc_visitors=*/{}, /*free_visitors=*/{}); int64_t pluggable_device_host_mem_limit_in_mb = -1; absl::Status status = ReadInt64FromEnvVar( "TF_GPU_HOST_MEM_LIMIT_IN_MB", 1LL << 17 /*128GB max by default*/, diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h index 0c3965886a088f..6e6b45fe887dca 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h @@ -117,10 +117,6 @@ class PluggableDeviceProcessState { std::vector pluggable_device_host_allocators_ TF_GUARDED_BY(mu_); - std::vector> - pluggable_device_host_alloc_visitors_ TF_GUARDED_BY(mu_); - std::vector> - pluggable_device_host_free_visitors_ TF_GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc index e0fa771c4b8280..65733614bdc54c 100644 --- a/tensorflow/core/common_runtime/process_util.cc +++ b/tensorflow/core/common_runtime/process_util.cc @@ -93,13 +93,13 @@ thread::ThreadPool* ComputePool(const SessionOptions& options) { int32 NumInterOpThreadsFromEnvironment() { int32_t num; const char* val = std::getenv("TF_NUM_INTEROP_THREADS"); - return (val && strings::safe_strto32(val, &num)) ? num : 0; + return (val && absl::SimpleAtoi(val, &num)) ? num : 0; } int32 NumIntraOpThreadsFromEnvironment() { int32_t num; const char* val = std::getenv("TF_NUM_INTRAOP_THREADS"); - return (val && strings::safe_strto32(val, &num)) ? num : 0; + return (val && absl::SimpleAtoi(val, &num)) ? num : 0; } #if defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_MKL) int32 OMPThreadsFromEnvironment() { diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h index 1d8856a344a452..71aac10bf6887a 100644 --- a/tensorflow/core/common_runtime/profile_handler.h +++ b/tensorflow/core/common_runtime/profile_handler.h @@ -41,8 +41,9 @@ class ProfileHandler { // - op_type: String name of the Op. // - details: Main content for timeline click text. virtual void RecordOneOp(const string& device, const NodeExecStats& stats, - bool is_copy, StringPiece label, StringPiece op_type, - StringPiece details) = 0; + bool is_copy, absl::string_view label, + absl::string_view op_type, + absl::string_view details) = 0; // Records that the current step finished. // diff --git a/tensorflow/core/common_runtime/quantize_training.cc b/tensorflow/core/common_runtime/quantize_training.cc index 6117cccaa0cf4c..c800552b5d3bca 100644 --- a/tensorflow/core/common_runtime/quantize_training.cc +++ b/tensorflow/core/common_runtime/quantize_training.cc @@ -151,7 +151,7 @@ absl::Status FindSaveOp(const Graph* graph, Node** save_op, return absl::OkStatus(); } -Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) { +Node* FindRestoreAllOp(const Graph* graph, absl::string_view save_prefix) { for (Node* node : graph->op_nodes()) { // The restore_all op should have the same prefix of the save_op. if (node->name() == strings::StrCat(save_prefix, "/restore_all")) { @@ -164,8 +164,8 @@ Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) { // Strips the last "/suffix" from a name. // We use this to construct the name of restore ops in the same way they are // constructed by the Saver. -StringPiece GetNodeNamePrefix(const Node* node) { - StringPiece name = node->name(); +absl::string_view GetNodeNamePrefix(const Node* node) { + absl::string_view name = node->name(); return name.substr(0, name.rfind('/')); } @@ -249,7 +249,7 @@ absl::Status AddRestoreVariableSubgraphs( Graph* graph, Node* save_op, const std::vector& in_edges, const std::vector& variables) { Node* prefix_op = in_edges[0]->src(); - StringPiece name_prefix = GetNodeNamePrefix(save_op); + absl::string_view name_prefix = GetNodeNamePrefix(save_op); Node* restore_all = FindRestoreAllOp(graph, name_prefix); if (restore_all == nullptr) { return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp"); diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 6a546835fd5f54..5b388ce68d68fd 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -74,7 +74,7 @@ absl::Status ShapeRefiner::InferShapesForFunctionSubNode( TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context)); InferenceContext* node_context = CHECK_NOTNULL(GetContext(node)); - if (StringPiece(node->type_string()) == kArgOp) { + if (absl::string_view(node->type_string()) == kArgOp) { // Handle special node: function input. // Shapes for these nodes are provided in the outer inference // context. @@ -102,7 +102,7 @@ absl::Status ShapeRefiner::InferShapesForFunctionSubNode( if (resource) { node_context->set_output_handle_shapes_and_types(0, *resource); } - } else if (StringPiece(node->type_string()) == kRetvalOp) { + } else if (absl::string_view(node->type_string()) == kRetvalOp) { // Handle special node: function output. // Shapes inferred for these nodes go into the outer inference // context. diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 89105e1b636129..c54f26e7cc460c 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -65,7 +65,7 @@ class ShapeRefinerTest : public ::testing::Test { int end, int stride, const char* expected, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int shrink_axis_mask = 0, - StringPiece test_op = "TensorAsShapeInt32") { + absl::string_view test_op = "TensorAsShapeInt32") { Scope root = Scope::DisabledShapeInferenceScope(); auto placeholder = ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape)); diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 695b7d55217094..3aeb903e423b01 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -223,7 +223,7 @@ static int ExtractGpuWithStreamAll(string device_name) { string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; - CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); + CHECK(absl::SimpleAtoi(ordered_capture, &gpu_id)); return gpu_id; } } @@ -252,7 +252,7 @@ static int ExtractGpuWithoutStream(string device_name) { string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; - CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); + CHECK(absl::SimpleAtoi(ordered_capture, &gpu_id)); return gpu_id; } } diff --git a/tensorflow/core/config/flags.cc b/tensorflow/core/config/flags.cc index 26e74e063639c3..d2d1ea502dfe9e 100644 --- a/tensorflow/core/config/flags.cc +++ b/tensorflow/core/config/flags.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace config { -Flag::Flag(StringPiece flag, bool default_value) { +Flag::Flag(absl::string_view flag, bool default_value) { bool val = default_value; if (ReadBoolFromEnvVar(absl::AsciiStrToUpper(flag), default_value, &val) .ok()) { diff --git a/tensorflow/core/config/flags.h b/tensorflow/core/config/flags.h index 3a01e65f12b294..c882cd3939f4af 100644 --- a/tensorflow/core/config/flags.h +++ b/tensorflow/core/config/flags.h @@ -25,7 +25,7 @@ namespace config { // Note: this class is not thread safe. class Flag { public: - explicit Flag(StringPiece flag_name, bool default_value); + explicit Flag(absl::string_view flag_name, bool default_value); bool value() { return value_; } void reset(bool value) { value_ = value; } diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index de3de817211936..96f7f1e63fefe5 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -148,6 +148,8 @@ cc_library( "//tensorflow/core:testlib", "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/kernels:function_ops", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc index 887e3b9b3bfa13..49c33c20911dde 100644 --- a/tensorflow/core/data/captured_function.cc +++ b/tensorflow/core/data/captured_function.cc @@ -402,8 +402,8 @@ class BorrowedArgsCallFrame : public CallFrameBase { absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, - const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, - std::unique_ptr* out_iterator) { + const InstantiatedCapturedFunction& inst_captured_func, + absl::string_view prefix, std::unique_ptr* out_iterator) { return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index, inst_captured_func, prefix, out_iterator, /*node=*/nullptr); @@ -412,8 +412,8 @@ absl::Status MakeIteratorFromInputElement( absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, - const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, - std::unique_ptr* out_iterator, + const InstantiatedCapturedFunction& inst_captured_func, + absl::string_view prefix, std::unique_ptr* out_iterator, const std::shared_ptr& node) { std::vector return_values; diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h index b72fcc8590c347..18fa698def2861 100644 --- a/tensorflow/core/data/captured_function.h +++ b/tensorflow/core/data/captured_function.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/dataset.h" @@ -51,8 +52,8 @@ class InstantiatedCapturedFunction; absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, - const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, - std::unique_ptr* out_iterator); + const InstantiatedCapturedFunction& inst_captured_func, + absl::string_view prefix, std::unique_ptr* out_iterator); // Creates an iterator for a dataset which is created by applying the given // function to the given input element. Pass non-null `node` to record @@ -60,8 +61,8 @@ absl::Status MakeIteratorFromInputElement( absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, - const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, - std::unique_ptr* out_iterator, + const InstantiatedCapturedFunction& inst_captured_func, + absl::string_view prefix, std::unique_ptr* out_iterator, const std::shared_ptr& node); struct ShortCircuitInfo { diff --git a/tensorflow/core/data/dataset_test_base.cc b/tensorflow/core/data/dataset_test_base.cc index 06fbbddaf713fd..a3702920b544c3 100644 --- a/tensorflow/core/data/dataset_test_base.cc +++ b/tensorflow/core/data/dataset_test_base.cc @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index b584149627e02d..f5a13c5f59261b 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -483,7 +483,8 @@ std::string DeterminismPolicy::String() const { } } -bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) { +bool MatchesAnyVersion(absl::string_view op_prefix, + absl::string_view op_to_match) { if (!absl::StartsWith(op_to_match, op_prefix)) { return false; } diff --git a/tensorflow/core/data/dataset_utils.h b/tensorflow/core/data/dataset_utils.h index be04ca67582116..929af873be19c3 100644 --- a/tensorflow/core/data/dataset_utils.h +++ b/tensorflow/core/data/dataset_utils.h @@ -251,7 +251,8 @@ class DummyResourceOp : public OpKernel { // MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true // MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true // MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false -bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match); +bool MatchesAnyVersion(absl::string_view op_prefix, + absl::string_view op_to_match); // Returns the index-th slice of a given tensor. If the index-th slice of // the tensor is not aligned, returns a deep copy of the tensor. diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc index da9e16cba39b49..8d572dd07894bd 100644 --- a/tensorflow/core/data/dataset_utils_test.cc +++ b/tensorflow/core/data/dataset_utils_test.cc @@ -523,7 +523,8 @@ TEST_P(GetExperimentsJobNameTest, DatasetUtils) { } } -// TODO(mpcallanan): Remove randomness from unit tests (see go/python-tips/048). +// Note: These tests use (deterministic) randomness. The behavior is correct but +// this approach is generally frowned upon (see go/python-tips/048). INSTANTIATE_TEST_SUITE_P( Test, GetExperimentsJobNameTest, ::testing::Values( diff --git a/tensorflow/core/data/serialization_utils.cc b/tensorflow/core/data/serialization_utils.cc index b1c7137a84a6fe..a37b16202b5219 100644 --- a/tensorflow/core/data/serialization_utils.cc +++ b/tensorflow/core/data/serialization_utils.cc @@ -107,8 +107,8 @@ absl::Status FindStatefulOps(const GraphDef& graph_def, } // namespace absl::Status ReadElementsFromCheckpoint( - IteratorContext* ctx, IteratorStateReader* reader, StringPiece key_prefix, - std::vector>* elements) { + IteratorContext* ctx, IteratorStateReader* reader, + absl::string_view key_prefix, std::vector>* elements) { int64_t num_elements; TF_RETURN_IF_ERROR( reader->ReadScalar(key_prefix, kNumElements, &num_elements)); @@ -132,7 +132,8 @@ absl::Status ReadElementsFromCheckpoint( return absl::OkStatus(); } -absl::Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, +absl::Status WriteElement(IteratorStateWriter* writer, + absl::string_view key_prefix, const std::vector>& elements, int64_t index) { const std::vector& element = elements[index]; @@ -147,7 +148,7 @@ absl::Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, } absl::Status WriteElementsToCheckpoint( - IteratorStateWriter* writer, StringPiece key_prefix, + IteratorStateWriter* writer, absl::string_view key_prefix, const std::vector>& elements) { TF_RETURN_IF_ERROR( writer->WriteScalar(key_prefix, kNumElements, elements.size())); @@ -158,7 +159,7 @@ absl::Status WriteElementsToCheckpoint( } absl::Status UpdateCheckpointElements( - IteratorStateWriter* writer, StringPiece key_prefix, + IteratorStateWriter* writer, absl::string_view key_prefix, const std::vector>& elements, const absl::flat_hash_set& checkpoint_indices) { TF_RETURN_IF_ERROR( @@ -184,33 +185,33 @@ VariantTensorDataReader::VariantTensorDataReader( } } -absl::Status VariantTensorDataReader::ReadScalar(StringPiece key, +absl::Status VariantTensorDataReader::ReadScalar(absl::string_view key, int64_t* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadScalar(prefix, key, val); } -absl::Status VariantTensorDataReader::ReadScalar(StringPiece name, - StringPiece key, +absl::Status VariantTensorDataReader::ReadScalar(absl::string_view name, + absl::string_view key, int64_t* val) const { return ReadScalarInternal(name, key, val); } -absl::Status VariantTensorDataReader::ReadScalar(StringPiece key, +absl::Status VariantTensorDataReader::ReadScalar(absl::string_view key, tstring* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadScalar(prefix, key, val); } -absl::Status VariantTensorDataReader::ReadScalar(StringPiece name, - StringPiece key, +absl::Status VariantTensorDataReader::ReadScalar(absl::string_view name, + absl::string_view key, tstring* val) const { return ReadScalarInternal(name, key, val); } -absl::Status VariantTensorDataReader::ReadTensor(StringPiece key, +absl::Status VariantTensorDataReader::ReadTensor(absl::string_view key, Tensor* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); @@ -218,27 +219,27 @@ absl::Status VariantTensorDataReader::ReadTensor(StringPiece key, } absl::Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, - StringPiece key, + absl::string_view key, Tensor* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadTensorInternal(flr, prefix, key, val); } -absl::Status VariantTensorDataReader::ReadTensor(StringPiece name, - StringPiece key, +absl::Status VariantTensorDataReader::ReadTensor(absl::string_view name, + absl::string_view key, Tensor* val) const { return ReadTensor(/*flr=*/nullptr, name, key, val); } absl::Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, - StringPiece name, - StringPiece key, + absl::string_view name, + absl::string_view key, Tensor* val) const { return ReadTensorInternal(flr, name, key, val); } -bool VariantTensorDataReader::Contains(StringPiece key) const { +bool VariantTensorDataReader::Contains(absl::string_view key) const { string prefix; if (!ExtractIteratorPrefix(key, &prefix).ok()) { return false; @@ -246,7 +247,8 @@ bool VariantTensorDataReader::Contains(StringPiece key) const { return Contains(prefix, key); } -bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const { +bool VariantTensorDataReader::Contains(absl::string_view n, + absl::string_view key) const { string name(n); auto it = map_.find(name); if (it == map_.end()) { @@ -257,8 +259,8 @@ bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const { } template -absl::Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, - StringPiece key, +absl::Status VariantTensorDataReader::ReadScalarInternal(absl::string_view n, + absl::string_view key, T* val) const { string name(n); auto it = map_.find(name); @@ -275,7 +277,7 @@ absl::Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, } absl::Status VariantTensorDataReader::ReadTensorInternal( - FunctionLibraryRuntime* flr, StringPiece n, StringPiece key, + FunctionLibraryRuntime* flr, absl::string_view n, absl::string_view key, Tensor* val) const { if (Contains(n, strings::StrCat(key, kIsDataset))) { return ReadDatasetInternal(flr, n, key, val); @@ -295,7 +297,7 @@ absl::Status VariantTensorDataReader::ReadTensorInternal( } absl::Status VariantTensorDataReader::ReadDatasetInternal( - FunctionLibraryRuntime* flr, StringPiece n, StringPiece key, + FunctionLibraryRuntime* flr, absl::string_view n, absl::string_view key, Tensor* val) const { if (flr == nullptr) { return errors::Internal( @@ -326,41 +328,41 @@ std::map VariantTensorDataReader::ReadAllTensors() { return result; } -absl::Status VariantTensorDataWriter::WriteScalar(StringPiece key, +absl::Status VariantTensorDataWriter::WriteScalar(absl::string_view key, const int64_t val) { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } -absl::Status VariantTensorDataWriter::WriteScalar(StringPiece name, - StringPiece key, +absl::Status VariantTensorDataWriter::WriteScalar(absl::string_view name, + absl::string_view key, const int64_t val) { return WriteScalarInternal(name, key, val); } -absl::Status VariantTensorDataWriter::WriteScalar(StringPiece key, +absl::Status VariantTensorDataWriter::WriteScalar(absl::string_view key, const tstring& val) { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } -absl::Status VariantTensorDataWriter::WriteScalar(StringPiece name, - StringPiece key, +absl::Status VariantTensorDataWriter::WriteScalar(absl::string_view name, + absl::string_view key, const tstring& val) { return WriteScalarInternal(name, key, val); } -absl::Status VariantTensorDataWriter::WriteTensor(StringPiece key, +absl::Status VariantTensorDataWriter::WriteTensor(absl::string_view key, const Tensor& val) { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteTensor(prefix, key, val); } -absl::Status VariantTensorDataWriter::WriteTensor(StringPiece name, - StringPiece key, +absl::Status VariantTensorDataWriter::WriteTensor(absl::string_view name, + absl::string_view key, const Tensor& val) { return WriteTensorInternal(name, key, val); } @@ -402,9 +404,8 @@ void VariantTensorDataWriter::GetData( } template -absl::Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name, - StringPiece key, - const T& val) { +absl::Status VariantTensorDataWriter::WriteScalarInternal( + absl::string_view name, absl::string_view key, const T& val) { if (is_flushed_) { return errors::FailedPrecondition( "Cannot call WriteScalar after GetData or ReleaseData is called"); @@ -414,8 +415,8 @@ absl::Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name, return WriteTensorInternal(name, key, val_t); } -absl::Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, - StringPiece key, +absl::Status VariantTensorDataWriter::WriteTensorInternal(absl::string_view n, + absl::string_view key, const Tensor& val) { DatasetBase* dataset; if (GetDatasetFromVariantTensor(val, &dataset).ok()) { @@ -440,7 +441,7 @@ absl::Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, } absl::Status VariantTensorDataWriter::WriteDatasetInternal( - StringPiece n, StringPiece key, const DatasetBase* dataset) { + absl::string_view n, absl::string_view key, const DatasetBase* dataset) { GraphDef graph_def; SerializationContext ctx((SerializationContext::Params())); TF_RETURN_IF_ERROR(AsGraphDef(dataset, std::move(ctx), &graph_def)); diff --git a/tensorflow/core/data/serialization_utils.h b/tensorflow/core/data/serialization_utils.h index 10f39712d5e2f2..e59ac959432082 100644 --- a/tensorflow/core/data/serialization_utils.h +++ b/tensorflow/core/data/serialization_utils.h @@ -43,15 +43,15 @@ inline constexpr absl::string_view kRetvalOp = "_Retval"; // Reads dataset elements from the checkpoint reader using the given key prefix. absl::Status ReadElementsFromCheckpoint( - IteratorContext* ctx, IteratorStateReader* reader, StringPiece key_prefix, - std::vector>* elements); + IteratorContext* ctx, IteratorStateReader* reader, + absl::string_view key_prefix, std::vector>* elements); // Writes dataset elements to the checkpoint writer using the given key prefix. // The elements can be read back by passing the same key prefix to // ReadElementsFromCheckpoint. Only one list of elements can be written under // the same key_prefix. absl::Status WriteElementsToCheckpoint( - IteratorStateWriter* writer, StringPiece key_prefix, + IteratorStateWriter* writer, absl::string_view key_prefix, const std::vector>& elements); // Updates the dataset elements in the checkpoint for given `checkpoint_indices` @@ -59,7 +59,7 @@ absl::Status WriteElementsToCheckpoint( // checkpointed these before. The elements can be read back by passing the same // key prefix to ReadElementsFromCheckpoint. absl::Status UpdateCheckpointElements( - IteratorStateWriter* writer, StringPiece key_prefix, + IteratorStateWriter* writer, absl::string_view key_prefix, const std::vector>& elements, const absl::flat_hash_set& checkpoint_indices); @@ -69,32 +69,33 @@ class VariantTensorDataReader : public IteratorStateReader { explicit VariantTensorDataReader( const std::vector& data); - bool Contains(StringPiece key) const override; - bool Contains(StringPiece name, StringPiece key) const override; + bool Contains(absl::string_view key) const override; + bool Contains(absl::string_view name, absl::string_view key) const override; - absl::Status ReadScalar(StringPiece key, int64_t* val) const override; - absl::Status ReadScalar(StringPiece name, StringPiece key, + absl::Status ReadScalar(absl::string_view key, int64_t* val) const override; + absl::Status ReadScalar(absl::string_view name, absl::string_view key, int64_t* val) const override; - absl::Status ReadScalar(StringPiece key, tstring* val) const override; - absl::Status ReadScalar(StringPiece name, StringPiece key, + absl::Status ReadScalar(absl::string_view key, tstring* val) const override; + absl::Status ReadScalar(absl::string_view name, absl::string_view key, tstring* val) const override; - absl::Status ReadTensor(StringPiece key, Tensor* val) const override; - absl::Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, + absl::Status ReadTensor(absl::string_view key, Tensor* val) const override; + absl::Status ReadTensor(FunctionLibraryRuntime* flr, absl::string_view key, Tensor* val) const override; - absl::Status ReadTensor(StringPiece name, StringPiece key, + absl::Status ReadTensor(absl::string_view name, absl::string_view key, Tensor* val) const override; - absl::Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, - StringPiece key, Tensor* val) const override; + absl::Status ReadTensor(FunctionLibraryRuntime* flr, absl::string_view name, + absl::string_view key, Tensor* val) const override; private: template - absl::Status ReadScalarInternal(StringPiece name, StringPiece key, + absl::Status ReadScalarInternal(absl::string_view name, absl::string_view key, T* val) const; - absl::Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name, - StringPiece key, Tensor* val) const; + absl::Status ReadTensorInternal(FunctionLibraryRuntime* flr, + absl::string_view name, absl::string_view key, + Tensor* val) const; absl::Status ReadDatasetInternal(FunctionLibraryRuntime* flr, - StringPiece name, StringPiece key, - Tensor* val) const; + absl::string_view name, + absl::string_view key, Tensor* val) const; // Produces all key/value pairs stored in this reader. Useful for debugging. std::map ReadAllTensors(); @@ -118,16 +119,16 @@ class VariantTensorDataReader : public IteratorStateReader { // Now the VariantTensorData objects can be used to serialize. class VariantTensorDataWriter : public IteratorStateWriter { public: - absl::Status WriteScalar(StringPiece key, int64_t val) override; - absl::Status WriteScalar(StringPiece name, StringPiece key, + absl::Status WriteScalar(absl::string_view key, int64_t val) override; + absl::Status WriteScalar(absl::string_view name, absl::string_view key, int64_t val) override; - absl::Status WriteScalar(StringPiece key, const tstring& val) override; - absl::Status WriteScalar(StringPiece name, StringPiece key, + absl::Status WriteScalar(absl::string_view key, const tstring& val) override; + absl::Status WriteScalar(absl::string_view name, absl::string_view key, const tstring& val) override; - absl::Status WriteTensor(StringPiece key, const Tensor& val) override; - absl::Status WriteTensor(StringPiece name, StringPiece key, + absl::Status WriteTensor(absl::string_view key, const Tensor& val) override; + absl::Status WriteTensor(absl::string_view name, absl::string_view key, const Tensor& val) override; // Releases the built VariantTensorData's to `variants`. Clears out all @@ -142,11 +143,12 @@ class VariantTensorDataWriter : public IteratorStateWriter { void Reset(); template - absl::Status WriteScalarInternal(StringPiece name, StringPiece key, - const T& val); - absl::Status WriteTensorInternal(StringPiece name, StringPiece key, - const Tensor& val); - absl::Status WriteDatasetInternal(StringPiece name, StringPiece key, + absl::Status WriteScalarInternal(absl::string_view name, + absl::string_view key, const T& val); + absl::Status WriteTensorInternal(absl::string_view name, + absl::string_view key, const Tensor& val); + absl::Status WriteDatasetInternal(absl::string_view name, + absl::string_view key, const DatasetBase* dataset); bool is_flushed_ = false; diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 8a76428a848dde..564ed6f59fefb5 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -844,6 +844,17 @@ cc_library( ], ) +cc_library( + name = "test_data_transfer", + testonly = True, + srcs = ["test_data_transfer.cc"], + deps = [ + ":data_transfer", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + cc_library( name = "test_util", testonly = True, diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 8270eae147d5fc..53e875ea2a6845 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -80,9 +80,9 @@ absl::StatusOr GetTransferServer( return transfer_server; } } - return errors::NotFound("protocol ", protocol, - " is not available for worker ", - task_info.worker_address()); + return absl::NotFoundError(absl::StrCat("Protocol '", protocol, + "' is not available for worker '", + task_info.worker_address(), "'.")); } } // namespace @@ -362,7 +362,7 @@ DataServiceClient::CreateGrpcWorkerClient(const TaskInfo& task_info) { } absl::StatusOr> -DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback( +DataServiceClient::CreateAlternativeWorkerClientMaybeWithGrpcFallback( const DataTransferServerInfo& transfer_server, const TaskInfo& task_info) { absl::StatusOr> worker = CreateDataServiceWorkerClient(params_.protocol, transfer_server, @@ -373,10 +373,17 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback( << task_info.worker_address() << "'."; return worker; } - LOG(INFO) << "Failed to start client for data transfer protocol '" - << transfer_server.protocol() << "' for worker '" - << task_info.worker_address() << "'; falling back to grpc. " - << "Original error: " << worker.status(); + std::string client_creation_error_message = + absl::StrCat("Failed to start client for data transfer protocol '", + transfer_server.protocol(), "' for worker '", + task_info.worker_address(), "'."); + if (!transfer_server.fall_back_to_grpc_at_client_creation_time()) { + return absl::InternalError( + absl::StrCat(client_creation_error_message, + " Original error: ", worker.status().message())); + } + LOG(INFO) << client_creation_error_message + << "; falling back to gRPC. Original error: " << worker.status(); metrics::RecordTFDataServiceDataTransferProtocolFallback( transfer_server.protocol(), static_cast(worker.status().raw_code()), @@ -398,16 +405,16 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { TF_ASSIGN_OR_RETURN( DataTransferServerInfo transfer_server, GetTransferServer(params_.data_transfer_protocol, task_info)); - return CreateAlternativeWorkerClientWithGrpcFallback(transfer_server, - task_info); + return CreateAlternativeWorkerClientMaybeWithGrpcFallback(transfer_server, + task_info); } if (std::string default_protocol = DefaultDataTransferProtocol(); default_protocol != kGrpcTransferProtocol) { absl::StatusOr transfer_server = GetTransferServer(default_protocol, task_info); if (transfer_server.ok()) { - return CreateAlternativeWorkerClientWithGrpcFallback(*transfer_server, - task_info); + return CreateAlternativeWorkerClientMaybeWithGrpcFallback( + *transfer_server, task_info); } VLOG(1) << "Failed to find transfer server for default data transfer " "protocol '" @@ -875,12 +882,25 @@ absl::Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, if (!IsPreemptedError(s)) { if (task->worker->GetDataTransferProtocol() == kGrpcTransferProtocol || task->worker->GetDataTransferProtocol() == kLocalTransferProtocol) { - return s; + return absl::Status( + s.code(), + absl::StrCat( + "Failed to get an element, with a nonretryable error: ", + s.message())); } - LOG(ERROR) << "Failed to use alternative data transfer protocol '" - << task->worker->GetDataTransferProtocol() << "' for worker '" - << task->info.worker_address() - << "'; falling back to grpc. Original error: " << s; + if (!task->worker->FallBackToGrpcAtGetElementTime()) { + return absl::Status( + s.code(), + absl::StrCat("Failed to get an element over data " + "transfer protocol '", + task->worker->GetDataTransferProtocol(), + "', with a nonretryable error: ", s.message())); + } + LOG(ERROR) << "Failed to get an element over data transfer protocol '" + << task->worker->GetDataTransferProtocol() + << "', with a nonretryable error; falling back to grpc. " + "Original error: " + << s; metrics::RecordTFDataServiceDataTransferProtocolError( task->worker->GetDataTransferProtocol(), static_cast(s.raw_code()), std::string(s.message())); diff --git a/tensorflow/core/data/service/client/data_service_client.h b/tensorflow/core/data/service/client/data_service_client.h index a5bb1fd634d83d..7c211d5551c46e 100644 --- a/tensorflow/core/data/service/client/data_service_client.h +++ b/tensorflow/core/data/service/client/data_service_client.h @@ -163,7 +163,7 @@ class DataServiceClient { absl::StatusOr> CreateGrpcWorkerClient(const TaskInfo& task_info); absl::StatusOr> - CreateAlternativeWorkerClientWithGrpcFallback( + CreateAlternativeWorkerClientMaybeWithGrpcFallback( const DataTransferServerInfo& transfer_server, const TaskInfo& task_info); void Heartbeat(); void UpdateTasks(const ClientHeartbeatResponse& resp); diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto index 9d2825082efed1..c92fc06a7baa05 100644 --- a/tensorflow/core/data/service/common.proto +++ b/tensorflow/core/data/service/common.proto @@ -128,10 +128,20 @@ enum TargetWorkers { } // Information about one of a worker server's data transfer servers. +// Next tag: 6 message DataTransferServerInfo { string protocol = 1; string address = 2; + // If provided, properties of the server used to determine compatibility with // a client. bytes compatibility_info = 3; + + // If `true`, data service clients should fall back to gRPC for this server if + // they fail to create a data transfer client for it. + bool fall_back_to_grpc_at_client_creation_time = 4; + + // If `true`, data service clients should fall back to gRPC for this server if + // it nonretryably fails to transfer an element. + bool fall_back_to_grpc_at_get_element_time = 5; } diff --git a/tensorflow/core/data/service/data_transfer.h b/tensorflow/core/data/service/data_transfer.h index cb5125b573ce97..23c8247def05ef 100644 --- a/tensorflow/core/data/service/data_transfer.h +++ b/tensorflow/core/data/service/data_transfer.h @@ -136,6 +136,14 @@ class DataTransferServer { virtual absl::StatusOr GetCompatibilityInfo() const { return std::string(); } + + // If `true`, data service clients should fall back to gRPC for this server if + // they fail to create a data transfer client for it. + virtual bool FallBackToGrpcAtClientCreationTime() const { return true; } + + // If `true`, data service clients should fall back to gRPC for this server if + // it nonretryably fails to transfer an element. + virtual bool FallBackToGrpcAtGetElementTime() const { return true; } }; } // namespace data diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h index b82b8cb0c89544..6fa299dc1e3435 100644 --- a/tensorflow/core/data/service/dispatcher_impl.h +++ b/tensorflow/core/data/service/dispatcher_impl.h @@ -385,10 +385,10 @@ class DataServiceDispatcherImpl { absl::flat_hash_map latest_worker_heartbeats_time_ TF_GUARDED_BY(mu_); - // TODO(mpcallanan): Don't recover completed snapshots. - // TODO(mpcallanan): Garbage collect completed snapshots. // A manager for each snapshot resumed or started during the lifetime of this - // dispatcher instance. + // dispatcher instance. Note that these are *not* garbage collected; managers + // for completed snapshots will remain here for the lifetime of the dispatcher + // instance. They will even be recovered if the dispatcher is restarted. absl::flat_hash_map> snapshots_ TF_GUARDED_BY(mu_); // A single stream assignment manager shared by all managers in `snapshots_`. diff --git a/tensorflow/core/data/service/graph_rewriters.h b/tensorflow/core/data/service/graph_rewriters.h index bdcf630bc88909..e1244fd57e54cb 100644 --- a/tensorflow/core/data/service/graph_rewriters.h +++ b/tensorflow/core/data/service/graph_rewriters.h @@ -31,8 +31,6 @@ limitations under the License. namespace tensorflow { namespace data { -// TODO(mpcallanan): Refactor rewriters into shared base class. - // Rewrites the dataset graph by removing the compression map. class RemoveCompressionMapRewriter { public: diff --git a/tensorflow/core/data/service/journal.cc b/tensorflow/core/data/service/journal.cc index 0462657e2363f7..78b8228a1c4ed3 100644 --- a/tensorflow/core/data/service/journal.cc +++ b/tensorflow/core/data/service/journal.cc @@ -34,7 +34,7 @@ namespace tensorflow { namespace data { namespace { -constexpr StringPiece kJournal = "journal"; +constexpr absl::string_view kJournal = "journal"; absl::Status ParseSequenceNumber(const std::string& journal_file, int64_t* sequence_number) { @@ -92,7 +92,7 @@ absl::Status FileJournalWriter::Write(const Update& update) { return absl::OkStatus(); } -FileJournalReader::FileJournalReader(Env* env, StringPiece journal_dir) +FileJournalReader::FileJournalReader(Env* env, absl::string_view journal_dir) : env_(env), journal_dir_(journal_dir) {} absl::Status FileJournalReader::EnsureInitialized() { diff --git a/tensorflow/core/data/service/journal.h b/tensorflow/core/data/service/journal.h index 7e909a268860d3..0c15856b574043 100644 --- a/tensorflow/core/data/service/journal.h +++ b/tensorflow/core/data/service/journal.h @@ -92,7 +92,7 @@ class JournalReader { // directory, in order of their sequence numbers. See FileJournalWriter above. class FileJournalReader : public JournalReader { public: - explicit FileJournalReader(Env* env, StringPiece journal_dir); + explicit FileJournalReader(Env* env, absl::string_view journal_dir); FileJournalReader(const FileJournalReader&) = delete; FileJournalReader& operator=(const FileJournalReader&) = delete; diff --git a/tensorflow/core/data/service/journal_test.cc b/tensorflow/core/data/service/journal_test.cc index bb9132d81725aa..7c79526c093dc8 100644 --- a/tensorflow/core/data/service/journal_test.cc +++ b/tensorflow/core/data/service/journal_test.cc @@ -67,7 +67,7 @@ Update MakeRegisterDatasetUpdate() { return update; } -absl::Status CheckJournalContent(StringPiece journal_dir, +absl::Status CheckJournalContent(absl::string_view journal_dir, const std::vector& expected) { FileJournalReader reader(Env::Default(), journal_dir); for (const auto& update : expected) { diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index b49fdbcd651f74..ddcc432dce92f7 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -229,6 +229,10 @@ void WorkerGrpcDataServer::MaybeStartAlternativeDataTransferServer( return; } alternative_transfer_server.set_compatibility_info(*compatibility_info); + alternative_transfer_server.set_fall_back_to_grpc_at_client_creation_time( + transfer_server_->FallBackToGrpcAtClientCreationTime()); + alternative_transfer_server.set_fall_back_to_grpc_at_get_element_time( + transfer_server_->FallBackToGrpcAtGetElementTime()); transfer_servers.push_back(alternative_transfer_server); } diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index ffc34db5936595..cff1e60f4a4972 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -26,6 +26,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/data:snapshot_utils", + "//tensorflow/core/data/service:common_proto_cc", "//tensorflow/core/data/service:dispatcher_client", "//tensorflow/core/data/service:test_cluster", "//tensorflow/core/data/service:test_util", @@ -74,8 +75,10 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/data:dataset_test_base", "//tensorflow/core/data:snapshot_utils", + "//tensorflow/core/data/service:common_proto_cc", "//tensorflow/core/data/service:test_util", "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -96,6 +99,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:split_utils", + "//tensorflow/core/framework:dataset_options_proto_cc", "//tensorflow/core/framework:op_requires", "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/status", @@ -144,6 +148,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/data:snapshot_utils", "//tensorflow/core/data/service:byte_size", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -311,6 +316,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -447,6 +453,9 @@ cc_library( "//tensorflow/core/data/service:task_runner", "//tensorflow/core/data/service:worker_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -503,6 +512,7 @@ tf_cc_test( "//tensorflow/core/data/service:common_proto_cc", "//tensorflow/core/data/service:task_runner", "//tensorflow/core/data/service:test_util", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -529,10 +539,12 @@ cc_library( "//tensorflow/core/data/service:byte_size", "//tensorflow/core/data/service:common_proto_cc", "//tensorflow/core/data/service:task_runner", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", ], diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc index f95fafb9343669..03fca7e012f7c9 100644 --- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc +++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/io/compression.h" +#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/snapshot/test_utils.h" diff --git a/tensorflow/core/data/service/snapshot/file_utils_test.cc b/tensorflow/core/data/service/snapshot/file_utils_test.cc index dc4efcc9497f22..1172c66c6f3406 100644 --- a/tensorflow/core/data/service/snapshot/file_utils_test.cc +++ b/tensorflow/core/data/service/snapshot/file_utils_test.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/io/compression.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/dataset_test_base.h" +#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc b/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc index 284c762354d260..de4804a80fdd07 100644 --- a/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc +++ b/tensorflow/core/data/service/snapshot/list_snapshot_chunks_dataset_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include #include -#include #include #include #include @@ -26,6 +25,7 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h" #include "tensorflow/core/data/split_utils.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc index 1623ac904c5484..fa32f3335ba18c 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider.h b/tensorflow/core/data/service/snapshot/prefetched_split_provider.h index 518f8a3712d099..2ec9472cc1a9be 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider.h +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/btree_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc index 49ec21ecf6e6b2..b134b962eeb806 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/data/utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc index ff1e2caea35b00..bd5bb4a25600b1 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include #include #include #include diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc index e6fcd97ef6d5dd..300730cc75654c 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h" #include -#include #include #include #include @@ -25,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/protobuf/status.pb.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc index fffd36c09139a5..72f7c330147446 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -29,9 +29,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/tsl/lib/io/compression.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.h b/tensorflow/core/data/service/snapshot/snapshot_manager.h index 5db495f16c87ce..dd3a76d640e6ec 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.h +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc index cff201261b00c4..8f9fbf47ceb4c2 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/service/snapshot/snapshot_manager.h" +#include #include #include diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc index 5a6f8200b589ab..0bb93b29818d3f 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc index 01412806950427..db06b19b461949 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc @@ -23,10 +23,12 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/time/time.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h index 3179ab167a6620..09d72d86845583 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/substitute.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc index c557d1630194e7..a70fbcb276f330 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc @@ -14,17 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/service/snapshot/snapshot_stream_writer.h" -#include #include #include -#include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/io/compression.h" #include "xla/tsl/lib/monitoring/cell_reader.h" @@ -39,6 +36,7 @@ limitations under the License. #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/test_utils.cc b/tensorflow/core/data/service/snapshot/test_utils.cc index 7b82dd7921a6a1..a93eeb696bcbd6 100644 --- a/tensorflow/core/data/service/snapshot/test_utils.cc +++ b/tensorflow/core/data/service/snapshot/test_utils.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" +#include "absl/time/time.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" diff --git a/tensorflow/core/data/service/snapshot/test_utils.h b/tensorflow/core/data/service/snapshot/test_utils.h index f8aee68541c587..efa31121a06ad8 100644 --- a/tensorflow/core/data/service/snapshot/test_utils.h +++ b/tensorflow/core/data/service/snapshot/test_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/time/time.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/core/data/service/task_runner.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/utils.cc b/tensorflow/core/data/service/snapshot/utils.cc index cb0ada6e01cd99..54790b24da809e 100644 --- a/tensorflow/core/data/service/snapshot/utils.cc +++ b/tensorflow/core/data/service/snapshot/utils.cc @@ -16,8 +16,6 @@ limitations under the License. #include -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h index a62669d7344d3d..b1d242fe8d3c08 100644 --- a/tensorflow/core/data/service/test_cluster.h +++ b/tensorflow/core/data/service/test_cluster.h @@ -173,7 +173,9 @@ DatasetClient::DatasetClient(const TestCluster& cluster) for (size_t i = 0; i < cluster.NumWorkers(); ++i) { worker_clients_[cluster_.WorkerAddress(i)] = std::make_unique( - cluster_.WorkerAddress(i), "grpc", "grpc", + cluster_.WorkerAddress(i), /*protocol=*/"grpc", + /*transfer_protocol=*/"grpc", + /*fall_back_to_grpc_at_get_element_time=*/true, /*accelerator_device_info=*/nullptr, /*allocator=*/nullptr); } } diff --git a/tensorflow/core/data/service/test_data_transfer.cc b/tensorflow/core/data/service/test_data_transfer.cc new file mode 100644 index 00000000000000..e30dd2b847981a --- /dev/null +++ b/tensorflow/core/data/service/test_data_transfer.cc @@ -0,0 +1,215 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/data/service/data_transfer.h" + +namespace tensorflow { +namespace data { + +// Fake alternative data transfer protocols: +// +// - good: No errors or fallback. +// +// - bad_with_primary_fallback: Fails at data transfer client creation time and +// falls back to gRPC. +// +// - bad_without_primary_fallback: Fails at data transfer client creation time +// and doesn't fall back, taking down the entire data service client. +// +// - bad_with_secondary_fallback: Fails at get element time and falls back to +// gRPC. +// +// - bad_without_secondary_fallback: Fails at get element time and doesn't fall +// back, taking down the entire data service client. +// +constexpr const char kGoodProtocol[] = "good"; +constexpr const char kBadProtocolWithPrimaryFallback[] = + "bad_with_primary_fallback"; +constexpr const char kBadProtocolWithoutPrimaryFallback[] = + "bad_without_primary_fallback"; +constexpr const char kBadProtocolWithSecondaryFallback[] = + "bad_with_secondary_fallback"; +constexpr const char kBadProtocolWithoutSecondaryFallback[] = + "bad_without_secondary_fallback"; + +// A server that works. +class GoodTestServer : public DataTransferServer { + public: + explicit GoodTestServer(DataTransferServer::GetElementT get_element, + bool fall_back_to_grpc_at_client_creation_time = true, + bool fall_back_to_grpc_at_get_element_time = true) + : get_element_(get_element), + fall_back_to_grpc_at_client_creation_time_( + fall_back_to_grpc_at_client_creation_time), + fall_back_to_grpc_at_get_element_time_( + fall_back_to_grpc_at_get_element_time) {} + + virtual absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) { + return get_element_(&req, &result); + } + + bool FallBackToGrpcAtClientCreationTime() const override { + return fall_back_to_grpc_at_client_creation_time_; + } + + bool FallBackToGrpcAtGetElementTime() const override { + return fall_back_to_grpc_at_get_element_time_; + } + + absl::Status Start(const experimental::WorkerConfig& config) override { + return absl::OkStatus(); + } + + int Port() const override { return -1; } + + private: + DataTransferServer::GetElementT get_element_; + bool fall_back_to_grpc_at_client_creation_time_; + bool fall_back_to_grpc_at_get_element_time_; +}; + +// A server that doesn't work (by failing at get element time). +class BadTestServer : public GoodTestServer { + public: + explicit BadTestServer(DataTransferServer::GetElementT get_element, + bool fall_back_to_grpc_at_client_creation_time = true, + bool fall_back_to_grpc_at_get_element_time = true) + : GoodTestServer(get_element, fall_back_to_grpc_at_client_creation_time, + fall_back_to_grpc_at_get_element_time) {} + + absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) override { + return absl::InternalError("Bad get element."); + } +}; + +// A working client for a server that may or may not work. +template +class TestClient : public DataTransferClient { + public: + explicit TestClient(std::shared_ptr server) : server_(server) {} + + absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) override { + return server_->GetElement(req, result); + } + + void TryCancel() override {} + + private: + std::shared_ptr server_; +}; + +class DataTransferRegistrar { + public: + DataTransferRegistrar() { + // "good". + RegisterServer(kGoodProtocol, good_); + RegisterClient(kGoodProtocol, good_); + + // "bad_with_primary_fallback". + RegisterUnusedServerForBadClient( + kBadProtocolWithPrimaryFallback, + /*fall_back_to_grpc_at_client_creation_time=*/true); + RegisterBadClient(kBadProtocolWithPrimaryFallback); + + // "bad_without_primary_fallback". + RegisterUnusedServerForBadClient( + kBadProtocolWithoutPrimaryFallback, + /*fall_back_to_grpc_at_client_creation_time=*/false); + RegisterBadClient(kBadProtocolWithoutPrimaryFallback); + + // "bad_with_secondary_fallback". + RegisterServer( + kBadProtocolWithSecondaryFallback, bad_with_secondary_fallback_, + /*fall_back_to_grpc_at_get_element_time=*/true); + RegisterClient(kBadProtocolWithSecondaryFallback, + bad_with_secondary_fallback_); + + // "bad_without_secondary_fallback". + RegisterServer( + kBadProtocolWithoutSecondaryFallback, bad_without_secondary_fallback_, + /*fall_back_to_grpc_at_get_element_time=*/false); + RegisterClient(kBadProtocolWithoutSecondaryFallback, + bad_without_secondary_fallback_); + } + + private: + // Registers a server that may or may not work. + template + void RegisterServer(const std::string& protocol, + std::shared_ptr& my_server, + bool fall_back_to_grpc_at_get_element_time = true) { + DataTransferServer::Register( + protocol, [&my_server, fall_back_to_grpc_at_get_element_time]( + DataTransferServer::GetElementT get_element, + std::shared_ptr* server) { + my_server = std::make_shared( + get_element, /*fall_back_to_grpc_at_client_creation_time=*/true, + fall_back_to_grpc_at_get_element_time); + *server = my_server; + return absl::OkStatus(); + }); + } + + // Registers a working client for a server that may or may not work. + template + void RegisterClient(const std::string& protocol, + std::shared_ptr& my_server) { + DataTransferClient::Register( + protocol, [&](DataTransferClient::Config config, + std::unique_ptr* client) { + *client = std::make_unique>(my_server); + return absl::OkStatus(); + }); + } + + // Registers a working server that shouldn't get used (because its client + // should fail first, which may or may not result in a fall back). + void RegisterUnusedServerForBadClient( + const std::string& protocol, + bool fall_back_to_grpc_at_client_creation_time) { + DataTransferServer::Register( + protocol, [fall_back_to_grpc_at_client_creation_time]( + DataTransferServer::GetElementT get_element, + std::shared_ptr* server) { + *server = std::make_shared( + get_element, fall_back_to_grpc_at_client_creation_time); + return absl::OkStatus(); + }); + } + + // Registers a nonworking client (via a client creation callback that fails). + void RegisterBadClient(const std::string& protocol) { + DataTransferClient::Register( + protocol, [](DataTransferClient::Config config, + std::unique_ptr* client) { + return absl::InternalError("Bad client."); + }); + } + + std::shared_ptr good_ = nullptr; + std::shared_ptr bad_with_secondary_fallback_ = nullptr; + std::shared_ptr bad_without_secondary_fallback_ = nullptr; +}; + +static DataTransferRegistrar data_transfer_registrar; + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/worker_client.cc b/tensorflow/core/data/service/worker_client.cc index 18510e5da36276..f38d93434ee35d 100644 --- a/tensorflow/core/data/service/worker_client.cc +++ b/tensorflow/core/data/service/worker_client.cc @@ -64,7 +64,8 @@ CreateDataServiceWorkerClient( Allocator* allocator) { auto client = std::make_unique( info.address(), dispatcher_protocol, info.protocol(), - accelerator_device_info, allocator); + info.fall_back_to_grpc_at_get_element_time(), accelerator_device_info, + allocator); TF_RETURN_IF_ERROR(client->Initialize()); TF_RETURN_WITH_CONTEXT_IF_ERROR( client->CheckCompatibility(info.compatibility_info()), diff --git a/tensorflow/core/data/service/worker_client.h b/tensorflow/core/data/service/worker_client.h index 2bb5328461f323..64ac446bd3064a 100644 --- a/tensorflow/core/data/service/worker_client.h +++ b/tensorflow/core/data/service/worker_client.h @@ -37,10 +37,13 @@ class DataServiceWorkerClient : public DataServiceClientBase { DataServiceWorkerClient( const std::string& address, const std::string& protocol, const std::string& transfer_protocol, + bool fall_back_to_grpc_at_get_element_time, const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, Allocator* allocator) : DataServiceClientBase(address, protocol), transfer_protocol_(transfer_protocol), + fall_back_to_grpc_at_get_element_time_( + fall_back_to_grpc_at_get_element_time), accelerator_device_info_(accelerator_device_info), allocator_(allocator) {} @@ -51,12 +54,21 @@ class DataServiceWorkerClient : public DataServiceClientBase { // Makes a best effort to cancel all outstanding calls in progress for the // client, and causes further calls to return Cancelled status. void TryCancel(); + // Returns an error if the client is incompatible with a server which has the // properties described in `compatibility_info`. absl::Status CheckCompatibility( const std::string& server_compatibility_info) const { return client_->CheckCompatibility(server_compatibility_info); } + + // If `true`, data service clients should fall back to gRPC for this worker + // client if it nonretryably fails to transfer an element using an alternative + // data transfer protocol. + bool FallBackToGrpcAtGetElementTime() const { + return fall_back_to_grpc_at_get_element_time_; + } + // Returns the data transfer protocol, preferring to use the local transfer // protocol if a local tf.data worker exists. std::string GetDataTransferProtocol() const; @@ -66,6 +78,7 @@ class DataServiceWorkerClient : public DataServiceClientBase { private: std::string transfer_protocol_; + bool fall_back_to_grpc_at_get_element_time_; const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_; Allocator* allocator_; diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index 50f3bc86cab92b..f5d90442b0639d 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -345,10 +345,10 @@ CustomWriter::~CustomWriter() { } } -absl::Status CustomWriter::WriteRecord(const StringPiece& data) { +absl::Status CustomWriter::WriteRecord(const absl::string_view& data) { char header[kHeaderSize]; core::EncodeFixed64(header, data.size()); - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + TF_RETURN_IF_ERROR(dest_->Append(absl::string_view(header, sizeof(header)))); return dest_->Append(data); } @@ -356,7 +356,7 @@ absl::Status CustomWriter::WriteRecord(const StringPiece& data) { absl::Status CustomWriter::WriteRecord(const absl::Cord& data) { char header[kHeaderSize]; core::EncodeFixed64(header, data.size()); - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + TF_RETURN_IF_ERROR(dest_->Append(absl::string_view(header, sizeof(header)))); return dest_->Append(data); } #endif // TF_CORD_SUPPORT @@ -917,7 +917,7 @@ absl::Status CustomReader::ReadTensorsV0(std::vector* read_tensors) { #if defined(PLATFORM_GOOGLE) absl::Cord c; TF_RETURN_IF_ERROR(ReadRecord(&c)); - record.ParseFromCord(c); + record.ParseFromString(c); #else // PLATFORM_GOOGLE tstring record_bytes; TF_RETURN_IF_ERROR(ReadRecord(&record_bytes)); diff --git a/tensorflow/core/data/snapshot_utils.h b/tensorflow/core/data/snapshot_utils.h index d543dcb3d29e40..f083cbe495fa72 100644 --- a/tensorflow/core/data/snapshot_utils.h +++ b/tensorflow/core/data/snapshot_utils.h @@ -154,7 +154,7 @@ class CustomWriter : public Writer { absl::Status Initialize(tensorflow::Env* env) override; private: - absl::Status WriteRecord(const StringPiece& data); + absl::Status WriteRecord(const absl::string_view& data); #if defined(TF_CORD_SUPPORT) absl::Status WriteRecord(const absl::Cord& data); diff --git a/tensorflow/core/debug/bfc_dump_reader.cc b/tensorflow/core/debug/bfc_dump_reader.cc index 5c780c7c9ae09b..aabdf146fc5e4a 100644 --- a/tensorflow/core/debug/bfc_dump_reader.cc +++ b/tensorflow/core/debug/bfc_dump_reader.cc @@ -38,7 +38,7 @@ MemoryDump ReadDumpFile(const string& fname) { } std::unique_ptr buffer(static_cast(malloc(file_size + 1))); DCHECK(buffer.get()); - StringPiece contents(buffer.get(), file_size); + absl::string_view contents(buffer.get(), file_size); status = file->Read(0, file_size, &contents, buffer.get()); if (!status.ok()) { LOG(ERROR) << "read from file " << fname << " failed " << status; diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index 1a40b13b227fd5..2b772e74c81153 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -382,7 +382,7 @@ absl::Status DebugNodeInserter::ParseDebugOpName( std::vector attribute_segs = str_util::Split(arguments, ";"); for (const string& attribute_seg : attribute_segs) { - StringPiece seg(attribute_seg); + absl::string_view seg(attribute_seg); str_util::RemoveWhitespaceContext(&seg); if (seg.empty()) { continue; @@ -429,7 +429,7 @@ absl::Status DebugNodeInserter::SetDebugNodeAttributes( debug_node->AddAttr(attr.name(), attr_value); } else if (attr.type() == "float") { float float_value = 0.0; - if (!::tensorflow::strings::safe_strtof(attr_value, &float_value)) { + if (!absl::SimpleAtof(attr_value, &float_value)) { return absl::InvalidArgumentError(absl::StrCat( "Invalid value string for float-type attribute ", attr.name(), "of debug node ", debug_node->name(), ": \"", attr_value, "\"")); @@ -437,7 +437,7 @@ absl::Status DebugNodeInserter::SetDebugNodeAttributes( debug_node->AddAttr(attr.name(), float_value); } else if (attr.type() == "int") { int64_t int_value = 0; - if (!::tensorflow::strings::safe_strto64(attr_value, &int_value)) { + if (!absl::SimpleAtoi(attr_value, &int_value)) { return absl::InvalidArgumentError(absl::StrCat( "Invalid value string for int-type attribute ", attr.name(), "of debug node ", debug_node->name(), ": \"", attr_value, "\"")); diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 9698076c36aba1..04317455a9450e 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -316,7 +316,7 @@ absl::Status ReadEventFromFile(const string& dump_file_path, Event* event) { return s; } - StringPiece result; + absl::string_view result; s = file->Read(0, file_size, &result, &(content)[0]); if (!s.ok()) { return s; diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 00515c71df7917..184740930d10db 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -109,6 +109,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@local_tsl//tsl/platform:stacktrace", ], ) @@ -172,6 +173,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/activity_watcher", "//tensorflow/core/protobuf:worker_proto_cc", + "@com_google_absl//absl/log", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_rpc_handler", @@ -512,6 +514,8 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/debug", "//tensorflow/core/protobuf:worker_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_tsl//tsl/profiler/lib:connected_traceme", "@local_tsl//tsl/profiler/lib:context_types_hdrs", diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index bdc2acbcd5b5a0..45e0327c7d37d0 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -107,8 +107,8 @@ BaseRemoteRendezvous::~BaseRemoteRendezvous() { // Returns true if "device_name" is a valid full name of local device // of the "worker". This helper is purely based on the worker name // and device name and does no lookups in the worker->device_mgr. -static bool IsLocalDevice(const StringPiece worker_name, - const StringPiece device_name) { +static bool IsLocalDevice(const absl::string_view worker_name, + const absl::string_view device_name) { return absl::StartsWith(device_name, worker_name); } diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index 90f60c1d903d56..a09cec8ab6c778 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -123,6 +123,8 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { (override)); MOCK_METHOD(void, CancelBarrierAsync, (std::string_view barrier_id, StatusCallback done), (override)); + MOCK_METHOD(absl::StatusOr>, GetAliveTasks, + (const std::vector& tasks), (override)); MOCK_METHOD(absl::StatusOr, GetEnv, (), (override)); MOCK_METHOD(void, SetError, (const absl::Status& error), (override)); MOCK_METHOD(absl::Status, ActivateWatch, diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 564bc065895cd5..3ee65b08338b61 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -15,10 +15,16 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include // NOLINT(build/c++11) +#include +#include +#include #include +#include +#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "tensorflow/core/common_runtime/build_graph_options.h" #include "tensorflow/core/common_runtime/debugger_state_interface.h" diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 87ff4621ac7199..5c8c7ce0f20c95 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -16,9 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ +#include +#include +#include +#include #include #include +#include "absl/status/status.h" #include "tensorflow/core/common_runtime/costmodel_manager.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" diff --git a/tensorflow/core/distributed_runtime/integration_test/BUILD b/tensorflow/core/distributed_runtime/integration_test/BUILD index 7408bcbfdc9f71..b79b482be20dce 100644 --- a/tensorflow/core/distributed_runtime/integration_test/BUILD +++ b/tensorflow/core/distributed_runtime/integration_test/BUILD @@ -49,8 +49,8 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:env", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@local_xla//xla/tsl/lib/core:status_test_util", ], @@ -137,8 +137,8 @@ tf_cc_test( "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:env", + "@com_google_absl//absl/synchronization", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", ], diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc index 9f803991417dce..250521412ea9d5 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/synchronization/barrier.h" #include "absl/time/time.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" @@ -27,7 +28,6 @@ limitations under the License. #include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" @@ -186,7 +186,7 @@ TEST(CAPI, MultiClientSetGetConfigInOp) { tensorflow::ServerDef server_def = GetMultiClientServerDef("worker", cluster_size); ConfigCoordinationService(&server_def); - BlockingCounter finish_counter(cluster_size); + absl::Barrier finish_counter(cluster_size); auto worker_thread_fn = [&](int worker_id) { tensorflow::ServerDef server_def_copy = server_def; // By default, server_def has task index set to 0. @@ -255,8 +255,7 @@ TEST(CAPI, MultiClientSetGetConfigInOp) { TFE_ExecutorWaitForAllPendingNodes(executor, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - finish_counter.DecrementCount(); - finish_counter.Wait(); + finish_counter.Block(); TFE_DeleteExecutor(executor); TFE_DeleteContext(ctx); }; @@ -273,9 +272,9 @@ TEST(CAPI, MultiClientCoordinationSetGetConfigs) { tensorflow::ServerDef server_def = GetMultiClientServerDef("worker", cluster_size); ConfigCoordinationService(&server_def); - tensorflow::BlockingCounter counter1(cluster_size); - tensorflow::BlockingCounter counter2(cluster_size); - tensorflow::BlockingCounter counter3(cluster_size); + absl::Barrier counter1(cluster_size); + absl::Barrier counter2(cluster_size); + absl::Barrier counter3(cluster_size); auto worker_thread_fn = [&](int worker_id) { tensorflow::ServerDef server_def_copy = server_def; @@ -302,8 +301,7 @@ TEST(CAPI, MultiClientCoordinationSetGetConfigs) { ctx, key.c_str(), tensorflow::strings::StrCat("value", worker_id).c_str(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - counter1.DecrementCount(); - counter1.Wait(); + counter1.Block(); const int next_id = (worker_id + 1) % cluster_size; // Setting next_key errors out because it has been set by another worker @@ -319,14 +317,12 @@ TEST(CAPI, MultiClientCoordinationSetGetConfigs) { value_buf->length}; EXPECT_EQ(value_str, tensorflow::strings::StrCat("value", next_id)); TF_DeleteBuffer(value_buf); - counter2.DecrementCount(); - counter2.Wait(); + counter2.Block(); // Delete key TFE_DeleteConfigKeyValue(ctx, key.c_str(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - counter3.DecrementCount(); - counter3.Wait(); + counter3.Block(); TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -345,9 +341,9 @@ TEST(CAPI, MultiClientPropagateError) { GetMultiClientServerDef("worker", cluster_size); ConfigCoordinationService(&server_def); // Barrier for initializing the cluster. - tensorflow::BlockingCounter counter1(cluster_size); + absl::Barrier counter1(cluster_size); // Barrier for finishing executing operations on all workers. - tensorflow::BlockingCounter counter2(cluster_size); + absl::Barrier counter2(cluster_size); auto worker_thread_fn = [&](int worker_id) { tensorflow::ServerDef server_def_copy = server_def; @@ -367,8 +363,7 @@ TEST(CAPI, MultiClientPropagateError) { TFE_EnableCollectiveOps(ctx, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - counter1.DecrementCount(); - counter1.Wait(); + counter1.Block(); // Set error from worker/1 if (worker_id == 1) { @@ -389,8 +384,7 @@ TEST(CAPI, MultiClientPropagateError) { TFE_DeleteTensorHandle(in); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteOp(allreduce); - counter2.DecrementCount(); - counter2.Wait(); + counter2.Block(); TFE_DeleteContext(ctx); TF_DeleteStatus(status); diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc index 25eb3da148a23f..7d767e9a8ce42a 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/synchronization/barrier.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" @@ -192,17 +192,17 @@ TEST_P(MultiClientSendRecvTest, TestMultiClientSendRecv) { tensorflow::ServerDef server_def = GetMultiClientServerDef("worker", cluster_size); - // Enable coordination service for propagating remote device attributess + // Enable coordination service for propagating remote device attributes auto* coord_config = server_def.mutable_default_session_config() ->mutable_experimental() ->mutable_coordination_config(); coord_config->set_service_type("standalone"); coord_config->set_service_leader("/job:worker/replica:0/task:0"); - // The blocking counter makes sure that worker/0 thread (leader that starts + // The barrier makes sure that worker/0 thread (leader that starts // the coordination service) does not exit early while other workers are still // interacting with the coordination service. - tensorflow::BlockingCounter counter(cluster_size); + absl::Barrier barrier(cluster_size); auto worker_thread_fn = [&](int worker_id) { tensorflow::ServerDef server_def_copy = server_def; @@ -347,12 +347,11 @@ TEST_P(MultiClientSendRecvTest, TestMultiClientSendRecv) { // retrieves it, we need to do the following steps: // 1. Since we created async EagerContext, we need to force each worker to // wait until all pending operations finish before deleting the context. - // 2. In addition, use the blocking counter to notify the 2 workers when + // 2. In addition, use the barrier to notify the 2 workers when // it is safe to clean up all the data. TFE_ContextAsyncWait(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - counter.DecrementCount(); - counter.Wait(); + barrier.Block(); { tensorflow::mutex_lock l(mu); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 761188d14f9b40..778b0c6c0198b1 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -332,7 +332,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { template absl::Status RunPartitionsHelper( - const std::unordered_map& feeds, + const std::unordered_map& + feeds, const FetchListType& fetches, const MasterEnv* env, int64_t step_id, int64_t execution_count, PerStepState* pss, CallOptions* call_opts, const ClientRequestType& req, ClientResponseType* resp, @@ -653,7 +654,8 @@ struct RunCallableResponseWrapper { template absl::Status MasterSession::ReffedClientGraph::RunPartitionsHelper( - const std::unordered_map& feeds, + const std::unordered_map& + feeds, const FetchListType& fetches, const MasterEnv* env, int64_t step_id, int64_t execution_count, PerStepState* pss, CallOptions* call_opts, const ClientRequestType& req, ClientResponseType* resp, @@ -825,7 +827,7 @@ absl::Status MasterSession::ReffedClientGraph::RunPartitions( VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " << execution_count; // Maps the names of fed tensors to their index in `req`. - std::unordered_map feeds(3); + std::unordered_map feeds(3); for (size_t i = 0; i < req.num_feeds(); ++i) { if (!feeds.insert({req.feed_name(i), i}).second) { return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i)); @@ -849,7 +851,7 @@ absl::Status MasterSession::ReffedClientGraph::RunPartitions( VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " << execution_count; // Maps the names of fed tensors to their index in `req`. - std::unordered_map feeds(3); + std::unordered_map feeds(3); for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) { if (!feeds.insert({callable_opts_.feed(i), i}).second) { // MakeCallable will fail if there are two feeds with the same name. diff --git a/tensorflow/core/distributed_runtime/remote_device.h b/tensorflow/core/distributed_runtime/remote_device.h index 766c9d8e167f8d..591531f94d567f 100644 --- a/tensorflow/core/distributed_runtime/remote_device.h +++ b/tensorflow/core/distributed_runtime/remote_device.h @@ -36,7 +36,7 @@ class WorkerCacheInterface; // This callback should have the same definition as DeviceMgr::LookupDevice // It assigns *device with pointer to Device of the given 'name', where 'name' // is either a full device name, or just the replica-local suffix. -typedef std::function +typedef std::function LookupLocalDevice; // Creates Remote Devices for the provided device attributes. Helpful when the diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 0b7b731e6f51fb..5f662e77795cf1 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -127,6 +127,9 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", ] + tf_grpc_cc_dependencies(), ) @@ -208,6 +211,8 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation", "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_xla//xla/tsl/distributed_runtime/rpc:async_service_interface", "@local_xla//xla/tsl/distributed_runtime/rpc:grpc_call", "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", @@ -508,6 +513,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/protobuf:worker_proto_cc", + "@com_google_absl//absl/status", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index 96945529341a09..e9effec0448504 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -39,6 +39,8 @@ cc_library( "//tensorflow/core/platform:error_payloads", "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@local_xla//xla/tsl/distributed_runtime:call_options", ] + tf_grpc_cc_dependencies(), ) @@ -56,6 +58,8 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/protobuf:eager_service_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@local_xla//xla/tsl/distributed_runtime/rpc:async_service_interface", "@local_xla//xla/tsl/distributed_runtime/rpc:grpc_call", ] + tf_grpc_cc_dependencies(), @@ -78,5 +82,7 @@ tf_cc_test( "//tensorflow/core/platform:env", "//tensorflow/core/platform:status", "//tensorflow/core/platform:strcat", + "@com_google_absl//absl/status", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index 00141a5dc89f30..fa8d608835d89e 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -15,10 +15,18 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include #include +#include #include +#include +#include +#include +#include #include "grpcpp/generic/generic_stub.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h index 8a926da488477b..2eb41b8a2103df 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_CLIENT_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_CLIENT_H_ +#include + #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc index 128da6b893add2..7b0cff37dc92d7 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include + +#include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/blocking_counter.h" diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index 5d58d415c81470..b9bea2ea437a7a 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -17,11 +17,13 @@ limitations under the License. #include +#include "absl/status/status.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" namespace tensorflow { namespace eager { diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index 7417b9a74a754d..7acc29556696bd 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -16,14 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ +#include + #include "grpcpp/alarm.h" #include "grpcpp/completion_queue.h" #include "grpcpp/server_builder.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" namespace tensorflow { namespace eager { @@ -45,7 +50,7 @@ class GrpcEagerServiceImpl : public tsl::AsyncServiceInterface { virtual ~GrpcEagerServiceImpl() {} // Create a master context in eager service. - absl::Status CreateMasterContext(const tensorflow::uint64 context_id, + absl::Status CreateMasterContext(tensorflow::uint64 context_id, EagerContext* context); void HandleRPCsLoop() override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index 1039acd85ef9c2..4cffa9e2ce40f7 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -301,12 +301,12 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { // Start tracing, including the ID attached to the RPC. tsl::profiler::TraceMe* TraceRpc( - StringPiece name, + absl::string_view name, const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) { - StringPiece id; + absl::string_view id; auto it = metadata.find(GrpcIdKey()); if (it != metadata.end()) { - id = StringPiece(it->second.data(), it->second.size()); + id = absl::string_view(it->second.data(), it->second.size()); } return new tsl::profiler::TraceMe( [&] { return strings::StrCat(name, ":", id); }, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 803d543aee63b7..6ccc00364c3962 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -115,7 +115,7 @@ class GrpcRemoteMaster : public MasterInterface { private: // Start tracing, attaching a unique ID to both the trace and the RPC. - tsl::profiler::TraceMe* NewTraceRpc(StringPiece name, + tsl::profiler::TraceMe* NewTraceRpc(absl::string_view name, ::grpc::ClientContext* ctx) { string trace_id = strings::StrCat(tsl::tracing::GetUniqueArg()); ctx->AddMetadata(GrpcIdKey(), trace_id); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 5b78748a909c06..8bd1193724f271 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -159,8 +159,7 @@ absl::Status GrpcServer::GetHostAndPort(const ServerDef& server_def, server_def.DebugString()); } auto colon_index = iter->second.find_last_of(':'); - if (!strings::safe_strto32(iter->second.substr(colon_index + 1), - port)) { + if (!absl::SimpleAtoi(iter->second.substr(colon_index + 1), port)) { return errors::InvalidArgument( "Could not parse port for local server from \"", iter->second, "\"."); @@ -419,8 +418,7 @@ absl::Status GrpcServer::WorkerCacheFactory( int requested_port; auto colon_index = host_port.find_last_of(':'); - if (!strings::safe_strto32(host_port.substr(colon_index + 1), - &requested_port)) { + if (!absl::SimpleAtoi(host_port.substr(colon_index + 1), &requested_port)) { return errors::Internal("Could not parse port for local server from \"", host_port, "\"."); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 1b5ae927544a14..9e293d70e0e3ea 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -826,7 +826,7 @@ TEST(GrpcSessionTest, LongErrorMessage) { auto a = test::graph::Constant(&g, Tensor()); a->set_assigned_device_name(dev_a); std::vector long_string_buffer(1024 * 1024, 'x'); - StringPiece long_string(long_string_buffer.data(), 1024 * 1024); + absl::string_view long_string(long_string_buffer.data(), 1024 * 1024); string name = strings::StrCat(long_string, "fantasia!"); auto a_err = test::graph::Error(&g, a, name); a_err->set_assigned_device_name(dev_a); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc index 33f40b9d39fa63..dd3848bc6e3ebf 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc @@ -15,12 +15,16 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h" +#include + +#include "grpcpp/impl/codegen/byte_buffer.h" #include "grpcpp/support/byte_buffer.h" #include "grpcpp/support/slice.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/io/proto_encode_helper.h" @@ -134,17 +138,17 @@ static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) { #endif } -void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, - ::grpc::ByteBuffer* result) { +absl::Status EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, + bool require_ack, + ::grpc::ByteBuffer* result) { const int kLargeTensorBytes = 1024; const int64_t kProtoBufLimitBytes = 1LL << 31; if (val.TotalBytes() > kProtoBufLimitBytes) { size_t exceeded_bytes = val.TotalBytes() - kProtoBufLimitBytes; - LOG(FATAL) << "Cannot encode a Tensor that exceeds the 2GB protobuf limit. " - "Exceeded bytes: " - << exceeded_bytes - << ", tensor shape: " << val.shape().AsProto().DebugString(); + return absl::InternalError(absl::StrCat( + "Cannot encode a Tensor that exceeds the 2GB protobuf limit. ", + "Exceeded bytes: ", exceeded_bytes)); } RecvTensorResponse response; @@ -169,7 +173,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size()); EncodeSkeleton(val, &e_skeleton); - StringPiece tdata = val.tensor_data(); + absl::string_view tdata = val.tensor_data(); uint32 overall_tensor_proto_bytesize = (e_skeleton.size() + VarLengthEncodingSize(TensorProto::kTensorContentFieldNumber, @@ -206,7 +210,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, e.WriteVarlengthBeginning(RecvTensorResponse::kTensorFieldNumber, overall_tensor_proto_bytesize); // (C) - e.WriteRawBytes(StringPiece(e_skeleton.data(), e_skeleton.size())); + e.WriteRawBytes(absl::string_view(e_skeleton.data(), e_skeleton.size())); // (D1) & (D2) e.WriteVarlengthBeginning(TensorProto::kTensorContentFieldNumber, tdata.size()); @@ -249,6 +253,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, ::grpc::ByteBuffer tmp(&slices[0], num_slices); result->Swap(&tmp); } + return absl::OkStatus(); } } // namespace grpc diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h index ffcc4a2bbfa7f4..393ef2a70f96e5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TENSOR_CODING_H_ #include "grpcpp/impl/codegen/byte_buffer.h" +#include "absl/status/status.h" namespace tensorflow { class Tensor; @@ -46,8 +47,9 @@ void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto, // "val" holds the tensor value to be encoded. // // Discards original contents of *result. -void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, - ::grpc::ByteBuffer* result); +absl::Status EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, + bool require_ack, + ::grpc::ByteBuffer* result); } // namespace grpc } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc index f4b36334237a09..1b6e71f048a57a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include "grpcpp/support/byte_buffer.h" #include "grpcpp/support/slice.h" +#include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -31,7 +33,8 @@ class GrpcTensorCodingTest : public ::testing::Test { void Validate(const Tensor& t, bool is_dead) { // Check by encoding to a ByteBuffer ::grpc::ByteBuffer buf; - grpc::EncodeTensorToByteBuffer(is_dead, t, false, &buf); + absl::Status s = grpc::EncodeTensorToByteBuffer(is_dead, t, false, &buf); + TF_EXPECT_OK(s); // Make a string std::vector<::grpc::Slice> slices; @@ -100,4 +103,12 @@ TEST_F(GrpcTensorCodingTest, Simple) { TEST_F(GrpcTensorCodingTest, StringTensor) { DoTestForStrings(DT_STRING); } +TEST_F(GrpcTensorCodingTest, LargeTensor) { + Tensor t(DT_INT8, TensorShape({1, 1 + (1LL << 31)})); + ::grpc::ByteBuffer buf; + absl::Status s = grpc::EncodeTensorToByteBuffer(/*is_dead=*/false, t, + /*require_ack=*/false, &buf); + EXPECT_EQ(s.code(), absl::StatusCode::kInternal); +} + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index 77f7d11283044f..45ee1c6df3ae1c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -56,7 +56,7 @@ absl::Status FillServerDef(const string& cluster_spec, const string& job_name, const string& job_name = job_pieces[0]; job_def->set_name(job_name); // Does a bit more validation of the tasks_per_replica. - const StringPiece spec = job_pieces[1]; + const absl::string_view spec = job_pieces[1]; // job_str is of form |. const std::vector host_ports = str_util::Split(spec, ';'); for (size_t i = 0; i < host_ports.size(); ++i) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc index f48ed0c11b73bc..d8a7a0b99dd9ab 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -62,7 +62,7 @@ absl::Status FillServerDef(const string& job_spec, const string& job_name, return errors::InvalidArgument("Invalid job string: ", job_str); } - const StringPiece spec = job_pieces[1]; + const absl::string_view spec = job_pieces[1]; // job_str is of form |. const std::vector host_ports = str_util::Split(spec, ';'); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 68abc533c1fa67..d6abcf6d117063 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -18,11 +18,14 @@ limitations under the License. #include #include #include +#include #include #include "grpcpp/alarm.h" #include "grpcpp/server_builder.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" #include "xla/tsl/protobuf/rpc_options.pb.h" @@ -455,13 +458,26 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, bool cache_enabled = (response_cache_ != nullptr && request_id != 0); - auto do_response = [response, done, cache_enabled]( + auto do_response = [request, response, done = std::move(done), cache_enabled]( const Tensor& tensor, bool is_dead, const absl::Status& status) { + absl::Status updated_status; if (status.ok()) { - grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response); + updated_status = grpc::EncodeTensorToByteBuffer(is_dead, tensor, + cache_enabled, response); + if (!updated_status.ok()) { + updated_status = absl::InternalError(absl::StrCat( + "Failed to encode tensor to byte buffer: ", + updated_status.message(), " (request_id: ", request->request_id(), + " step_id: ", request->step_id(), + " rendezvous_key: ", request->rendezvous_key(), ")")); + LOG(ERROR) << "Failure to encode response during GrpcRecvTensorAsync: " + << updated_status; + } + } else { + updated_status = status; } - done(status); + done(updated_status); }; // If response cache is enabled and the response cache already contains the diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 51cbbbac941437..fffd799235ee70 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -58,7 +58,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { public: RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {} - void Init(WorkerInterface* wi, int64_t step_id, StringPiece key, + void Init(WorkerInterface* wi, int64_t step_id, absl::string_view key, AllocatorAttributes alloc_attrs, Device* dst_device, const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) { wi_ = wi; diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index aa6399f55c01a0..a881b2952fa5fa 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/protobuf/coordination_config.pb.h" @@ -207,7 +208,8 @@ absl::Status SessionMgr::CreateSession( } auto device_mgr = std::make_unique(std::move(renamed_devices)); - LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) { + LookupLocalDevice cb = [&device_mgr](absl::string_view name, + Device** device) { return device_mgr->LookupDevice(name, device); }; AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb, @@ -220,6 +222,7 @@ absl::Status SessionMgr::CreateSession( } auto graph_mgr = std::make_unique(worker_env_, device_mgr.get()); + VLOG(1) << "Creating WorkerSession with owned DeviceMgr."; worker_session.reset(new WorkerSession( session, worker_name, std::unique_ptr(worker_cache), @@ -244,6 +247,7 @@ absl::Status SessionMgr::CreateSession( // WorkerSession has been deleted. auto graph_mgr = std::make_unique(worker_env_, worker_env_->device_mgr); + VLOG(1) << "Creating WorkerSession with borrowed DeviceMgr."; worker_session = WorkerSession::CreateWithBorrowedDeviceMgr( session, worker_name, std::unique_ptr(worker_cache), diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index 4b4c7e4d8f5c32..1990f0c17c66a4 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -196,7 +196,7 @@ bool TensorResponse::ParseTensorSubmessage( seen_tensor_content = true; TensorShape shape(tensor_meta->tensor_shape()); Tensor t(allocator_, tensor_meta->dtype(), shape); - StringPiece buf = t.tensor_data(); + absl::string_view buf = t.tensor_data(); if (static_cast(num_bytes) != buf.size()) return false; // TODO(jeff,sanjay): Figure out a way to avoid this copy if // the underlying ZeroCopyInputStream data is properly aligned diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index d40e409d22770c..d9286d0d148843 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/lib/monitoring/gauge.h" +#include "tsl/platform/stacktrace.h" namespace tensorflow { @@ -197,6 +198,8 @@ WorkerSession::WorkerSession( } WorkerSession::~WorkerSession() { + VLOG(1) << "WorkerSession::~WorkerSession @@stacktrace\n " + << tsl::CurrentStackTrace(); if (graph_mgr_) { absl::Status s = graph_mgr_->DeregisterAll(); if (!s.ok()) { diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 3f5c3c7f485634..092fabe6cbe427 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -201,7 +201,7 @@ template <> struct is_string : std::true_type {}; template <> -struct is_string<::tensorflow::StringPiece> : std::true_type {}; +struct is_string : std::true_type {}; template <> struct is_string : std::true_type {}; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index fd8f9bcf72a2c9..11bb2ef4c5cb3c 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -840,6 +840,7 @@ tf_cuda_library( "variant_tensor_data.h", ], visibility = [ + "//learning/infra/runtime/experimental/mixed_engine:__subpackages__", "//tensorflow:__pkg__", "//tensorflow/core:__pkg__", "//tensorflow/core/runtime_fallback:__subpackages__", @@ -890,6 +891,7 @@ tf_cuda_library( "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", "@local_xla//xla/tsl/framework:device_type", "@local_xla//xla/tsl/util:byte_swap_array", ], diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 351ba293276456..f1ed3aca82dc1c 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -186,8 +186,8 @@ string SummarizeString(const string& str) { // If the string is long, replace the middle with ellipses. constexpr int kMaxStringSummarySize = 80; if (escaped.size() >= kMaxStringSummarySize) { - StringPiece prefix(escaped); - StringPiece suffix = prefix; + absl::string_view prefix(escaped); + absl::string_view suffix = prefix; prefix.remove_suffix(escaped.size() - 10); suffix.remove_prefix(escaped.size() - 10); return strings::StrCat("\"", prefix, "...", suffix, "\""); @@ -351,7 +351,8 @@ string SummarizeAttrValue(const AttrValue& attr_value) { return ""; // Prevent missing return warning } -Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { +absl::Status AttrValueHasType(const AttrValue& attr_value, + absl::string_view type) { int num_set = 0; #define VALIDATE_FIELD(name, type_string, oneof_case) \ @@ -449,7 +450,8 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { return absl::OkStatus(); } -bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { +bool ParseAttrValue(absl::string_view type, absl::string_view text, + AttrValue* out) { // Parse type. string field_name; bool is_list = absl::ConsumePrefix(&type, "list("); @@ -483,7 +485,7 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { if (is_list) { // TextFormat parser considers "i: 7" to be the same as "i: [7]", // but we only want to allow list values with []. - StringPiece cleaned = text; + absl::string_view cleaned = text; str_util::RemoveLeadingWhitespace(&cleaned); str_util::RemoveTrailingWhitespace(&cleaned); if (cleaned.size() < 2 || cleaned[0] != '[' || @@ -552,11 +554,12 @@ void SetAttrValue(absl::Span value, AttrValue* out) { } } -void SetAttrValue(StringPiece value, AttrValue* out) { +void SetAttrValue(absl::string_view value, AttrValue* out) { out->set_s(value.data(), value.size()); } -void SetAttrValue(const absl::Span value, AttrValue* out) { +void SetAttrValue(const absl::Span value, + AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { out->mutable_list()->add_s(v.data(), v.size()); diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h index fa6b6dda979b1e..b6f7c972c71624 100644 --- a/tensorflow/core/framework/attr_value_util.h +++ b/tensorflow/core/framework/attr_value_util.h @@ -45,7 +45,8 @@ class NameAttrList; std::string SummarizeAttrValue(const AttrValue& attr_value); // Generates an error if attr_value doesn't have the indicated attr type. -absl::Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); +absl::Status AttrValueHasType(const AttrValue& attr_value, + absl::string_view type); // Converts a text proto value from "text" into the field of *out // indicated by "type" (e.g. from the type field of an AttrDef). @@ -54,13 +55,14 @@ absl::Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); // * If type:"list(string)" and text:"['foo', 'bar']", // then *out is set to "list { s: ['foo', 'bar'] }" // Returns true on success. -bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out); +bool ParseAttrValue(absl::string_view type, absl::string_view text, + AttrValue* out); // Sets *out based on the type of value. void SetAttrValue(const std::string& value, AttrValue* out); void SetAttrValue(const tstring& value, AttrValue* out); void SetAttrValue(const char* value, AttrValue* out); -void SetAttrValue(StringPiece value, AttrValue* out); +void SetAttrValue(absl::string_view value, AttrValue* out); void SetAttrValue(int64_t value, AttrValue* out); void SetAttrValue(int32_t value, AttrValue* out); void SetAttrValue(float value, AttrValue* out); @@ -77,7 +79,7 @@ void SetAttrValue(const NameAttrList& value, AttrValue* out); void SetAttrValue(absl::Span value, AttrValue* out); void SetAttrValue(absl::Span value, AttrValue* out); void SetAttrValue(absl::Span value, AttrValue* out); -void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); void SetAttrValue(absl::Span value, AttrValue* out); void SetAttrValue(absl::Span value, AttrValue* out); void SetAttrValue(absl::Span value, AttrValue* out); diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index 996acd12d78b3b..1e576146448d0a 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -176,14 +176,14 @@ CollectiveContext::CollectiveContext( int64_t CollectiveExecutor::kInvalidId = -1; /*static*/ -Status CollectiveRegistry::Lookup( +absl::Status CollectiveRegistry::Lookup( const string& collective_name, CollectiveImplementationInterface** implementation) { return LookupHelper(collective_name, implementation, false); } /*static*/ -Status CollectiveRegistry::LookupParamResolverInstance( +absl::Status CollectiveRegistry::LookupParamResolverInstance( const string& collective_name, CollectiveImplementationInterface** implementation) { return LookupHelper(collective_name, implementation, true); @@ -198,8 +198,8 @@ void CollectiveRegistry::GetAll( } /*static*/ -Status CollectiveRegistry::Register(const string& collective_name, - Factory factory) { +absl::Status CollectiveRegistry::Register(const string& collective_name, + Factory factory) { std::vector* registry = MutableCollectiveRegistry(); for (const RegistrationInfo& reg_info : *registry) { if (reg_info.name == collective_name) @@ -211,7 +211,7 @@ Status CollectiveRegistry::Register(const string& collective_name, } /*static*/ -Status CollectiveRegistry::LookupHelper( +absl::Status CollectiveRegistry::LookupHelper( const string& collective_name, CollectiveImplementationInterface** implementation, bool param_resolver) { std::vector* registry = MutableCollectiveRegistry(); diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index b400203013b0b2..53e64d698b1f28 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -40,7 +40,7 @@ namespace shape_inference { // The V2 version computes windowed output size with arbitrary dilation_rate and // explicit padding, while the original version only handles the cases where // dilation_rates equal to 1 and the padding is SAME or VALID. -Status GetWindowedOutputSizeFromDimsV2( +absl::Status GetWindowedOutputSizeFromDimsV2( shape_inference::InferenceContext* c, shape_inference::DimensionHandle input_size, shape_inference::DimensionOrConstant filter_size, int64_t dilation_rate, @@ -87,7 +87,7 @@ Status GetWindowedOutputSizeFromDimsV2( return absl::OkStatus(); } -Status GetWindowedOutputSizeFromDims( +absl::Status GetWindowedOutputSizeFromDims( shape_inference::InferenceContext* c, shape_inference::DimensionHandle input_size, shape_inference::DimensionOrConstant filter_size, int64_t stride, @@ -106,7 +106,7 @@ Status GetWindowedOutputSizeFromDims( -1, -1, output_size); } -Status UnchangedShape(shape_inference::InferenceContext* c) { +absl::Status UnchangedShape(shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data != nullptr) { @@ -115,7 +115,7 @@ Status UnchangedShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status MatMulShape(shape_inference::InferenceContext* c) { +absl::Status MatMulShape(shape_inference::InferenceContext* c) { ShapeHandle a; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a)); @@ -142,8 +142,8 @@ namespace { // Validate that an Einsum subscript contains exactly one or zero ellipsis; and // that periods (.) occur only within an ellipses (...). -Status ValidateEinsumEllipsis(absl::string_view subscript, - bool* found_ellipsis) { +absl::Status ValidateEinsumEllipsis(absl::string_view subscript, + bool* found_ellipsis) { const int num_periods = absl::c_count(subscript, '.'); if (num_periods != 0 && num_periods != 3) { return errors::InvalidArgument( @@ -160,7 +160,7 @@ Status ValidateEinsumEllipsis(absl::string_view subscript, } // namespace -Status EinsumShape(shape_inference::InferenceContext* c) { +absl::Status EinsumShape(shape_inference::InferenceContext* c) { // We assume that the equation has a valid format. Either (x),(y)->(z) // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or // more latin alphabets and contains at most one ellipsis ('...'). @@ -314,7 +314,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { +absl::Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { ShapeHandle a_shape; ShapeHandle b_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape)); @@ -351,7 +351,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status BatchMatMulShape(shape_inference::InferenceContext* c) { +absl::Status BatchMatMulShape(shape_inference::InferenceContext* c) { ShapeHandle a_shape; ShapeHandle b_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape)); @@ -387,12 +387,12 @@ Status BatchMatMulShape(shape_inference::InferenceContext* c) { // -------------------------------------------------------------------------- -Status BiasAddShape(shape_inference::InferenceContext* c) { +absl::Status BiasAddShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; // Fetch the data_format attribute, which may not exist. string data_format; - Status s = c->GetAttr("data_format", &data_format); + absl::Status s = c->GetAttr("data_format", &data_format); if (s.ok() && data_format == "NCHW") { TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); @@ -446,11 +446,11 @@ Status BiasAddShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status BiasAddGradShape(shape_inference::InferenceContext* c) { +absl::Status BiasAddGradShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; // Fetch the data_format attribute, which may not exist. string data_format; - Status s = c->GetAttr("data_format", &data_format); + absl::Status s = c->GetAttr("data_format", &data_format); if (s.ok() && data_format == "NCHW") { TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); @@ -463,10 +463,9 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, - const ShapeHandle shape_handle, - const string& tensor_name, - shape_inference::InferenceContext* c) { +absl::Status CheckFormatConstraintsOnShape( + const TensorFormat tensor_format, const ShapeHandle shape_handle, + const string& tensor_name, shape_inference::InferenceContext* c) { if (tensor_format == FORMAT_NCHW_VECT_C) { // Check that the vect dim has size 4 or 32. const int num_dims = c->Rank(shape_handle); @@ -482,7 +481,7 @@ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, return absl::OkStatus(); } -Status DatasetIteratorShape(shape_inference::InferenceContext* c) { +absl::Status DatasetIteratorShape(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); std::vector output_shapes; @@ -502,10 +501,10 @@ Status DatasetIteratorShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, - const std::vector& spatial, - DimensionOrConstant C, ShapeHandle* out, - shape_inference::InferenceContext* context) { +absl::Status MakeShapeFromFormat( + TensorFormat format, DimensionOrConstant N, + const std::vector& spatial, DimensionOrConstant C, + ShapeHandle* out, shape_inference::InferenceContext* context) { const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); std::vector dims_actual(num_dims); dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N); @@ -527,11 +526,11 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, return absl::OkStatus(); } -Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, - DimensionHandle* batch_dim, - absl::Span spatial_dims, - DimensionHandle* filter_dim, - InferenceContext* context) { +absl::Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, + DimensionHandle* batch_dim, + absl::Span spatial_dims, + DimensionHandle* filter_dim, + InferenceContext* context) { const int32_t rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); // Batch. @@ -554,11 +553,13 @@ Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, } // vect_size must be provided if format is NCHW_VECT_C. -Status ShapeFromDimensions(DimensionHandle batch_dim, - absl::Span spatial_dims, - DimensionHandle filter_dim, TensorFormat format, - absl::optional vect_size, - InferenceContext* context, ShapeHandle* shape) { +absl::Status ShapeFromDimensions(DimensionHandle batch_dim, + absl::Span spatial_dims, + DimensionHandle filter_dim, + TensorFormat format, + absl::optional vect_size, + InferenceContext* context, + ShapeHandle* shape) { const int32_t rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); std::vector out_dims(rank); @@ -590,8 +591,8 @@ Status ShapeFromDimensions(DimensionHandle batch_dim, namespace { -Status Conv2DShapeImpl(shape_inference::InferenceContext* c, - bool supports_explicit_padding) { +absl::Status Conv2DShapeImpl(shape_inference::InferenceContext* c, + bool supports_explicit_padding) { string data_format_str, filter_format_str; if (!c->GetAttr("data_format", &data_format_str).ok()) { data_format_str = "NHWC"; @@ -706,7 +707,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); std::vector explicit_paddings; if (supports_explicit_padding) { - Status s = c->GetAttr("explicit_paddings", &explicit_paddings); + absl::Status s = c->GetAttr("explicit_paddings", &explicit_paddings); // Use the default value, which is an empty list, if the attribute is not // found. Otherwise return the error to the caller. if (!s.ok() && !errors::IsNotFound(s)) { @@ -722,7 +723,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, std::vector p_list; // `padding_list` attribute is used by Fused int8 convolutions to support // explicit paddings. - Status s_p_list = c->GetAttr("padding_list", &p_list); + absl::Status s_p_list = c->GetAttr("padding_list", &p_list); if (!s_p_list.ok() && !errors::IsNotFound(s_p_list)) { return s_p_list; } @@ -766,7 +767,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, } // namespace // Shape function for general Convolution operation. -Status ConvShape(shape_inference::InferenceContext* c) { +absl::Status ConvShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape = c->input(0); ShapeHandle filter_shape = c->input(1); @@ -933,7 +934,7 @@ Status ConvShape(shape_inference::InferenceContext* c) { "Explicit padding not supported for 3D Convolution"); } std::vector explicit_paddings; - Status s = c->GetAttr("explicit_paddings", &explicit_paddings); + absl::Status s = c->GetAttr("explicit_paddings", &explicit_paddings); // Use the default value, which is an empty list, if the attribute is not // found. Otherwise return the error to the caller. if (!s.ok() && !absl::IsNotFound(s)) { @@ -985,25 +986,26 @@ Status ConvShape(shape_inference::InferenceContext* c) { } // Shape function for Conv2D-like operations that support explicit padding. -Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) { +absl::Status Conv2DShapeWithExplicitPadding( + shape_inference::InferenceContext* c) { return Conv2DShapeImpl(c, true); } // Shape function for Conv2D-like operations that do not support explicit // padding. -Status Conv2DShape(shape_inference::InferenceContext* c) { +absl::Status Conv2DShape(shape_inference::InferenceContext* c) { return Conv2DShapeImpl(c, false); } // TODO(mjanusz): Unify all conv/pooling shape functions. -Status Conv3DShape(shape_inference::InferenceContext* c) { +absl::Status Conv3DShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); ShapeHandle filter_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); string data_format; - Status s = c->GetAttr("data_format", &data_format); + absl::Status s = c->GetAttr("data_format", &data_format); std::vector dilations; TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); @@ -1110,7 +1112,7 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { +absl::Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { string data_format_str; if (!c->GetAttr("data_format", &data_format_str).ok()) { data_format_str = "NHWC"; @@ -1182,11 +1184,12 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { +absl::Status Conv2DBackpropFilterWithBiasShape( + shape_inference::InferenceContext* c) { ShapeHandle input_shape; // Fetch the data_format attribute, which may not exist. string data_format; - Status s = c->GetAttr("data_format", &data_format); + absl::Status s = c->GetAttr("data_format", &data_format); TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); if (s.ok() && data_format == "NCHW") { @@ -1203,8 +1206,8 @@ Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { namespace { -Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, - bool supports_explicit_padding) { +absl::Status DepthwiseConv2DNativeShapeImpl( + shape_inference::InferenceContext* c, bool supports_explicit_padding) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); ShapeHandle filter_shape; @@ -1233,7 +1236,7 @@ Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, } string data_format_str; - Status s = c->GetAttr("data_format", &data_format_str); + absl::Status s = c->GetAttr("data_format", &data_format_str); TensorFormat data_format; if (!s.ok() || !FormatFromString(data_format_str, &data_format)) { data_format = FORMAT_NHWC; @@ -1280,7 +1283,7 @@ Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, std::vector explicit_paddings; if (supports_explicit_padding) { - Status status = c->GetAttr("explicit_paddings", &explicit_paddings); + absl::Status status = c->GetAttr("explicit_paddings", &explicit_paddings); // Use the default value, which is an empty list, if the attribute is not // found. Otherwise return the error to the caller. if (!status.ok() && !errors::IsNotFound(status)) { @@ -1325,19 +1328,19 @@ Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, }; // namespace -Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { +absl::Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { return DepthwiseConv2DNativeShapeImpl(c, false); } -Status DepthwiseConv2DNativeShapeWithExplicitPadding( +absl::Status DepthwiseConv2DNativeShapeWithExplicitPadding( shape_inference::InferenceContext* c) { return DepthwiseConv2DNativeShapeImpl(c, true); } -Status AvgPoolShape(shape_inference::InferenceContext* c) { +absl::Status AvgPoolShape(shape_inference::InferenceContext* c) { string data_format_str; TensorFormat data_format; - Status s = c->GetAttr("data_format", &data_format_str); + absl::Status s = c->GetAttr("data_format", &data_format_str); if (s.ok()) { FormatFromString(data_format_str, &data_format); } else { @@ -1403,7 +1406,7 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status AvgPoolGradShape(shape_inference::InferenceContext* c) { +absl::Status AvgPoolGradShape(shape_inference::InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); @@ -1411,7 +1414,7 @@ Status AvgPoolGradShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status FusedBatchNormShape(shape_inference::InferenceContext* c) { +absl::Status FusedBatchNormShape(shape_inference::InferenceContext* c) { string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; @@ -1453,13 +1456,13 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) { +absl::Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(FusedBatchNormShape(c)); c->set_output(5, c->UnknownShape()); return absl::OkStatus(); } -Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { +absl::Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c)); string data_format_str; @@ -1484,7 +1487,7 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { +absl::Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; @@ -1525,7 +1528,7 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { +absl::Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(FusedBatchNormGradShape(c)); int num_side_inputs; @@ -1561,8 +1564,8 @@ Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, - int32* lower_diag_index, int32* upper_diag_index) { +absl::Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, + int32* lower_diag_index, int32* upper_diag_index) { // This function assumes that the shape of diag_index_tensor is fully defined. if (diag_index_tensor->dims() == 0) { *lower_diag_index = diag_index_tensor->scalar()(); @@ -1584,7 +1587,7 @@ Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, return absl::OkStatus(); } -Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { +absl::Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { ShapeHandle input_shape, diag_index_shape, unused_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape)); @@ -1637,7 +1640,7 @@ Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { +absl::Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { // Checks input ranks. ShapeHandle input_shape, diag_index_shape, unused_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape)); @@ -1738,7 +1741,7 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { +absl::Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { ShapeHandle input_shape, diag_shape, diag_index_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape)); @@ -1810,11 +1813,11 @@ Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, - bool supports_explicit_padding) { +absl::Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, + bool supports_explicit_padding) { string data_format_str; TensorFormat data_format; - Status s = c->GetAttr("data_format", &data_format_str); + absl::Status s = c->GetAttr("data_format", &data_format_str); if (s.ok()) { FormatFromString(data_format_str, &data_format); } else { @@ -1866,7 +1869,7 @@ Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, std::vector explicit_paddings; if (supports_explicit_padding) { - Status status = c->GetAttr("explicit_paddings", &explicit_paddings); + absl::Status status = c->GetAttr("explicit_paddings", &explicit_paddings); // Use the default value, which is an empty list, if the attribute is not // found. Otherwise return the error to the caller. if (!status.ok() && !errors::IsNotFound(status)) { @@ -1906,22 +1909,24 @@ Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, return absl::OkStatus(); } -Status MaxPoolShape(shape_inference::InferenceContext* c) { +absl::Status MaxPoolShape(shape_inference::InferenceContext* c) { return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false); } -Status MaxPoolGradShape(shape_inference::InferenceContext* c) { +absl::Status MaxPoolGradShape(shape_inference::InferenceContext* c) { return UnchangedShapeWithRank(c, 4); } -Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) { +absl::Status MaxPoolShapeWithExplicitPadding( + shape_inference::InferenceContext* c) { return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true); } -Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { +absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, + int num_inputs) { string data_format_str; TensorFormat data_format; - Status s = c->GetAttr("data_format", &data_format_str); + absl::Status s = c->GetAttr("data_format", &data_format_str); if (s.ok()) { FormatFromString(data_format_str, &data_format); } else { @@ -2020,12 +2025,12 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { return absl::OkStatus(); } -Status Pool3DShape(shape_inference::InferenceContext* c) { +absl::Status Pool3DShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); string data_format; - Status s = c->GetAttr("data_format", &data_format); + absl::Status s = c->GetAttr("data_format", &data_format); std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); @@ -2102,11 +2107,11 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status MaxPool3DGradShape(shape_inference::InferenceContext* c) { +absl::Status MaxPool3DGradShape(shape_inference::InferenceContext* c) { return UnchangedShapeWithRank(c, 5); } -Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { +absl::Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); @@ -2114,7 +2119,7 @@ Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } -Status UnknownShape(shape_inference::InferenceContext* c) { +absl::Status UnknownShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); } @@ -2122,9 +2127,9 @@ Status UnknownShape(shape_inference::InferenceContext* c) { } template -Status ReductionShapeHelper(const Tensor* reduction_indices_t, - const int32_t input_rank, - std::set* true_indices) { +absl::Status ReductionShapeHelper(const Tensor* reduction_indices_t, + const int32_t input_rank, + std::set* true_indices) { auto reduction_indices = reduction_indices_t->flat(); for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { const T reduction_index = reduction_indices(i); @@ -2144,7 +2149,7 @@ Status ReductionShapeHelper(const Tensor* reduction_indices_t, return absl::OkStatus(); } -Status ReductionShape(InferenceContext* c) { +absl::Status ReductionShape(InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle indices; @@ -2201,8 +2206,8 @@ Status ReductionShape(InferenceContext* c) { return absl::OkStatus(); } -Status ConcatShapeHelper(InferenceContext* c, int start_value_index, - int end_value_index, int dim_index) { +absl::Status ConcatShapeHelper(InferenceContext* c, int start_value_index, + int end_value_index, int dim_index) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused)); const Tensor* concat_dim_t = c->input_tensor(dim_index); @@ -2289,29 +2294,30 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, return absl::OkStatus(); } -Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { +absl::Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { return ConcatShapeHelper(c, 1 /* start_value_index */, 1 + num_inputs_to_concat /* end_value_index */, 0 /* dim_index */); } -Status ConcatV2Shape(InferenceContext* c) { +absl::Status ConcatV2Shape(InferenceContext* c) { return ConcatShapeHelper(c, 0 /* start_value_index */, c->num_inputs() - 1 /* end_value_index */, c->num_inputs() - 1 /* dim_index */); } -Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) { +absl::Status QuantizedConcatV2Shape(InferenceContext* c, + int num_inputs_to_concat) { return ConcatShapeHelper(c, 0 /* start_value_index */, num_inputs_to_concat /* end_value_index */, num_inputs_to_concat /* dim_index */); } -Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, - ShapeHandle shape_x, - ShapeHandle shape_y, - bool incompatible_shape_error, - ShapeHandle* out) { +absl::Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, + ShapeHandle shape_x, + ShapeHandle shape_y, + bool incompatible_shape_error, + ShapeHandle* out) { CHECK_NOTNULL(out); if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { *out = c->UnknownShape(); @@ -2382,7 +2388,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } } else { DimensionHandle dim; - Status s = c->Merge(dim_x, dim_y, &dim); + absl::Status s = c->Merge(dim_x, dim_y, &dim); if (!s.ok()) { if (!incompatible_shape_error) { *out = c->MakeShape({}); @@ -2398,14 +2404,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, return absl::OkStatus(); } -Status RandomShape(shape_inference::InferenceContext* c) { +absl::Status RandomShape(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); c->set_output(0, out); return absl::OkStatus(); } -Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { +absl::Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { ShapeHandle s_data = c->input(0); ShapeHandle s_segment_ids = c->input(1); ShapeHandle s_num_segments = c->input(2); @@ -2441,9 +2447,9 @@ namespace { // This SliceHelper processes the output shape of the `slice` // when the tensor of `sizes` is available. template -Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, - const Tensor* sizes_value, - std::vector* dims) { +absl::Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, + const Tensor* sizes_value, + std::vector* dims) { auto sizes_vec = sizes_value->vec(); for (int i = 0; i < sizes_value->NumElements(); ++i) { DimensionHandle dim = c->Dim(c->input(0), i); @@ -2467,7 +2473,7 @@ Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, } } // namespace -Status SliceShape(InferenceContext* c) { +absl::Status SliceShape(InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle begin_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); @@ -2543,8 +2549,10 @@ Status SliceShape(InferenceContext* c) { return absl::OkStatus(); } -Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, - ShapeHandle values_shape, ShapeHandle shape_shape) { +absl::Status ValidateSparseTensor(InferenceContext* c, + ShapeHandle indices_shape, + ShapeHandle values_shape, + ShapeHandle shape_shape) { // Validate ranks. ShapeHandle unused_shape; TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); @@ -2584,7 +2592,7 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, return absl::OkStatus(); } -Status ValidateVariableResourceHandle( +absl::Status ValidateVariableResourceHandle( InferenceContext* c, std::vector* shape_and_type) { auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data == nullptr || handle_data->empty()) { @@ -2604,7 +2612,7 @@ Status ValidateVariableResourceHandle( return absl::OkStatus(); } -Status GatherNdShape(InferenceContext* c) { +absl::Status GatherNdShape(InferenceContext* c) { ShapeHandle params; std::vector handle_shape_and_type; if (c->input_handle_shapes_and_types(0) != nullptr) { @@ -2640,9 +2648,10 @@ Status GatherNdShape(InferenceContext* c) { return absl::OkStatus(); } -Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, - ShapeHandle updates_shape, - ShapeHandle input_shape) { +absl::Status ScatterNdShapeHelper(InferenceContext* c, + ShapeHandle indices_shape, + ShapeHandle updates_shape, + ShapeHandle input_shape) { if (c->Value(c->NumElements(input_shape)) == 0 && (c->Value(c->NumElements(indices_shape)) > 0 || c->Value(c->NumElements(updates_shape)) > 0)) { @@ -2667,7 +2676,7 @@ Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, TF_RETURN_IF_ERROR( c->Subshape(updates_shape, 0, outer_dims, &prefix_updates)); - Status s = c->Merge(prefix_indices, prefix_updates, &unused); + absl::Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( "Dimensions [0,", outer_dims, @@ -2703,7 +2712,7 @@ Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, return absl::OkStatus(); } -Status ExplicitShape(InferenceContext* c) { +absl::Status ExplicitShape(InferenceContext* c) { PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); ShapeHandle output_shape; @@ -2712,7 +2721,7 @@ Status ExplicitShape(InferenceContext* c) { return absl::OkStatus(); } -Status ExplicitShapes(InferenceContext* c) { +absl::Status ExplicitShapes(InferenceContext* c) { std::vector shapes; TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); if (shapes.empty()) { @@ -2727,7 +2736,7 @@ Status ExplicitShapes(InferenceContext* c) { return absl::OkStatus(); } -Status SparseReduceShapeFn(InferenceContext* c) { +absl::Status SparseReduceShapeFn(InferenceContext* c) { // Input 0: input_indices // Input 1: input_values // Input 2: input_shape @@ -2775,7 +2784,7 @@ Status SparseReduceShapeFn(InferenceContext* c) { return UnknownShape(c); } -Status QuantizedConv2DShape(InferenceContext* c) { +absl::Status QuantizedConv2DShape(InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); @@ -2787,7 +2796,7 @@ Status QuantizedConv2DShape(InferenceContext* c) { return absl::OkStatus(); } -Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { +absl::Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { std::vector fused_ops; TF_RETURN_IF_ERROR(c->GetAttr("fused_ops", &fused_ops)); ShapeHandle unused, channel; @@ -2834,19 +2843,19 @@ Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { return absl::OkStatus(); } -Status FusedQuantizedConv2DShape(InferenceContext* c) { +absl::Status FusedQuantizedConv2DShape(InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShapeImpl(c, true)); TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); return absl::OkStatus(); } -Status FusedQuantizedDepthwiseConv2D(InferenceContext* c) { +absl::Status FusedQuantizedDepthwiseConv2D(InferenceContext* c) { TF_RETURN_IF_ERROR(DepthwiseConv2DNativeShapeImpl(c, true)); TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); return absl::OkStatus(); } -Status QuantizedAvgPoolShape(InferenceContext* c) { +absl::Status QuantizedAvgPoolShape(InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c)); ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); @@ -2856,9 +2865,9 @@ Status QuantizedAvgPoolShape(InferenceContext* c) { return absl::OkStatus(); } -Status QuantizeV2Shape(InferenceContext* c) { +absl::Status QuantizeV2Shape(InferenceContext* c) { int axis = -1; - Status s = c->GetAttr("axis", &axis); + absl::Status s = c->GetAttr("axis", &axis); if (!s.ok() && s.code() != error::NOT_FOUND) { return s; } @@ -2882,7 +2891,7 @@ Status QuantizeV2Shape(InferenceContext* c) { return absl::OkStatus(); } -Status ReduceScatterShape(shape_inference::InferenceContext* c) { +absl::Status ReduceScatterShape(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle in = c->input(0); if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index f1d43d6c2abfd3..1be1633fb48ff5 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -27,28 +27,28 @@ namespace shape_inference { // Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support // EXPLICIT padding. -Status GetWindowedOutputSizeFromDims(InferenceContext* c, - DimensionHandle input_size, - DimensionOrConstant filter_size, - int64_t stride, Padding padding_type, - DimensionHandle* output_size); +absl::Status GetWindowedOutputSizeFromDims(InferenceContext* c, + DimensionHandle input_size, + DimensionOrConstant filter_size, + int64_t stride, Padding padding_type, + DimensionHandle* output_size); // The V2 version computes the same outputs with arbitrary dilation_rate, and // supports EXPLICIT padding. For detailed equations, refer to the comments // for GetWindowedOutputSize(). The 'padding_before' and 'padding_after' // parameters are only used if padding_type == EXPLICIT. -Status GetWindowedOutputSizeFromDimsV2( +absl::Status GetWindowedOutputSizeFromDimsV2( InferenceContext* c, DimensionHandle input_size, DimensionOrConstant filter_size, int64_t dilation_rate, int64_t stride, Padding padding_type, int64_t padding_before, int64_t padding_after, DimensionHandle* output_size); // Transfers shape of input(0) to output(0). -Status UnchangedShape(shape_inference::InferenceContext* c); +absl::Status UnchangedShape(shape_inference::InferenceContext* c); // Transfers shape of input(0) to output(0), after asserting its rank is . -inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, - int32_t rank) { +inline absl::Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, + int32_t rank) { ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out)); c->set_output(0, out); @@ -56,7 +56,7 @@ inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, } // Transfers shape of input(0) to output(0), after asserting its rank >= . -inline Status UnchangedShapeWithRankAtLeast( +inline absl::Status UnchangedShapeWithRankAtLeast( shape_inference::InferenceContext* c, int32_t rank) { ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); @@ -65,8 +65,8 @@ inline Status UnchangedShapeWithRankAtLeast( } // Transfers shape of input(0) to output(0), after asserting its rank <= . -inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c, - int32_t rank) { +inline absl::Status UnchangedShapeWithRankAtMost( + shape_inference::InferenceContext* c, int32_t rank) { ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out)); c->set_output(0, out); @@ -74,18 +74,18 @@ inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c, } // Shape function for use with ops no outputs. -inline Status NoOutputs(shape_inference::InferenceContext* c) { +inline absl::Status NoOutputs(shape_inference::InferenceContext* c) { return absl::OkStatus(); } // Shape function for ops that output a single scalar value. -inline Status ScalarShape(shape_inference::InferenceContext* c) { +inline absl::Status ScalarShape(shape_inference::InferenceContext* c) { c->set_output(0, c->Scalar()); return absl::OkStatus(); } // Shape function for binary ops where both inputs and the output match. -inline Status MergeBothInputsShapeFn(InferenceContext* c) { +inline absl::Status MergeBothInputsShapeFn(InferenceContext* c) { ShapeHandle out; TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out)); c->set_output(0, out); @@ -93,149 +93,154 @@ inline Status MergeBothInputsShapeFn(InferenceContext* c) { } // Shape function for dataset iterators. -Status DatasetIteratorShape(shape_inference::InferenceContext* c); +absl::Status DatasetIteratorShape(shape_inference::InferenceContext* c); // Returns a new shape with the specified dims arranged in the specified // format. The returned value is owned by this context. // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. -Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, - const std::vector& spatial, - DimensionOrConstant C, ShapeHandle* out, - shape_inference::InferenceContext* context); +absl::Status MakeShapeFromFormat( + TensorFormat format, DimensionOrConstant N, + const std::vector& spatial, DimensionOrConstant C, + ShapeHandle* out, shape_inference::InferenceContext* context); // Shape function for MatMul-like operations. -Status MatMulShape(shape_inference::InferenceContext* c); +absl::Status MatMulShape(shape_inference::InferenceContext* c); // Shape function for Batched MatMul-like operations with broadcasting across // batch dimensions. -Status BatchMatMulV2Shape(shape_inference::InferenceContext* c); +absl::Status BatchMatMulV2Shape(shape_inference::InferenceContext* c); // Shape function for BatchMatMul-like operations -Status BatchMatMulShape(shape_inference::InferenceContext* c); +absl::Status BatchMatMulShape(shape_inference::InferenceContext* c); // Shape function for Einsum. -Status EinsumShape(shape_inference::InferenceContext* c); +absl::Status EinsumShape(shape_inference::InferenceContext* c); // Shape function for BiasAdd-like operations. -Status BiasAddShape(shape_inference::InferenceContext* c); +absl::Status BiasAddShape(shape_inference::InferenceContext* c); // Shape function for BiasAddGrad-like operations. -Status BiasAddGradShape(shape_inference::InferenceContext* c); +absl::Status BiasAddGradShape(shape_inference::InferenceContext* c); // Shape function for general Convolution operation -Status ConvShape(shape_inference::InferenceContext* c); +absl::Status ConvShape(shape_inference::InferenceContext* c); // Shape function for Conv2D-like operations that support explicit padding. -Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c); +absl::Status Conv2DShapeWithExplicitPadding( + shape_inference::InferenceContext* c); // Shape function for Conv2D-like operations that do not support explicit // padding. -Status Conv2DShape(shape_inference::InferenceContext* c); +absl::Status Conv2DShape(shape_inference::InferenceContext* c); // Shape function for Conv3D-like operations. -Status Conv3DShape(shape_inference::InferenceContext* c); +absl::Status Conv3DShape(shape_inference::InferenceContext* c); // Shape function for DepthwiseConv2D-like operations that support explicit // padding. -Status DepthwiseConv2DNativeShapeWithExplicitPadding( +absl::Status DepthwiseConv2DNativeShapeWithExplicitPadding( shape_inference::InferenceContext* c); // Shape function for DepthwiseConv2D-like operations that do not support // explicit padding. -Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); +absl::Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); // Shape function for Conv2DBackpropInput. -Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c); +absl::Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c); // Shape function for Conv2DBackpropFilterWithBias. -Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c); +absl::Status Conv2DBackpropFilterWithBiasShape( + shape_inference::InferenceContext* c); // Shape function for AvgPool-like operations. -Status AvgPoolShape(shape_inference::InferenceContext* c); +absl::Status AvgPoolShape(shape_inference::InferenceContext* c); // Shape function for AvgPoolGrad-like operations. -Status AvgPoolGradShape(shape_inference::InferenceContext* c); +absl::Status AvgPoolGradShape(shape_inference::InferenceContext* c); // Shape function for FusedBatchNorm and FusedBatchNormV2 operations. -Status FusedBatchNormShape(shape_inference::InferenceContext* c); +absl::Status FusedBatchNormShape(shape_inference::InferenceContext* c); // Shape function for FusedBatchNormV3 operations. -Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c); +absl::Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c); // Shape function for _FusedBatchNormEx operations. -Status FusedBatchNormExShape(shape_inference::InferenceContext* c); +absl::Status FusedBatchNormExShape(shape_inference::InferenceContext* c); // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations. -Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); +absl::Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); // Shape function for _FusedBatchNormGradEx operations. -Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c); +absl::Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c); // Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations. -Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c); +absl::Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c); // Shape function for MatrixDiagV2 and MatrixDiagV3 operations. -Status MatrixDiagV2Shape(shape_inference::InferenceContext* c); +absl::Status MatrixDiagV2Shape(shape_inference::InferenceContext* c); // Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations. -Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c); +absl::Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c); // Shape function for MaxPool-like operations that support explicit padding. -Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c); +absl::Status MaxPoolShapeWithExplicitPadding( + shape_inference::InferenceContext* c); // Shape function for MaxPool-like operations that do not support explicit // padding. -Status MaxPoolShape(shape_inference::InferenceContext* c); +absl::Status MaxPoolShape(shape_inference::InferenceContext* c); // Shape function for MaxPoolV2-like operations. -Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs); +absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, + int num_inputs); // Shape function for MaxPoolGrad-like operations. -Status MaxPoolGradShape(shape_inference::InferenceContext* c); +absl::Status MaxPoolGradShape(shape_inference::InferenceContext* c); // Shape function for 3D Pooling operations. -Status Pool3DShape(shape_inference::InferenceContext* c); +absl::Status Pool3DShape(shape_inference::InferenceContext* c); // Shape function for MaxPool3DGrad-like operations. -Status MaxPool3DGradShape(shape_inference::InferenceContext* c); +absl::Status MaxPool3DGradShape(shape_inference::InferenceContext* c); // Shape function for AvgPool3DGrad-like operations. -Status AvgPool3DGradShape(shape_inference::InferenceContext* c); +absl::Status AvgPool3DGradShape(shape_inference::InferenceContext* c); // Shape function for use with ops whose output shapes are unknown. -Status UnknownShape(shape_inference::InferenceContext* c); +absl::Status UnknownShape(shape_inference::InferenceContext* c); // Shape function for reduction operations. -Status ReductionShape(shape_inference::InferenceContext* c); +absl::Status ReductionShape(shape_inference::InferenceContext* c); // Shape function for unsorted segment operations. -Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c); +absl::Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c); // Shape function for concat operations. // is the number of inputs to concatenate and are taken // from inputs // [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input. -Status ConcatShape(shape_inference::InferenceContext* c, - int num_inputs_to_concat); +absl::Status ConcatShape(shape_inference::InferenceContext* c, + int num_inputs_to_concat); // Shape function for concat operations. -Status ConcatV2Shape(shape_inference::InferenceContext* c); +absl::Status ConcatV2Shape(shape_inference::InferenceContext* c); -Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat); +absl::Status QuantizedConcatV2Shape(InferenceContext* c, + int num_inputs_to_concat); // Shape function for binary operators that broadcast their inputs // and with output to output_index. // Note: out cannot be NULL. -Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, - ShapeHandle shape_x, - ShapeHandle shape_y, - bool incompatible_shape_error, - ShapeHandle* out); +absl::Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, + ShapeHandle shape_x, + ShapeHandle shape_y, + bool incompatible_shape_error, + ShapeHandle* out); // Shape function for binary operators that broadcast their inputs // and with output to output_index. -inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, - int output_index) { +inline absl::Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, + int output_index) { ShapeHandle out; TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( c, c->input(0), c->input(1), true, &out)); @@ -245,57 +250,61 @@ inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, // Shape function for binary operators that broadcast their inputs. // Tested by ops/math_ops_test.cc. -inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { +inline absl::Status BroadcastBinaryOpShapeFn(InferenceContext* c) { return BroadcastBinaryOpOutputShapeFn(c, 0); } // Shape function for random operations. -Status RandomShape(shape_inference::InferenceContext* c); +absl::Status RandomShape(shape_inference::InferenceContext* c); // Shape function for Slice operations. -Status SliceShape(shape_inference::InferenceContext* c); +absl::Status SliceShape(shape_inference::InferenceContext* c); // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. -Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, - ShapeHandle values_shape, ShapeHandle shape_shape); +absl::Status ValidateSparseTensor(InferenceContext* c, + ShapeHandle indices_shape, + ShapeHandle values_shape, + ShapeHandle shape_shape); -Status ValidateVariableResourceHandle( +absl::Status ValidateVariableResourceHandle( InferenceContext* c, std::vector* shape_and_type); // Shape function for GatherNd operations. -Status GatherNdShape(InferenceContext* c); +absl::Status GatherNdShape(InferenceContext* c); // Helper shape function for ScatterNd.../TensorScatter... operations. -Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, - ShapeHandle updates_shape, ShapeHandle input_shape); +absl::Status ScatterNdShapeHelper(InferenceContext* c, + ShapeHandle indices_shape, + ShapeHandle updates_shape, + ShapeHandle input_shape); // Shape function for ops with an explicit "shape" attribute. -Status ExplicitShape(InferenceContext* c); +absl::Status ExplicitShape(InferenceContext* c); // Shape function for multiple-output ops with an explicit "shapes" attribute. -Status ExplicitShapes(InferenceContext* c); +absl::Status ExplicitShapes(InferenceContext* c); // Shape function for SparseReduceMax and SparseReduceSum. -Status SparseReduceShapeFn(InferenceContext* c); +absl::Status SparseReduceShapeFn(InferenceContext* c); // Shape function for QuantizedConv2D op. -Status QuantizedConv2DShape(InferenceContext* c); +absl::Status QuantizedConv2DShape(InferenceContext* c); // Shape function for _QuantizedConv2D op/fusion. -Status FusedQuantizedConv2DShape(InferenceContext* c); +absl::Status FusedQuantizedConv2DShape(InferenceContext* c); // Shape function for _QuantizedDepthwiseConv2D op/fusion. -Status FusedQuantizedDepthwiseConv2D(InferenceContext* c); +absl::Status FusedQuantizedDepthwiseConv2D(InferenceContext* c); // Shape function for QuantizedAvgPool op -Status QuantizedAvgPoolShape(InferenceContext* c); +absl::Status QuantizedAvgPoolShape(InferenceContext* c); // Shape function for QuantizeV2 op -Status QuantizeV2Shape(InferenceContext* c); +absl::Status QuantizeV2Shape(InferenceContext* c); // Shape function for ReduceScatter ops -Status ReduceScatterShape(shape_inference::InferenceContext* c); +absl::Status ReduceScatterShape(shape_inference::InferenceContext* c); } // namespace shape_inference diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 564fe9a2eeedaa..bf4d30401f2233 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -207,7 +207,7 @@ REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant") .Device(DEVICE_GPU), UnwrapDatasetVariantOp); -static Status WrappedDatasetVariantDeviceCopy( +static absl::Status WrappedDatasetVariantDeviceCopy( const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { *to = WrappedDatasetVariantWrapper(from); @@ -228,15 +228,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper, } // namespace -Status GraphDefBuilderWrapper::AddDataset(const DatasetBase* dataset, - const std::vector& inputs, - Node** output) { +absl::Status GraphDefBuilderWrapper::AddDataset( + const DatasetBase* dataset, const std::vector& inputs, + Node** output) { return AddDataset(dataset, inputs, {}, output); } -Status GraphDefBuilderWrapper::AddDataset( +absl::Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector& inputs, - const std::vector>& attrs, + const std::vector>& attrs, Node** output) { std::vector> enumerated_inputs(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { @@ -245,21 +245,21 @@ Status GraphDefBuilderWrapper::AddDataset( return AddDataset(dataset, enumerated_inputs, {}, attrs, output); } -Status GraphDefBuilderWrapper::AddDataset( +absl::Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector>& inputs, const std::vector>>& list_inputs, - const std::vector>& attrs, + const std::vector>& attrs, Node** output) { return AddDataset(dataset, inputs, list_inputs, attrs, /*use_dataset_name=*/false, output); } -Status GraphDefBuilderWrapper::AddDataset( +absl::Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector>& inputs, const std::vector>>& list_inputs, - const std::vector>& attrs, + const std::vector>& attrs, bool use_dataset_name, Node** output) { auto& type_string = dataset->type_string(); auto opts = absl::make_unique(b_->opts()); @@ -323,7 +323,7 @@ Status GraphDefBuilderWrapper::AddDataset( return absl::OkStatus(); } -Status GraphDefBuilderWrapper::AddFunction( +absl::Status GraphDefBuilderWrapper::AddFunction( SerializationContext* ctx, const string& function_name, const FunctionLibraryDefinition& lib_def) { if (b_->HasFunction(function_name)) { @@ -383,7 +383,7 @@ void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, bool GraphDefBuilderWrapper::HasAttr(const string& name, const string& attr_name) const { const OpDef* op_def = nullptr; - Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def); + absl::Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def); if (!s.ok() || op_def == nullptr) { return false; } @@ -516,7 +516,7 @@ void MemoryCheckpoint::Purge(const std::string& prefix) { } } -Status MemoryCheckpoint::Save(IteratorStateWriter* writer) const { +absl::Status MemoryCheckpoint::Save(IteratorStateWriter* writer) const { for (const auto& [id, value] : int_values_) { auto [prefix, key] = id_registry_->Get(id); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix, key, value)); @@ -532,8 +532,8 @@ Status MemoryCheckpoint::Save(IteratorStateWriter* writer) const { return absl::OkStatus(); } -Status IteratorBase::InitializeBase(IteratorContext* ctx, - const IteratorBase* parent) { +absl::Status IteratorBase::InitializeBase(IteratorContext* ctx, + const IteratorBase* parent) { parent_ = parent; id_ = Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast(this)); @@ -554,7 +554,7 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, return absl::OkStatus(); } -Status GetCompressedElementFromVariantTensor( +absl::Status GetCompressedElementFromVariantTensor( const Tensor& tensor, const CompressedElement** out_compressed_element) { if (!(tensor.dtype() == DT_VARIANT && TensorShapeUtils::IsScalar(tensor.shape()))) { @@ -626,7 +626,7 @@ std::string FullName(const std::string& prefix, const std::string& name) { return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name); } -Status ExtractIteratorPrefix(StringPiece key, string* prefix) { +absl::Status ExtractIteratorPrefix(absl::string_view key, string* prefix) { if (!absl::StartsWith(key, data::kFullNameRandomHex)) { return errors::InvalidArgument("Key: ", key, " was not generated using full_name."); @@ -642,8 +642,8 @@ Status ExtractIteratorPrefix(StringPiece key, string* prefix) { return absl::OkStatus(); } -Status GetDatasetFromVariantTensor(const Tensor& tensor, - DatasetBase** out_dataset) { +absl::Status GetDatasetFromVariantTensor(const Tensor& tensor, + DatasetBase** out_dataset) { if (!(tensor.dtype() == DT_VARIANT && TensorShapeUtils::IsScalar(tensor.shape()))) { return errors::InvalidArgument( @@ -661,7 +661,7 @@ Status GetDatasetFromVariantTensor(const Tensor& tensor, return absl::OkStatus(); } -Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { +absl::Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { if (!(tensor->dtype() == DT_VARIANT && TensorShapeUtils::IsScalar(tensor->shape()))) { return errors::InvalidArgument( @@ -768,7 +768,7 @@ void MergeOptions(const protobuf::MessageLite& source, } // namespace internal void DatasetBase::Initialize(const Metadata& metadata) { - Status s = ComputeNumSources(); + absl::Status s = ComputeNumSources(); if (!s.ok()) { LOG_EVERY_N_SEC(ERROR, 10) << s; } @@ -784,9 +784,9 @@ void DatasetBase::Initialize(const Metadata& metadata) { } } -Status DatasetBase::ComputeNumSources() { +absl::Status DatasetBase::ComputeNumSources() { std::vector inputs; - Status s = InputDatasets(&inputs); + absl::Status s = InputDatasets(&inputs); if (errors::IsUnimplemented(s)) { return s; } @@ -811,7 +811,7 @@ Status DatasetBase::ComputeNumSources() { return absl::OkStatus(); } -Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { +absl::Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { CardinalityOptions options; options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_MODERATE); int64 cardinality = Cardinality(options); @@ -829,14 +829,14 @@ Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { return absl::OkStatus(); } -Status DatasetBase::Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const { +absl::Status DatasetBase::Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const { return errors::Unimplemented("Random access is not implemented for dataset ", DebugString()); } -Status DatasetBase::Get(AnyContext ctx, int64 index, - std::vector* out_tensors) const { +absl::Status DatasetBase::Get(AnyContext ctx, int64 index, + std::vector* out_tensors) const { return errors::Unimplemented("Random access is not implemented for dataset ", DebugString()); } @@ -852,9 +852,9 @@ absl::StatusOr DatasetBase::Finalize( return finalized_dataset_.get(); } -Status DatasetBase::MergeOptionsFromInputs() { +absl::Status DatasetBase::MergeOptionsFromInputs() { std::vector inputs; - Status s = InputDatasets(&inputs); + absl::Status s = InputDatasets(&inputs); if (errors::IsUnimplemented(s)) { return s; } @@ -874,13 +874,13 @@ Status DatasetBase::MergeOptionsFromInputs() { return absl::OkStatus(); } -Status DatasetBase::MakeIterator( +absl::Status DatasetBase::MakeIterator( IteratorContext* ctx, const IteratorBase* parent, const string& output_prefix, std::unique_ptr* iterator) const { if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset") { std::vector inputs; - Status s = InputDatasets(&inputs); + absl::Status s = InputDatasets(&inputs); return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator); } tsl::profiler::TraceMe traceme( @@ -890,7 +890,7 @@ Status DatasetBase::MakeIterator( }, tsl::profiler::TraceMeLevel::kInfo); *iterator = MakeIteratorInternal(output_prefix); - Status s = (*iterator)->InitializeBase(ctx, parent); + absl::Status s = (*iterator)->InitializeBase(ctx, parent); if (s.ok()) { s.Update((*iterator)->Initialize(ctx)); ctx->SaveCheckpoint(iterator->get()); @@ -902,10 +902,10 @@ Status DatasetBase::MakeIterator( return s; } -Status DatasetBase::MakeSplitProviders( +absl::Status DatasetBase::MakeSplitProviders( std::vector>* split_providers) const { std::vector inputs; - Status s = InputDatasets(&inputs); + absl::Status s = InputDatasets(&inputs); if (errors::IsUnimplemented(s)) { return errors::Unimplemented( "Cannot create split providers for dataset of type ", type_string(), @@ -963,7 +963,7 @@ int64_t DatasetBase::Cardinality(CardinalityOptions options) const { return cardinality_; } -Status DatasetBase::InputDatasets( +absl::Status DatasetBase::InputDatasets( std::vector* inputs) const { return errors::Unimplemented( "Cannot compute input sources for dataset of type ", type_string(), @@ -972,9 +972,9 @@ Status DatasetBase::InputDatasets( "source dataset, it should return empty inputs."); } -Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( +absl::Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( SerializationContext* ctx, const DatasetBase* dataset, Node** output) { - Status status = dataset->AsGraphDefInternal(ctx, this, output); + absl::Status status = dataset->AsGraphDefInternal(ctx, this, output); if (ctx->is_graph_rewrite()) { if (status.ok()) { // Record cardinality in an unregistered attributes so that rewrites have @@ -1001,7 +1001,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( return status; } -Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor( +absl::Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor( SerializationContext* ctx, const Tensor& t, Node** output) { if (t.dtype() == DT_VARIANT) { // If the input tensor is a variant, it may represent a multi-dimensional @@ -1011,13 +1011,13 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor( // // If this fails, we fallback to using its Variant::Encode() based // serialization. - Status s = AddDatasetOrTensorHelper(ctx, t, output); + absl::Status s = AddDatasetOrTensorHelper(ctx, t, output); if (s.ok()) { return s; } } if (t.dtype() == DT_RESOURCE && !ctx->is_graph_rewrite()) { - Status s = AddResourceHelper(ctx, t, output); + absl::Status s = AddResourceHelper(ctx, t, output); if (!errors::IsUnimplemented(s)) { // Fall through to AddTensor if AsGraphDef is not implemented for this // resource. @@ -1027,7 +1027,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor( return AddTensor(t, output); } -Status DatasetBase::DatasetGraphDefBuilder::AddIdentity( +absl::Status DatasetBase::DatasetGraphDefBuilder::AddIdentity( SerializationContext* ctx, const std::string& name_prefix, Node** input, Node** output) { *output = @@ -1036,7 +1036,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddIdentity( return absl::OkStatus(); } -Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( +absl::Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( SerializationContext* ctx, const Tensor& t, Node** output) { if (t.dims() == 0) { DatasetBase* dataset; @@ -1058,7 +1058,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( return absl::OkStatus(); } -Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper( +absl::Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper( SerializationContext* ctx, const Tensor& t, Node** output) { if (t.NumElements() == 0) { return errors::InvalidArgument("Empty resouce handle"); @@ -1128,9 +1128,9 @@ string DatasetBaseIterator::BuildTraceMeName() { return result; } -Status DatasetBaseIterator::GetNext(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) { +absl::Status DatasetBaseIterator::GetNext(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) { activity_watcher::ActivityScope activity_scope([&]() { activity_watcher::Activity::Attributes attributes; attributes["iterator_prefix"] = prefix(); @@ -1152,7 +1152,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, node_->record_start(now_nanos); } out_tensors->clear(); - Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); + absl::Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); ctx->SaveCheckpoint(this); if (!SymbolicCheckpointCompatible()) { ctx->UpdateCheckpointStatus([this]() { @@ -1192,8 +1192,9 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, return s; } -Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip, - bool* end_of_sequence, int* num_skipped) { +absl::Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, + int* num_skipped) { tsl::profiler::TraceMe activity([&] { return BuildTraceMeName(); }, tsl::profiler::TraceMeLevel::kInfo); DVLOG(3) << prefix() << " Skip enter"; @@ -1208,7 +1209,7 @@ Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip, } node_->record_start(now_nanos); } - Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped); + absl::Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped); if (collect_resource_usage(ctx)) { int64_t now_nanos = EnvTime::NowNanos(); node_->record_stop(now_nanos); @@ -1229,9 +1230,10 @@ Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip, return s; } -Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip, - bool* end_of_sequence, - int* num_skipped) { +absl::Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, + int num_to_skip, + bool* end_of_sequence, + int* num_skipped) { *num_skipped = 0; for (int i = 0; i < num_to_skip; ++i) { std::vector out_tensors; diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index d50a831826b9df..70ebc12a3f9f6c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -87,7 +87,7 @@ void MergeOptions(const protobuf::MessageLite& source, protobuf::MessageLite* destination); } // namespace internal -using TraceMeMetadata = std::vector>; +using TraceMeMetadata = std::vector>; // Maps the index of dataset elements to a globally shuffled index. See the // comment for IteratorContext::Params::index_mapper for more details. @@ -135,28 +135,32 @@ inline bool IsTFDataFunction(const FunctionDef& func) { class IteratorStateReader { public: // Determines whether the iterator state contains the given key. - virtual bool Contains(StringPiece key) const = 0; - virtual bool Contains(StringPiece name, StringPiece key) const = 0; + virtual bool Contains(absl::string_view key) const = 0; + virtual bool Contains(absl::string_view name, + absl::string_view key) const = 0; // Reads an integer for the given key. - virtual Status ReadScalar(StringPiece key, int64_t* val) const = 0; - virtual Status ReadScalar(StringPiece name, StringPiece key, - int64_t* val) const = 0; + virtual absl::Status ReadScalar(absl::string_view key, + int64_t* val) const = 0; + virtual absl::Status ReadScalar(absl::string_view name, absl::string_view key, + int64_t* val) const = 0; // Reads a string for the given key. - virtual Status ReadScalar(StringPiece key, tstring* val) const = 0; - virtual Status ReadScalar(StringPiece name, StringPiece key, - tstring* val) const = 0; + virtual absl::Status ReadScalar(absl::string_view key, + tstring* val) const = 0; + virtual absl::Status ReadScalar(absl::string_view name, absl::string_view key, + tstring* val) const = 0; // Reads a tensor for the given key. // TODO(jsimsa): Remove non-FLR overrides once all callers are updated. - virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0; - virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, - Tensor* val) const = 0; - virtual Status ReadTensor(StringPiece name, StringPiece key, - Tensor* val) const = 0; - virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, - StringPiece key, Tensor* val) const = 0; + virtual absl::Status ReadTensor(absl::string_view key, Tensor* val) const = 0; + virtual absl::Status ReadTensor(FunctionLibraryRuntime* flr, + absl::string_view key, Tensor* val) const = 0; + virtual absl::Status ReadTensor(absl::string_view name, absl::string_view key, + Tensor* val) const = 0; + virtual absl::Status ReadTensor(FunctionLibraryRuntime* flr, + absl::string_view name, absl::string_view key, + Tensor* val) const = 0; virtual ~IteratorStateReader() {} }; @@ -173,19 +177,25 @@ class IteratorStateReader { class IteratorStateWriter { public: // Writes an integer for the given key. - virtual Status WriteScalar(StringPiece key, const int64_t val) = 0; - virtual Status WriteScalar(StringPiece name, StringPiece key, - const int64_t val) = 0; + virtual absl::Status WriteScalar(absl::string_view key, + const int64_t val) = 0; + virtual absl::Status WriteScalar(absl::string_view name, + absl::string_view key, + const int64_t val) = 0; // Writes a string for the given key. - virtual Status WriteScalar(StringPiece key, const tstring& val) = 0; - virtual Status WriteScalar(StringPiece name, StringPiece key, - const tstring& val) = 0; + virtual absl::Status WriteScalar(absl::string_view key, + const tstring& val) = 0; + virtual absl::Status WriteScalar(absl::string_view name, + absl::string_view key, + const tstring& val) = 0; // Writes a tensor for the given key. - virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; - virtual Status WriteTensor(StringPiece name, StringPiece key, - const Tensor& val) = 0; + virtual absl::Status WriteTensor(absl::string_view key, + const Tensor& val) = 0; + virtual absl::Status WriteTensor(absl::string_view name, + absl::string_view key, + const Tensor& val) = 0; virtual ~IteratorStateWriter() {} @@ -201,7 +211,7 @@ class IteratorStateWriter { std::string FullName(const std::string& prefix, const std::string& name); // Extracts iterator prefix from key generated by `FullName`. -Status ExtractIteratorPrefix(StringPiece key, string* prefix); +absl::Status ExtractIteratorPrefix(absl::string_view key, string* prefix); // Interface for objects that can be checkpointed. class Checkpointable { @@ -209,9 +219,10 @@ class Checkpointable { Checkpointable() = default; virtual ~Checkpointable() = default; - virtual Status Save(SerializationContext* ctx, - IteratorStateWriter* writer) = 0; - virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) = 0; + virtual absl::Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) = 0; + virtual absl::Status Restore(IteratorContext* ctx, + IteratorStateReader* reader) = 0; }; // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. @@ -224,7 +235,7 @@ class GraphDefBuilderWrapper { // non-null if the method returns with an OK status. // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. template - Status AddScalar(const T& val, Node** output) { + absl::Status AddScalar(const T& val, Node** output) { Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); val_t.scalar()() = val; AddTensorInternal(val_t, output); @@ -240,7 +251,7 @@ class GraphDefBuilderWrapper { // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? template - Status AddVector(const std::vector& val, Node** output) { + absl::Status AddVector(const std::vector& val, Node** output) { Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({static_cast(val.size())})); for (size_t i = 0; i < val.size(); i++) { @@ -253,7 +264,7 @@ class GraphDefBuilderWrapper { return absl::OkStatus(); } - Status AddVector(const std::vector& val, Node** output) { + absl::Status AddVector(const std::vector& val, Node** output) { Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({static_cast(val.size())})); for (size_t i = 0; i < val.size(); i++) { @@ -271,7 +282,7 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. The returned `Node` // pointer is owned by the backing graph of `GraphDefBuilder`. - Status AddTensor(const Tensor& val, Node** output) { + absl::Status AddTensor(const Tensor& val, Node** output) { AddTensorInternal(val, output); if (*output == nullptr) { return errors::Internal("AddTensor: Failed to build Const op."); @@ -284,7 +295,7 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. The returned `Node` // pointer is owned by the backing graph of `GraphDefBuilder`. - Status AddPlaceholder(const Tensor& val, Node** output) { + absl::Status AddPlaceholder(const Tensor& val, Node** output) { AddPlaceholderInternal(val, output); if (*output == nullptr) { return errors::Internal( @@ -310,25 +321,25 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. The returned `Node` // pointer is owned by the backing `Graph` of `GraphDefBuilder`. - Status AddDataset(const DatasetBase* dataset, - const std::vector& inputs, Node** output); - Status AddDataset(const DatasetBase* dataset, - const std::vector& inputs, - const std::vector>& attrs, - Node** output); - Status AddDataset( + absl::Status AddDataset(const DatasetBase* dataset, + const std::vector& inputs, Node** output); + absl::Status AddDataset( + const DatasetBase* dataset, const std::vector& inputs, + const std::vector>& attrs, + Node** output); + absl::Status AddDataset( const DatasetBase* dataset, const std::vector>& inputs, const std::vector>>& list_inputs, - const std::vector>& attrs, + const std::vector>& attrs, Node** output); - Status AddDataset( + absl::Status AddDataset( const DatasetBase* dataset, const std::vector>& inputs, const std::vector>>& list_inputs, - const std::vector>& attrs, + const std::vector>& attrs, bool use_dataset_name, Node** output); // Adds a user-defined function with name `function_name` to the graph and @@ -338,8 +349,9 @@ class GraphDefBuilderWrapper { // returns an InvalidArgumentError. If the function with name `function_name` // or any of its dependent functions are stateful, and the context does not // explicitly permit stateful functions, returns an InvalidArgument error. - Status AddFunction(SerializationContext* ctx, const string& function_name, - const FunctionLibraryDefinition& lib_def); + absl::Status AddFunction(SerializationContext* ctx, + const string& function_name, + const FunctionLibraryDefinition& lib_def); template void BuildAttrValue(const T& value, AttrValue* attr) { @@ -370,9 +382,9 @@ class GraphDefBuilderWrapper { return false; } - Status AddAttrFunctions(SerializationContext* ctx, - const AttrValue& attr_value, - const FunctionLibraryDefinition& lib_def) { + absl::Status AddAttrFunctions(SerializationContext* ctx, + const AttrValue& attr_value, + const FunctionLibraryDefinition& lib_def) { if (attr_value.has_func()) { TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def)); } else if (attr_value.has_list()) { @@ -417,15 +429,16 @@ class SplitProvider { virtual ~SplitProvider() {} // Stores the next split in `*split`, setting `*end_of_splits` to indicate // whether there were any splits left. - virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0; + virtual absl::Status GetNext(Tensor* split, bool* end_of_splits) = 0; // Resets the split provider to its beginning. - virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; // Saves the state of this split provider. - virtual Status Save(std::function full_name, - IteratorStateWriter* writer) = 0; + virtual absl::Status Save(std::function full_name, + IteratorStateWriter* writer) = 0; // Restores the state of this split provider. - virtual Status Restore(std::function full_name, - IteratorStateReader* reader) = 0; + virtual absl::Status Restore( + std::function full_name, + IteratorStateReader* reader) = 0; // Returns the number of splits: // - If there are a finite number of splits, returns a non-negative count. // - If there are an infinite number of splits, returns kInfiniteCardinality. @@ -495,34 +508,35 @@ class MemoryCheckpoint final : public IteratorStateWriter { } // BEGIN implementation of `IteratorStateWriter` interface - Status WriteScalar(StringPiece key, int64_t val) override { + absl::Status WriteScalar(absl::string_view key, int64_t val) override { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } - Status WriteScalar(StringPiece name, StringPiece key, int64_t val) override { + absl::Status WriteScalar(absl::string_view name, absl::string_view key, + int64_t val) override { auto id = id_registry_->Add(string(name), string(key)); int_values_[id] = val; return absl::OkStatus(); } - Status WriteScalar(StringPiece key, const tstring& val) override { + absl::Status WriteScalar(absl::string_view key, const tstring& val) override { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } - Status WriteScalar(StringPiece name, StringPiece key, - const tstring& val) override { + absl::Status WriteScalar(absl::string_view name, absl::string_view key, + const tstring& val) override { auto id = id_registry_->Add(string(name), string(key)); str_values_[id] = val; return absl::OkStatus(); } - Status WriteTensor(StringPiece key, const Tensor& val) override { + absl::Status WriteTensor(absl::string_view key, const Tensor& val) override { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteTensor(prefix, key, val); } - Status WriteTensor(StringPiece name, StringPiece key, - const Tensor& val) override { + absl::Status WriteTensor(absl::string_view name, absl::string_view key, + const Tensor& val) override { auto id = id_registry_->Add(string(name), string(key)); tensor_values_[id] = val; return absl::OkStatus(); @@ -533,7 +547,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { std::string DebugString() const; // Returns the status of the in-memory checkpoint. - Status GetStatus() const { return status_; } + absl::Status GetStatus() const { return status_; } // Merges state of another checkpoint into this checkpoint, overwriting // existing state (if applicable). @@ -546,17 +560,17 @@ class MemoryCheckpoint final : public IteratorStateWriter { void Purge(const std::string& prefix); // Stores the in-memory checkpoint to the given writer. - Status Save(IteratorStateWriter* writer) const; + absl::Status Save(IteratorStateWriter* writer) const; // Updates the status of the in-memory checkpoint with the given status. - void UpdateStatus(Status status) { status_.Update(status); } + void UpdateStatus(absl::Status status) { status_.Update(status); } private: explicit MemoryCheckpoint(std::shared_ptr registry, bool is_root) : is_root_(is_root), id_registry_(registry) {} void operator=(const MemoryCheckpoint&) = delete; - Status status_ = absl::OkStatus(); + absl::Status status_ = absl::OkStatus(); // Only set to true for the checkpoint in IteratorResource. // Root checkpoint does not track expired prefixes. const bool is_root_ = false; @@ -574,7 +588,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { class SerializationContext { public: // Handles the external state according to the external state policy. - Status HandleCheckExternalStateStatus(Status s) { + absl::Status HandleCheckExternalStateStatus(absl::Status s) { if (s.ok()) { return s; } @@ -1001,7 +1015,7 @@ class IteratorContext { } // Updates the status of the checkpoint with the given status. - void UpdateCheckpointStatus(std::function status_fn) { + void UpdateCheckpointStatus(std::function status_fn) { if (symbolic_checkpoint()) { checkpoint_.UpdateStatus(status_fn()); } @@ -1070,11 +1084,12 @@ class IteratorBase : public Checkpointable { // // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and // potentially remove this method. - virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) = 0; + virtual absl::Status GetNext(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; - Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, - bool* end_of_sequence) { + absl::Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, + bool* end_of_sequence) { return GetNext(&ctx, out_tensors, end_of_sequence); } @@ -1092,11 +1107,11 @@ class IteratorBase : public Checkpointable { // `*end_of_sequence = true` and return `OkStatus()`. `*num_skipped` will // store the number of outputs that are skipped. When `*end_of_sequence` is // `false`, `*num_skipped` should equal to `num_to_skip`. - virtual Status Skip(IteratorContext* ctx, int num_to_skip, - bool* end_of_sequence, int* num_skipped) = 0; + virtual absl::Status Skip(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped) = 0; - virtual Status Skip(IteratorContext&& ctx, int num_to_skip, - bool* end_of_sequence, int* num_skipped) { + virtual absl::Status Skip(IteratorContext&& ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped) { return Skip(&ctx, num_to_skip, end_of_sequence, num_skipped); } @@ -1119,13 +1134,16 @@ class IteratorBase : public Checkpointable { // Performs initialization that needs to happen outside of a constructor to // properly propagate errors. - virtual Status Initialize(IteratorContext* ctx) { return absl::OkStatus(); } + virtual absl::Status Initialize(IteratorContext* ctx) { + return absl::OkStatus(); + } // Performs initialization of the base iterator. - Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); + absl::Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); // Saves the state of this iterator. - Status Save(SerializationContext* ctx, IteratorStateWriter* writer) override { + absl::Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) override { int64_t start_us = EnvTime::NowMicros(); TF_RETURN_IF_ERROR(SaveInternal(ctx, writer)); VLOG(1) << "Saved " << prefix() << " in " @@ -1134,7 +1152,8 @@ class IteratorBase : public Checkpointable { } // Restores the state of this iterator. - Status Restore(IteratorContext* ctx, IteratorStateReader* reader) override { + absl::Status Restore(IteratorContext* ctx, + IteratorStateReader* reader) override { int64_t start_us = EnvTime::NowMicros(); TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader)); ctx->SaveCheckpoint(this); @@ -1157,8 +1176,8 @@ class IteratorBase : public Checkpointable { // This is needed so that sub-classes of IteratorBase can call // `SaveInternal` on their input iterators. - Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, - const std::unique_ptr& input) { + absl::Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, + const std::unique_ptr& input) { if (ctx->symbolic_checkpoint()) { return absl::OkStatus(); } @@ -1167,13 +1186,13 @@ class IteratorBase : public Checkpointable { // This is needed so that sub-classes of IteratorBase can call // `RestoreInternal` on their input iterators. - Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, - const std::unique_ptr& input) { + absl::Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, + const std::unique_ptr& input) { return input->Restore(ctx, reader); } - Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader, - const std::unique_ptr& input) { + absl::Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader, + const std::unique_ptr& input) { return RestoreInput(&ctx, reader, input); } @@ -1181,8 +1200,8 @@ class IteratorBase : public Checkpointable { // // This method is used to store the state of the iterator in a checkpoint. // implementations have an override. - virtual Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) = 0; + virtual absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) = 0; // Restores the state of this iterator. // @@ -1192,8 +1211,8 @@ class IteratorBase : public Checkpointable { // its `Initialize` method has been called, but its `GetNext` method has // never been called. // implementations have an override. - virtual Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) = 0; + virtual absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) = 0; // Returns a pointer to the node representing this iterator in the performance // model. It may be null, if performance modeling is not enabled for this @@ -1256,13 +1275,13 @@ int64_t GetTotalBytes(const std::vector& element); // by the tensor. The consumer must either acquire its own reference to the // dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not // destroyed or mutated while the retrieved pointer is in use. -Status GetDatasetFromVariantTensor(const Tensor& tensor, - DatasetBase** out_dataset); +absl::Status GetDatasetFromVariantTensor(const Tensor& tensor, + DatasetBase** out_dataset); // Stores a `DatasetBase` object in `tensor`. // // The ownership of `dataset` is transferred to `tensor`. -Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); +absl::Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); // Represents a (potentially infinite) range of outputs, where each // output is a tuple of tensors. @@ -1304,18 +1323,18 @@ class DatasetBase : public core::RefCounted { // // The prefix identifies the sequence of iterators leading up to the newly // created iterator. - Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent, - const string& output_prefix, - std::unique_ptr* iterator) const; + absl::Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent, + const string& output_prefix, + std::unique_ptr* iterator) const; - Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent, - const string& output_prefix, - std::unique_ptr* iterator) const { + absl::Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent, + const string& output_prefix, + std::unique_ptr* iterator) const { return MakeIterator(&ctx, parent, output_prefix, iterator); } // Returns a new iterator restored from the checkpoint data in `reader`. - Status MakeIteratorFromCheckpoint( + absl::Status MakeIteratorFromCheckpoint( IteratorContext* ctx, const string& output_prefix, IteratorStateReader* reader, std::unique_ptr* iterator) const { @@ -1331,7 +1350,7 @@ class DatasetBase : public core::RefCounted { return absl::OkStatus(); } - Status MakeIteratorFromCheckpoint( + absl::Status MakeIteratorFromCheckpoint( IteratorContext&& ctx, const string& output_prefix, IteratorStateReader* reader, std::unique_ptr* iterator) const { @@ -1341,7 +1360,7 @@ class DatasetBase : public core::RefCounted { // Returns a split provider which partitions the dataset's data into splits // and provides them in a sequence. The split provider is stored in // `*split_provider`. - virtual Status MakeSplitProviders( + virtual absl::Status MakeSplitProviders( std::vector>* split_providers) const; // Returns a vector of DataType values, representing the respective @@ -1388,26 +1407,27 @@ class DatasetBase : public core::RefCounted { // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a // default implementation of `MakeSplitProvider` when there is a single input // dataset. - virtual Status InputDatasets(std::vector* inputs) const; + virtual absl::Status InputDatasets( + std::vector* inputs) const; // Indicates whether the dataset depends on any external state which would // prevent it from being serializable. If so, the method returns // `errors::FailedPrecondition` with a message that identifies the external // state. Otherwise, the method returns `OkStatus()`. - virtual Status CheckExternalState() const = 0; + virtual absl::Status CheckExternalState() const = 0; // Indicates whether the dataset is compatible with random access. - Status CheckRandomAccessCompatible(const int64 index) const; + absl::Status CheckRandomAccessCompatible(const int64 index) const; // Return the element at a particular index for a randomly accessible dataset. - virtual Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const; + virtual absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const; // Same as above, but with an `AnyContext`, which can be constructed from // either an `OpKernelContext` or `IteratorContext`. Used to support datasets // that provide random access through both the dataset and iterator APIs. - virtual Status Get(AnyContext ctx, int64 index, - std::vector* out_tensors) const; + virtual absl::Status Get(AnyContext ctx, int64 index, + std::vector* out_tensors) const; // Returns true if the dataset and its inputs support random access. virtual absl::Status RandomIndexingCompatible() const { @@ -1428,19 +1448,19 @@ class DatasetBase : public core::RefCounted { public: explicit DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} - Status AddInputDataset(SerializationContext* ctx, - const DatasetBase* dataset, Node** output); - Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val, - Node** output); - Status AddIdentity(SerializationContext* ctx, - const std::string& name_prefix, Node** input, - Node** output); - - private: - Status AddDatasetOrTensorHelper(SerializationContext* ctx, + absl::Status AddInputDataset(SerializationContext* ctx, + const DatasetBase* dataset, Node** output); + absl::Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val, Node** output); - Status AddResourceHelper(SerializationContext* ctx, const Tensor& val, + absl::Status AddIdentity(SerializationContext* ctx, + const std::string& name_prefix, Node** input, Node** output); + + private: + absl::Status AddDatasetOrTensorHelper(SerializationContext* ctx, + const Tensor& val, Node** output); + absl::Status AddResourceHelper(SerializationContext* ctx, const Tensor& val, + Node** output); }; protected: @@ -1456,9 +1476,9 @@ class DatasetBase : public core::RefCounted { // 2) To save the dataset so that it can restore at a later point (possibly in // different environment). If a subclass of `DatasetBase` does not implement // this method, then this migration will not be possible. - virtual Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const = 0; + virtual absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const = 0; virtual std::unique_ptr MakeIteratorInternal( const string& prefix) const = 0; @@ -1467,17 +1487,17 @@ class DatasetBase : public core::RefCounted { private: // Computes and stores the cardinality of a given dataset. - Status ComputeCardinality(); + absl::Status ComputeCardinality(); // Computes the number of source datasets feeding into this dataset. A source // dataset is a leaf in the subtree of dataset inputs. - Status ComputeNumSources(); + absl::Status ComputeNumSources(); // Merges options from inputs to this dataset. If there is a conflict in a // field value, the options set on this dataset takes precedence over those in // the inputs. The order of precedence on the inputs is in the same order as // how they appear for this dataset. - Status MergeOptionsFromInputs(); + absl::Status MergeOptionsFromInputs(); const string type_string_; const string node_name_; @@ -1526,18 +1546,19 @@ class DatasetBaseIterator : public IteratorBase { // following format "name#arg_1=value_,...,arg_n=value_n". string BuildTraceMeName(); - Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) final; + absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) final; - Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, - bool* end_of_sequence) { + absl::Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, + bool* end_of_sequence) { return GetNext(&ctx, out_tensors, end_of_sequence); } - Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence, - int* num_skipped) final; + absl::Status Skip(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped) final; - Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final { + absl::Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) final { VLOG(2) << "Attempting to save checkpoints on iterator (prefix: " << prefix() << ") from " << dataset()->DebugString(); return IteratorBase::Save(ctx, writer); @@ -1545,16 +1566,18 @@ class DatasetBaseIterator : public IteratorBase { // Returns a copy of the `status` where the error message is prepended with // dataset name and the iterator prefix. - Status AddErrorContext(const Status& status) const { - return Status(status.code(), - strings::StrCat("Error in user-defined function passed to ", - dataset()->metadata().name(), - " transformation with iterator: ", prefix(), - ": ", status.message())); + absl::Status AddErrorContext(const absl::Status& status) const { + return absl::Status( + status.code(), + strings::StrCat("Error in user-defined function passed to ", + dataset()->metadata().name(), + " transformation with iterator: ", prefix(), ": ", + status.message())); } protected: - Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final { + absl::Status Restore(IteratorContext* ctx, + IteratorStateReader* reader) final { VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: " << prefix() << ") from " << dataset()->DebugString(); return IteratorBase::Restore(ctx, reader); @@ -1565,13 +1588,13 @@ class DatasetBaseIterator : public IteratorBase { // See the docstring of `GetNext` method regaring the contract for // `out_tensors` and `end_of_sequence`. Implementations may assume that // `*out_tensors` is empty. - virtual Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) = 0; + virtual absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; // Internal implementation of Skip that is wrapped in tracing logic - virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip, - bool* end_of_sequence, int* num_skipped); + virtual absl::Status SkipInternal(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped); string full_name(const string& name) const { return FullName(params_.prefix, name); @@ -1693,8 +1716,9 @@ class DatasetIterator : public DatasetBaseIterator { }; template -Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { +absl::Status ParseScalarArgument(OpKernelContext* ctx, + const absl::string_view& argument_name, + T* output) { const Tensor* argument_t; TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); if (!TensorShapeUtils::IsScalar(argument_t->shape())) { @@ -1705,9 +1729,9 @@ Status ParseScalarArgument(OpKernelContext* ctx, } template -Status ParseVectorArgument(OpKernelContext* ctx, - const StringPiece& argument_name, - std::vector* output) { +absl::Status ParseVectorArgument(OpKernelContext* ctx, + const absl::string_view& argument_name, + std::vector* output) { const Tensor* argument_t; TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); if (!TensorShapeUtils::IsVector(argument_t->shape())) { diff --git a/tensorflow/core/framework/dataset_stateful_op_allowlist.h b/tensorflow/core/framework/dataset_stateful_op_allowlist.h index b92acf5fb74972..cc25c801bf60b1 100644 --- a/tensorflow/core/framework/dataset_stateful_op_allowlist.h +++ b/tensorflow/core/framework/dataset_stateful_op_allowlist.h @@ -25,12 +25,12 @@ namespace data { // See below macro for usage details. class AllowlistedStatefulOpRegistry { public: - Status Add(string op_name) { + absl::Status Add(string op_name) { op_names_.insert(std::move(op_name)); return absl::OkStatus(); } - Status Remove(string op_name) { + absl::Status Remove(string op_name) { op_names_.erase(op_name); return absl::OkStatus(); } diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h index 08231d55d3a160..7b5bfcb1042142 100644 --- a/tensorflow/core/framework/device.h +++ b/tensorflow/core/framework/device.h @@ -54,7 +54,7 @@ namespace tensorflow { class Device : public DeviceBase { public: // Callback type that takes a Status and returns void. - typedef std::function DoneCallback; + typedef std::function DoneCallback; Device(Env* env, const DeviceAttributes& device_attributes); ~Device() override; @@ -102,7 +102,7 @@ class Device : public DeviceBase { // Blocks until all operations queued on the device at the time of // the call have completed. Returns any error pending on the device // at completion. - virtual Status Sync() = 0; + virtual absl::Status Sync() = 0; // Calls the given callback when all operations queued on the device at the // time of the call have completed. The callback is passed any error pending @@ -128,7 +128,7 @@ class Device : public DeviceBase { // current status in a non-blocking way, without using blocking calls such as // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device // status is also updated with the retrieved stream status. - virtual Status RefreshStatus() { + virtual absl::Status RefreshStatus() { return errors::Unimplemented( "RefreshStatus is not supported on this device."); } @@ -141,7 +141,7 @@ class Device : public DeviceBase { // // 'graph' supplies the partition of the graph assigned to this // device. - virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { + virtual absl::Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { return absl::OkStatus(); } @@ -151,7 +151,7 @@ class Device : public DeviceBase { // // The caller takes ownership of one reference on the output DeviceContext*, // and should call Unref(). - virtual Status TryGetDeviceContext(DeviceContext** out_context) { + virtual absl::Status TryGetDeviceContext(DeviceContext** out_context) { *out_context = nullptr; return absl::OkStatus(); } diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc index ac2d383f96ef5d..44db0a284f1f79 100644 --- a/tensorflow/core/framework/device_base.cc +++ b/tensorflow/core/framework/device_base.cc @@ -35,7 +35,7 @@ DeviceBase::~DeviceBase() { } absl::Status DeviceContext::CopyDeviceTensorToCPUSync( - const Tensor* device_tensor, StringPiece tensor_name, Device* device, + const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor) { absl::Notification n; absl::Status status; diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 065707fde4b8c2..fe5099fa361429 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -86,8 +86,9 @@ class DeviceContext : public core::RefCounted { } // Same as CopyCPUTensorToDevice, but in a synchronous way. - Status CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, Device* device, - Tensor* device_tensor) const; + absl::Status CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor) const; // Copies a tensor in this device. virtual void CopyTensorInSameDevice(const Tensor* input_tensor, @@ -100,22 +101,24 @@ class DeviceContext : public core::RefCounted { // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated // to be of the same size as "device_tensor". virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done) { + absl::string_view tensor_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) { done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); } // Same as `CopyDeviceTensorToCPU`, but blocks until the copy is done. - Status CopyDeviceTensorToCPUSync(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, - Tensor* cpu_tensor); + absl::Status CopyDeviceTensorToCPUSync(const Tensor* device_tensor, + absl::string_view tensor_name, + Device* device, Tensor* cpu_tensor); // If possible, wait for all events on *stream to complete then execute func. // A non-OK Status is returned otherwise. The stream argument should be the // one provided by AcceleratorDeviceInfo. This function is not applicable to // devices that don't provide such a value. - virtual Status ThenExecute(Device* device, stream_executor::Stream* stream, - std::function func) { + virtual absl::Status ThenExecute(Device* device, + stream_executor::Stream* stream, + std::function func) { return errors::Internal("ThenExecute not supported by device"); } @@ -225,10 +228,10 @@ class DeviceBase { // This is overridden by GPU devices to reinitialize the derived // type returned by MakeGpuDevice. - virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/, - PerOpGpuDevice* /*device*/, - DeviceContext* /*dc*/, - Allocator* /*allocator*/) { + virtual absl::Status ReinitializeGpuDevice(OpKernelContext* /*context*/, + PerOpGpuDevice* /*device*/, + DeviceContext* /*dc*/, + Allocator* /*allocator*/) { return absl::OkStatus(); } @@ -253,9 +256,9 @@ class DeviceBase { // OpKernelContext and handle the copies from device memory via send // and receive nodes, instead of requiring that each device handle // the copies here as well as in copy ops. - virtual Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) { + virtual absl::Status MakeTensorFromProto( + const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, + Tensor* tensor) { return errors::Internal("Device does not implement MakeTensorFromProto()"); } diff --git a/tensorflow/core/framework/device_factory.cc b/tensorflow/core/framework/device_factory.cc index e39d768a56c785..392b44f2eb177c 100644 --- a/tensorflow/core/framework/device_factory.cc +++ b/tensorflow/core/framework/device_factory.cc @@ -127,7 +127,8 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { return it->second.factory.get(); } -Status DeviceFactory::ListAllPhysicalDevices(std::vector* devices) { +absl::Status DeviceFactory::ListAllPhysicalDevices( + std::vector* devices) { // CPU first. A CPU device is required. // TODO(b/183974121): Consider merge the logic into the loop below. auto cpu_factory = GetFactory("CPU"); @@ -154,7 +155,7 @@ Status DeviceFactory::ListAllPhysicalDevices(std::vector* devices) { return absl::OkStatus(); } -Status DeviceFactory::ListPluggablePhysicalDevices( +absl::Status DeviceFactory::ListPluggablePhysicalDevices( std::vector* devices) { tf_shared_lock l(*get_device_factory_lock()); for (auto& p : device_factories()) { @@ -166,7 +167,7 @@ Status DeviceFactory::ListPluggablePhysicalDevices( return absl::OkStatus(); } -Status DeviceFactory::GetAnyDeviceDetails( +absl::Status DeviceFactory::GetAnyDeviceDetails( int device_index, std::unordered_map* details) { if (device_index < 0) { return errors::InvalidArgument("Device index out of bounds: ", @@ -209,7 +210,7 @@ Status DeviceFactory::GetAnyDeviceDetails( orig_device_index); } -Status DeviceFactory::AddCpuDevices( +absl::Status DeviceFactory::AddCpuDevices( const SessionOptions& options, const string& name_prefix, std::vector>* devices) { auto cpu_factory = GetFactory("CPU"); @@ -226,7 +227,7 @@ Status DeviceFactory::AddCpuDevices( return absl::OkStatus(); } -Status DeviceFactory::AddDevices( +absl::Status DeviceFactory::AddDevices( const SessionOptions& options, const string& name_prefix, std::vector>* devices) { // CPU first. A CPU device is required. diff --git a/tensorflow/core/framework/device_factory.h b/tensorflow/core/framework/device_factory.h index 7957af3cbad869..8b07d15cfc0dac 100644 --- a/tensorflow/core/framework/device_factory.h +++ b/tensorflow/core/framework/device_factory.h @@ -43,17 +43,17 @@ class DeviceFactory { static DeviceFactory* GetFactory(const std::string& device_type); // Append to "*devices" CPU devices. - static Status AddCpuDevices(const SessionOptions& options, - const std::string& name_prefix, - std::vector>* devices); + static absl::Status AddCpuDevices( + const SessionOptions& options, const std::string& name_prefix, + std::vector>* devices); // Append to "*devices" all suitable devices, respecting // any device type specific properties/counts listed in "options". // // CPU devices are added first. - static Status AddDevices(const SessionOptions& options, - const std::string& name_prefix, - std::vector>* devices); + static absl::Status AddDevices(const SessionOptions& options, + const std::string& name_prefix, + std::vector>* devices); // Helper for tests. Create a single device of type "type". The // returned device is always numbered zero, so if creating multiple @@ -66,30 +66,31 @@ class DeviceFactory { // possible physical devices. // // CPU is are added first. - static Status ListAllPhysicalDevices(std::vector* devices); + static absl::Status ListAllPhysicalDevices(std::vector* devices); // Iterate through all device factories and build a list of all of the // possible pluggable physical devices. - static Status ListPluggablePhysicalDevices(std::vector* devices); + static absl::Status ListPluggablePhysicalDevices( + std::vector* devices); // Get details for a specific device among all device factories. // 'device_index' indexes into devices from ListAllPhysicalDevices. - static Status GetAnyDeviceDetails( + static absl::Status GetAnyDeviceDetails( int device_index, std::unordered_map* details); // For a specific device factory list all possible physical devices. - virtual Status ListPhysicalDevices(std::vector* devices) = 0; + virtual absl::Status ListPhysicalDevices(std::vector* devices) = 0; // Get details for a specific device for a specific factory. Subclasses // can store arbitrary device information in the map. 'device_index' indexes // into devices from ListPhysicalDevices. - virtual Status GetDeviceDetails(int device_index, - std::unordered_map* details) { + virtual absl::Status GetDeviceDetails( + int device_index, std::unordered_map* details) { return absl::OkStatus(); } // Most clients should call AddDevices() instead. - virtual Status CreateDevices( + virtual absl::Status CreateDevices( const SessionOptions& options, const std::string& name_prefix, std::vector>* devices) = 0; diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc index bf7edef06ddae9..ec424f890883eb 100644 --- a/tensorflow/core/framework/fake_input.cc +++ b/tensorflow/core/framework/fake_input.cc @@ -33,12 +33,12 @@ class FakeInputImpl { void SetN(int n); void SetDataType(DataType dt); void SetTypeList(DataTypeSlice dts); - Status AddInputToBuilder(); + absl::Status AddInputToBuilder(); private: static string FakeNodeName(int in_index); - Status GetN(int* n) const; - Status GetDataType(DataType* dt) const; + absl::Status GetN(int* n) const; + absl::Status GetDataType(DataType* dt) const; void NSources(int n, DataType dt) const; void SourceList(DataTypeSlice dts) const; @@ -82,7 +82,7 @@ void FakeInputImpl::SetTypeList(DataTypeSlice dts) { dts_ = dts; } -Status FakeInputImpl::AddInputToBuilder() { +absl::Status FakeInputImpl::AddInputToBuilder() { if (dts_specified_) { SourceList(dts_); @@ -101,7 +101,8 @@ Status FakeInputImpl::AddInputToBuilder() { } else { if (!dt_specified_ && !arg_->type_list_attr().empty()) { DataTypeVector dts; - Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); + absl::Status status = + GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); if (!status.ok()) { return errors::InvalidArgument( "Could not infer list of types for input '", arg_->name(), @@ -124,11 +125,11 @@ string FakeInputImpl::FakeNodeName(int in_index) { return string(&c, 1); } -Status FakeInputImpl::GetN(int* n) const { +absl::Status FakeInputImpl::GetN(int* n) const { if (n_specified_) { *n = n_; } else { - Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); + absl::Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); if (!status.ok()) { return errors::InvalidArgument("Could not infer length of input '", arg_->name(), "': ", status.message()); @@ -137,14 +138,14 @@ Status FakeInputImpl::GetN(int* n) const { return absl::OkStatus(); } -Status FakeInputImpl::GetDataType(DataType* dt) const { +absl::Status FakeInputImpl::GetDataType(DataType* dt) const { if (dt_specified_) { *dt = dt_; return absl::OkStatus(); // Ignore is_ref field of arg_. } else if (arg_->type() != DT_INVALID) { *dt = arg_->type(); } else if (!arg_->type_attr().empty()) { - Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt); + absl::Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt); if (!status.ok()) { // Check if the type attr has a default const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_); diff --git a/tensorflow/core/framework/full_type_util.cc b/tensorflow/core/framework/full_type_util.cc index b76b1d52274095..f494f2ef2bd766 100644 --- a/tensorflow/core/framework/full_type_util.cc +++ b/tensorflow/core/framework/full_type_util.cc @@ -139,21 +139,21 @@ OpTypeConstructor VariadicTensorContainer(FullTypeId t, namespace { -typedef absl::flat_hash_map AttrMap; +typedef absl::flat_hash_map AttrMap; -inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t); +inline absl::Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t); -Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { +absl::Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { if (t.args_size() != 0) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Unexpected Var type, expected args_size 0, found ", t.args_size())); } - StringPiece var_name = t.s(); + absl::string_view var_name = t.s(); if (!attrs.contains(var_name)) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("could not find an attribute for key '", var_name, "'")); } @@ -165,34 +165,37 @@ Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { } else if (attr_type == AttrValue::kList) { const auto& attr_list = attr->list(); if (attr_list.type_size() != 1) { - return Status(absl::StatusCode::kUnimplemented, - absl::StrCat("lists or other than one type element\n", - attr_list.DebugString(), "\nkey=", var_name)); + return absl::Status( + absl::StatusCode::kUnimplemented, + absl::StrCat("lists or other than one type element\n", + attr_list.DebugString(), "\nkey=", var_name)); } map_dtype_to_tensor(attr_list.type(0), t); } else { - return Status(absl::StatusCode::kUnimplemented, - absl::StrCat("unsupported attribute type ", - attr->DebugString(), " for name ", var_name)); + return absl::Status( + absl::StatusCode::kUnimplemented, + absl::StrCat("unsupported attribute type ", attr->DebugString(), + " for name ", var_name)); } t.clear_s(); return absl::OkStatus(); } -Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { +absl::Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { if (t.args_size() != 3) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("illegal FOR_EACH type, expected 3 args, got ", - t.args_size())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("illegal FOR_EACH type, expected 3 args, got ", + t.args_size())); } const auto& cont = t.args(0); const auto& tmpl = t.args(1); const auto& t_var = t.args(2); - StringPiece var_name = t_var.s(); + absl::string_view var_name = t_var.s(); if (!attrs.contains(var_name)) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("could not find an attribute for key '", var_name, "'")); } @@ -213,9 +216,10 @@ Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { const auto& attr_list = attr->list(); int tsize = attr_list.type_size(); if (tsize == 0) { - return Status(absl::StatusCode::kUnimplemented, - absl::StrCat("unsupported list attribute type\n", - attr_list.DebugString(), "\nkey=", var_name)); + return absl::Status( + absl::StatusCode::kUnimplemented, + absl::StrCat("unsupported list attribute type\n", + attr_list.DebugString(), "\nkey=", var_name)); } AttrValue replacement; attrs[var_name] = &replacement; @@ -233,15 +237,16 @@ Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { attrs[var_name] = attr; } else { - return Status(absl::StatusCode::kUnimplemented, - absl::StrCat("unsupported attribute type\n", - attr->DebugString(), "\nfor name ", var_name)); + return absl::Status( + absl::StatusCode::kUnimplemented, + absl::StrCat("unsupported attribute type\n", attr->DebugString(), + "\nfor name ", var_name)); } t = result; return absl::OkStatus(); } -Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { +absl::Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { int nargs = t.args_size(); for (int j = 0; j < nargs; j++) { FullTypeDef* arg_t = t.mutable_args(j); @@ -260,7 +265,7 @@ Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { return absl::OkStatus(); } -inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { +inline absl::Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { // Resolve dependent types. The convention for op registrations is to use // attributes as type variables. // See https://www.tensorflow.org/guide/create_op#type_polymorphism. @@ -286,8 +291,8 @@ inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { } // namespace -Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, - FullTypeDef& target) { +absl::Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, + FullTypeDef& target) { target.Clear(); target.set_type_id(TFT_PRODUCT); diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 61cfee4198de94..9e1b0a612a869a 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -77,8 +77,8 @@ namespace tensorflow { // Otherwise (arg_def is a simple type T), *is_type_list is set to // false, and *dtypes is set to a single element vector, whose only // element is T. -Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, - bool* is_type_list, DataTypeVector* dtypes) { +absl::Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes) { dtypes->clear(); if (!arg_def.type_list_attr().empty()) { const AttrValue* v = attrs.FindByString(arg_def.type_list_attr()); @@ -126,13 +126,14 @@ void AddAttr(const string& name, const T& val, NodeDef* ndef) { SetAttrValue(val, &((*ndef->mutable_attr())[name])); } -Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { +absl::Status ValidateSignatureWithAttrs(const OpDef& sig, + AttrSlice attr_values) { // attr_values should specify all attrs defined in fdef, except for those // which have a default value for (const auto& attr : sig.attr()) { const AttrValue* attr_value = attr_values.FindByString(attr.name()); if (attr_value) { - Status status = AttrValueHasType(*attr_value, attr.type()); + absl::Status status = AttrValueHasType(*attr_value, attr.type()); if (!status.ok()) { errors::AppendToMessage(&status, "for attr '", attr.name(), "'"); return status; @@ -182,10 +183,11 @@ class FunctionInstantiationHelper { // Builds index for nodes that can be used as node's input arguments. // `resource_arg_unique_id`: if non-negative, will be populated to the // "_resource_arg_unique_id" attribute of the arg node. - Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values, - const FunctionDef::ArgAttrs* arg_attrs, - bool ints_on_device, - int64_t resource_arg_unique_id) { + absl::Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, + AttrSlice attr_values, + const FunctionDef::ArgAttrs* arg_attrs, + bool ints_on_device, + int64_t resource_arg_unique_id) { bool is_type_list; DataTypeVector dtypes; TF_RETURN_IF_ERROR( @@ -232,8 +234,8 @@ class FunctionInstantiationHelper { return absl::OkStatus(); } - Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, - const int arg_index) { + absl::Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, + const int arg_index) { const OpDef* node_sig = nullptr; TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig)); if (node_sig->output_arg_size() == 0) { @@ -262,7 +264,7 @@ class FunctionInstantiationHelper { return absl::OkStatus(); } - Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { + absl::Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { const OpDef* fnode_sig = nullptr; TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig)); NodeDef* gnode = AddNode(fnode.name()); @@ -366,7 +368,7 @@ class FunctionInstantiationHelper { return absl::OkStatus(); } - Status AddReturnNode( + absl::Status AddReturnNode( const OpDef::ArgDef& ret_def, AttrSlice attrs, const ::tensorflow::protobuf::Map& ret_map, bool ints_on_device, int* ret_index) { @@ -445,7 +447,7 @@ class FunctionInstantiationHelper { }; // Adds an item into the input name index. - Status AddItem(const string& name, const NameInfoItem& item) { + absl::Status AddItem(const string& name, const NameInfoItem& item) { if (!index_.insert({name, item}).second) { return errors::InvalidArgument( strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret", @@ -587,9 +589,9 @@ string Print(const NodeDef& n) { strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]"); } strings::StrAppend(&out, "("); - std::vector dat; + std::vector dat; std::vector dep; - for (StringPiece s : n.input()) { + for (absl::string_view s : n.input()) { if (absl::ConsumePrefix(&s, "^")) { dep.emplace_back(s); } else { @@ -725,9 +727,9 @@ string Print(absl::Span nodes) { return out; } -Status AddDefaultAttrs(const string& op, - const GetFunctionSignature& get_function, - AttrValueMap* attrs) { +absl::Status AddDefaultAttrs(const string& op, + const GetFunctionSignature& get_function, + AttrValueMap* attrs) { const OpDef* op_def = nullptr; TF_RETURN_IF_ERROR(get_function(op, &op_def)); AttrSlice attr_slice(attrs); @@ -743,9 +745,9 @@ Status AddDefaultAttrs(const string& op, } // end namespace -Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, - GetFunctionSignature get_function, - InstantiationResult* result) { +absl::Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { if (VLOG_IS_ON(5)) { const auto& signature = fdef.signature(); VLOG(5) << "Instantiate function definition: name=" << signature.name() @@ -769,7 +771,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, attr_values_ints_on_device->b()); FunctionInstantiationHelper helper(get_function, result); - Status s; + absl::Status s; for (int i = 0, e = sig.input_arg_size(); i < e; ++i) { const OpDef::ArgDef& arg_def = sig.input_arg(i); auto it = fdef.arg_attr().find(i); @@ -1147,7 +1149,7 @@ FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, FunctionCallFrame::~FunctionCallFrame() {} -Status FunctionCallFrame::SetArgs(absl::Span args) { +absl::Status FunctionCallFrame::SetArgs(absl::Span args) { // Input type checks. if (args.size() != arg_types_.size()) { return errors::InvalidArgument("Expects ", arg_types_.size(), @@ -1165,7 +1167,7 @@ Status FunctionCallFrame::SetArgs(absl::Span args) { return absl::OkStatus(); } -Status FunctionCallFrame::GetRetvals(std::vector* rets) const { +absl::Status FunctionCallFrame::GetRetvals(std::vector* rets) const { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { @@ -1179,8 +1181,8 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { return absl::OkStatus(); } -Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, - bool allow_dead_tensors) { +absl::Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, + bool allow_dead_tensors) { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { @@ -1195,7 +1197,7 @@ Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, return absl::OkStatus(); } -Status FunctionCallFrame::GetArg(int index, const Tensor** val) { +absl::Status FunctionCallFrame::GetArg(int index, const Tensor** val) { if (index < 0 || static_cast(index) >= args_.size()) { return errors::InvalidArgument("GetArg ", index, " is not within [0, ", args_.size(), ")"); @@ -1204,7 +1206,7 @@ Status FunctionCallFrame::GetArg(int index, const Tensor** val) { return absl::OkStatus(); } -Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { +absl::Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { if (index < 0 || static_cast(index) >= rets_.size()) { return errors::InvalidArgument("SetRetval ", index, " is not within [0, ", rets_.size(), ")"); @@ -1248,8 +1250,8 @@ void FunctionRecord::finalize() { absl::StatusOr FunctionRecord::mutable_fdef() { if (finalized_) { - return Status(absl::StatusCode::kPermissionDenied, - "Can not mutate FunctionDef after finalization."); + return absl::Status(absl::StatusCode::kPermissionDenied, + "Can not mutate FunctionDef after finalization."); } return &fdef_; @@ -1397,45 +1399,45 @@ core::RefCountPtr FunctionLibraryDefinition::FindHelper( } } -Status FunctionLibraryDefinition::AddFunctionDef( +absl::Status FunctionLibraryDefinition::AddFunctionDef( const FunctionDef& fdef, const StackTracesMap& stack_traces) { mutex_lock l(mu_); bool added; FunctionRecord* record = new FunctionRecord(fdef, stack_traces, true); core::ScopedUnref scoped_unref(record); - Status status = AddHelper(record, &added); + absl::Status status = AddHelper(record, &added); return status; } -Status FunctionLibraryDefinition::AddFunctionDef( +absl::Status FunctionLibraryDefinition::AddFunctionDef( FunctionDef&& fdef, StackTracesMap&& stack_traces) { mutex_lock l(mu_); bool added; FunctionRecord* record = new FunctionRecord(std::move(fdef), std::move(stack_traces), true); core::ScopedUnref scoped_unref(record); - Status status = AddHelper(record, &added); + absl::Status status = AddHelper(record, &added); return status; } -Status FunctionLibraryDefinition::AddFunctionDefHelper( +absl::Status FunctionLibraryDefinition::AddFunctionDefHelper( FunctionDef&& fdef, StackTracesMap&& stack_traces, bool* added) { FunctionRecord* record = new FunctionRecord(std::move(fdef), std::move(stack_traces), true); core::ScopedUnref scoped_unref(record); - Status status = AddHelper(record, added); + absl::Status status = AddHelper(record, added); return status; } -Status FunctionLibraryDefinition::AddFunctionRecord( +absl::Status FunctionLibraryDefinition::AddFunctionRecord( core::RefCountPtr record) TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); bool added; return AddHelper(record.get(), &added); } -Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, - bool* added) { +absl::Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, + bool* added) { *added = false; auto iter = records_.find(registration->fdef().signature().name()); if (iter != records_.end()) { @@ -1463,7 +1465,7 @@ Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, return absl::OkStatus(); } -Status FunctionLibraryDefinition::CopyFunctionDefFrom( +absl::Status FunctionLibraryDefinition::CopyFunctionDefFrom( const string& name, const FunctionLibraryDefinition& other) { if (default_registry() != other.default_registry()) { return errors::InvalidArgument( @@ -1496,14 +1498,15 @@ Status FunctionLibraryDefinition::CopyFunctionDefFrom( } } -Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { +absl::Status FunctionLibraryDefinition::AddGradientDef( + const GradientDef& grad) { mutex_lock l(mu_); bool added; return AddGradientDefHelper(grad, &added); } -Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, - bool* added) { +absl::Status FunctionLibraryDefinition::AddGradientDefHelper( + const GradientDef& grad, bool* added) { *added = false; string* entry = &func_grad_[grad.function_name()]; if (!entry->empty()) { @@ -1521,14 +1524,14 @@ Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, return absl::OkStatus(); } -Status FunctionLibraryDefinition::AddLibrary( +absl::Status FunctionLibraryDefinition::AddLibrary( const FunctionLibraryDefinition& other) { // Clone `other` to ensure thread-safety (grabbing `other`'s lock for // the duration of the function could lead to deadlock). return AddLibrary(FunctionLibraryDefinition(other)); } -Status FunctionLibraryDefinition::AddLibrary( +absl::Status FunctionLibraryDefinition::AddLibrary( FunctionLibraryDefinition&& other) { mutex_lock l(mu_); mutex_lock l2(other.mu_); @@ -1536,12 +1539,12 @@ Status FunctionLibraryDefinition::AddLibrary( // we can roll them back on error. std::vector funcs; std::vector funcs_with_grads; - Status s; + absl::Status s; bool added; for (const auto& [name, record] : other.records_) { s = AddHelper(record, &added); if (!s.ok()) { - Status remove_status = Remove(funcs, funcs_with_grads); + absl::Status remove_status = Remove(funcs, funcs_with_grads); if (!remove_status.ok()) { return remove_status; } @@ -1557,7 +1560,7 @@ Status FunctionLibraryDefinition::AddLibrary( grad.set_gradient_func(iter.second); s = AddGradientDefHelper(grad, &added); if (!s.ok()) { - Status remove_status = Remove(funcs, funcs_with_grads); + absl::Status remove_status = Remove(funcs, funcs_with_grads); if (!remove_status.ok()) { return remove_status; } @@ -1570,22 +1573,23 @@ Status FunctionLibraryDefinition::AddLibrary( return absl::OkStatus(); } -Status FunctionLibraryDefinition::AddLibrary( +absl::Status FunctionLibraryDefinition::AddLibrary( const FunctionDefLibrary& lib_def) { return AddLibrary(FunctionDefLibrary(lib_def), /*stack_traces=*/{}); } -Status FunctionLibraryDefinition::AddLibrary(FunctionDefLibrary&& lib_def) { +absl::Status FunctionLibraryDefinition::AddLibrary( + FunctionDefLibrary&& lib_def) { return AddLibrary(std::move(lib_def), /*stack_traces=*/{}); } -Status FunctionLibraryDefinition::AddLibrary( +absl::Status FunctionLibraryDefinition::AddLibrary( const FunctionDefLibrary& lib_def, const FunctionDefLibraryStackTraces& library_traces) { return AddLibrary(FunctionDefLibrary(lib_def), library_traces); } -Status FunctionLibraryDefinition::AddLibrary( +absl::Status FunctionLibraryDefinition::AddLibrary( FunctionDefLibrary&& lib_def, const FunctionDefLibraryStackTraces& library_traces) { // Remember the funcs and grads that we added successfully so that @@ -1593,7 +1597,7 @@ Status FunctionLibraryDefinition::AddLibrary( mutex_lock l(mu_); std::vector funcs; std::vector funcs_with_grads; - Status s; + absl::Status s; bool added; for (FunctionDef& fdef : *lib_def.mutable_function()) { std::string name = fdef.signature().name(); @@ -1602,7 +1606,7 @@ Status FunctionLibraryDefinition::AddLibrary( : StackTracesMap(); s = AddFunctionDefHelper(std::move(fdef), std::move(stack_traces), &added); if (!s.ok()) { - Status remove_status = Remove(funcs, funcs_with_grads); + absl::Status remove_status = Remove(funcs, funcs_with_grads); if (!remove_status.ok()) { return remove_status; } @@ -1615,7 +1619,7 @@ Status FunctionLibraryDefinition::AddLibrary( for (const GradientDef& grad : lib_def.gradient()) { s = AddGradientDefHelper(grad, &added); if (!s.ok()) { - Status remove_status = Remove(funcs, funcs_with_grads); + absl::Status remove_status = Remove(funcs, funcs_with_grads); if (!remove_status.ok()) { return remove_status; } @@ -1628,7 +1632,7 @@ Status FunctionLibraryDefinition::AddLibrary( return absl::OkStatus(); } -Status FunctionLibraryDefinition::ReplaceFunction( +absl::Status FunctionLibraryDefinition::ReplaceFunction( const string& func, const FunctionDef& fdef, const StackTracesMap& stack_traces) { mutex_lock l(mu_); @@ -1639,7 +1643,8 @@ Status FunctionLibraryDefinition::ReplaceFunction( return absl::OkStatus(); } -Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { +absl::Status FunctionLibraryDefinition::ReplaceGradient( + const GradientDef& grad) { mutex_lock l(mu_); bool added; TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); @@ -1647,13 +1652,14 @@ Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { return absl::OkStatus(); } -Status FunctionLibraryDefinition::RemoveFunction(const string& func) { +absl::Status FunctionLibraryDefinition::RemoveFunction(const string& func) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); return absl::OkStatus(); } -Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { +absl::Status FunctionLibraryDefinition::RemoveFunctionHelper( + const string& func) { auto iter = records_.find(func); if (iter == records_.end()) { return errors::InvalidArgument("Tried to remove non-existent function '", @@ -1674,7 +1680,7 @@ void FunctionLibraryDefinition::Clear() { func_grad_.clear(); } -Status FunctionLibraryDefinition::RemoveGradient(const string& func) { +absl::Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); if (i == func_grad_.end()) { return errors::InvalidArgument("Tried to remove non-existent gradient '", @@ -1684,10 +1690,10 @@ Status FunctionLibraryDefinition::RemoveGradient(const string& func) { return absl::OkStatus(); } -Status FunctionLibraryDefinition::Remove( +absl::Status FunctionLibraryDefinition::Remove( const std::vector& funcs, const std::vector& funcs_with_grads) { - Status s; + absl::Status s; for (const string& f : funcs) { s = RemoveFunctionHelper(f); if (!s.ok()) { @@ -1712,7 +1718,7 @@ string FunctionLibraryDefinition::FindGradientHelper(const string& func) const { return gtl::FindWithDefault(func_grad_, func, ""); } -Status FunctionLibraryDefinition::LookUp( +absl::Status FunctionLibraryDefinition::LookUp( const string& op, const OpRegistrationData** op_reg_data) const { tf_shared_lock l(mu_); auto iter = records_.find(op); @@ -1723,7 +1729,8 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } -string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { +string FunctionLibraryDefinition::UniqueFunctionName( + absl::string_view prefix) const { tf_shared_lock l(mu_); int index = 0; string name = strings::StrCat(prefix, index); @@ -1792,8 +1799,9 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { } template -Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, - const string& attr, T* value) const { +absl::Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, + const string& attr, + T* value) const { const FunctionDef* fdef = GetAttrImpl(ndef); if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) { return absl::OkStatus(); @@ -1802,8 +1810,9 @@ Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, } template -Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr, - T* value) const { +absl::Status FunctionLibraryDefinition::GetAttr(const Node& node, + const string& attr, + T* value) const { return GetAttr(node.def(), attr, value); } @@ -1941,7 +1950,7 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition( for (const string& func_name : reachable_funcs) { // This should never fail, because we copy functions from a valid flib and // use the same default registry. - Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib); + absl::Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib); TF_DCHECK_OK(added); const string grad_func_name = flib.FindGradient(func_name); @@ -1950,7 +1959,7 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition( grad.set_function_name(func_name); grad.set_gradient_func(grad_func_name); // It can only fail if function already has a gradient function. - const Status added_grad = reachable_flib.AddGradientDef(grad); + const absl::Status added_grad = reachable_flib.AddGradientDef(grad); TF_DCHECK_OK(added_grad); } } @@ -2033,7 +2042,8 @@ string FunctionLibraryRuntime::Options::DebugString() const { " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")"); } -void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { +void FunctionDefHelper::AttrValueWrapper::InitFromString( + absl::string_view val) { if (val.size() >= 2 && val[0] == '$') { proto.set_placeholder(val.data() + 1, val.size() - 1); } else { @@ -2231,7 +2241,7 @@ bool RegisterOp(const string& op, Creator func) { return true; } -Status GetOpGradientCreator(const string& op, Creator* creator) { +absl::Status GetOpGradientCreator(const string& op, Creator* creator) { auto fac = GetOpGradFactory(); auto iter = fac->find(op); if (iter == fac->end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 4c81d7b79ed457..8c77af3808d516 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -118,7 +118,7 @@ class FunctionDefHelper { } private: - void InitFromString(StringPiece val); + void InitFromString(absl::string_view val); }; // Constructs an AttrValue.func given the "name" and "attrs". @@ -237,7 +237,8 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( } template <> -inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + absl::string_view val) { InitFromString(val); } @@ -534,7 +535,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Generates new function name with the specified prefix that is unique // across this library. - std::string UniqueFunctionName(StringPiece prefix) const + std::string UniqueFunctionName(absl::string_view prefix) const TF_LOCKS_EXCLUDED(mu_); // Given a node def 'ndef', inspects attributes of the callee diff --git a/tensorflow/core/framework/function_handle_cache.cc b/tensorflow/core/framework/function_handle_cache.cc index add92c44aff5bc..6b9119b681af88 100644 --- a/tensorflow/core/framework/function_handle_cache.cc +++ b/tensorflow/core/framework/function_handle_cache.cc @@ -26,13 +26,13 @@ FunctionHandleCache::FunctionHandleCache(FunctionLibraryRuntime* lib) strings::Printf("%lld", static_cast(random::New64()))) {} FunctionHandleCache::~FunctionHandleCache() { - Status s = Clear(); + absl::Status s = Clear(); if (!s.ok()) { LOG(ERROR) << "Failed to clear function handle cache: " << s.ToString(); } } -Status FunctionHandleCache::Instantiate( +absl::Status FunctionHandleCache::Instantiate( const string& function_name, AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions options, FunctionLibraryRuntime::Handle* handle) { @@ -54,7 +54,7 @@ Status FunctionHandleCache::Instantiate( return absl::OkStatus(); } -Status FunctionHandleCache::Clear() { +absl::Status FunctionHandleCache::Clear() { mutex_lock l(mu_); for (const auto& entry : handles_) { TF_RETURN_IF_ERROR(lib_->ReleaseHandle(entry.second)); diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index ae06188b8bc83a..5e5c64d2a2a5ee 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -48,7 +48,8 @@ GraphDef GDef(absl::Span nodes, } // Helper to construct a NodeDef. -NodeDef NDef(StringPiece name, StringPiece op, absl::Span inputs, +NodeDef NDef(absl::string_view name, absl::string_view op, + absl::Span inputs, absl::Span> attrs, const string& device) { NodeDef n; diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 06e0c3a6d36ca9..93cae697e62d15 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -56,7 +56,8 @@ class Attrs { // Helper to construct a NodeDef. NodeDef NDef( - StringPiece name, StringPiece op, absl::Span inputs, + absl::string_view name, absl::string_view op, + absl::Span inputs, absl::Span> attrs = {}, const string& device = ""); diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 8b9a8615bc6113..73a8516bff5eb7 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -45,22 +45,22 @@ string SummarizeGraphDef(const GraphDef& graph_def) { return ret; } -Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { +absl::Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { for (const NodeDef& node : graph_def.node()) { TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); } return absl::OkStatus(); } -Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, - const OpRegistryInterface& op_registry, - int node_offset) { +absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset) { return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false); } -Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, - const OpRegistryInterface& op_registry, - int node_offset, bool skip_unknown_ops) { +absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset, bool skip_unknown_ops) { if (node_offset > graph_def->node_size()) { return errors::InvalidArgument( "Tried to add default attrs to GraphDef " @@ -71,7 +71,7 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, for (int i = node_offset; i < graph_def->node_size(); ++i) { NodeDef* node_def = graph_def->mutable_node(i); const OpDef* op_def; - Status s = op_registry.LookUpOpDef(node_def->op(), &op_def); + absl::Status s = op_registry.LookUpOpDef(node_def->op(), &op_def); if (s.ok()) { AddDefaultsToNodeDef(*op_def, node_def); } else if (!skip_unknown_ops) { @@ -82,7 +82,7 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, return absl::OkStatus(); } -static Status RemoveNewDefaultAttrsFromNodeDef( +static absl::Status RemoveNewDefaultAttrsFromNodeDef( NodeDef* node_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, std::set>* op_attr_removed) { @@ -134,7 +134,7 @@ static bool IsFunction(const GraphDef& graph_def, const string& op_name) { return false; } -Status RemoveNewDefaultAttrsFromGraphDef( +absl::Status RemoveNewDefaultAttrsFromGraphDef( GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, std::set>* op_attr_removed) { @@ -171,7 +171,7 @@ void StripDefaultAttributes(const OpRegistryInterface& op_registry, const OpDef* op_def; const OpRegistrationData* op_reg_data = nullptr; - Status s = op_registry.LookUp(node->op(), &op_reg_data); + absl::Status s = op_registry.LookUp(node->op(), &op_reg_data); if (!s.ok()) { VLOG(1) << "Ignoring encountered unknown operation " << SummarizeNodeDef(*node) @@ -246,9 +246,9 @@ void OpsUsedByGraph(const GraphDef& graph_def, } } -Status StrippedOpListForGraph(const GraphDef& graph_def, - const OpRegistryInterface& op_registry, - OpList* stripped_op_list) { +absl::Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list) { std::set used_ops; OpsUsedByGraph(graph_def, &used_ops); diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc index fcd48e3fc5e047..b699037d7317a1 100644 --- a/tensorflow/core/framework/graph_to_functiondef.cc +++ b/tensorflow/core/framework/graph_to_functiondef.cc @@ -61,7 +61,7 @@ class NodeNameMapping { // Records name as a used name. If this name is already used, // returns an error status. - Status UseOutputName(const string& name); + absl::Status UseOutputName(const string& name); // Look up how a node name was previously normalized/uniquified. // Returns empty if name was never seen. @@ -137,7 +137,7 @@ string NodeNameMapping::Uniquify(const string& name) { return uniqued; } -Status NodeNameMapping::UseOutputName(const string& name) { +absl::Status NodeNameMapping::UseOutputName(const string& name) { const auto& iter = used_names_.find(name); if (iter != used_names_.end()) { return errors::InvalidArgument( @@ -154,7 +154,7 @@ string NodeNameMapping::Lookup(const string& name) const { return iter->second; } -Status FillFunctionBody( +absl::Status FillFunctionBody( const string& fn_name, const NodeNameMapping& node_names, const std::vector& body_nodes, const absl::flat_hash_map& tensor_renaming, @@ -321,7 +321,7 @@ Status FillFunctionBody( return absl::OkStatus(); } -Status GraphToFunctionDefHelper( +absl::Status GraphToFunctionDefHelper( const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, bool set_stateful_from_nodes, bool copy_placeholder_attrs_from_nodes, const std::vector& body_nodes, @@ -439,7 +439,7 @@ Status GraphToFunctionDefHelper( TF_RETURN_IF_ERROR( NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { - const StringPiece& output_name = output.first; + const absl::string_view& output_name = output.first; int index_start = output.second.first; int index_end = output.second.second; for (int i = index_start; i < index_end; ++i) { @@ -488,7 +488,7 @@ Status GraphToFunctionDefHelper( const uint64 hash = FunctionDefHash(*fdef); string encoded; TF_RETURN_IF_ERROR(Base64Encode( - StringPiece(reinterpret_cast(&hash), sizeof(hash)), + absl::string_view(reinterpret_cast(&hash), sizeof(hash)), &encoded)); // Besides letters and digits our Base64 encoding uses '_' and '-'. // Dash is invalid in operation names and multiple underscores in random @@ -539,7 +539,7 @@ Status GraphToFunctionDefHelper( return absl::OkStatus(); } -Status GraphToFunctionDefHelper( +absl::Status GraphToFunctionDefHelper( const Graph& graph, const string& name, const std::function(const Node*)>& control_ret, const std::vector& output_names, bool allow_destructive_reads, @@ -615,17 +615,17 @@ Status GraphToFunctionDefHelper( } // anonymous namespace -Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, - bool append_hash_to_fn_name, - bool set_stateful_from_nodes, - bool copy_placeholder_attrs_from_nodes, - const std::vector& body_nodes, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& output_names, - const std::vector& control_outputs, - const std::vector& control_output_names, - const char* description, FunctionDef* fdef) { +absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + bool append_hash_to_fn_name, + bool set_stateful_from_nodes, + bool copy_placeholder_attrs_from_nodes, + const std::vector& body_nodes, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, + const char* description, FunctionDef* fdef) { return GraphToFunctionDefHelper( fn_body, fn_name, append_hash_to_fn_name, set_stateful_from_nodes, copy_placeholder_attrs_from_nodes, body_nodes, inputs, outputs, @@ -634,7 +634,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, return absl::OkStatus(); } -Status GraphToFunctionDef( +absl::Status GraphToFunctionDef( const Graph& graph, const string& name, const std::function(const Node*)>& control_ret, FunctionDef* fdef) { @@ -643,20 +643,20 @@ Status GraphToFunctionDef( /*allow_destructive_reads=*/false, fdef); } -Status GraphToFunctionDef(const Graph& graph, const string& name, - FunctionDef* fdef) { +absl::Status GraphToFunctionDef(const Graph& graph, const string& name, + FunctionDef* fdef) { return GraphToFunctionDef(graph, name, /*control_ret=*/nullptr, fdef); } -Status GraphToFunctionDef(const Graph& graph, const string& name, - const std::vector& output_names, - FunctionDef* fdef) { +absl::Status GraphToFunctionDef(const Graph& graph, const string& name, + const std::vector& output_names, + FunctionDef* fdef) { return GraphToFunctionDefHelper(graph, name, /*control_ret=*/nullptr, output_names, /*allow_destructive_reads=*/false, fdef); } -Status GraphToFunctionDef( +absl::Status GraphToFunctionDef( std::unique_ptr graph, const string& name, const std::function(const Node*)>& control_ret, FunctionDef* fdef) { diff --git a/tensorflow/core/framework/kernel_def_util.cc b/tensorflow/core/framework/kernel_def_util.cc index d1f556bdaa9288..f82faf9b0a50fa 100644 --- a/tensorflow/core/framework/kernel_def_util.cc +++ b/tensorflow/core/framework/kernel_def_util.cc @@ -33,8 +33,8 @@ bool InTypeList(DataType dt, const AttrValue& type_list) { } } // namespace -Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, - bool* match) { +absl::Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, + bool* match) { *match = false; for (const auto& constraint : kernel_def.constraint()) { auto constraint_value_case = AttrValue::VALUE_NOT_SET; diff --git a/tensorflow/core/framework/kernel_shape_util.cc b/tensorflow/core/framework/kernel_shape_util.cc index f06a366f435e5f..9a60b1bd762019 100644 --- a/tensorflow/core/framework/kernel_shape_util.cc +++ b/tensorflow/core/framework/kernel_shape_util.cc @@ -20,11 +20,10 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { -Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t dilation_rate, int64_t stride, - Padding padding_type, int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after) { +absl::Status GetWindowedOutputSizeVerbose( + int64_t input_size, int64_t filter_size, int64_t dilation_rate, + int64_t stride, Padding padding_type, int64_t* output_size, + int64_t* padding_before, int64_t* padding_after) { if (stride <= 0) { return errors::InvalidArgument("Stride must be > 0, but got ", stride); } @@ -66,10 +65,10 @@ Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, return absl::OkStatus(); } -Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, - int dilation_rate, int64_t stride, - Padding padding_type, int64_t* output_size, - int64_t* padding_size) { +absl::Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, + int dilation_rate, int64_t stride, + Padding padding_type, int64_t* output_size, + int64_t* padding_size) { if (padding_type == Padding::EXPLICIT) { return errors::Internal( "GetWindowedOutputSize does not handle EXPLICIT padding; call " @@ -81,13 +80,13 @@ Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, padding_size, &padding_after_unused); } -Status Get3dOutputSizeV2(const std::array& input, - const std::array& window, - const std::array& dilations, - const std::array& strides, - Padding padding_type, - std::array* output_ptr, - std::array* padding_ptr) { +absl::Status Get3dOutputSizeV2(const std::array& input, + const std::array& window, + const std::array& dilations, + const std::array& strides, + Padding padding_type, + std::array* output_ptr, + std::array* padding_ptr) { for (size_t i = 0; i < input.size(); ++i) { TF_RETURN_IF_ERROR(GetWindowedOutputSize( input[i], window[i], dilations[i], strides[i], padding_type, diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index d428f6d463ea51..4c7de27c2afb32 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -43,8 +43,8 @@ struct Library { // and OpList. Ops and kernels are registered as globals when a library is // loaded for the first time. Without caching, every subsequent load would not // perform initialization again, so the OpList would be empty. -Status LoadDynamicLibrary(const char* library_filename, void** result, - const void** buf, size_t* len) { +absl::Status LoadDynamicLibrary(const char* library_filename, void** result, + const void** buf, size_t* len) { static mutex mu(LINKER_INITIALIZED); static std::unordered_map loaded_libs; Env* env = Env::Default(); @@ -55,13 +55,13 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, if (loaded_libs.find(library_filename) != loaded_libs.end()) { library = loaded_libs[library_filename]; } else { - Status s = OpRegistry::Global()->ProcessRegistrations(); + absl::Status s = OpRegistry::Global()->ProcessRegistrations(); if (!s.ok()) { return s; } TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher( - [&library, &seen_op_names](const Status& s, - const OpDef& opdef) -> Status { + [&library, &seen_op_names](const absl::Status& s, + const OpDef& opdef) -> absl::Status { if (errors::IsAlreadyExists(s)) { if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { // Over writing a registration of an op not in this custom op diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc index 910c8a92a744fb..53e231bdc7fedd 100644 --- a/tensorflow/core/framework/local_rendezvous.cc +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -141,12 +141,14 @@ LocalRendezvous::~LocalRendezvous() { } namespace { -uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); } +uint64 KeyHash(const absl::string_view& k) { + return Hash64(k.data(), k.size()); +} } // namespace -Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, - const Rendezvous::Args& send_args, - const Tensor& val, const bool is_dead) { +absl::Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, + const Tensor& val, const bool is_dead) { uint64 key_hash = KeyHash(key.FullKey()); DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); @@ -330,7 +332,7 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, queue->push_back(new Item( std::move(rc_owner), recv_args, [this, cm, token, done = std::move(done)]( - const Status& s, const Rendezvous::Args& send_args, + const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { // TryDeregisterCallback returns true when the cancellation callback // is successfully deregistered. If it fails because the CM already @@ -387,7 +389,7 @@ std::vector >& LocalRendezvous::aborted_rendezs_ = *new std::vector >(); -void LocalRendezvous::StartAbort(const Status& status) { +void LocalRendezvous::StartAbort(const absl::Status& status) { DoAbort(status); if (rc_owner_) { @@ -396,7 +398,7 @@ void LocalRendezvous::StartAbort(const Status& status) { } } -void LocalRendezvous::DoAbort(const Status& status) { +void LocalRendezvous::DoAbort(const absl::Status& status) { CHECK(!status.ok()); { mutex_lock l(mu_); @@ -436,7 +438,7 @@ void LocalRendezvous::DoAbort(const Status& status) { } } -Status LocalRendezvous::status() { +absl::Status LocalRendezvous::status() { tf_shared_lock ml(mu_); return status_; } diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index 2dc224c3f5b6ea..eb8e0bc8eaff70 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { namespace lookup { -Status LookupInterface::CheckKeyShape(const TensorShape& shape) { +absl::Status LookupInterface::CheckKeyShape(const TensorShape& shape) { if (!TensorShapeUtils::EndsWith(shape, key_shape())) { return errors::InvalidArgument("Input key shape ", shape.DebugString(), " must end with the table's key shape ", @@ -30,8 +30,8 @@ Status LookupInterface::CheckKeyShape(const TensorShape& shape) { return absl::OkStatus(); } -Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, - const Tensor& values) { +absl::Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, + const Tensor& values) { if (keys.dtype() != key_dtype()) { return errors::InvalidArgument("Key must be type ", key_dtype(), " but got ", keys.dtype()); @@ -43,8 +43,8 @@ Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, return absl::OkStatus(); } -Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, - const Tensor& values) { +absl::Status LookupInterface::CheckKeyAndValueTensorsHelper( + const Tensor& keys, const Tensor& values) { TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values)); TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape())); @@ -61,17 +61,17 @@ Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, return absl::OkStatus(); } -Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys, - const Tensor& values) { +absl::Status LookupInterface::CheckKeyAndValueTensorsForInsert( + const Tensor& keys, const Tensor& values) { return CheckKeyAndValueTensorsHelper(keys, values); } -Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys, - const Tensor& values) { +absl::Status LookupInterface::CheckKeyAndValueTensorsForImport( + const Tensor& keys, const Tensor& values) { return CheckKeyAndValueTensorsHelper(keys, values); } -Status LookupInterface::CheckKeyTensorForRemove(const Tensor& keys) { +absl::Status LookupInterface::CheckKeyTensorForRemove(const Tensor& keys) { if (keys.dtype() != key_dtype()) { return errors::InvalidArgument("Key must be type ", key_dtype(), " but got ", keys.dtype()); @@ -79,8 +79,8 @@ Status LookupInterface::CheckKeyTensorForRemove(const Tensor& keys) { return CheckKeyShape(keys.shape()); } -Status LookupInterface::CheckFindArguments(const Tensor& key, - const Tensor& default_value) { +absl::Status LookupInterface::CheckFindArguments(const Tensor& key, + const Tensor& default_value) { TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value)); TF_RETURN_IF_ERROR(CheckKeyShape(key.shape())); TensorShape fullsize_value_shape = key.shape(); diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index b983cf95d8ca4a..8b187beb125740 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -80,17 +80,18 @@ MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype) { return DataTypeAlwaysOnHost(dtype) ? HOST_MEMORY : DEVICE_MEMORY; } -Status MemoryTypesForNode(const OpRegistryInterface* op_registry, - const DeviceType& device_type, const NodeDef& ndef, - MemoryTypeVector* inp_mtypes, - MemoryTypeVector* out_mtypes) { +absl::Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + const DeviceType& device_type, + const NodeDef& ndef, + MemoryTypeVector* inp_mtypes, + MemoryTypeVector* out_mtypes) { // Look up the Op registered for this op name. const OpDef* op_def; TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def)); // Look up the Kernel registered for this node def. const KernelDef* kdef = nullptr; - Status status = + absl::Status status = FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */); DataTypeVector inp_dtypes; diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 1fc6622bebe170..47c0dd1c6c2eab 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -363,7 +363,8 @@ inline void UpdateStateValues(Node::ModelParameters* parameters) { // Recursively produces protos for nodes in a subtree of `output` node and // appends them to nodes of the given model. -Status ModelToProtoHelper(std::shared_ptr output, ModelProto* model) { +absl::Status ModelToProtoHelper(std::shared_ptr output, + ModelProto* model) { model->set_output(output->id()); std::list> to_serialize = {output}; auto& nodes = *model->mutable_nodes(); @@ -379,7 +380,8 @@ Status ModelToProtoHelper(std::shared_ptr output, ModelProto* model) { } // Recursively produces node tree rooted in `output` from the given model proto. -Status ModelFromProtoHelper(ModelProto model, std::shared_ptr* output) { +absl::Status ModelFromProtoHelper(ModelProto model, + std::shared_ptr* output) { if (model.nodes().empty()) { return errors::Internal( "Cannot restore model from proto because it has no nodes."); @@ -552,7 +554,7 @@ class InterleaveMany : public Node { self_processing_time + inputs_processing_time; } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::INTERLEAVE_MANY); return absl::OkStatus(); @@ -775,7 +777,7 @@ class AsyncInterleaveMany : public Node { return (*parameter)->value * AverageBufferedElementSizeLocked(); } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY); return absl::OkStatus(); @@ -867,7 +869,7 @@ class KnownRatio : public Node { self_processing_time + inputs_processing_time; } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::KNOWN_RATIO); node_proto->set_ratio(ratio_); @@ -1247,7 +1249,7 @@ class UnknownRatio : public Node { self_processing_time + inputs_processing_time; } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN_RATIO); return absl::OkStatus(); @@ -1301,7 +1303,7 @@ class Unknown : public Node { TotalProcessingTimeForInputs(*total_processing_times); } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN); return absl::OkStatus(); @@ -1330,7 +1332,7 @@ class AsyncKnownRatio : public AsyncRatio { is_legacy_prefetch_autotuned_); } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO); node_proto->set_ratio(Ratio()); @@ -1387,7 +1389,7 @@ class AsyncUnknownRatio : public AsyncRatio { Args{id_, name_, std::move(output)}, parameters); } - Status ToProto(ModelProto::Node* node_proto) const override { + absl::Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_UNKNOWN_RATIO); return absl::OkStatus(); @@ -2138,7 +2140,7 @@ double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) { return 0; } -Status Node::ToProto(ModelProto::Node* node_proto) const { +absl::Status Node::ToProto(ModelProto::Node* node_proto) const { tf_shared_lock l(mu_); node_proto->set_id(id_); node_proto->set_name(name_); @@ -2171,8 +2173,8 @@ Status Node::ToProto(ModelProto::Node* node_proto) const { return absl::OkStatus(); } -Status Node::FromProtoHelper(ModelProto::Node node_proto, - std::shared_ptr node) { +absl::Status Node::FromProtoHelper(ModelProto::Node node_proto, + std::shared_ptr node) { { tf_shared_lock l(node->mu_); node->autotune_.store(node_proto.autotune()); @@ -2221,9 +2223,9 @@ Status Node::FromProtoHelper(ModelProto::Node node_proto, return absl::OkStatus(); } -Status Node::FromProto(ModelProto::Node node_proto, - std::shared_ptr output, - std::shared_ptr* node) { +absl::Status Node::FromProto(ModelProto::Node node_proto, + std::shared_ptr output, + std::shared_ptr* node) { // Note that parameters are restored in `FromProtoHelper`. Args args = {node_proto.id(), node_proto.name(), std::move(output)}; switch (node_proto.node_class()) { @@ -2274,7 +2276,7 @@ Model::Model(std::optional dataset_name) tf_shared_lock snapshot_lock(mu_); if (snapshot_ != nullptr) { ModelProto model_proto; - Status s = ModelToProtoHelper(snapshot_, &model_proto); + absl::Status s = ModelToProtoHelper(snapshot_, &model_proto); if (s.ok()) { *model_proto.mutable_optimization_params() = optimization_params_; tf_shared_lock l(gap_mu_); @@ -2538,12 +2540,12 @@ bool Model::ShouldStop(int64_t cpu_budget, int64_t ram_budget, } // TODO(jsimsa): Add support for tracking and using the model input time. -Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, - std::function cpu_budget_func, - double ram_budget_share, - std::optional fixed_ram_budget, - RamBudgetManager& ram_budget_manager, - CancellationManager* cancellation_manager) { +absl::Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, + std::function cpu_budget_func, + double ram_budget_share, + std::optional fixed_ram_budget, + RamBudgetManager& ram_budget_manager, + CancellationManager* cancellation_manager) { std::function unused; TF_RETURN_IF_ERROR(RegisterCancellationCallback( cancellation_manager, @@ -3182,7 +3184,7 @@ double Model::TotalProcessingTime(std::shared_ptr node) { return node->TotalProcessingTime(/*processing_times=*/nullptr); } -Status Model::ToProto(ModelProto* model_proto) { +absl::Status Model::ToProto(ModelProto* model_proto) { tf_shared_lock l(mu_); model_proto->set_id_counter(id_counter_); TF_RETURN_IF_ERROR(ModelToProtoHelper(output_, model_proto)); @@ -3197,7 +3199,8 @@ Status Model::ToProto(ModelProto* model_proto) { return absl::OkStatus(); } -Status Model::FromProto(ModelProto model_proto, std::unique_ptr* model) { +absl::Status Model::FromProto(ModelProto model_proto, + std::unique_ptr* model) { std::unique_ptr restored_model = std::make_unique(); mutex_lock l(restored_model->mu_); TF_RETURN_IF_ERROR( @@ -3207,8 +3210,8 @@ Status Model::FromProto(ModelProto model_proto, std::unique_ptr* model) { return absl::OkStatus(); } -Status Model::Save(const string& fname, std::shared_ptr snapshot, - const OptimizationParams& optimization_params) { +absl::Status Model::Save(const string& fname, std::shared_ptr snapshot, + const OptimizationParams& optimization_params) { ModelProto model_proto; std::unique_ptr model_snapshot = std::make_unique(); { @@ -3223,8 +3226,8 @@ Status Model::Save(const string& fname, std::shared_ptr snapshot, return WriteBinaryProto(Env::Default(), fname, model_proto); } -Status Model::Load(const string& fname, std::unique_ptr* model, - OptimizationParams* optimization_params) { +absl::Status Model::Load(const string& fname, std::unique_ptr* model, + OptimizationParams* optimization_params) { ModelProto model_proto; TF_RETURN_IF_ERROR( ReadTextOrBinaryProto(Env::Default(), fname, &model_proto)); @@ -3246,7 +3249,7 @@ std::string Model::DebugString() { } // TODO(jsimsa): Populate OptimizationParams. ModelProto model_proto; - Status s = ModelToProtoHelper(snapshot, &model_proto); + absl::Status s = ModelToProtoHelper(snapshot, &model_proto); if (s.ok()) { cached_debug_string_ = model_proto.DebugString(); } else { diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index 86365b494217bd..727a66d45f2f41 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -24,24 +24,26 @@ limitations under the License. namespace tensorflow { -NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt) +NodeDefBuilder::NodeOut::NodeOut(absl::string_view n, int i, DataType dt) : node(n), index(i), data_type(dt) {} NodeDefBuilder::NodeOut::NodeOut() { // uninitialized, call Reset() before use. } -void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { +void NodeDefBuilder::NodeOut::Reset(absl::string_view n, int i, DataType dt) { node = string(n); index = i; data_type = dt; } -NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, +NodeDefBuilder::NodeDefBuilder(absl::string_view name, + absl::string_view op_name, const OpRegistryInterface* op_registry, const NodeDebugInfo* debug) { node_def_.set_name(string(name)); - const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_); + const absl::Status status = + op_registry->LookUpOpDef(string(op_name), &op_def_); if (status.ok()) { Initialize(); } else { @@ -51,13 +53,14 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, if (debug != nullptr) MergeDebugInfo(*debug, &node_def_); } -NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, +NodeDefBuilder::NodeDefBuilder(absl::string_view name, + absl::string_view op_name, const NodeDebugInfo& debug) : NodeDefBuilder(name, op_name) { MergeDebugInfo(debug, &node_def_); } -NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def) +NodeDefBuilder::NodeDefBuilder(absl::string_view name, const OpDef* op_def) : op_def_(op_def) { node_def_.set_name(string(name)); Initialize(); @@ -87,13 +90,14 @@ bool NodeDefBuilder::NextArgAvailable() { NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { if (NextArgAvailable()) { - Status status = fake_input(*op_def_, inputs_specified_, node_def_, this); + absl::Status status = + fake_input(*op_def_, inputs_specified_, node_def_, this); if (!status.ok()) errors_.push_back(std::string(status.message())); } return *this; } -NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index, +NodeDefBuilder& NodeDefBuilder::Input(absl::string_view src_node, int src_index, DataType dt) { const OpDef::ArgDef* arg = NextArgDef(); if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); @@ -113,7 +117,7 @@ NodeDefBuilder& NodeDefBuilder::Input(absl::Span src_list) { } void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, - StringPiece src_node, int src_index, + absl::string_view src_node, int src_index, DataType dt) { AddInput(src_node, src_index); @@ -170,7 +174,7 @@ void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, } } -void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) { +void NodeDefBuilder::AddInput(absl::string_view src_node, int src_index) { if (src_node.empty()) { errors_.push_back("Empty input node name"); } else if (src_node[0] == '^') { @@ -201,17 +205,17 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, } } -NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) { +NodeDefBuilder& NodeDefBuilder::ControlInput(absl::string_view src_node) { control_inputs_.emplace_back(src_node); return *this; } -NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { +NodeDefBuilder& NodeDefBuilder::Device(absl::string_view device_spec) { node_def_.set_device(string(device_spec)); return *this; } -Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { +absl::Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { const std::vector* errors_ptr = &errors_; std::vector errors_storage; if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) { @@ -266,7 +270,7 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { } } -bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name, +bool NodeDefBuilder::AttrValueAlreadyPresent(absl::string_view name, const AttrValue& value) { if (const AttrValue* found = AttrSlice(node_def_).Find(name)) { if (!AreAttrValuesEqual(*found, value)) { @@ -279,14 +283,16 @@ bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name, return false; } -NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { +NodeDefBuilder& NodeDefBuilder::Attr(absl::string_view name, + const AttrValue& value) { if (!AttrValueAlreadyPresent(name, value)) { AddNodeAttr(name, value, &node_def_); } return *this; } -NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) { +NodeDefBuilder& NodeDefBuilder::Attr(absl::string_view name, + AttrValue&& value) { if (!AttrValueAlreadyPresent(name, value)) { AddNodeAttr(name, std::move(value), &node_def_); } @@ -299,7 +305,7 @@ NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) { SetAttrValue(value, &attr_value); \ return Attr(name, attr_value); \ } -ATTR(StringPiece) +ATTR(absl::string_view) ATTR(const char*) ATTR(int32_t) ATTR(int64_t) @@ -311,7 +317,7 @@ ATTR(const PartialTensorShape&) ATTR(const Tensor&) ATTR(const TensorProto&) ATTR(const NameAttrList&) -ATTR(absl::Span) +ATTR(absl::Span) ATTR(absl::Span) ATTR(absl::Span) ATTR(absl::Span) diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index 5a19f774c7a199..47b14f185800cf 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -53,9 +53,9 @@ class NodeDefBuilder { public: // To specify an output to be consumed by one of the Input() methods below. struct NodeOut { - NodeOut(StringPiece n, int i, DataType dt); + NodeOut(absl::string_view n, int i, DataType dt); NodeOut(); // uninitialized, call Reset() before use. - void Reset(StringPiece n, int i, DataType dt); + void Reset(absl::string_view n, int i, DataType dt); string node; int index; DataType data_type; @@ -65,19 +65,19 @@ class NodeDefBuilder { // the Op plus a registry) for the NodeDef. Other fields are // specified by calling the methods below. // REQUIRES: The OpDef must satisfy ValidateOpDef(). - NodeDefBuilder(StringPiece name, StringPiece op_name, + NodeDefBuilder(absl::string_view name, absl::string_view op_name, const OpRegistryInterface* op_registry = OpRegistry::Global(), const NodeDebugInfo* debug = nullptr); - NodeDefBuilder(StringPiece name, StringPiece op_name, + NodeDefBuilder(absl::string_view name, absl::string_view op_name, const NodeDebugInfo& debug); // REQUIRES: in addition, *op_def must outlive *this. - NodeDefBuilder(StringPiece name, const OpDef* op_def); + NodeDefBuilder(absl::string_view name, const OpDef* op_def); // You must call one Input() function per input_arg in the Op, // *and in the same order as the input_args appear in the OpDef.* // For inputs that take a single tensor. - NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt); + NodeDefBuilder& Input(absl::string_view src_node, int src_index, DataType dt); NodeDefBuilder& Input(const NodeOut& src); // For inputs that take a list of tensors. @@ -87,47 +87,52 @@ class NodeDefBuilder { NodeDefBuilder& Input(FakeInputFunctor fake_input); // Specify that this node must only run after src_node. - NodeDefBuilder& ControlInput(StringPiece src_node); + NodeDefBuilder& ControlInput(absl::string_view src_node); // Constrains what devices this node may be scheduled on. - NodeDefBuilder& Device(StringPiece device_spec); + NodeDefBuilder& Device(absl::string_view device_spec); // Sets the attr, if not already set. If already set with a different // value, an error will be returned from Finalize(). - NodeDefBuilder& Attr(StringPiece name, const AttrValue& value); - NodeDefBuilder& Attr(StringPiece name, AttrValue&& value); - NodeDefBuilder& Attr(StringPiece name, StringPiece value); - NodeDefBuilder& Attr(StringPiece name, const char* value); - NodeDefBuilder& Attr(StringPiece name, int32_t value); - NodeDefBuilder& Attr(StringPiece name, int64_t value); - NodeDefBuilder& Attr(StringPiece name, float value); - NodeDefBuilder& Attr(StringPiece name, double value); - NodeDefBuilder& Attr(StringPiece name, bool value); - NodeDefBuilder& Attr(StringPiece name, DataType value); - NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value); - NodeDefBuilder& Attr(StringPiece name, const Tensor& value); - NodeDefBuilder& Attr(StringPiece name, const TensorProto& value); - NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, const std::vector& value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, + NodeDefBuilder& Attr(absl::string_view name, const AttrValue& value); + NodeDefBuilder& Attr(absl::string_view name, AttrValue&& value); + NodeDefBuilder& Attr(absl::string_view name, absl::string_view value); + NodeDefBuilder& Attr(absl::string_view name, const char* value); + NodeDefBuilder& Attr(absl::string_view name, int32_t value); + NodeDefBuilder& Attr(absl::string_view name, int64_t value); + NodeDefBuilder& Attr(absl::string_view name, float value); + NodeDefBuilder& Attr(absl::string_view name, double value); + NodeDefBuilder& Attr(absl::string_view name, bool value); + NodeDefBuilder& Attr(absl::string_view name, DataType value); + NodeDefBuilder& Attr(absl::string_view name, const PartialTensorShape& value); + NodeDefBuilder& Attr(absl::string_view name, const Tensor& value); + NodeDefBuilder& Attr(absl::string_view name, const TensorProto& value); + NodeDefBuilder& Attr(absl::string_view name, const NameAttrList& value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, const std::vector& value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); - NodeDefBuilder& Attr(StringPiece name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); template - NodeDefBuilder& Attr(StringPiece name, std::initializer_list value) { + NodeDefBuilder& Attr(absl::string_view name, std::initializer_list value) { return Attr(name, gtl::ArraySlice(value)); } @@ -156,13 +161,13 @@ class NodeDefBuilder { bool NextArgAvailable(); // These do the main work of the Input() methods. - void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node, + void SingleInput(const OpDef::ArgDef* input_arg, absl::string_view src_node, int src_index, DataType dt); void ListInput(const OpDef::ArgDef* input_arg, absl::Span src_list); // Add "src_node:src_index" to the list of inputs in the node_def_. - void AddInput(StringPiece src_node, int src_index); + void AddInput(absl::string_view src_node, int src_index); // Generate an error if you can't pass dt when expected is expected. void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected, @@ -179,7 +184,7 @@ class NodeDefBuilder { // Returns true if an attr named `name` is already present in the node_def_. // If such an attr is already present and `value` is not equal to the present // value, an error is generated. - bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value); + bool AttrValueAlreadyPresent(absl::string_view name, const AttrValue& value); const OpDef* op_def_; NodeDef node_def_; diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc index c89932b13ee518..af72436c32a34d 100644 --- a/tensorflow/core/framework/node_def_builder_test.cc +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -51,7 +51,8 @@ class NodeDefBuilderTest : public ::testing::Test { // expectations. void ExpectSuccess(NodeDefBuilder& builder, // NOLINT DataTypeSlice expected_in_types, - DataTypeSlice expected_out_types, StringPiece proto) { + DataTypeSlice expected_out_types, + absl::string_view proto) { NodeDef node_def; absl::Status status = builder.Finalize(&node_def); TF_EXPECT_OK(status); diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 183a80ac18b1f5..c94e34bfa48be7 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -64,7 +64,7 @@ AttrSlice::AttrSlice(const NodeDef& node_def) AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {} -string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { +string SummarizeAttrsHelper(AttrSlice attrs, absl::string_view device) { string ret; // We sort the attrs so the output is deterministic. @@ -92,9 +92,10 @@ string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { } string AttrSlice::SummarizeNode() const { - return ndef_ ? SummarizeNodeDef(*ndef_) - : strings::StrCat( - "[", SummarizeAttrsHelper(*this, StringPiece()), "]"); + return ndef_ + ? SummarizeNodeDef(*ndef_) + : strings::StrCat( + "[", SummarizeAttrsHelper(*this, absl::string_view()), "]"); } string AttrSlice::DebugString() const { @@ -135,7 +136,7 @@ string SummarizeAttrs(const NodeDef& node_def) { } string FormatNodeDefForError( - StringPiece node_name, bool has_experimental_debug_info, + absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info) { return !has_experimental_debug_info || experimental_debug_info.original_node_names().empty() @@ -151,7 +152,7 @@ string FormatNodeDefForError(const NodeDef& node_def) { node_def.experimental_debug_info()); } -const AttrValue* AttrSlice::Find(StringPiece attr_name) const { +const AttrValue* AttrSlice::Find(absl::string_view attr_name) const { // Currently, the collection used for NodeDef::attr() (google::protobuf::Map) // requires that the keys used for lookups have type 'const string&'. Because // this method takes a StringPiece, it is necessary to allocate a temporary @@ -182,12 +183,13 @@ const AttrValue* AttrSlice::FindByString(const string& attr_name) const { } } -Status AttrSlice::CheckFind(StringPiece attr_name, - const AttrValue* attr_value) const { +absl::Status AttrSlice::CheckFind(absl::string_view attr_name, + const AttrValue* attr_value) const { if (attr_value != nullptr) { return absl::OkStatus(); } - Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); + absl::Status s = + errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); // Skip AttachDef for internal attrs since it is a little bit // expensive and it is common for them to correctly not be included // in a NodeDef. @@ -197,14 +199,14 @@ Status AttrSlice::CheckFind(StringPiece attr_name, return s; } -Status AttrSlice::Find(StringPiece attr_name, - const AttrValue** attr_value) const { +absl::Status AttrSlice::Find(absl::string_view attr_name, + const AttrValue** attr_value) const { *attr_value = Find(attr_name); return CheckFind(attr_name, *attr_value); } -Status AttrSlice::FindByString(const string& attr_name, - const AttrValue** attr_value) const { +absl::Status AttrSlice::FindByString(const string& attr_name, + const AttrValue** attr_value) const { *attr_value = FindByString(attr_name); return CheckFind(attr_name, *attr_value); } @@ -342,31 +344,32 @@ DEFINE_GET_ATTR( DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); #undef DEFINE_GET_ATTR -bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { +bool HasNodeAttr(const NodeDef& node_def, absl::string_view attr_name) { return node_def.attr().find(string(attr_name)) != node_def.attr().end(); } static const string& kEmptyString = *new string(); -const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) { +const string& GetNodeAttrString(const AttrSlice& attrs, + absl::string_view attr_name) { const AttrValue* attr_value = attrs.Find(attr_name); if (attr_value == nullptr) { return kEmptyString; } - Status s = AttrValueHasType(*attr_value, "string"); + absl::Status s = AttrValueHasType(*attr_value, "string"); if (!s.ok()) { return kEmptyString; } return attr_value->s(); } -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value) { const AttrValue* attr_value = attrs.Find(attr_name); if (attr_value == nullptr) { return false; } - Status s = AttrValueHasType(*attr_value, "list(string)"); + absl::Status s = AttrValueHasType(*attr_value, "list(string)"); if (!s.ok()) { return false; } @@ -377,13 +380,13 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return true; } -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value) { const AttrValue* attr_value = attrs.Find(attr_name); if (attr_value == nullptr) { return false; } - Status s = AttrValueHasType(*attr_value, "list(shape)"); + absl::Status s = AttrValueHasType(*attr_value, "list(shape)"); if (!s.ok()) { return false; } @@ -394,8 +397,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return true; } -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - DataTypeVector* value) { +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + DataTypeVector* value) { const AttrValue* attr_value; TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)")); @@ -405,8 +408,8 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return absl::OkStatus(); } -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - const TensorProto** value) { +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + const TensorProto** value) { const AttrValue* attr_value; TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); @@ -414,13 +417,13 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return absl::OkStatus(); } -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, const TensorProto** value) { const AttrValue* attr_value = attrs.Find(attr_name); if (attr_value == nullptr) { return false; } - Status s = AttrValueHasType(*attr_value, "tensor"); + absl::Status s = AttrValueHasType(*attr_value, "tensor"); if (!s.ok()) { return false; } @@ -428,8 +431,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return true; } -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - const NameAttrList** value) { +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + const NameAttrList** value) { const AttrValue* attr_value; TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); @@ -437,13 +440,13 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return absl::OkStatus(); } -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, const NameAttrList** value) { const AttrValue* attr_value = attrs.Find(attr_name); if (attr_value == nullptr) { return false; } - Status s = AttrValueHasType(*attr_value, "func"); + absl::Status s = AttrValueHasType(*attr_value, "func"); if (!s.ok()) { return false; } @@ -451,8 +454,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, return true; } -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - Padding* value) { +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + Padding* value) { string str_value; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value)); return GetPaddingFromString(str_value, value); @@ -461,8 +464,8 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, namespace { // Helper for InOutTypesForNode(). template -Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs, - const OpDef::ArgDef& arg_def, DataTypeVector* sig) { +absl::Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs, + const OpDef::ArgDef& arg_def, DataTypeVector* sig) { const int original_size = sig->size(); if (!arg_def.number_attr().empty()) { // Same type repeated "repeats" times. @@ -528,8 +531,8 @@ Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs, } // namespace -Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, - int input_port, DataType* input_type) { +absl::Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, + int input_port, DataType* input_type) { DataTypeVector input_types; for (const auto& arg : op_def.input_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types)); @@ -544,16 +547,16 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, node_def.name()); } -Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, - DataTypeVector* inputs) { +absl::Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs) { for (const auto& arg : op_def.input_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); } return absl::OkStatus(); } -Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, - int output_port, DataType* output_type) { +absl::Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, + int output_port, DataType* output_type) { DataTypeVector output_types; for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types)); @@ -568,30 +571,31 @@ Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, node_def.name()); } -Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, - DataTypeVector* outputs) { +absl::Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* outputs) { for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); } return absl::OkStatus(); } -Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, - DataTypeVector* outputs) { +absl::Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, + DataTypeVector* outputs) { for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs)); } return absl::OkStatus(); } -Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, - DataTypeVector* inputs, DataTypeVector* outputs) { +absl::Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, + DataTypeVector* outputs) { TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs)); return OutputTypesForNode(node_def, op_def, outputs); } -Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, - int* num_outputs) { +absl::Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs) { DataTypeVector outputs; TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs)); *num_outputs = outputs.size(); @@ -631,7 +635,7 @@ int OpPortIdToArgId(const NodeDef& node, return -1; } -Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { +absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { if (node_def.op() != op_def.name()) { return errors::InvalidArgument( "NodeDef op '", node_def.op(), "' does not match ", @@ -723,8 +727,9 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { namespace { // Helpers for NameRangesForNode() -Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def, - const OpDef& op_def, int* num) { +absl::Status ComputeArgRange(const AttrSlice& attrs, + const OpDef::ArgDef& arg_def, const OpDef& op_def, + int* num) { if (!arg_def.number_attr().empty()) { // Same type repeated "num" times. return GetNodeAttr(attrs, arg_def.number_attr(), num); @@ -742,9 +747,10 @@ Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def, return absl::OkStatus(); } -Status NameRangesHelper(const AttrSlice& attrs, - const protobuf::RepeatedPtrField& args, - const OpDef& op_def, NameRangeMap* result) { +absl::Status NameRangesHelper( + const AttrSlice& attrs, + const protobuf::RepeatedPtrField& args, const OpDef& op_def, + NameRangeMap* result) { int start = 0; int num; for (const auto& arg : args) { @@ -757,8 +763,8 @@ Status NameRangesHelper(const AttrSlice& attrs, } // namespace -Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, - NameRangeMap* inputs, NameRangeMap* outputs) { +absl::Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { if (inputs != nullptr) { TF_RETURN_IF_ERROR( NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs)); @@ -794,7 +800,7 @@ namespace { using ::tensorflow::tstring; using ::tensorflow::strings::Scanner; -bool IsValidNodeName(StringPiece sp) { +bool IsValidNodeName(absl::string_view sp) { Scanner scanner(sp); scanner.One(Scanner::LETTER_DIGIT_DOT) .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); @@ -812,7 +818,7 @@ bool IsValidNodeName(StringPiece sp) { } } -bool IsValidDataInputName(StringPiece sp) { +bool IsValidDataInputName(absl::string_view sp) { // Data inputs are op_name, op_name:0, or op_name:12345. Scanner scan(sp); scan.One(Scanner::LETTER_DIGIT_DOT) @@ -840,7 +846,7 @@ bool IsValidDataInputName(StringPiece sp) { } } -bool IsValidControlInputName(StringPiece sp) { +bool IsValidControlInputName(absl::string_view sp) { Scanner scan(sp); scan.OneLiteral("^") .One(Scanner::LETTER_DIGIT_DOT) @@ -859,11 +865,12 @@ bool IsValidControlInputName(StringPiece sp) { } } -const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); +const absl::string_view kColocationGroupPrefixStringPiece( + kColocationGroupPrefix); } // namespace -Status ValidateOpInput(const string& input_name, bool* is_control_input) { +absl::Status ValidateOpInput(const string& input_name, bool* is_control_input) { *is_control_input = false; if (IsValidDataInputName(input_name)) { return absl::OkStatus(); @@ -875,7 +882,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) { } } -Status ValidateNodeName(const string& node_name) { +absl::Status ValidateNodeName(const string& node_name) { if (IsValidNodeName(node_name)) { return absl::OkStatus(); } else { @@ -883,8 +890,8 @@ Status ValidateNodeName(const string& node_name) { } } -Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { - Status s = ValidateNodeName(node_def.name()); +absl::Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { + absl::Status s = ValidateNodeName(node_def.name()); if (!s.ok()) { return AttachDef(s, node_def); } @@ -906,8 +913,8 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { return absl::OkStatus(); } -Status AttachDef(const Status& status, const NodeDef& node_def, - bool allow_multiple_formatted_node) { +absl::Status AttachDef(const absl::Status& status, const NodeDef& node_def, + bool allow_multiple_formatted_node) { string node_error; if (!allow_multiple_formatted_node && absl::StrContains(status.message(), "{{node ")) { @@ -920,12 +927,13 @@ Status AttachDef(const Status& status, const NodeDef& node_def, strings::StrCat(status.message(), "\n\t", " [[", node_error, "]]")); } -void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { +void AddNodeAttr(absl::string_view name, const AttrValue& value, + NodeDef* node_def) { node_def->mutable_attr()->insert( AttrValueMap::value_type(string(name), value)); } -void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) { +void AddNodeAttr(absl::string_view name, AttrValue&& value, NodeDef* node_def) { (*node_def->mutable_attr())[string(name)] = std::move(value); } @@ -935,7 +943,7 @@ void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) { SetAttrValue(value, &attr_value); \ AddNodeAttr(name, attr_value, node_def); \ } -ADD_NODE_ATTR(StringPiece) +ADD_NODE_ATTR(absl::string_view) ADD_NODE_ATTR(const char*) ADD_NODE_ATTR(int32_t) ADD_NODE_ATTR(int64_t) @@ -947,7 +955,7 @@ ADD_NODE_ATTR(const PartialTensorShape&) ADD_NODE_ATTR(const Tensor&) ADD_NODE_ATTR(const TensorProto&) ADD_NODE_ATTR(const NameAttrList&) -ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) ADD_NODE_ATTR(absl::Span) ADD_NODE_ATTR(absl::Span) ADD_NODE_ATTR(absl::Span) @@ -963,7 +971,8 @@ ADD_NODE_ATTR(absl::Span) ADD_NODE_ATTR(absl::Span) #undef ADD_NODE_ATTR -void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { +void AddAttr(absl::string_view name, const AttrValue& value, + AttrValueMap* map) { map->insert(AttrValueMap::value_type(string(name), value)); } @@ -976,8 +985,10 @@ void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { ADD_ATTR(bool) #undef ADD_ATTR -Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, - NodeDef* node_def, bool uniquify_frame_name) { +absl::Status AddPrefixAndSuffixToNode(absl::string_view prefix, + absl::string_view suffix, + NodeDef* node_def, + bool uniquify_frame_name) { node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix)); // Update frame name to avoid multiple LoopCond nodes in one frame. @@ -993,8 +1004,8 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, return absl::OkStatus(); } -Status MaybeAddPrefixToColocationConstraints( - const std::unordered_set& match, StringPiece prefix, +absl::Status MaybeAddPrefixToColocationConstraints( + const std::unordered_set& match, absl::string_view prefix, NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); if (attr == node_def->mutable_attr()->end()) { @@ -1003,7 +1014,7 @@ Status MaybeAddPrefixToColocationConstraints( auto constraints_list = attr->second.mutable_list(); auto constraints_size = constraints_list->s_size(); for (size_t i = 0; i < constraints_size; ++i) { - StringPiece original(constraints_list->s(i)); + absl::string_view original(constraints_list->s(i)); if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) { if (match.find(string(original)) != match.end()) { (*constraints_list->mutable_s(i)) = @@ -1014,7 +1025,7 @@ Status MaybeAddPrefixToColocationConstraints( return absl::OkStatus(); } -Status MaybeUpdateColocationConstraintsWithMap( +absl::Status MaybeUpdateColocationConstraintsWithMap( const std::map& node_name_map, NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); @@ -1024,7 +1035,7 @@ Status MaybeUpdateColocationConstraintsWithMap( auto constraints_list = attr->second.mutable_list(); auto constraints_size = constraints_list->s_size(); for (size_t i = 0; i < constraints_size; ++i) { - StringPiece original(constraints_list->s(i)); + absl::string_view original(constraints_list->s(i)); if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) { if (node_name_map.find(original) != node_name_map.end()) { (*constraints_list->mutable_s(i)) = diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index b5eb424a89bd58..2b82c596fee301 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -71,76 +71,80 @@ extern const char* const kTpuExecuteStagingNodeName; std::string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary = -1); std::string SummarizeAttrs(const NodeDef& node_def); -std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device); +std::string SummarizeAttrsHelper(AttrSlice attrs, absl::string_view device); // Produces a formatted string pattern from the node which can uniquely identify // this node upstream to produce an informative error message. The pattern // followed is: {{node }} std::string FormatNodeDefForError(const NodeDef& node_def); std::string FormatNodeDefForError( - StringPiece node_name, bool has_experimental_debug_info, + absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info); typedef protobuf::Map AttrValueMap; // Adds an attr with name and value to *node_def. // The type of the attr is based on the type of value. -void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, int32_t value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, int64_t value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, float value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, double value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, const PartialTensorShape& value, +void AddNodeAttr(absl::string_view name, const AttrValue& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, const NameAttrList& value, +void AddNodeAttr(absl::string_view name, AttrValue&& value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::string_view value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, const char* value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, int32_t value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, int64_t value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, float value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, double value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, bool value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, DataType value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const PartialTensorShape& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, const Tensor& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, const TensorProto& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, const NameAttrList& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, + absl::Span value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, const std::vector& value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, const std::vector& value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); -void AddNodeAttr(StringPiece name, absl::Span value, +void AddNodeAttr(absl::string_view name, + absl::Span value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, + absl::Span value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, NodeDef* node_def); // Version to workaround C++'s "perfect" forwarding not being able to // forward {...} initialization. template -void AddNodeAttr(StringPiece name, std::initializer_list value, +void AddNodeAttr(absl::string_view name, std::initializer_list value, NodeDef* node_def) { AddNodeAttr(name, gtl::ArraySlice(value), node_def); } // Adds an attr to an attr value map. -void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map); -void AddAttr(StringPiece name, bool value, AttrValueMap* map); +void AddAttr(absl::string_view name, const AttrValue& value, AttrValueMap* map); +void AddAttr(absl::string_view name, bool value, AttrValueMap* map); class AttrSlice { public: @@ -153,12 +157,13 @@ class AttrSlice { // Returns the attr with attr_name if found. Otherwise, returns // nullptr. - const AttrValue* Find(StringPiece attr_name) const; + const AttrValue* Find(absl::string_view attr_name) const; const AttrValue* FindByString(const std::string& attr_name) const; // Returns the attr_value for attr_name if found. Otherwise, returns a // NotFound status. - absl::Status Find(StringPiece attr_name, const AttrValue** attr_value) const; + absl::Status Find(absl::string_view attr_name, + const AttrValue** attr_value) const; absl::Status FindByString(const std::string& attr_name, const AttrValue** attr_value) const; @@ -196,7 +201,7 @@ class AttrSlice { return ndef_ != nullptr ? &ndef_->attr() : attrs_; } - absl::Status CheckFind(StringPiece attr_name, + absl::Status CheckFind(absl::string_view attr_name, const AttrValue* attr_value) const; const NodeDef* ndef_; @@ -204,59 +209,59 @@ class AttrSlice { }; // Return true if the attr with the name attr_name is defined in node_def. -bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name); +bool HasNodeAttr(const NodeDef& node_def, absl::string_view attr_name); // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have // a matching type, a non-ok status will be returned. -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::string* value); // type: "string" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, tstring* value); // type: "tstring" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, int64_t* value); // type: "int" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, int32* value); // type: "int" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, float* value); // type: "float" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, bool* value); // type: "bool" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, DataType* value); // type: "type" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, TensorShapeProto* value); // type: "shape" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, TensorShape* value); // type: "shape" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, PartialTensorShape* value); // type: "shape" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, Tensor* value); // type: "tensor" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(string)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(tstring)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(int)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(int)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(float)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(bool)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(type)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, DataTypeVector* value); // type "list(type)" absl::Status GetNodeAttr( - const AttrSlice& attrs, StringPiece attr_name, + const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(shape)" absl::Status GetNodeAttr( - const AttrSlice& attrs, StringPiece attr_name, + const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(shape)" absl::Status GetNodeAttr( - const AttrSlice& attrs, StringPiece attr_name, + const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type "list(shape)" -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(tensor)" template @@ -268,66 +273,66 @@ StatusOr GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) { // This version avoids copying the TensorProto. // REQUIRES: Must not use *value beyond the lifetime of node_def. -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, const TensorProto** value); // type: "tensor" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, const TensorProto** value); // type: "tensor" // This version avoids copying the NameAttrList. // REQUIRES: Must not use *value beyond the lifetime of node_def. -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, const NameAttrList** value); // type: "func" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, const NameAttrList** value); // type: "func" // These versions copies the NameAttrList(s). -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, NameAttrList* value); // type: "func" absl::Status GetNodeAttr( - const AttrSlice& attrs, StringPiece attr_name, + const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(func)" // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have // a matching type, false is returned. -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::string* value); // type: "string" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, int64_t* value); // type: "int" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "int" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, int32* value); // type: "int" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, float* value); // type: "float" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, bool* value); // type: "bool" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, DataType* value); // type: "type" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, TensorShape* value); // type: "shape" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(string)" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(tstring)" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(int)" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(float)" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(bool)" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(type)" -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector value); // type: "shape" // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute // values. -bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(string)" bool TryGetNodeAttr( - const AttrSlice& attrs, StringPiece attr_name, + const AttrSlice& attrs, absl::string_view attr_name, std::vector* value); // type: "list(shape)" // Look up the attr with name attr_name and return a reference to its value. @@ -335,10 +340,10 @@ bool TryGetNodeAttr( // a matching type, a reference to an empty string is returned. // REQUIRES: Must not use the returned value beyond the lifetime of node_def. const std::string& GetNodeAttrString(const AttrSlice& attrs, - StringPiece attr_name); + absl::string_view attr_name); // Specialization to parse an attribute directly into a Padding enum. -absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, Padding* value); // Computes the input type for a specific node input. @@ -395,7 +400,8 @@ absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); // space, the returned `NameRangeMap` objects borrow the input/output // argument names from `op_def`. The `op_def` must outlive the // returned `NameRangeMap` objects. -typedef gtl::FlatMap, hash> +typedef gtl::FlatMap, + hash> NameRangeMap; absl::Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, NameRangeMap* inputs, NameRangeMap* outputs); @@ -428,14 +434,15 @@ absl::Status AttachDef(const absl::Status& status, const NodeDef& node_def, // Appends the given prefix and suffix to the original node name in order to // make the name unique. If it's an "Enter" node and uniquify_frame_name is // true, use the same way to reset attribute "frame_name". -absl::Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, +absl::Status AddPrefixAndSuffixToNode(absl::string_view prefix, + absl::string_view suffix, NodeDef* node_def, bool uniquify_frame_name = true); // Appends the given prefix to the colocation group name if the name exists // in `to_match`. absl::Status MaybeAddPrefixToColocationConstraints( - const std::unordered_set& match, StringPiece prefix, + const std::unordered_set& match, absl::string_view prefix, NodeDef* node_def); // Updates the colocation constraint name with the one provided in the map (if diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 67bde1fc71e228..52c6a48c2eaadd 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -57,7 +57,7 @@ void ExpectSuccess(const NodeDef& good, const OpDef& op_def) { void ExpectFailure(const NodeDef& bad, const OpDef& op_def, const string& message) { - Status status = ValidateNodeDef(bad, op_def); + absl::Status status = ValidateNodeDef(bad, op_def); EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad) << "; OpDef: " << SummarizeOpDef(op_def); @@ -323,14 +323,14 @@ void ExpectValidSyntax(const NodeDef& good) { } void ExpectInvalidSyntax(const NodeDef& bad, const string& message) { - Status status = ValidateExternalNodeDefSyntax(bad); + absl::Status status = ValidateExternalNodeDefSyntax(bad); ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad); EXPECT_TRUE(errors::IsInvalidArgument(status)) << status << "; NodeDef: " << SummarizeNodeDef(bad); - EXPECT_TRUE(absl::StrContains(StringPiece(status.ToString()), message)) + EXPECT_TRUE(absl::StrContains(absl::string_view(status.ToString()), message)) << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", " << message; } @@ -876,10 +876,10 @@ TEST(AttachDef, AllowMultipleFormattedNode) { a.set_name("a"); NodeDef b; b.set_name("b"); - Status s = Status(absl::StatusCode::kCancelled, "Error"); - Status s2 = AttachDef(s, a, true); + absl::Status s = absl::Status(absl::StatusCode::kCancelled, "Error"); + absl::Status s2 = AttachDef(s, a, true); EXPECT_EQ("Error\n\t [[{{node a}}]]", s2.message()); - Status s3 = AttachDef(s2, b, true); + absl::Status s3 = AttachDef(s2, b, true); EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]", s3.message()); } @@ -888,10 +888,10 @@ TEST(AttachDef, DisallowMultipleFormattedNode) { a.set_name("a"); NodeDef b; b.set_name("b"); - Status s = Status(absl::StatusCode::kCancelled, "Error"); - Status s2 = AttachDef(s, a, false); + absl::Status s = absl::Status(absl::StatusCode::kCancelled, "Error"); + absl::Status s2 = AttachDef(s, a, false); EXPECT_EQ("Error\n\t [[{{node a}}]]", s2.message()); - Status s3 = AttachDef(s2, b, false); + absl::Status s3 = AttachDef(s2, b, false); EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]", s3.message()); } diff --git a/tensorflow/core/framework/node_properties.cc b/tensorflow/core/framework/node_properties.cc index 4af538b3b2c1c5..cfa4de99780fdb 100644 --- a/tensorflow/core/framework/node_properties.cc +++ b/tensorflow/core/framework/node_properties.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { // static -Status NodeProperties::CreateFromNodeDef( +absl::Status NodeProperties::CreateFromNodeDef( NodeDef node_def, const OpRegistryInterface* op_registry, std::shared_ptr* props) { const OpDef* op_def; diff --git a/tensorflow/core/framework/node_properties_test.cc b/tensorflow/core/framework/node_properties_test.cc index 5621137c7aba71..8e1dd344e91261 100644 --- a/tensorflow/core/framework/node_properties_test.cc +++ b/tensorflow/core/framework/node_properties_test.cc @@ -40,8 +40,8 @@ class MockOpRegistry : public OpRegistryInterface { // Returns an error status and sets *op_reg_data to nullptr if no OpDef is // registered under that name, otherwise returns the registered OpDef. // Caller must not delete the returned pointer. - Status LookUp(const string& op_type_name, - const OpRegistrationData** op_reg_data) const override { + absl::Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override { if (op_type_name == "Foo") { *op_reg_data = &op_reg_; return absl::OkStatus(); diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 3c3970506389f9..6b328989ab8725 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -32,15 +32,15 @@ limitations under the License. namespace tensorflow { -Status DefaultValidator(const OpRegistryInterface& op_registry) { +absl::Status DefaultValidator(const OpRegistryInterface& op_registry) { LOG(WARNING) << "No kernel validator registered with OpRegistry."; return absl::OkStatus(); } // OpRegistry ----------------------------------------------------------------- -Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, - const OpDef** op_def) const { +absl::Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, + const OpDef** op_def) const { *op_def = nullptr; const OpRegistrationData* op_reg_data = nullptr; TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data)); @@ -62,8 +62,8 @@ void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { namespace { // Helper function that returns Status message for failed LookUp. -Status OpNotFound(const string& op_type_name) { - Status status = errors::NotFound( +absl::Status OpNotFound(const string& op_type_name) { + absl::Status status = errors::NotFound( "Op type not registered '", op_type_name, "' in binary running on ", port::Hostname(), ". ", "Make sure the Op and Kernel are registered in the binary running in " @@ -76,8 +76,8 @@ Status OpNotFound(const string& op_type_name) { } } // namespace -Status OpRegistry::LookUp(const string& op_type_name, - const OpRegistrationData** op_reg_data) const { +absl::Status OpRegistry::LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const { if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } @@ -148,7 +148,7 @@ void OpRegistry::GetOpRegistrationData( } } -Status OpRegistry::SetWatcher(const Watcher& watcher) { +absl::Status OpRegistry::SetWatcher(const Watcher& watcher) { mutex_lock lock(mu_); if (watcher_ && watcher) { return errors::AlreadyExists( @@ -162,7 +162,7 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const { mutex_lock lock(mu_); MustCallDeferred(); - std::vector> sorted; + std::vector> sorted; sorted.reserve(registry_.size()); for (const auto& item : registry_) { sorted.emplace_back(item.first, item.second.get()); @@ -190,7 +190,7 @@ void OpRegistry::ClearDeferredRegistrations() { deferred_.clear(); } -Status OpRegistry::ProcessRegistrations() const { +absl::Status OpRegistry::ProcessRegistrations() const { mutex_lock lock(mu_); return CallDeferred(); } @@ -216,12 +216,12 @@ bool OpRegistry::MustCallDeferred() const { return true; } -Status OpRegistry::CallDeferred() const { +absl::Status OpRegistry::CallDeferred() const { if (initialized_) return absl::OkStatus(); initialized_ = true; registry_.reserve(registry_.size() + deferred_.size()); for (const auto& op_data_factory : deferred_) { - Status s = RegisterAlreadyLocked(op_data_factory); + absl::Status s = RegisterAlreadyLocked(op_data_factory); if (!s.ok()) { return s; } @@ -230,11 +230,11 @@ Status OpRegistry::CallDeferred() const { return absl::OkStatus(); } -Status OpRegistry::RegisterAlreadyLocked( +absl::Status OpRegistry::RegisterAlreadyLocked( const OpRegistrationDataFactory& op_data_factory) const { auto op_reg_data = std::make_unique(); const auto* op_reg_data_raw = op_reg_data.get(); - Status s = op_data_factory(op_reg_data.get()); + absl::Status s = op_data_factory(op_reg_data.get()); if (s.ok()) { s = ValidateOpDef(op_reg_data->op_def); } @@ -243,7 +243,7 @@ Status OpRegistry::RegisterAlreadyLocked( .second) { s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); } - Status watcher_status = s; + absl::Status watcher_status = s; if (watcher_) { watcher_status = watcher_(s, op_reg_data_raw->op_def); } @@ -276,8 +276,8 @@ const OpRegistrationData* OpListOpRegistry::LookUp( return iter->second.get(); } -Status OpListOpRegistry::LookUp(const string& op_type_name, - const OpRegistrationData** op_reg_data) const { +absl::Status OpListOpRegistry::LookUp( + const string& op_type_name, const OpRegistrationData** op_reg_data) const { if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } @@ -286,10 +286,8 @@ namespace register_op { InitOnStartupMarker OpDefBuilderWrapper::operator()() { OpRegistry::Global()->Register( - [builder = - std::move(builder_)](OpRegistrationData* op_reg_data) -> Status { - return builder.Finalize(op_reg_data); - }); + [builder = std::move(builder_)](OpRegistrationData* op_reg_data) + -> absl::Status { return builder.Finalize(op_reg_data); }); return {}; } diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 83aa4d8e1974dd..466f6cbf3311f9 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -96,7 +96,7 @@ bool ConsumeAttrNumber(StringPiece* sp, int64_t* out) { return false; } int64_t value = 0; - if (!strings::safe_strto64(match, &value)) { + if (!absl::SimpleAtoi(match, &value)) { return false; } *out = value; @@ -664,7 +664,7 @@ OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() { return *this; } -Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { +absl::Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { std::vector errors = errors_; *op_reg_data = op_reg_data_; diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index 80d2d37545ebe2..74ef92c33366d9 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -40,7 +40,7 @@ class OpDefBuilderTest : public ::testing::Test { protected: OpDefBuilder b() { return OpDefBuilder("Test"); } - void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto, + void ExpectSuccess(const OpDefBuilder& builder, absl::string_view proto, OpShapeInferenceFn* shape_fn_out = nullptr) { OpRegistrationData op_reg_data; absl::Status status = builder.Finalize(&op_reg_data); @@ -61,7 +61,7 @@ class OpDefBuilderTest : public ::testing::Test { } } - void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) { + void ExpectOrdered(const OpDefBuilder& builder, absl::string_view proto) { OpRegistrationData op_reg_data; absl::Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index 1da0aa726d64ca..c1b180cd6b5caa 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -41,7 +41,7 @@ bool HasAttrStyleType(const OpDef::ArgDef& arg) { !arg.type_list_attr().empty(); } -Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { +absl::Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (auto allowed : allowed_values.list().type()) { if (dt == allowed) { @@ -61,7 +61,7 @@ Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { " is not in the list of allowed values: ", allowed_str); } -Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { +absl::Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (const auto& allowed : allowed_values.list().s()) { if (str == allowed) { @@ -83,8 +83,8 @@ Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { } // namespace // Requires: attr has already been validated. -Status ValidateAttrValue(const AttrValue& attr_value, - const OpDef::AttrDef& attr) { +absl::Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr) { // Is it a valid value? TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), " for attr '", attr.name(), "'"); @@ -146,7 +146,7 @@ Status ValidateAttrValue(const AttrValue& attr_value, return absl::OkStatus(); } -const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { +const OpDef::AttrDef* FindAttr(absl::string_view name, const OpDef& op_def) { for (int i = 0; i < op_def.attr_size(); ++i) { if (op_def.attr(i).name() == name) { return &op_def.attr(i); @@ -155,7 +155,7 @@ const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { return nullptr; } -OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { +OpDef::AttrDef* FindAttrMutable(absl::string_view name, OpDef* op_def) { for (int i = 0; i < op_def->attr_size(); ++i) { if (op_def->attr(i).name() == name) { return op_def->mutable_attr(i); @@ -164,7 +164,7 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { return nullptr; } -const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { +const OpDef::ArgDef* FindInputArg(absl::string_view name, const OpDef& op_def) { for (int i = 0; i < op_def.input_arg_size(); ++i) { if (op_def.input_arg(i).name() == name) { return &op_def.input_arg(i); @@ -173,7 +173,7 @@ const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { return nullptr; } -const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { +const ApiDef::Arg* FindInputArg(absl::string_view name, const ApiDef& api_def) { for (int i = 0; i < api_def.in_arg_size(); ++i) { if (api_def.in_arg(i).name() == name) { return &api_def.in_arg(i); @@ -190,9 +190,9 @@ const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { } \ } while (false) -static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, - bool output, - absl::flat_hash_set* names) { +static absl::Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, + bool output, + absl::flat_hash_set* names) { const string suffix = strings::StrCat( output ? " for output '" : " for input '", arg.name(), "'"); VALIDATE(names->emplace(arg.name()).second, "Duplicate name: ", arg.name()); @@ -247,7 +247,7 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, return absl::OkStatus(); } -bool IsValidOpName(StringPiece sp) { +bool IsValidOpName(absl::string_view sp) { using ::tensorflow::strings::Scanner; Scanner scanner(sp); @@ -266,13 +266,14 @@ bool IsValidOpName(StringPiece sp) { } } -Status ValidateOpDef(const OpDef& op_def) { +absl::Status ValidateOpDef(const OpDef& op_def) { if (!absl::StartsWith(op_def.name(), "_")) { VALIDATE(IsValidOpName(op_def.name()), "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); } - absl::flat_hash_set names; // for detecting duplicate names + absl::flat_hash_set + names; // for detecting duplicate names for (const auto& attr : op_def.attr()) { // Validate name VALIDATE(names.emplace(attr.name()).second, @@ -282,11 +283,11 @@ Status ValidateOpDef(const OpDef& op_def) { attr.name(), " that matches a data type"); // Validate type - StringPiece type(attr.type()); + absl::string_view type(attr.type()); bool is_list = absl::ConsumePrefix(&type, "list("); bool found = false; - for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", - "tensor", "func"}) { + for (absl::string_view valid : {"string", "int", "float", "bool", "type", + "shape", "tensor", "func"}) { if (absl::ConsumePrefix(&type, valid)) { found = true; break; @@ -348,7 +349,7 @@ Status ValidateOpDef(const OpDef& op_def) { #undef VALIDATE -Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { +absl::Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { if (op_def.has_deprecation()) { const OpDeprecation& dep = op_def.deprecation(); if (graph_def_version >= dep.version()) { @@ -499,7 +500,7 @@ string MinStr(const OpDef::AttrDef& attr) { return strings::StrCat(attr.minimum()); } -typedef absl::flat_hash_map AttrMap; +typedef absl::flat_hash_map AttrMap; void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) { for (const auto& attr : op_def.attr()) { (*attr_map)[attr.name()] = &attr; @@ -618,7 +619,7 @@ string ComputeArgSignature( } // namespace -Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { +absl::Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { #define VALIDATE(CONDITION, ...) \ if (!(CONDITION)) { \ return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \ @@ -687,9 +688,9 @@ Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { return absl::OkStatus(); } -Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, - const OpDef& penultimate_op, - const OpDef& new_op) { +absl::Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, + const OpDef& penultimate_op, + const OpDef& new_op) { AttrMap new_attrs, old_attrs; FillAttrMap(old_op, &old_attrs); FillAttrMap(new_op, &new_attrs); @@ -726,7 +727,8 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, return absl::OkStatus(); } -Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { +absl::Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, + const OpDef& new_op) { AttrMap new_attrs, old_attrs; FillAttrMap(old_op, &old_attrs); FillAttrMap(new_op, &new_attrs); @@ -862,11 +864,11 @@ bool OpDefEqual(const OpDef& o1, const OpDef& o2) { if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false; // `control_output` order doesn't matter. - std::vector control_output1(o1.control_output().begin(), - o1.control_output().end()); + std::vector control_output1(o1.control_output().begin(), + o1.control_output().end()); std::sort(control_output1.begin(), control_output1.end()); - std::vector control_output2(o2.control_output().begin(), - o2.control_output().end()); + std::vector control_output2(o2.control_output().begin(), + o2.control_output().end()); std::sort(control_output2.begin(), control_output2.end()); if (control_output1 != control_output2) return false; diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index e116f89229dc54..be1f08225c0e2e 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -43,16 +43,16 @@ absl::Status ValidateAttrValue(const AttrValue& attr_value, // The following search through op_def for an attr with the indicated name. // Returns nullptr if no such attr is found. -const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def); -OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def); +const OpDef::AttrDef* FindAttr(absl::string_view name, const OpDef& op_def); +OpDef::AttrDef* FindAttrMutable(absl::string_view name, OpDef* op_def); // Searches op_def for input argument with the indicated name. // Returns nullptr if no such attr is found. -const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def); +const OpDef::ArgDef* FindInputArg(absl::string_view name, const OpDef& op_def); // Searches api_def for input argument with the indicated name. // Returns nullptr if no such attr is found. -const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def); +const ApiDef::Arg* FindInputArg(absl::string_view name, const ApiDef& api_def); // Produce a human-readable version of an op_def that is more concise // than a text-format proto. Excludes descriptions. diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 9151e1b0448fb2..d5e6ab1a7cd227 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { -string WordWrap(StringPiece prefix, StringPiece str, int width) { +string WordWrap(absl::string_view prefix, absl::string_view str, int width) { const string indent_next_line = "\n" + Spaces(prefix.size()); width -= prefix.size(); string result; @@ -43,16 +43,16 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) { break; } auto space = str.rfind(' ', width); - if (space == StringPiece::npos) { + if (space == absl::string_view::npos) { // Rather make a too-long line and break at a space. space = str.find(' '); - if (space == StringPiece::npos) { + if (space == absl::string_view::npos) { strings::StrAppend(&result, str); break; } } // Breaking at character at position . - StringPiece to_append = str.substr(0, space); + absl::string_view to_append = str.substr(0, space); str.remove_prefix(space + 1); // Remove spaces at break. while (absl::EndsWith(to_append, " ")) { @@ -69,7 +69,7 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) { return result; } -bool ConsumeEquals(StringPiece* description) { +bool ConsumeEquals(absl::string_view* description) { if (absl::ConsumePrefix(description, "=")) { while (absl::ConsumePrefix(description, " ")) { // Also remove spaces after "=". @@ -84,12 +84,12 @@ bool ConsumeEquals(StringPiece* description) { // contains the maximum prefix of the input `*orig` that doesn't // contain `split_ch`, and `*orig` contains everything after the // first `split_ch`. -static bool SplitAt(char split_ch, StringPiece* orig, - StringPiece* before_split) { +static bool SplitAt(char split_ch, absl::string_view* orig, + absl::string_view* before_split) { auto pos = orig->find(split_ch); - if (pos == StringPiece::npos) { + if (pos == absl::string_view::npos) { *before_split = *orig; - *orig = StringPiece(); + *orig = absl::string_view(); return false; } else { *before_split = orig->substr(0, pos); @@ -100,9 +100,9 @@ static bool SplitAt(char split_ch, StringPiece* orig, // Does this line start with ":" where "" is // in multi_line_fields? Sets *colon_pos to the position of the colon. -static bool StartsWithFieldName(StringPiece line, +static bool StartsWithFieldName(absl::string_view line, const std::vector& multi_line_fields) { - StringPiece up_to_colon; + absl::string_view up_to_colon; if (!SplitAt(':', &line, &up_to_colon)) return false; while (absl::ConsumePrefix(&up_to_colon, " ")) ; // Remove leading spaces. @@ -114,7 +114,7 @@ static bool StartsWithFieldName(StringPiece line, return false; } -static bool ConvertLine(StringPiece line, +static bool ConvertLine(absl::string_view line, const std::vector& multi_line_fields, string* ml) { // Is this a field we should convert? @@ -122,8 +122,8 @@ static bool ConvertLine(StringPiece line, return false; } // Has a matching field name, so look for "..." after the colon. - StringPiece up_to_colon; - StringPiece after_colon = line; + absl::string_view up_to_colon; + absl::string_view after_colon = line; SplitAt(':', &after_colon, &up_to_colon); while (absl::ConsumePrefix(&after_colon, " ")) ; // Remove leading spaces. @@ -132,12 +132,12 @@ static bool ConvertLine(StringPiece line, return false; } auto last_quote = after_colon.rfind('\"'); - if (last_quote == StringPiece::npos) { + if (last_quote == absl::string_view::npos) { // Error: we don't see the expected matching quote, abort the conversion. return false; } - StringPiece escaped = after_colon.substr(0, last_quote); - StringPiece suffix = after_colon.substr(last_quote + 1); + absl::string_view escaped = after_colon.substr(0, last_quote); + absl::string_view suffix = after_colon.substr(last_quote + 1); // We've now parsed line into ': ""' string unescaped; @@ -163,13 +163,13 @@ static bool ConvertLine(StringPiece line, return true; } -string PBTxtToMultiline(StringPiece pbtxt, +string PBTxtToMultiline(absl::string_view pbtxt, const std::vector& multi_line_fields) { string ml; // Probably big enough, since the input and output are about the // same size, but just a guess. ml.reserve(pbtxt.size() * (17. / 16)); - StringPiece line; + absl::string_view line; while (!pbtxt.empty()) { // Split pbtxt into its first line and everything after. SplitAt('\n', &pbtxt, &line); @@ -184,8 +184,8 @@ string PBTxtToMultiline(StringPiece pbtxt, // Given a single line of text `line` with first : at `colon`, determine if // there is an "<set_visibility(new_api_def.visibility()); @@ -480,18 +480,19 @@ ApiDefMap::ApiDefMap(const OpList& op_list) { ApiDefMap::~ApiDefMap() {} -Status ApiDefMap::LoadFileList(Env* env, const std::vector& filenames) { +absl::Status ApiDefMap::LoadFileList(Env* env, + const std::vector& filenames) { for (const auto& filename : filenames) { TF_RETURN_IF_ERROR(LoadFile(env, filename)); } return absl::OkStatus(); } -Status ApiDefMap::LoadFile(Env* env, const string& filename) { +absl::Status ApiDefMap::LoadFile(Env* env, const string& filename) { if (filename.empty()) return absl::OkStatus(); string contents; TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); - Status status = LoadApiDef(contents); + absl::Status status = LoadApiDef(contents); if (!status.ok()) { // Return failed status annotated with filename to aid in debugging. return errors::CreateWithUpdatedMessage( @@ -501,7 +502,7 @@ Status ApiDefMap::LoadFile(Env* env, const string& filename) { return absl::OkStatus(); } -Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { +absl::Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { const string contents = PBTxtFromMultiline(api_def_file_contents); ApiDefs api_defs; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h index 1db41eb401117f..27ffe522a6dd35 100644 --- a/tensorflow/core/framework/op_gen_lib.h +++ b/tensorflow/core/framework/op_gen_lib.h @@ -35,17 +35,17 @@ inline string Spaces(int n) { return string(n, ' '); } // after the first by prefix.size() spaces. Intended use case is something // like prefix = " Foo(" and str is a list of arguments (terminated by a ")"). // TODO(josh11b): Option to wrap on ", " instead of " " when possible. -string WordWrap(StringPiece prefix, StringPiece str, int width); +string WordWrap(absl::string_view prefix, absl::string_view str, int width); // Looks for an "=" at the beginning of *description. If found, strips it off // (and any following spaces) from *description and return true. Otherwise // returns false. -bool ConsumeEquals(StringPiece* description); +bool ConsumeEquals(absl::string_view* description); // Convert text-serialized protobufs to/from multiline format. -string PBTxtToMultiline(StringPiece pbtxt, +string PBTxtToMultiline(absl::string_view pbtxt, const std::vector& multi_line_fields); -string PBTxtFromMultiline(StringPiece multiline_pbtxt); +string PBTxtFromMultiline(absl::string_view multiline_pbtxt); // Takes a list of files with ApiDefs text protos, and allows you to // look up the specific ApiDef for any given op. diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index f15065f8628fb0..06dd8fddd128ec 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -66,10 +66,10 @@ const char* kDisableJitKernelsEnvVar = "TF_DISABLE_JIT_KERNELS"; namespace { -Status MatchSignatureHelper(const DataTypeSlice expected_inputs, - const DataTypeSlice expected_outputs, - const DataTypeSlice inputs, - const DataTypeSlice outputs) { +absl::Status MatchSignatureHelper(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs, + const DataTypeSlice inputs, + const DataTypeSlice outputs) { bool signature_mismatch = false; if (inputs.size() != expected_inputs.size()) signature_mismatch = true; @@ -188,8 +188,8 @@ OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, OpKernel::~OpKernel() {} -Status OpKernel::InputRange(StringPiece input_name, int* start, - int* stop) const { +absl::Status OpKernel::InputRange(absl::string_view input_name, int* start, + int* stop) const { const auto result = input_name_map_.find(input_name); if (result == input_name_map_.end()) { return errors::InvalidArgument("Unknown input name: ", input_name); @@ -200,8 +200,8 @@ Status OpKernel::InputRange(StringPiece input_name, int* start, } } -Status OpKernel::OutputRange(StringPiece output_name, int* start, - int* stop) const { +absl::Status OpKernel::OutputRange(absl::string_view output_name, int* start, + int* stop) const { const auto result = output_name_map_.find(output_name); if (result == output_name_map_.end()) { return errors::InvalidArgument("Unknown output name: ", output_name); @@ -261,7 +261,7 @@ OpKernelConstruction::OpKernelConstruction( const std::shared_ptr& props, const MemoryTypeSlice& input_memory_types, const MemoryTypeSlice& output_memory_types, int graph_def_version, - Status* status) + absl::Status* status) : device_type_(std::move(device_type)), device_(device), allocator_(allocator), @@ -273,23 +273,23 @@ OpKernelConstruction::OpKernelConstruction( graph_def_version_(graph_def_version), status_(status) {} -bool OpKernelConstruction::HasAttr(StringPiece attr_name) const { +bool OpKernelConstruction::HasAttr(absl::string_view attr_name) const { return HasNodeAttr(def(), attr_name); } -void OpKernelConstruction::SetStatus(const Status& status) { +void OpKernelConstruction::SetStatus(const absl::Status& status) { status_->Update(status); } -Status OpKernelConstruction::MatchSignature( +absl::Status OpKernelConstruction::MatchSignature( const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { return MatchSignatureHelper(expected_inputs, expected_outputs, props_->input_types, props_->output_types); } -Status OpKernelConstruction::allocate_temp(DataType type, - const TensorShape& shape, - Tensor* out_temp) { +absl::Status OpKernelConstruction::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { AllocationAttributes attr; attr.allocation_will_be_logged = true; Tensor new_temp(allocator_, type, shape, attr); @@ -306,10 +306,9 @@ Status OpKernelConstruction::allocate_temp(DataType type, return absl::OkStatus(); } -Status OpKernelConstruction::allocate_temp(DataType type, - const TensorShape& shape, - Tensor* out_temp, - AllocatorAttributes allocator_attr) { +absl::Status OpKernelConstruction::allocate_temp( + DataType type, const TensorShape& shape, Tensor* out_temp, + AllocatorAttributes allocator_attr) { if (allocator_attr.scope_id != 0) { return errors::InvalidArgument( "ScopedAllocator cannot be used via OpKernelConstruction."); @@ -349,7 +348,7 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) params_->ensure_eigen_gpu_device(); if (params_->eigen_gpu_device != nullptr) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); - Status s = params_->device->ReinitializeGpuDevice( + absl::Status s = params_->device->ReinitializeGpuDevice( this, params_->eigen_gpu_device, params_->op_device_context, eigen_gpu_allocator); if (!s.ok()) { @@ -400,11 +399,12 @@ Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { } } -void OpKernelContext::SetStatus(const Status& status) { +void OpKernelContext::SetStatus(const absl::Status& status) { status_.Update(status); } -Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { +absl::Status OpKernelContext::input(absl::string_view name, + const Tensor** tensor) { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); if (input_is_ref(index)) { @@ -415,7 +415,8 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { return absl::OkStatus(); } -Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { +absl::Status OpKernelContext::input_dtype(absl::string_view name, + DataType* dtype) const { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); const TensorValue& value(params_->inputs[index]); @@ -423,7 +424,8 @@ Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { return absl::OkStatus(); } -Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { +absl::Status OpKernelContext::input_ref_mutex(absl::string_view name, + mutex** out_mutex) { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); *out_mutex = input_ref_mutex(index); @@ -506,8 +508,8 @@ bool OpKernelContext::forward_input_to_output_with_shape( } } -Status OpKernelContext::forward_input_to_output_with_shape( - StringPiece input_name, StringPiece output_name, +absl::Status OpKernelContext::forward_input_to_output_with_shape( + absl::string_view input_name, absl::string_view output_name, const TensorShape& output_shape, Tensor** output) { int input_index, output_index; TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index)); @@ -588,7 +590,7 @@ std::unique_ptr OpKernelContext::forward_input( return output_tensor; } -Status OpKernelContext::forward_input_or_allocate_temp( +absl::Status OpKernelContext::forward_input_or_allocate_temp( absl::Span candidate_input_indices, DataType type, const TensorShape& shape, const AllocatorAttributes& allocator_attr, Tensor* out_temp) { @@ -604,7 +606,7 @@ Status OpKernelContext::forward_input_or_allocate_temp( return allocate_temp(type, shape, out_temp, allocator_attr); } -Status OpKernelContext::forward_input_or_allocate_output( +absl::Status OpKernelContext::forward_input_or_allocate_output( absl::Span candidate_input_indices, int output_index, const TensorShape& output_shape, Tensor** output, int* forwarded_input) { for (int input_index : candidate_input_indices) { @@ -622,10 +624,11 @@ Status OpKernelContext::forward_input_or_allocate_output( return allocate_output(output_index, output_shape, output); } -Status OpKernelContext::forward_input_or_allocate_output( - absl::Span candidate_input_names, - StringPiece output_name, const TensorShape& output_shape, Tensor** output) { - for (const StringPiece& input_name : candidate_input_names) { +absl::Status OpKernelContext::forward_input_or_allocate_output( + absl::Span candidate_input_names, + absl::string_view output_name, const TensorShape& output_shape, + Tensor** output) { + for (const absl::string_view& input_name : candidate_input_names) { if (forward_input_to_output_with_shape(input_name, output_name, output_shape, output) .ok()) { @@ -648,8 +651,8 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) { } } -Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, - bool lock_held) { +absl::Status OpKernelContext::mutable_input(absl::string_view name, + Tensor* tensor, bool lock_held) { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); if (!input_is_ref(index)) { @@ -666,9 +669,9 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, return absl::OkStatus(); } -Status OpKernelContext::replace_ref_input(StringPiece name, - const Tensor& tensor, - bool lock_held) { +absl::Status OpKernelContext::replace_ref_input(absl::string_view name, + const Tensor& tensor, + bool lock_held) { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); if (!input_is_ref(index)) { @@ -679,22 +682,24 @@ Status OpKernelContext::replace_ref_input(StringPiece name, return absl::OkStatus(); } -Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { +absl::Status OpKernelContext::input_list(absl::string_view name, + OpInputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpInputList(this, start, stop); return absl::OkStatus(); } -Status OpKernelContext::mutable_input_list(StringPiece name, - OpMutableInputList* list) { +absl::Status OpKernelContext::mutable_input_list(absl::string_view name, + OpMutableInputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpMutableInputList(this, start, stop); return absl::OkStatus(); } -Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { +absl::Status OpKernelContext::output_list(absl::string_view name, + OpOutputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); *list = OpOutputList(this, start, stop); @@ -707,8 +712,9 @@ void OpKernelContext::maybe_initialize_scope_id_set() { } } -Status OpKernelContext::allocate_output(int index, const TensorShape& shape, - Tensor** tensor) { +absl::Status OpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** tensor) { if (index < 0) { return errors::Internal("allocate_output with bad index=", index, " kernel=", params_->op_kernel->name()); @@ -730,9 +736,9 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape, return allocate_output(index, shape, tensor, attr); } -Status OpKernelContext::allocate_output(StringPiece name, - const TensorShape& shape, - Tensor** tensor) { +absl::Status OpKernelContext::allocate_output(absl::string_view name, + const TensorShape& shape, + Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -744,10 +750,10 @@ Status OpKernelContext::allocate_output(StringPiece name, return allocate_output(start, shape, tensor); } -Status OpKernelContext::allocate_output(StringPiece name, - const TensorShape& shape, - Tensor** tensor, - AllocatorAttributes attr) { +absl::Status OpKernelContext::allocate_output(absl::string_view name, + const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -759,7 +765,7 @@ Status OpKernelContext::allocate_output(StringPiece name, return allocate_output(start, shape, tensor, attr); } -Status OpKernelContext::allocate_tensor( +absl::Status OpKernelContext::allocate_tensor( DataType type, const TensorShape& shape, Tensor* out_tensor, AllocatorAttributes attr, const AllocationAttributes& allocation_attr) { Allocator* a = get_allocator(attr); @@ -783,9 +789,10 @@ Status OpKernelContext::allocate_tensor( return absl::OkStatus(); } -Status OpKernelContext::allocate_output(int index, const TensorShape& shape, - Tensor** output, - AllocatorAttributes attr) { +absl::Status OpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** output, + AllocatorAttributes attr) { if (index < 0) { return errors::Internal("allocate_output with bad index=", index, " kernel=", params_->op_kernel->name()); @@ -821,7 +828,7 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape, op_kernel().name_view().data(), step_id(), "output", type, [&shape]() { return shape.DebugString(); }); auto output_tensor = std::make_unique(); - Status s = allocate_tensor(type, shape, output_tensor.get(), attr); + absl::Status s = allocate_tensor(type, shape, output_tensor.get(), attr); if (s.ok()) { outputs_[index] = TensorValue(output_tensor.release()); *output = outputs_[index].tensor; @@ -829,7 +836,7 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape, return s; } -Status OpKernelContext::allocate_temp( +absl::Status OpKernelContext::allocate_temp( DataType type, const TensorShape& shape, Tensor* out_temp, AllocatorAttributes allocator_attr, const AllocationAttributes& allocation_attr) { @@ -851,7 +858,7 @@ Status OpKernelContext::allocate_temp( tsl::profiler::ScopedMemoryDebugAnnotation op_annotation( op_kernel().name_view().data(), step_id(), "temp", type, [&shape]() { return shape.DebugString(); }); - Status s = + absl::Status s = allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr); if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) { Allocator* a = get_allocator(allocator_attr); @@ -867,20 +874,21 @@ Status OpKernelContext::allocate_temp( return s; } -Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp, - AllocatorAttributes allocator_attr) { +absl::Status OpKernelContext::allocate_temp( + DataType type, const TensorShape& shape, Tensor* out_temp, + AllocatorAttributes allocator_attr) { return allocate_temp(type, shape, out_temp, allocator_attr, AllocationAttributes()); } -Status OpKernelContext::allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp) { +absl::Status OpKernelContext::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { return allocate_temp(type, shape, out_temp, AllocatorAttributes()); } -Status OpKernelContext::get_input_index(StringPiece name, - int* out_index) const { +absl::Status OpKernelContext::get_input_index(absl::string_view name, + int* out_index) const { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { @@ -893,8 +901,8 @@ Status OpKernelContext::get_input_index(StringPiece name, return absl::OkStatus(); } -Status OpKernelContext::get_output_index(StringPiece name, - int* out_index) const { +absl::Status OpKernelContext::get_output_index(absl::string_view name, + int* out_index) const { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -907,14 +915,16 @@ Status OpKernelContext::get_output_index(StringPiece name, return absl::OkStatus(); } -Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { +absl::Status OpKernelContext::set_output(absl::string_view name, + const Tensor& tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output(index, tensor); return absl::OkStatus(); } -Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) { +absl::Status OpKernelContext::set_output(absl::string_view name, + Tensor&& tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output(index, std::move(tensor)); @@ -957,11 +967,13 @@ bool OpKernelContext::maybe_set_output_by_allocate_and_copy( op_kernel().name_view().data(), step_id(), "output", tensor.dtype(), [&tensor]() { return tensor.shape().DebugString(); }); auto new_tensor = std::make_unique(); - Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(), - output_alloc_attr(index)); + absl::Status s = + allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(), + output_alloc_attr(index)); TF_CHECK_OK(s); device()->CopyTensorInSameDevice(&tensor, new_tensor.get(), - op_device_context(), [](const Status&) {}); + op_device_context(), + [](const absl::Status&) {}); outputs_[index] = TensorValue(new_tensor.release()); } return allocate_and_copy; @@ -1021,15 +1033,16 @@ void OpKernelContext::set_output_ref(int index, mutex* mu, outputs_[index] = TensorValue(mu, tensor_for_ref); } -Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, - Tensor* tensor_for_ref) { +absl::Status OpKernelContext::set_output_ref(absl::string_view name, mutex* mu, + Tensor* tensor_for_ref) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output_ref(index, mu, tensor_for_ref); return absl::OkStatus(); } -Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { +absl::Status OpKernelContext::mutable_output(absl::string_view name, + Tensor** tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); *tensor = mutable_output(index); @@ -1051,8 +1064,8 @@ bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return true; } -Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, - const DataTypeSlice expected_outputs) { +absl::Status OpKernelContext::MatchSignature( + const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { DataTypeVector inputs; for (const TensorValue& t : params_->inputs) { inputs.push_back(t.dtype()); @@ -1140,7 +1153,7 @@ const string& OpKernelContext::executor_type() const { // OpKernel registration ------------------------------------------------------ struct KernelRegistration { - KernelRegistration(const KernelDef& d, StringPiece c, + KernelRegistration(const KernelDef& d, absl::string_view c, std::unique_ptr f) : def(d), kernel_class_name(c), factory(std::move(f)) {} @@ -1171,7 +1184,7 @@ static const char kKernelLibPattern[] = "libtfkernel*.so"; // Returns Status::OK if the dynamic library at the given path is safe to // load with some level of confidence. -static Status IsProbablySafeToLoad(const string& path) { +static absl::Status IsProbablySafeToLoad(const string& path) { // A map of platform string to required CPU feature. using port::CPUFeature; static const auto* feature_map = @@ -1182,11 +1195,11 @@ static Status IsProbablySafeToLoad(const string& path) { std::vector platform_strings; int result = GetPlatformStrings(path, &platform_strings); if (result) { - return Status(absl::StatusCode::kUnknown, strerror(result)); + return absl::Status(absl::StatusCode::kUnknown, strerror(result)); } if (platform_strings.empty()) { - return Status(absl::StatusCode::kFailedPrecondition, - "Didn't find any platform strings"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Didn't find any platform strings"); } std::vector missing_features; for (const auto& platform_string : platform_strings) { @@ -1218,13 +1231,13 @@ void LoadDynamicKernelsInternal() { string bazel_kernel_dir = io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels"); std::vector files; - Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files); + absl::Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files); if (s_kernel_dir.ok()) { string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern); for (const auto& file : files) { string fullpath = io::JoinPath(bazel_kernel_dir, file); if (env->MatchPath(fullpath, dll_spec)) { - Status s = IsProbablySafeToLoad(fullpath); + absl::Status s = IsProbablySafeToLoad(fullpath); if (!s.ok() && override_abi_check) { LOG(WARNING) << "Loading UNSAFE library " << fullpath << " because ABI check override is set: " << s.message(); @@ -1251,8 +1264,8 @@ void LoadDynamicKernels() { absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal); } -static string Key(StringPiece op_type, const DeviceType& device_type, - StringPiece label) { +static string Key(absl::string_view op_type, const DeviceType& device_type, + absl::string_view label) { return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", label); } @@ -1330,7 +1343,7 @@ static KernelRegistry* GlobalKernelRegistryTyped() { namespace kernel_factory { void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, - StringPiece kernel_class_name, + absl::string_view kernel_class_name, std::unique_ptr factory) { const string key = Key(kernel_def->op(), DeviceType(kernel_def->device_type()), @@ -1378,12 +1391,12 @@ const string& GetKernelLabelAttr(const AttrSlice& node_attrs) { } // TODO(irving): Replace with const Node& version below. -Status FindKernelRegistration( - const DeviceType& device_type, StringPiece node_name, +absl::Status FindKernelRegistration( + const DeviceType& device_type, absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info, - StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg, - bool* was_attr_mismatch) { + absl::string_view node_op, AttrSlice node_attrs, + const KernelRegistration** reg, bool* was_attr_mismatch) { *reg = nullptr; *was_attr_mismatch = false; @@ -1457,10 +1470,10 @@ Status FindKernelRegistration( return absl::OkStatus(); } -Status FindKernelRegistration(const DeviceType& device_type, - const NodeDef& node_def, - const KernelRegistration** reg, - bool* was_attr_mismatch) { +absl::Status FindKernelRegistration(const DeviceType& device_type, + const NodeDef& node_def, + const KernelRegistration** reg, + bool* was_attr_mismatch) { return FindKernelRegistration( device_type, node_def.name(), node_def.has_experimental_debug_info(), node_def.experimental_debug_info(), node_def.op(), @@ -1473,18 +1486,18 @@ bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def) { const KernelRegistration* reg = nullptr; bool was_attr_mismatch; - Status result = + absl::Status result = FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch); return result.ok() && reg != nullptr; } // TODO(irving): Change const NodeDef& to const Node& -Status FindKernelDef( - const DeviceType& device_type, StringPiece node_name, +absl::Status FindKernelDef( + const DeviceType& device_type, absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info, - StringPiece node_op, StringPiece node_device, AttrSlice node_attrs, - const KernelDef** def, string* kernel_class_name) { + absl::string_view node_op, absl::string_view node_device, + AttrSlice node_attrs, const KernelDef** def, string* kernel_class_name) { const KernelRegistration* reg = nullptr; bool was_attr_mismatch; TF_RETURN_IF_ERROR(FindKernelRegistration( @@ -1492,7 +1505,7 @@ Status FindKernelDef( experimental_debug_info, node_op, node_attrs, ®, &was_attr_mismatch)); if (reg == nullptr) { const std::string device_str = DeviceTypeString(device_type); - Status s = errors::NotFound( + absl::Status s = errors::NotFound( "No registered '", node_op, "' OpKernel for ", device_str, " devices compatible with node ", FormatNodeDefForError(node_name, has_experimental_debug_info, @@ -1521,15 +1534,16 @@ Status FindKernelDef( return absl::OkStatus(); } -Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, - const KernelDef** def, string* kernel_class_name) { +absl::Status FindKernelDef(const DeviceType& device_type, + const NodeDef& node_def, const KernelDef** def, + string* kernel_class_name) { return FindKernelDef( device_type, node_def.name(), node_def.has_experimental_debug_info(), node_def.experimental_debug_info(), node_def.op(), node_def.device(), AttrSlice(&node_def.attr()), def, kernel_class_name); } -Status SupportedDeviceTypesForNode( +absl::Status SupportedDeviceTypesForNode( const std::vector& prioritized_types, const NodeDef& def, PrioritizedDeviceTypeVector* prioritized_device_types, const DeviceNameUtils::ParsedName* local_address_spec) { @@ -1538,7 +1552,7 @@ Status SupportedDeviceTypesForNode( // a user-defined function and only calls this // SupportedDeviceTypesForNode for primitive ops. const OpRegistrationData* op_reg_data; - const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data); + const absl::Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data); if (s.ok()) { bool exists_attr_mismatch = false; for (const DeviceType& device_type : prioritized_types) { @@ -1626,12 +1640,12 @@ KernelList GetFilteredRegisteredKernels( return kernel_list; } -KernelList GetRegisteredKernelsForOp(StringPiece op_name) { +KernelList GetRegisteredKernelsForOp(absl::string_view op_name) { auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; }; return GetFilteredRegisteredKernels(op_pred); } -string KernelsRegisteredForOp(StringPiece op_name) { +string KernelsRegisteredForOp(absl::string_view op_name) { KernelList kernel_list = GetRegisteredKernelsForOp(op_name); if (kernel_list.kernel_size() == 0) return " \n"; string ret; @@ -1654,7 +1668,7 @@ string KernelsRegisteredForOp(StringPiece op_name) { * copying the NodeDef. */ std::unique_ptr CreateOpKernel( DeviceType device_type, DeviceBase* device, Allocator* allocator, - const NodeDef& node_def, int graph_def_version, Status* status) { + const NodeDef& node_def, int graph_def_version, absl::Status* status) { // Look up the Op registered for this op name. std::shared_ptr props; status->Update(NodeProperties::CreateFromNodeDef( @@ -1671,31 +1685,31 @@ std::unique_ptr CreateOpKernel( std::unique_ptr CreateOpKernel( DeviceType device_type, DeviceBase* device, Allocator* allocator, const std::shared_ptr& props, int graph_def_version, - Status* status) { + absl::Status* status) { OpKernel* kernel = nullptr; *status = CreateOpKernel(std::move(device_type), device, allocator, /*flib=*/nullptr, props, graph_def_version, &kernel); return std::unique_ptr(kernel); } -Status CreateOpKernel(DeviceType device_type, DeviceBase* device, - Allocator* allocator, FunctionLibraryRuntime* flib, - const std::shared_ptr& props, - int graph_def_version, OpKernel** kernel) { +absl::Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel) { return CreateOpKernel(std::move(device_type), device, allocator, flib, /* resource_mgr= */ nullptr, props, graph_def_version, kernel); } -Status CreateOpKernel(DeviceType device_type, DeviceBase* device, - Allocator* allocator, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, - const std::shared_ptr& props, - int graph_def_version, OpKernel** kernel) { +absl::Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel) { const NodeDef& node_def = props->node_def; bool was_attr_mismatch; const KernelRegistration* registration = nullptr; - Status s; + absl::Status s; if (props != nullptr) { VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); @@ -1748,7 +1762,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, namespace { -bool FindArgInOp(StringPiece arg_name, +bool FindArgInOp(absl::string_view arg_name, const protobuf::RepeatedPtrField& args) { for (const auto& arg : args) { if (arg_name == arg.name()) { @@ -1760,13 +1774,15 @@ bool FindArgInOp(StringPiece arg_name, } // namespace -Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { +absl::Status ValidateKernelRegistrations( + const OpRegistryInterface& op_registry) { auto typed_registry = GlobalKernelRegistryTyped(); tf_shared_lock lock(typed_registry->mu); for (const auto& key_registration : typed_registry->registry) { const KernelDef& kernel_def(key_registration.second.def); const OpRegistrationData* op_reg_data; - const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data); + const absl::Status status = + op_registry.LookUp(kernel_def.op(), &op_reg_data); if (!status.ok()) { LOG(WARNING) << "OpKernel ('" << kernel_def.ShortDebugString() << "') for unknown op: " << kernel_def.op(); @@ -1795,48 +1811,49 @@ const Eigen::GpuDevice& OpKernelContext::eigen_device() const { return eigen_gpu_device(); } -void OpKernelConstruction::CtxFailure(const Status& s) { +void OpKernelConstruction::CtxFailure(const absl::Status& s) { VLOG(1) << s; SetStatus(s); } -void OpKernelConstruction::CtxFailureWithWarning(const Status& s) { +void OpKernelConstruction::CtxFailureWithWarning(const absl::Status& s) { LOG(WARNING) << s; SetStatus(s); } void OpKernelConstruction::CtxFailure(const char* file, int line, - const Status& s) { + const absl::Status& s) { VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line << " : " << s; SetStatus(s); } void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line, - const Status& s) { + const absl::Status& s) { LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line << " : " << s; SetStatus(s); } -void OpKernelContext::CtxFailure(const Status& s) { +void OpKernelContext::CtxFailure(const absl::Status& s) { VLOG(1) << s; SetStatus(s); } -void OpKernelContext::CtxFailureWithWarning(const Status& s) { +void OpKernelContext::CtxFailureWithWarning(const absl::Status& s) { LOG(WARNING) << s; SetStatus(s); } -void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) { +void OpKernelContext::CtxFailure(const char* file, int line, + const absl::Status& s) { VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line << " : " << s; SetStatus(s); } void OpKernelContext::CtxFailureWithWarning(const char* file, int line, - const Status& s) { + const absl::Status& s) { LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line << " : " << s; SetStatus(s); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 264b66471291ad..d925bc214b20bc 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -866,7 +866,7 @@ class OpKernelContext { Tensor** output) TF_MUST_USE_RESULT; absl::Status forward_input_to_output_with_shape( StringPiece input_name, StringPiece output_name, - const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; + const TensorShape& output_shape, Tensor** output); // Returns a pointer to a Tensor aliasing the underlying buffer backing // input[input_index] iff @@ -910,11 +910,11 @@ class OpKernelContext { absl::Status forward_input_or_allocate_output( absl::Span candidate_input_indices, int output_index, const TensorShape& output_shape, Tensor** output, - int* forwarded_input = nullptr) TF_MUST_USE_RESULT; + int* forwarded_input = nullptr); absl::Status forward_input_or_allocate_output( absl::Span candidate_input_names, StringPiece output_name, const TensorShape& output_shape, - Tensor** output) TF_MUST_USE_RESULT; + Tensor** output); // Tries to reuse one of the inputs given in input_indices as a temporary. // If none of the given inputs can be forwarded, calls @@ -922,11 +922,11 @@ class OpKernelContext { absl::Status forward_input_or_allocate_temp( absl::Span candidate_input_indices, DataType type, const TensorShape& shape, const AllocatorAttributes& allocator_attr, - Tensor* out_temp) TF_MUST_USE_RESULT; + Tensor* out_temp); absl::Status forward_input_or_allocate_temp( absl::Span candidate_input_indices, DataType type, - const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT { + const TensorShape& shape, Tensor* out_temp) { return forward_input_or_allocate_temp(candidate_input_indices, type, shape, AllocatorAttributes(), out_temp); } @@ -996,20 +996,18 @@ class OpKernelContext { // // REQUIRES: !IsRefType(expected_output_dtype(index)) absl::Status allocate_output(int index, const TensorShape& shape, - Tensor** tensor) TF_MUST_USE_RESULT; + Tensor** tensor); absl::Status allocate_output(StringPiece name, const TensorShape& shape, - Tensor** tensor) TF_MUST_USE_RESULT; + Tensor** tensor); // The following methods use the supplied attributes instead of // those in output_attr_array. The caller is responsible for // ensuring that the attributes are "compatible" with the // output_attr_array, e.g. the tensor is allocated on the correct // device. See comment above. absl::Status allocate_output(int index, const TensorShape& shape, - Tensor** tensor, - AllocatorAttributes attr) TF_MUST_USE_RESULT; + Tensor** tensor, AllocatorAttributes attr); absl::Status allocate_output(StringPiece name, const TensorShape& shape, - Tensor** tensor, - AllocatorAttributes attr) TF_MUST_USE_RESULT; + Tensor** tensor, AllocatorAttributes attr); // Allocates a temporary Tensor of the specified type and // shape. Devices such as GPUs that enqueue Ops for lazy execution diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index bea1208053c5e2..be8341c3753303 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -65,8 +65,8 @@ class TestOp2 : public ::tensorflow::OpKernel { public: explicit TestOp2(::tensorflow::OpKernelConstruction* context) : OpKernel(context) { - ::tensorflow::Status status = context->MatchSignature( - {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32}); + absl::Status status = context->MatchSignature({::tensorflow::DT_INT32}, + {::tensorflow::DT_INT32}); match_signature_ = status.ok(); context->SetStatus(status); } @@ -205,7 +205,7 @@ class OpKernelTest : public ::testing::Test { void ExpectSuccess(const string& op_type, DeviceType device_type, const DataTypeVector& inputs, const DataTypeVector& outputs) { - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel( std::move(device_type), &device_, cpu_allocator(), CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status)); @@ -221,7 +221,7 @@ class OpKernelTest : public ::testing::Test { error::Code code) { NodeDef node_def; protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def); - Status status; + absl::Status status; std::unique_ptr op( CreateOpKernel(std::move(device_type), &device_, cpu_allocator(), node_def, TF_GRAPH_DEF_VERSION, &status)); @@ -412,7 +412,7 @@ TEST_F(OpKernelTest, InputDtype) { OpKernelContext::Params params; DummyDevice device(env); params.device = &device; - Status status; + absl::Status status; std::unique_ptr op( CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), @@ -440,7 +440,7 @@ TEST_F(OpKernelTest, InputOnly) { OpKernelContext::Params params; DummyDevice device(env); params.device = &device; - Status status; + absl::Status status; std::unique_ptr op( CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), @@ -465,7 +465,7 @@ TEST_F(OpKernelTest, RefInputs) { OpKernelContext::Params params; DummyDevice device(env); params.device = &device; - Status status; + absl::Status status; std::unique_ptr op( CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("RefInputs", {DT_FLOAT_REF, DT_FLOAT_REF}), @@ -493,7 +493,7 @@ TEST_F(OpKernelTest, AllocateOutput) { OpKernelContext::Params params; DummyDevice device(env); params.device = &device; - Status status; + absl::Status status; std::unique_ptr op( CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), @@ -508,7 +508,7 @@ TEST_F(OpKernelTest, AllocateOutput) { Tensor* output = nullptr; // Allocating to index -1 should fail (Only 0 should work). - Status s = ctx->allocate_output(-1, TensorShape({}), &output); + absl::Status s = ctx->allocate_output(-1, TensorShape({}), &output); EXPECT_THAT(s, tensorflow::testing::StatusIs(error::INTERNAL)); EXPECT_THAT(s.message(), ::testing::ContainsRegex("bad index=-1")); @@ -595,7 +595,7 @@ TEST_F(OpKernelTest, ScopedAllocationTest) { OpKernelContext::Params params; auto sa_device = std::make_unique(env); params.device = sa_device.get(); - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel( DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("Test4", {DT_FLOAT}), TF_GRAPH_DEF_VERSION, &status)); @@ -633,7 +633,7 @@ TEST_F(OpKernelTest, TraceString) { DummyDevice device(env); params.device = &device; - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel( DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("Test4", {DT_FLOAT}), TF_GRAPH_DEF_VERSION, &status)); @@ -729,7 +729,7 @@ REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), TEST_F(OpKernelBuilderTest, DuplicateKernel) { const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); PrioritizedDeviceTypeVector devs; - Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + absl::Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains( status.message(), "Multiple OpKernel registrations match NodeDef")); @@ -749,7 +749,7 @@ TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { const NodeDef ndef = CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); PrioritizedDeviceTypeVector devs; - Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + absl::Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains( status.message(), "Multiple OpKernel registrations match NodeDef")); @@ -770,7 +770,7 @@ REGISTER_KERNEL_BUILDER(Name("BadConstraint") TEST_F(OpKernelBuilderTest, BadConstraint) { const NodeDef ndef = CreateNodeDef("BadConstraint", {}); PrioritizedDeviceTypeVector devs; - Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + absl::Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.message(), @@ -790,7 +790,7 @@ TEST_F(OpKernelBuilderTest, OpOutputList) { OpKernelContext::Params params; DummyDevice device(env); params.device = &device; - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel( DEVICE_CPU, params.device, cpu_allocator(), CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}), @@ -867,7 +867,7 @@ class GetAttrKernel : public ::tensorflow::OpKernel { std::vector shape_proto_list; TensorShape shape; std::vector shape_list; - std::vector> status; + std::vector> status; }; class GetAttrTest : public OpKernelBuilderTest {}; @@ -1074,7 +1074,7 @@ TEST_F(LabelTest, Filter) { void BM_InputRangeHelper(::testing::benchmark::State& state, const NodeDef& node_def, const char* input_name, int expected_start, int expected_stop) { - Status status; + absl::Status status; auto device = std::make_unique(Env::Default()); std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), @@ -1150,7 +1150,7 @@ void BM_TraceString(::testing::benchmark::State& state) { } // Build OpKernel and OpKernelContext - Status status; + absl::Status status; auto device = std::make_unique(Env::Default()); std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), node_def, diff --git a/tensorflow/core/framework/op_registration_test.cc b/tensorflow/core/framework/op_registration_test.cc index 286a0db358702c..d11f819aa99134 100644 --- a/tensorflow/core/framework/op_registration_test.cc +++ b/tensorflow/core/framework/op_registration_test.cc @@ -25,10 +25,11 @@ namespace tensorflow { namespace { void Register(const string& op_name, OpRegistry* registry) { - registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { - op_reg_data->op_def.set_name(op_name); - return absl::OkStatus(); - }); + registry->Register( + [op_name](OpRegistrationData* op_reg_data) -> absl::Status { + op_reg_data->op_def.set_name(op_name); + return absl::OkStatus(); + }); } } // namespace @@ -45,11 +46,11 @@ TEST(OpRegistrationTest, TestBasic) { TEST(OpRegistrationTest, TestDuplicate) { std::unique_ptr registry(new OpRegistry); Register("Foo", registry.get()); - Status s = registry->ProcessRegistrations(); + absl::Status s = registry->ProcessRegistrations(); EXPECT_TRUE(s.ok()); - TF_EXPECT_OK( - registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status { + TF_EXPECT_OK(registry->SetWatcher( + [](const absl::Status& s, const OpDef& op_def) -> absl::Status { EXPECT_TRUE(errors::IsAlreadyExists(s)); return absl::OkStatus(); })); diff --git a/tensorflow/core/framework/op_requires.h b/tensorflow/core/framework/op_requires.h index 85e4f53bcf81f1..d9a7e35c539ee9 100644 --- a/tensorflow/core/framework/op_requires.h +++ b/tensorflow/core/framework/op_requires.h @@ -128,8 +128,10 @@ namespace tensorflow { namespace op_requires_internal { +// ctx is usually a plain pointer, but could be a smart pointer, so we accept it +// by const ref. template -bool OkImpl(Ctx&& ctx, const char* file, int line, const S& s) { +bool OkImpl(const Ctx& ctx, const char* file, int line, const S& s) { if (!TF_PREDICT_TRUE(s.ok())) { CheckNotInComputeAsync(ctx, "OP_REQUIRES_OK_ASYNC"); ctx->CtxFailureWithWarning(file, line, s); @@ -139,8 +141,10 @@ bool OkImpl(Ctx&& ctx, const char* file, int line, const S& s) { } } +// ctx is usually a plain pointer, but could be a smart pointer, so we accept it +// by const ref. template -bool OkAsyncImpl(Ctx&& ctx, const char* file, int line, const S& s) { +bool OkAsyncImpl(const Ctx& ctx, const char* file, int line, const S& s) { if (!TF_PREDICT_TRUE(s.ok())) { ctx->CtxFailureWithWarning(file, line, s); return false; diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index 6af4d8973b3e1c..2f583903f43670 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -35,9 +35,9 @@ OpSegment::~OpSegment() { for (const auto& kv : sessions_) delete kv.second; } -Status OpSegment::FindOrCreate(const string& session_handle, - const string& node_name, OpKernel** kernel, - CreateKernelFn create_fn) { +absl::Status OpSegment::FindOrCreate(const string& session_handle, + const string& node_name, OpKernel** kernel, + CreateKernelFn create_fn) { { mutex_lock l(mu_); auto item = gtl::FindPtrOrNull(sessions_, session_handle); @@ -49,7 +49,7 @@ Status OpSegment::FindOrCreate(const string& session_handle, return absl::OkStatus(); } } - Status s = create_fn(kernel); + absl::Status s = create_fn(kernel); if (!s.ok()) { LOG(ERROR) << "Create kernel failed: " << s; return s; diff --git a/tensorflow/core/framework/ops_util.cc b/tensorflow/core/framework/ops_util.cc index abe57812774933..9a4de9240822bd 100644 --- a/tensorflow/core/framework/ops_util.cc +++ b/tensorflow/core/framework/ops_util.cc @@ -37,9 +37,9 @@ Eigen::PaddingType BrainPadding2EigenPadding(Padding padding) { return Eigen::PADDING_SAME; // Prevent compiler warning about missing return } -Status GetBroadcastSize(const int index, const int in_size, const int ksize, - const int stride, const int pad_size, int* bindex, - int* bsize) { +absl::Status GetBroadcastSize(const int index, const int in_size, + const int ksize, const int stride, + const int pad_size, int* bindex, int* bsize) { // Cannot have index beyond the input size. if (index * stride > in_size) { return errors::InvalidArgument( diff --git a/tensorflow/core/framework/partial_tensor_shape_test.cc b/tensorflow/core/framework/partial_tensor_shape_test.cc index 77f81cc5a8a549..0556d3f1ad386b 100644 --- a/tensorflow/core/framework/partial_tensor_shape_test.cc +++ b/tensorflow/core/framework/partial_tensor_shape_test.cc @@ -73,7 +73,7 @@ TEST(PartialTensorShapeTest, Concatenate) { TEST(PartialTensorShapeTest, ConcatenateWithStatus) { PartialTensorShape s({10, 5, 20}); PartialTensorShape s2; - Status status = s.ConcatenateWithStatus(400, &s2); + absl::Status status = s.ConcatenateWithStatus(400, &s2); EXPECT_TRUE(status.ok()); EXPECT_EQ(s2.num_elements(), 400000); EXPECT_EQ(s2.dims(), 4); diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc index 2e433fb1359d5a..a71818430c0292 100644 --- a/tensorflow/core/framework/reader_base.cc +++ b/tensorflow/core/framework/reader_base.cc @@ -40,12 +40,12 @@ int64_t ReaderBase::NumWorkUnitsCompleted() { return work_finished_; } -Status ReaderBase::Reset() { +absl::Status ReaderBase::Reset() { mutex_lock lock(mu_); return ResetLocked(); } -Status ReaderBase::ResetLocked() { +absl::Status ReaderBase::ResetLocked() { work_started_ = 0; work_finished_ = 0; num_records_produced_ = 0; @@ -53,25 +53,25 @@ Status ReaderBase::ResetLocked() { return absl::OkStatus(); } -Status ReaderBase::SerializeState(tstring* state) { +absl::Status ReaderBase::SerializeState(tstring* state) { mutex_lock lock(mu_); return SerializeStateLocked(state); } -Status ReaderBase::SerializeStateLocked(tstring* state) { +absl::Status ReaderBase::SerializeStateLocked(tstring* state) { return errors::Unimplemented("Reader SerializeState"); } -Status ReaderBase::RestoreState(const tstring& state) { +absl::Status ReaderBase::RestoreState(const tstring& state) { mutex_lock lock(mu_); - Status status = RestoreStateLocked(state); + absl::Status status = RestoreStateLocked(state); if (!status.ok()) { ResetLocked().IgnoreError(); } return status; } -Status ReaderBase::RestoreStateLocked(const tstring& state) { +absl::Status ReaderBase::RestoreStateLocked(const tstring& state) { return errors::Unimplemented("Reader RestoreState"); } @@ -93,7 +93,7 @@ int64_t ReaderBase::ReadUpTo(const int64_t num_records, QueueInterface* queue, if (!context->status().ok()) { return records_produced_this_call; } - Status status = OnWorkStartedLocked(); + absl::Status status = OnWorkStartedLocked(); if (status.ok()) { work_started_++; } else { @@ -103,7 +103,7 @@ int64_t ReaderBase::ReadUpTo(const int64_t num_records, QueueInterface* queue, } bool at_end = false; - Status status = + absl::Status status = ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end); // This call so far. records_produced_this_call += num_records_produced; @@ -133,14 +133,14 @@ int64_t ReaderBase::ReadUpTo(const int64_t num_records, QueueInterface* queue, } // Default implementation just reads one record at a time. -Status ReaderBase::ReadUpToLocked(int64_t num_records, - std::vector* keys, - std::vector* values, - int64_t* num_read, bool* at_end) { +absl::Status ReaderBase::ReadUpToLocked(int64_t num_records, + std::vector* keys, + std::vector* values, + int64_t* num_read, bool* at_end) { bool produced = false; tstring key; tstring value; - Status status = ReadLocked(&key, &value, &produced, at_end); + absl::Status status = ReadLocked(&key, &value, &produced, at_end); if (produced) { keys->push_back(std::move(key)); values->push_back(std::move(value)); @@ -160,7 +160,7 @@ void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value, if (!context->status().ok()) { return; } - Status status = OnWorkStartedLocked(); + absl::Status status = OnWorkStartedLocked(); if (status.ok()) { work_started_++; } else { @@ -171,7 +171,7 @@ void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value, bool produced = false; bool at_end = false; - Status status = ReadLocked(key, value, &produced, &at_end); + absl::Status status = ReadLocked(key, value, &produced, &at_end); if (!at_end && status.ok() && !produced) { status = errors::Internal( @@ -236,7 +236,7 @@ tstring ReaderBase::KeyName(const tstring& key) const { return strings::StrCat(current_work(), ":", key); } -Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { +absl::Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { work_started_ = state.work_started(); work_finished_ = state.work_finished(); num_records_produced_ = state.num_records_produced(); diff --git a/tensorflow/core/framework/reader_base.h b/tensorflow/core/framework/reader_base.h index 644a5618f7564e..73842644d15992 100644 --- a/tensorflow/core/framework/reader_base.h +++ b/tensorflow/core/framework/reader_base.h @@ -52,28 +52,29 @@ class ReaderBase : public ReaderInterface { // d) If there was an error producing (e.g. an error reading the file, // data corruption), return a non-OK() status. ReadLocked may be // called again if the user reruns this part of the graph. - virtual Status ReadLocked(tstring* key, tstring* value, bool* produced, - bool* at_end) = 0; + virtual absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) = 0; // Descendants may optionally implement these ------------------------------- // Produce up to num_records next key/value pairs from the current // work item, in the same manner of ReadLocked. - virtual Status ReadUpToLocked(int64_t num_records, std::vector* keys, - std::vector* values, int64_t* num_read, - bool* at_end); + virtual absl::Status ReadUpToLocked(int64_t num_records, + std::vector* keys, + std::vector* values, + int64_t* num_read, bool* at_end); // Called when work starts / finishes. - virtual Status OnWorkStartedLocked() { return absl::OkStatus(); } - virtual Status OnWorkFinishedLocked() { return absl::OkStatus(); } + virtual absl::Status OnWorkStartedLocked() { return absl::OkStatus(); } + virtual absl::Status OnWorkFinishedLocked() { return absl::OkStatus(); } // Called to reset the Reader to a newly constructed state. - virtual Status ResetLocked(); + virtual absl::Status ResetLocked(); // Default implementation generates an Unimplemented error. // See the protected helper methods below. - virtual Status SerializeStateLocked(tstring* state); - virtual Status RestoreStateLocked(const tstring& state); + virtual absl::Status SerializeStateLocked(tstring* state); + virtual absl::Status RestoreStateLocked(const tstring& state); // Accessors ---------------------------------------------------------------- @@ -99,7 +100,7 @@ class ReaderBase : public ReaderInterface { // Restores ReaderBase state from state. Assumes state was filled // using SaveBaseState() above. - Status RestoreBaseState(const ReaderBaseState& state); + absl::Status RestoreBaseState(const ReaderBaseState& state); private: // For descendants that wish to obtain the next work item in a different way. @@ -119,11 +120,11 @@ class ReaderBase : public ReaderInterface { std::vector* keys, std::vector* value, OpKernelContext* context) override; - Status Reset() override; + absl::Status Reset() override; int64_t NumRecordsProduced() override; int64_t NumWorkUnitsCompleted() override; - Status SerializeState(tstring* state) override; - Status RestoreState(const tstring& state) override; + absl::Status SerializeState(tstring* state) override; + absl::Status RestoreState(const tstring& state) override; mutable mutex mu_; const string name_; diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h index 1433a54e5e7d12..bc1a7629ce55b1 100644 --- a/tensorflow/core/framework/reader_op_kernel.h +++ b/tensorflow/core/framework/reader_op_kernel.h @@ -68,7 +68,7 @@ class ReaderOpKernel : public ResourceOpKernel { virtual bool IsCancellable() const { return false; } virtual void Cancel() {} - Status CreateResource(ReaderInterface** reader) + absl::Status CreateResource(ReaderInterface** reader) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { *reader = factory_(); if (*reader == nullptr) { diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index f0ba19ae865868..30b0f50fe3696f 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -89,6 +89,9 @@ limitations under the License. #define TF_CALL_float8_e5m2(m) m(::tensorflow::float8_e5m2) #define TF_CALL_float8_e4m3fn(m) m(::tensorflow::float8_e4m3fn) +#define TF_CALL_float8_e4m3fnuz(m) m(::tensorflow::float8_e4m3fnuz) +#define TF_CALL_float8_e4m3b11fnuz(m) m(::tensorflow::float8_e4m3b11fnuz) +#define TF_CALL_float8_e5m2fnuz(m) m(::tensorflow::float8_e5m2fnuz) #define TF_CALL_int4(m) m(::tensorflow::int4) #define TF_CALL_uint4(m) m(::tensorflow::uint4) @@ -127,6 +130,9 @@ limitations under the License. #define TF_CALL_float8_e5m2(m) #define TF_CALL_float8_e4m3fn(m) +#define TF_CALL_float8_e4m3fnuz(m) +#define TF_CALL_float8_e4m3b11fnuz(m) +#define TF_CALL_float8_e5m2fnuz(m) #define TF_CALL_int4(m) #define TF_CALL_uint4(m) @@ -164,6 +170,9 @@ limitations under the License. #define TF_CALL_float8_e5m2(m) #define TF_CALL_float8_e4m3fn(m) +#define TF_CALL_float8_e4m3fnuz(m) +#define TF_CALL_float8_e4m3b11fnuz(m) +#define TF_CALL_float8_e5m2fnuz(m) #define TF_CALL_int4(m) #define TF_CALL_uint4(m) diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 1792a1c1fed17d..e7a0f0b20061b9 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/local_rendezvous.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" @@ -38,15 +39,15 @@ namespace tensorflow { Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) { const char* b_base = b.buf_.data(); buf_ = b.buf_; - src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base), - b.src_device.size()); + src_device = absl::string_view(buf_.data() + (b.src_device.data() - b_base), + b.src_device.size()); src = b.src; src_incarnation = b.src_incarnation; - dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base), - b.dst_device.size()); + dst_device = absl::string_view(buf_.data() + (b.dst_device.data() - b_base), + b.dst_device.size()); dst = b.dst; - edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base), - b.edge_name.size()); + edge_name = absl::string_view(buf_.data() + (b.edge_name.data() - b_base), + b.edge_name.size()); return *this; } @@ -61,31 +62,30 @@ string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation, // // "src_incarnation" is used to distinguish a worker when it // restarts. - char buf[strings::kFastToBufferSize]; - return strings::StrCat( - src_device, ";", strings::Uint64ToHexString(src_incarnation, buf), ";", + return absl::StrCat( + src_device, ";", absl::Hex(src_incarnation, absl::kZeroPad16), ";", dst_device, ";", name, ";", frame_iter.frame_id, ":", frame_iter.iter_id); } // Return the prefix of "*s" up to the next occurrence of "delim", or // the whole remaining string if "delim" is not found. "*s" is advanced // past the string returned plus the delimiter (if found). -static StringPiece ConsumeNextPart(StringPiece* s, char delim) { +static absl::string_view ConsumeNextPart(absl::string_view* s, char delim) { for (size_t offset = 0; offset < s->size(); offset++) { if ((*s)[offset] == delim) { - StringPiece result(s->data(), offset); + absl::string_view result(s->data(), offset); s->remove_prefix(offset + 1); // +1: remove delim, as well return result; } } // No delimiter found: return rest of string - StringPiece result(s->data(), s->size()); + absl::string_view result(s->data(), s->size()); s->remove_prefix(s->size()); return result; } /* static */ -Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { +absl::Status Rendezvous::ParseKey(absl::string_view key, ParsedKey* out) { if (key.data() == out->buf_.data()) { // Caller used our buf_ string directly, so we don't need to copy. (The // SendOp and RecvOp implementations do this, for example). @@ -95,8 +95,8 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { // for the lifetime of the ParsedKey object. out->buf_.assign(key.data(), key.size()); } - StringPiece s(out->buf_); - StringPiece parts[5]; + absl::string_view s(out->buf_); + absl::string_view parts[5]; for (int i = 0; i < 5; i++) { parts[i] = ConsumeNextPart(&s, ';'); } @@ -106,9 +106,9 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { strings::HexStringToUint64(parts[1], &out->src_incarnation) && DeviceNameUtils::ParseFullName(parts[2], &out->dst) && !parts[3].empty()) { - out->src_device = StringPiece(parts[0].data(), parts[0].size()); - out->dst_device = StringPiece(parts[2].data(), parts[2].size()); - out->edge_name = StringPiece(parts[3].data(), parts[3].size()); + out->src_device = absl::string_view(parts[0].data(), parts[0].size()); + out->dst_device = absl::string_view(parts[2].data(), parts[2].size()); + out->edge_name = absl::string_view(parts[3].data(), parts[3].size()); return absl::OkStatus(); } return errors::InvalidArgument("Invalid rendezvous key: ", key); @@ -116,15 +116,15 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { RendezvousInterface::~RendezvousInterface() {} -Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args, - Tensor* val, bool* is_dead, - int64_t timeout_ms) { - Status ret; +absl::Status RendezvousInterface::Recv(const ParsedKey& key, + const Args& recv_args, Tensor* val, + bool* is_dead, int64_t timeout_ms) { + absl::Status ret; Notification n; RecvAsync(key, recv_args, - [&ret, &n, val, is_dead](const Status& s, const Args& send_args, - const Args& recv_args, const Tensor& v, - const bool dead) { + [&ret, &n, val, is_dead]( + const absl::Status& s, const Args& send_args, + const Args& recv_args, const Tensor& v, const bool dead) { ret = s; *val = v; *is_dead = dead; @@ -134,8 +134,8 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args, int64_t timeout_us = timeout_ms * 1000; bool notified = WaitForNotificationWithTimeout(&n, timeout_us); if (!notified) { - return Status(absl::StatusCode::kDeadlineExceeded, - "Timed out waiting for notification"); + return absl::Status(absl::StatusCode::kDeadlineExceeded, + "Timed out waiting for notification"); } } else { n.WaitForNotification(); @@ -143,8 +143,8 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args, return ret; } -Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, - Tensor* val, bool* is_dead) { +absl::Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, + Tensor* val, bool* is_dead) { const int64_t no_timeout = 0; return Recv(key, args, val, is_dead, no_timeout); } @@ -154,8 +154,8 @@ class LocalRendezvousWrapper : public Rendezvous { public: LocalRendezvousWrapper(int num_shards) : impl_(this, num_shards) {} - Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, - const bool is_dead) override { + absl::Status Send(const ParsedKey& key, const Args& send_args, + const Tensor& val, const bool is_dead) override { return impl_.Send(key, send_args, val, is_dead); } @@ -164,7 +164,9 @@ class LocalRendezvousWrapper : public Rendezvous { impl_.RecvAsync(key, recv_args, std::move(done)); } - void StartAbort(const Status& status) override { impl_.StartAbort(status); } + void StartAbort(const absl::Status& status) override { + impl_.StartAbort(status); + } private: LocalRendezvous impl_; diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index 87861994226707..97a5daffcae3ee 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -58,18 +58,18 @@ class RendezvousInterface { // Parses the key constructed by CreateKey and parse src/dst device // names into structures respectively. struct ParsedKey { - StringPiece src_device; + absl::string_view src_device; DeviceNameUtils::ParsedName src; uint64 src_incarnation = 0; - StringPiece dst_device; + absl::string_view dst_device; DeviceNameUtils::ParsedName dst; - StringPiece edge_name; + absl::string_view edge_name; ParsedKey() {} ParsedKey(const ParsedKey& b) { *this = b; } ParsedKey& operator=(const ParsedKey& b); - StringPiece FullKey() const { return buf_; } + absl::string_view FullKey() const { return buf_; } private: friend class Rendezvous; @@ -164,7 +164,7 @@ class Rendezvous : public RendezvousInterface, public core::WeakRefCounted { const std::string& name, const FrameAndIter& frame_iter); - static absl::Status ParseKey(StringPiece key, ParsedKey* out); + static absl::Status ParseKey(absl::string_view key, ParsedKey* out); }; // Returns a Rendezvous instance that is limited to use only by diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index 1c52e259ba55b1..96dcf0c8729aa4 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -227,14 +227,14 @@ TEST_F(LocalRendezvousTest, CancelMultiple) { Notification n1; Notification n2; Notification n3; - Status s0; - Status s1; - Status s2; - Status s3; + absl::Status s0; + absl::Status s1; + absl::Status s2; + absl::Status s3; rendez_->RecvAsync( KeyFoo(), args, - [&n0, &s0](const Status& s, const Rendezvous::Args& send_args, + [&n0, &s0](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, const bool dead) { s0.Update(s); @@ -242,7 +242,7 @@ TEST_F(LocalRendezvousTest, CancelMultiple) { }); rendez_->RecvAsync( KeyFoo(), args_with_cancellation, - [&n1, &s1](const Status& s, const Rendezvous::Args& send_args, + [&n1, &s1](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, const bool dead) { s1.Update(s); @@ -250,7 +250,7 @@ TEST_F(LocalRendezvousTest, CancelMultiple) { }); rendez_->RecvAsync( KeyFoo(), args, - [&n2, &s2](const Status& s, const Rendezvous::Args& send_args, + [&n2, &s2](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, const bool dead) { s2.Update(s); @@ -258,7 +258,7 @@ TEST_F(LocalRendezvousTest, CancelMultiple) { }); rendez_->RecvAsync( KeyFoo(), args_with_cancellation, - [&n3, &s3](const Status& s, const Rendezvous::Args& send_args, + [&n3, &s3](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, const bool dead) { s3.Update(s); @@ -304,7 +304,7 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) { TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args, V(strings::StrCat(i)), false)); }); - auto recv_done = [this, &state, i](const Status& status, + auto recv_done = [this, &state, i](const absl::Status& status, const Rendezvous::Args& sender_args, const Rendezvous::Args& recver_args, const Tensor& val, const bool val_dead) { @@ -365,7 +365,7 @@ TEST_F(LocalRendezvousTest, RecvAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; - Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead); + absl::Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead); EXPECT_TRUE(absl::IsAborted(status)); } @@ -381,7 +381,7 @@ TEST_F(LocalRendezvousTest, RecvSleepAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; - Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead); + absl::Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead); EXPECT_TRUE(absl::IsAborted(status)); } @@ -421,7 +421,7 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) { args1.device_context = new DummyDeviceContext(1); rendez_->RecvAsync( KeyFoo(), args1, - [&n](const Status& s, const Rendezvous::Args& send_args, + [&n](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { CHECK_EQ(123, dynamic_cast( send_args.device_context) @@ -462,7 +462,7 @@ void BM_RecvSend(::testing::benchmark::State& state) { bool received = false; rendez->RecvAsync( KeyFoo(), args, - [&val, &received](const Status& /*s*/, + [&val, &received](const absl::Status& /*s*/, const Rendezvous::Args& /*send_args*/, const Rendezvous::Args& /*recv_args*/, const Tensor& tensor, bool /*is_dead*/) { diff --git a/tensorflow/core/framework/resource_handle.cc b/tensorflow/core/framework/resource_handle.cc index 0fe49206846a5f..93fc5360e68c9c 100644 --- a/tensorflow/core/framework/resource_handle.cc +++ b/tensorflow/core/framework/resource_handle.cc @@ -55,8 +55,8 @@ ResourceHandle::ResourceHandle(const ResourceHandleProto& proto) { TF_CHECK_OK(FromProto(proto)); } -Status ResourceHandle::BuildResourceHandle(const ResourceHandleProto& proto, - ResourceHandle* out) { +absl::Status ResourceHandle::BuildResourceHandle( + const ResourceHandleProto& proto, ResourceHandle* out) { if (out == nullptr) return errors::Internal( "BuildResourceHandle() was called with nullptr for the output"); @@ -78,7 +78,7 @@ void ResourceHandle::AsProto(ResourceHandleProto* proto) const { } } -Status ResourceHandle::FromProto(const ResourceHandleProto& proto) { +absl::Status ResourceHandle::FromProto(const ResourceHandleProto& proto) { set_device(proto.device()); set_container(proto.container()); set_name(proto.name()); @@ -88,7 +88,7 @@ Status ResourceHandle::FromProto(const ResourceHandleProto& proto) { for (const auto& dtype_and_shape : proto.dtypes_and_shapes()) { DataType dtype = dtype_and_shape.dtype(); PartialTensorShape shape; - Status s = PartialTensorShape::BuildPartialTensorShape( + absl::Status s = PartialTensorShape::BuildPartialTensorShape( dtype_and_shape.shape(), &shape); if (!s.ok()) { return s; @@ -147,7 +147,7 @@ ResourceHandle ResourceHandle::MakeRefCountingHandle( return result; } -Status ResourceHandle::ValidateType(const TypeIndex& type_index) const { +absl::Status ResourceHandle::ValidateType(const TypeIndex& type_index) const { if (type_index.hash_code() != hash_code()) { return errors::InvalidArgument( "Trying to access a handle's resource using the wrong type. ", diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index a738f8d735addd..4e2d26ff4a764e 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -53,9 +53,11 @@ ResourceHandle MakeResourceHandle( return result; } -Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, - const string& container, const string& name, - const TypeIndex& type_index) { +absl::Status MakeResourceHandleToOutput(OpKernelContext* context, + int output_index, + const string& container, + const string& name, + const TypeIndex& type_index) { Tensor* handle; TF_RETURN_IF_ERROR( context->allocate_output(output_index, TensorShape({}), &handle)); @@ -66,7 +68,7 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, namespace internal { -Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { +absl::Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { if (ctx->device()->attributes().name() != p.device()) { return errors::InvalidArgument( "Trying to access resource ", p.name(), " located in device ", @@ -77,8 +79,8 @@ Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { } // end namespace internal -Status ResourceMgr::InsertDebugTypeName(uint64 hash_code, - const string& type_name) { +absl::Status ResourceMgr::InsertDebugTypeName(uint64 hash_code, + const string& type_name) { auto iter = debug_type_names_.emplace(hash_code, type_name); if (iter.first->second != type_name) { return errors::AlreadyExists("Duplicate hash code found for type ", @@ -182,9 +184,9 @@ string ResourceMgr::DebugString() const { return absl::StrJoin(text, "\n"); } -Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, - const string& name, ResourceBase* resource, - bool owns_resource) { +absl::Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, + const string& name, ResourceBase* resource, + bool owns_resource) { Container* container = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { Container** ptr = &containers_[container_name]; if (*ptr == nullptr) { @@ -197,7 +199,7 @@ Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, // key can contain a StringPiece that borrows from the string in the value. ResourceAndName resource_and_name(name); - StringPiece borrowed_name(*resource_and_name.name); + absl::string_view borrowed_name(*resource_and_name.name); if (owns_resource) { resource_and_name.resource = core::RefCountPtr(resource); @@ -225,23 +227,24 @@ Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, type.name()); } -Status ResourceMgr::Lookup(const ResourceHandle& handle, - ResourceBase** resource) const { +absl::Status ResourceMgr::Lookup(const ResourceHandle& handle, + ResourceBase** resource) const { tf_shared_lock l(mu_); return DoLookup(handle.container(), handle.hash_code(), /*type_name=*/"ResourceBase", handle.name(), resource); } -Status ResourceMgr::DoLookup(const string& container, TypeIndex type, - const string& name, - ResourceBase** resource) const { +absl::Status ResourceMgr::DoLookup(const string& container, TypeIndex type, + const string& name, + ResourceBase** resource) const { return DoLookup(container, type.hash_code(), type.name(), name, resource); } -Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code, - const string& type_name, - const string& resource_name, - ResourceBase** resource) const { +absl::Status ResourceMgr::DoLookup(const string& container, + uint64 type_hash_code, + const string& type_name, + const string& resource_name, + ResourceBase** resource) const { const Container* b = gtl::FindPtrOrNull(containers_, container); if (b == nullptr) { return errors::NotFound("Container ", container, @@ -262,11 +265,9 @@ Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code, return absl::OkStatus(); } -Status ResourceMgr::PopResourceAndName(const string& container, - uint64 type_hash_code, - const string& resource_name, - const string& type_name, - ResourceAndName& resource_and_name) { +absl::Status ResourceMgr::PopResourceAndName( + const string& container, uint64 type_hash_code, const string& resource_name, + const string& type_name, ResourceAndName& resource_and_name) { mutex_lock l(mu_); Container* b = gtl::FindPtrOrNull(containers_, container); if (b == nullptr) { @@ -282,9 +283,10 @@ Status ResourceMgr::PopResourceAndName(const string& container, return absl::OkStatus(); } -Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, - const string& resource_name, - const string& type_name) { +absl::Status ResourceMgr::DoDelete(const string& container, + uint64 type_hash_code, + const string& resource_name, + const string& type_name) { ResourceAndName resource_and_name; TF_RETURN_IF_ERROR(PopResourceAndName( container, type_hash_code, resource_name, type_name, resource_and_name)); @@ -300,17 +302,17 @@ Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, return absl::OkStatus(); } -Status ResourceMgr::DoDelete(const string& container, TypeIndex type, - const string& resource_name) { +absl::Status ResourceMgr::DoDelete(const string& container, TypeIndex type, + const string& resource_name) { return DoDelete(container, type.hash_code(), resource_name, type.name()); } -Status ResourceMgr::Delete(const ResourceHandle& handle) { +absl::Status ResourceMgr::Delete(const ResourceHandle& handle) { return DoDelete(handle.container(), handle.hash_code(), handle.name(), ""); } -Status ResourceMgr::Cleanup(const string& container) { +absl::Status ResourceMgr::Cleanup(const string& container) { { tf_shared_lock l(mu_); if (!gtl::FindOrNull(containers_, container)) { @@ -334,7 +336,7 @@ Status ResourceMgr::Cleanup(const string& container) { return absl::OkStatus(); } -static bool IsValidContainerName(StringPiece s) { +static bool IsValidContainerName(absl::string_view s) { using ::tensorflow::strings::Scanner; return Scanner(s) .One(Scanner::LETTER_DIGIT_DOT) @@ -343,8 +345,8 @@ static bool IsValidContainerName(StringPiece s) { .GetResult(); } -Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, - bool use_node_name_as_default) { +absl::Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default) { CHECK(rmgr); rmgr_ = rmgr; string attr_container; @@ -387,8 +389,8 @@ const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat()(0); } -Status HandleFromInput(OpKernelContext* ctx, int input, - ResourceHandle* handle) { +absl::Status HandleFromInput(OpKernelContext* ctx, int input, + ResourceHandle* handle) { TF_ASSIGN_OR_RETURN(const Tensor* tensor, ctx->get_input(input)); if (tensor->NumElements() == 0) { return absl::InvalidArgumentError("Empty resource handle"); @@ -397,8 +399,8 @@ Status HandleFromInput(OpKernelContext* ctx, int input, return absl::OkStatus(); } -Status HandleFromInput(OpKernelContext* ctx, StringPiece input, - ResourceHandle* handle) { +absl::Status HandleFromInput(OpKernelContext* ctx, absl::string_view input, + ResourceHandle* handle) { const Tensor* tensor; TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); if (tensor->NumElements() == 0) { @@ -408,8 +410,8 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input, return absl::OkStatus(); } -Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, - ResourceBase** value) { +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + ResourceBase** value) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); if (p.IsRefCounting()) { TF_ASSIGN_OR_RETURN(*value, p.GetResource()); @@ -419,7 +421,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, return ctx->resource_manager()->Lookup(p, value); } -Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); if (p.IsRefCounting()) { return absl::OkStatus(); diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 658ed31ebfea9f..74e26b43588a56 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -123,19 +123,19 @@ class ScopedStepContainer { const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT; // Pass through to ResourceMgr::Create with the container name template - Status Create(ResourceMgr* rm, const std::string& name, - T* resource) TF_MUST_USE_RESULT; + absl::Status Create(ResourceMgr* rm, const std::string& name, T* resource); // Pass through to ResourceMgr::Delete with the container name template - Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT; + absl::Status Delete(ResourceMgr* rm, const std::string& name); // Pass through to ResourceMgr::Lookup with the container name template - Status Lookup(ResourceMgr* rm, const std::string& name, - T** resource) const TF_MUST_USE_RESULT; + absl::Status Lookup(ResourceMgr* rm, const std::string& name, + T** resource) const; // Pass through to ResourceMgr::LookupOrCreate with the container name template - Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource, - std::function creator) TF_MUST_USE_RESULT; + absl::Status LookupOrCreate(ResourceMgr* rm, const std::string& name, + T** resource, + std::function creator); int64_t StepId() const { return step_id_; } private: @@ -162,8 +162,8 @@ class ResourceMgr { // REQUIRES: std::is_base_of // REQUIRES: resource != nullptr. template - Status Create(const std::string& container, const std::string& name, - T* resource) TF_MUST_USE_RESULT; + absl::Status Create(const std::string& container, const std::string& name, + T* resource); // Creates a unowned resource "name" in the "container". The caller does NOT // transfer the ownership of any ref on "resource" to *this, regardless of @@ -176,8 +176,8 @@ class ResourceMgr { // REQUIRES: std::is_base_of // REQUIRES: resource != nullptr. template - Status CreateUnowned(const std::string& container, const std::string& name, - T* resource) TF_MUST_USE_RESULT; + absl::Status CreateUnowned(const std::string& container, + const std::string& name, T* resource); // If "container" has a resource "name", returns it in "*resource" and // the caller takes the ownership of one ref on "*resource". @@ -185,24 +185,24 @@ class ResourceMgr { // REQUIRES: std::is_base_of // REQUIRES: resource != nullptr template - Status Lookup(const std::string& container, const std::string& name, - T** resource) const TF_MUST_USE_RESULT; + absl::Status Lookup(const std::string& container, const std::string& name, + T** resource) const; // If the resource manager has a resource matching "handle", returns it in // "*resource" and the caller takes the ownership of one ref on "*resource". // // REQUIRES: resource != nullptr - Status Lookup(const ResourceHandle& handle, - ResourceBase** resource) const TF_MUST_USE_RESULT; + absl::Status Lookup(const ResourceHandle& handle, + ResourceBase** resource) const; // Similar to Lookup, but looks up multiple resources at once, with only a // single lock acquisition. If containers_and_names[i] is uninitialized // then this function does not modify resources[i]. template - Status LookupMany(absl::Span const> - containers_and_names, - std::vector>* resources) const - TF_MUST_USE_RESULT; + absl::Status LookupMany( + absl::Span const> + containers_and_names, + std::vector>* resources) const; // If "container" has a resource "name", returns it in // "*resource". Otherwise, invokes creator() to create the resource. @@ -215,22 +215,21 @@ class ResourceMgr { // REQUIRES: std::is_base_of // REQUIRES: resource != nullptr template - Status LookupOrCreate(const std::string& container, const std::string& name, - T** resource, - std::function creator) TF_MUST_USE_RESULT; + absl::Status LookupOrCreate(const std::string& container, + const std::string& name, T** resource, + std::function creator); // Deletes the resource "name" from the "container". // // REQUIRES: std::is_base_of template - Status Delete(const std::string& container, - const std::string& name) TF_MUST_USE_RESULT; + absl::Status Delete(const std::string& container, const std::string& name); // Deletes the resource pointed by "handle". - Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT; + absl::Status Delete(const ResourceHandle& handle); // Deletes all resources from the "container" and removes the container. - Status Cleanup(const std::string& container) TF_MUST_USE_RESULT; + absl::Status Cleanup(const std::string& container); // Deletes all resources in all containers. void Clear(); @@ -239,7 +238,7 @@ class ResourceMgr { std::string DebugString() const; private: - typedef std::pair Key; + typedef std::pair Key; struct KeyHash { std::size_t operator()(const Key& k) const { return Hash64(k.second.data(), k.second.size(), k.first); @@ -278,42 +277,44 @@ class ResourceMgr { absl::flat_hash_map containers_ TF_GUARDED_BY(mu_); template - Status LookupInternal(const std::string& container, const std::string& name, - T** resource) const - TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; - Status LookupInternal(const std::string& container, uint64 type_hash_code, + absl::Status LookupInternal(const std::string& container, + const std::string& name, T** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + absl::Status LookupInternal(const std::string& container, + uint64 type_hash_code, const std::string& name, + ResourceBase** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + absl::Status DoCreate(const std::string& container, TypeIndex type, + const std::string& name, ResourceBase* resource, + bool owns_resource) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Status DoLookup(const std::string& container, TypeIndex type, const std::string& name, ResourceBase** resource) const - TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; - - Status DoCreate(const std::string& container, TypeIndex type, - const std::string& name, ResourceBase* resource, - bool owns_resource) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; - - Status DoLookup(const std::string& container, TypeIndex type, - const std::string& name, ResourceBase** resource) const - TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; - Status DoLookup(const std::string& container, uint64 type_hash_code, - const std::string& type_name, - const std::string& resource_name, - ResourceBase** resource) const - TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; - - Status DoDelete(const std::string& container, uint64 type_hash_code, - const std::string& resource_name, - const std::string& type_name) TF_MUST_USE_RESULT; - Status DoDelete(const std::string& container, TypeIndex type, - const std::string& resource_name) TF_MUST_USE_RESULT; + TF_SHARED_LOCKS_REQUIRED(mu_); + absl::Status DoLookup(const std::string& container, uint64 type_hash_code, + const std::string& type_name, + const std::string& resource_name, + ResourceBase** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + absl::Status DoDelete(const std::string& container, uint64 type_hash_code, + const std::string& resource_name, + const std::string& type_name); + absl::Status DoDelete(const std::string& container, TypeIndex type, + const std::string& resource_name); // Pops the ResourceAndName entry. The entry is moved from the list to // the output argument `resource_and_name`. - Status PopResourceAndName( - const std::string& container, uint64 type_hash_code, - const std::string& resource_name, const std::string& type_name, - ResourceAndName& resource_and_name) TF_MUST_USE_RESULT; + absl::Status PopResourceAndName(const std::string& container, + uint64 type_hash_code, + const std::string& resource_name, + const std::string& type_name, + ResourceAndName& resource_and_name); // Inserts the type name for 'hash_code' into the hash_code to type name map. - Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT; + absl::Status InsertDebugTypeName(uint64 hash_code, + const std::string& type_name) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns the type name for the 'hash_code'. // Returns "" if a resource with such a type was never inserted into @@ -362,49 +363,54 @@ ResourceHandle MakeResourceHandle( dtypes_and_shapes, definition_stack_trace); } -Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, - const std::string& container, - const std::string& name, - const TypeIndex& type_index); +absl::Status MakeResourceHandleToOutput(OpKernelContext* context, + int output_index, + const std::string& container, + const std::string& name, + const TypeIndex& type_index); // Returns a resource handle from a numbered op input. const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); // Safely returns a resource handle from a numbered op input. // Prevents segfault by checking for empty resource handle. -Status HandleFromInput(OpKernelContext* ctx, int input, ResourceHandle* handle); +absl::Status HandleFromInput(OpKernelContext* ctx, int input, + ResourceHandle* handle); // Returns a resource handle by name, as defined in the OpDef. // Also prevents segfault by checking for empty resource handle. -Status HandleFromInput(OpKernelContext* ctx, StringPiece input, - ResourceHandle* handle); +absl::Status HandleFromInput(OpKernelContext* ctx, absl::string_view input, + ResourceHandle* handle); // Create a resource pointed by a given resource handle. // // If successful, the caller transfers the ownership of one ref on `resource` to // `ctx->resource_mgr()`. template -Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); +absl::Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T* value); // Looks up a resource pointed by a given resource handle. // // If the lookup is successful, the caller takes the ownership of one ref on // `*value`, and must call its `Unref()` method when it has finished using it. template -Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value); // Looks up a resource pointed by a given resource handle. // // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid // requiring the caller to explicitly call `Unref()`. template -Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, - core::RefCountPtr* value); +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr* value); // Looks up multiple resources pointed by a sequence of resource handles. If // p[i] is uninitialized then values[i] is unmodified. template -Status LookupResources(OpKernelContext* ctx, absl::Span p, - std::vector>* values); +absl::Status LookupResources(OpKernelContext* ctx, + absl::Span p, + std::vector>* values); // Looks up or creates a resource. // @@ -416,23 +422,25 @@ Status LookupResources(OpKernelContext* ctx, absl::Span p, // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid // requiring the caller to explicitly call `Unref()`. template -Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, - T** value, std::function creator); +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, T** value, + std::function creator); // Looks up or creates a resource. template -Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, - core::RefCountPtr* value, - std::function creator); +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, + core::RefCountPtr* value, + std::function creator); // Destroys a resource pointed by a given resource handle. template -Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); // Same as above, but uses the hash code of the type directly. // The type name information will be missing in the debug output when the // resource is not present in the container. -Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); // Policy helper to decide which container/shared_name to use for a // stateful kernel that accesses shared resource. @@ -453,9 +461,9 @@ class ContainerInfo { // Otherwise, if "use_node_name_as_default" is true, the kernel's // node name is used as the resource name. Otherwise, a string // unique to this process is used. - Status Init(ResourceMgr* rmgr, const NodeDef& ndef, - bool use_node_name_as_default); - Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { + absl::Status Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default); + absl::Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { return Init(rmgr, ndef, false); } @@ -490,8 +498,9 @@ class ContainerInfo { // Returns OK if the resource is found and transfers one ref of // *resource to the caller. Otherwise, returns an error. template -Status GetResourceFromContext(OpKernelContext* ctx, - const std::string& input_name, T** resource); +absl::Status GetResourceFromContext(OpKernelContext* ctx, + const std::string& input_name, + T** resource); // Utility op kernel to check if a handle to resource type T is initialized. template @@ -637,8 +646,8 @@ void CheckDeriveFromResourceBase() { } template -Status ResourceMgr::Create(const std::string& container, - const std::string& name, T* resource) { +absl::Status ResourceMgr::Create(const std::string& container, + const std::string& name, T* resource) { CheckDeriveFromResourceBase(); CHECK(resource != nullptr); mutex_lock l(mu_); @@ -647,8 +656,8 @@ Status ResourceMgr::Create(const std::string& container, } template -Status ResourceMgr::CreateUnowned(const std::string& container, - const std::string& name, T* resource) { +absl::Status ResourceMgr::CreateUnowned(const std::string& container, + const std::string& name, T* resource) { CheckDeriveFromResourceBase(); mutex_lock l(mu_); return DoCreate(container, TypeIndex::Make(), name, resource, @@ -656,15 +665,15 @@ Status ResourceMgr::CreateUnowned(const std::string& container, } template -Status ResourceMgr::Lookup(const std::string& container, - const std::string& name, T** resource) const { +absl::Status ResourceMgr::Lookup(const std::string& container, + const std::string& name, T** resource) const { CheckDeriveFromResourceBase(); tf_shared_lock l(mu_); return LookupInternal(container, name, resource); } template -Status ResourceMgr::LookupMany( +absl::Status ResourceMgr::LookupMany( absl::Span const> containers_and_names, std::vector>* resources) const { @@ -673,7 +682,7 @@ Status ResourceMgr::LookupMany( resources->resize(containers_and_names.size()); for (size_t i = 0; i < containers_and_names.size(); ++i) { T* resource; - Status s = LookupInternal( + absl::Status s = LookupInternal( *containers_and_names[i].first, *containers_and_names[i].second, &resource); if (s.ok()) { @@ -695,11 +704,11 @@ struct TypeCastFunctor { }; template -Status ResourceMgr::LookupInternal(const std::string& container, - const std::string& name, - T** resource) const { +absl::Status ResourceMgr::LookupInternal(const std::string& container, + const std::string& name, + T** resource) const { ResourceBase* found = nullptr; - Status s = DoLookup(container, TypeIndex::Make(), name, &found); + absl::Status s = DoLookup(container, TypeIndex::Make(), name, &found); if (s.ok()) { // It's safe to down cast 'found' to T* since // typeid(T).hash_code() is part of the map key. @@ -709,12 +718,12 @@ Status ResourceMgr::LookupInternal(const std::string& container, } template -Status ResourceMgr::LookupOrCreate(const std::string& container, - const std::string& name, T** resource, - std::function creator) { +absl::Status ResourceMgr::LookupOrCreate( + const std::string& container, const std::string& name, T** resource, + std::function creator) { CheckDeriveFromResourceBase(); *resource = nullptr; - Status s; + absl::Status s; { tf_shared_lock l(mu_); s = LookupInternal(container, name, resource); @@ -734,15 +743,16 @@ Status ResourceMgr::LookupOrCreate(const std::string& container, } template -Status ResourceMgr::Delete(const std::string& container, - const std::string& name) { +absl::Status ResourceMgr::Delete(const std::string& container, + const std::string& name) { CheckDeriveFromResourceBase(); return DoDelete(container, TypeIndex::Make(), name); } template -Status GetResourceFromContext(OpKernelContext* ctx, - const std::string& input_name, T** resource) { +absl::Status GetResourceFromContext(OpKernelContext* ctx, + const std::string& input_name, + T** resource) { DataType dtype; TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype)); if (dtype == DT_RESOURCE) { @@ -771,10 +781,11 @@ Status GetResourceFromContext(OpKernelContext* ctx, namespace internal { -Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); +absl::Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); template -Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { +absl::Status ValidateDeviceAndType(OpKernelContext* ctx, + const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); TF_RETURN_IF_ERROR(p.ValidateType()); return absl::OkStatus(); @@ -786,7 +797,8 @@ Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { // one ref on "*value" to the resource manager in "ctx", regardless of whether // this operation succeeds or fails. template -Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) { +absl::Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T* value) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); return ctx->resource_manager()->Create(p.container(), p.name(), value); } @@ -797,8 +809,8 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) { // Always returns a new reference to the resource in "*value". The caller shall // call (*value)->Unref(). template -Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, - T** value) { +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); if (p.IsRefCounting()) { TF_ASSIGN_OR_RETURN(*value, p.GetResource()); @@ -813,14 +825,14 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, // Finds the resource as "*value" from the handle. This is a type-erased // variant of LookupResource above. -Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, - ResourceBase** value); +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + ResourceBase** value); // If the resource manager in "ctx" has a resource matching "p", returns it in // "*value". template -Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, - core::RefCountPtr* value) { +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr* value) { T* raw_ptr = nullptr; TF_RETURN_IF_ERROR(LookupResource(ctx, p, &raw_ptr)); value->reset(raw_ptr); @@ -831,9 +843,9 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, // Similar to Lookup, but looks up multiple resources at once, with only a // single lock acquisition. template -Status LookupResources(OpKernelContext* ctx, - absl::Span p, - std::vector>* values) { +absl::Status LookupResources(OpKernelContext* ctx, + absl::Span p, + std::vector>* values) { std::vector> containers_and_names( p.size()); for (size_t i = 0; i < p.size(); ++i) { @@ -851,8 +863,9 @@ Status LookupResources(OpKernelContext* ctx, // its execution, because a non-reentrant lock is held during the creator() call // in order to guarantee atomicity of LookupOrCreateResource(). template -Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, - T** value, std::function creator) { +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, T** value, + std::function creator) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value, creator); @@ -865,9 +878,10 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, // its execution, because a non-reentrant lock is held during the creator() call // in order to guarantee atomicity of LookupOrCreateResource(). template -Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, - core::RefCountPtr* value, - std::function creator) { +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, + core::RefCountPtr* value, + std::function creator) { T* raw_ptr = nullptr; TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, p, &raw_ptr, creator)); value->reset(raw_ptr); @@ -877,7 +891,7 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, // Deletes the resource pointed by "p", using the resource manager in "ctx". template -Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); // This is a noop because ResourceMgr does not hold a reference. // NOTE(feyu): if we can convert all resources handle to ref-counting, then @@ -889,7 +903,7 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { } // Deletes the resource pointed by "p", using the resource manager in "ctx". -Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); template void IsResourceInitialized::Compute(OpKernelContext* ctx) { @@ -994,31 +1008,32 @@ ResourceHandle ScopedStepContainer::MakeResourceHandle( } template -Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name, - T** resource) const { +absl::Status ScopedStepContainer::Lookup(ResourceMgr* rm, + const std::string& name, + T** resource) const { return rm->Lookup(container_, name, resource); } template -Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, - const std::string& name, - T** resource, - std::function creator) { +absl::Status ScopedStepContainer::LookupOrCreate( + ResourceMgr* rm, const std::string& name, T** resource, + std::function creator) { mutex_lock ml(mu_); dirty_ = true; return rm->LookupOrCreate(container_, name, resource, creator); } template -Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name, - T* resource) { +absl::Status ScopedStepContainer::Create(ResourceMgr* rm, + const std::string& name, T* resource) { mutex_lock ml(mu_); dirty_ = true; return rm->Create(container_, name, resource); } template -Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) { +absl::Status ScopedStepContainer::Delete(ResourceMgr* rm, + const std::string& name) { return rm->Delete(container_, name); } diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 6b12270ab97528..21d36dd16c04f8 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -80,7 +80,7 @@ string LookupOrCreate(ResourceMgr* rm, const string& container, return ret; } -static void HasError(const Status& s, const error::Code code, +static void HasError(const absl::Status& s, const error::Code code, const string& substr) { EXPECT_EQ(s.code(), code); EXPECT_TRUE(absl::StrContains(s.message(), substr)) @@ -88,10 +88,10 @@ static void HasError(const Status& s, const error::Code code, } template -Status FindErr(const ResourceMgr& rm, const string& container, - const string& name) { +absl::Status FindErr(const ResourceMgr& rm, const string& container, + const string& name) { T* r; - Status s = rm.Lookup(container, name, &r); + absl::Status s = rm.Lookup(container, name, &r); CHECK(!s.ok()); return s; } @@ -250,9 +250,9 @@ TEST(ResourceMgrTest, CreateOrLookupRaceCondition) { EXPECT_EQ(1, atomic_int); } -Status ComputePolicy(const string& attr_container, - const string& attr_shared_name, - bool use_node_name_as_default, string* result) { +absl::Status ComputePolicy(const string& attr_container, + const string& attr_shared_name, + bool use_node_name_as_default, string* result) { ContainerInfo cinfo; ResourceMgr rmgr; NodeDef ndef; @@ -292,8 +292,9 @@ TEST(ContainerInfo, Basic) { EXPECT_EQ(Policy(".cat", "bar", true), "[.cat,bar,public]"); } -Status WrongPolicy(const string& attr_container, const string& attr_shared_name, - bool use_node_name_as_default) { +absl::Status WrongPolicy(const string& attr_container, + const string& attr_shared_name, + bool use_node_name_as_default) { string dbg; auto s = ComputePolicy(attr_container, attr_shared_name, use_node_name_as_default, &dbg); diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc index 8c0b32d352fe2d..47020ee1fcb670 100644 --- a/tensorflow/core/framework/run_handler_util.cc +++ b/tensorflow/core/framework/run_handler_util.cc @@ -26,7 +26,7 @@ namespace tensorflow { double ParamFromEnvWithDefault(const char* var_name, double default_value) { const char* val = std::getenv(var_name); double num; - return (val && strings::safe_strtod(val, &num)) ? num : default_value; + return (val && absl::SimpleAtod(val, &num)) ? num : default_value; } std::vector ParamFromEnvWithDefault(const char* var_name, diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 9a34865b3810b7..b63269f68c3368 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -119,7 +119,7 @@ absl::Status InferenceContext::Run( } absl::Status InferenceContext::set_output( - StringPiece output_name, const std::vector& shapes) { + absl::string_view output_name, const std::vector& shapes) { auto result = output_name_map_.find(output_name); if (result == output_name_map_.end()) { return errors::InvalidArgument("Unknown output name: ", output_name); @@ -137,7 +137,7 @@ absl::Status InferenceContext::set_output( return absl::OkStatus(); } -absl::Status InferenceContext::input(StringPiece input_name, +absl::Status InferenceContext::input(absl::string_view input_name, std::vector* output) const { const auto result = input_name_map_.find(input_name); if (result == input_name_map_.end()) { @@ -151,7 +151,7 @@ absl::Status InferenceContext::input(StringPiece input_name, return absl::OkStatus(); } -absl::Status InferenceContext::output(StringPiece output_name, +absl::Status InferenceContext::output(absl::string_view output_name, std::vector* output) const { const auto result = output_name_map_.find(output_name); if (result == output_name_map_.end()) { diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 4c02335ba82f82..8bfd301d860de1 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -340,7 +340,7 @@ class InferenceContext { void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; } ShapeHandle input(int64_t idx) const { return inputs_[idx]; } - absl::Status input(StringPiece input_name, + absl::Status input(absl::string_view input_name, std::vector* output) const; int num_inputs() const { return inputs_.size(); } @@ -394,20 +394,20 @@ class InferenceContext { ShapeHandle output(int64_t idx) const { return outputs_.at(idx); } void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; } - absl::Status set_output(StringPiece output_name, + absl::Status set_output(absl::string_view output_name, const std::vector& shapes); int num_outputs() const { return outputs_.size(); } ShapeHandle output(int idx) const { return outputs_.at(idx); } - absl::Status output(StringPiece output_name, + absl::Status output(absl::string_view output_name, std::vector* output) const; // Returns the value for attribute named `attr_name`. - absl::Status GetAttr(StringPiece attr_name, + absl::Status GetAttr(absl::string_view attr_name, const AttrValue** attr_value) const { return attrs_.Find(attr_name, attr_value); } - const AttrValue* GetAttr(StringPiece attr_name) const { + const AttrValue* GetAttr(absl::string_view attr_name) const { return attrs_.Find(attr_name); } @@ -467,69 +467,63 @@ class InferenceContext { // the shape with asserted rank in <*out>. Otherwise return an error. // // Note that <*out> may be set to . - absl::Status WithRank(ShapeHandle shape, int64_t rank, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status WithRank(ShapeHandle shape, int64_t rank, ShapeHandle* out); absl::Status WithRankAtLeast(ShapeHandle shape, int64_t rank, - ShapeHandle* out) TF_MUST_USE_RESULT; + ShapeHandle* out); absl::Status WithRankAtMost(ShapeHandle shape, int64_t rank, - ShapeHandle* out) TF_MUST_USE_RESULT; + ShapeHandle* out); // If has value , or its value is unknown, returns OK and returns // the dimension with asserted value in <*out>. Otherwise returns an error. // // Note that <*out> may be set to . absl::Status WithValue(DimensionHandle dim, int64_t value, - DimensionHandle* out) TF_MUST_USE_RESULT; + DimensionHandle* out); // Merges and and returns the merged shape in <*out>. See // 'MergeInput' function for full details and examples. - absl::Status Merge(ShapeHandle s0, ShapeHandle s1, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Merge(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out); // Asserts that 's rank >= 's rank, and the first // dimensions of are compatible with the dimensions of // . // Returns the merged results in <*s_out> and <*prefix_out>. absl::Status MergePrefix(ShapeHandle s, ShapeHandle prefix, - ShapeHandle* s_out, - ShapeHandle* prefix_out) TF_MUST_USE_RESULT; + ShapeHandle* s_out, ShapeHandle* prefix_out); // Merges and and returns the merged dimension in <*out>. If // and have incompatible values, returns an error. // // Note that <*out> may be set to or . absl::Status Merge(DimensionHandle d0, DimensionHandle d1, - DimensionHandle* out) TF_MUST_USE_RESULT; + DimensionHandle* out); // Returns in <*out> a sub-shape of with dimensions [start:]. // can be negative to index from the end of the shape. If > // rank of , then an empty subshape is returned. - absl::Status Subshape(ShapeHandle s, int64_t start, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Subshape(ShapeHandle s, int64_t start, ShapeHandle* out); // Returns in <*out> a sub-shape of , with dimensions [start:end]. // and can be negative, to index from the end of the shape. // and are set to the rank of if > rank of . absl::Status Subshape(ShapeHandle s, int64_t start, int64_t end, - ShapeHandle* out) TF_MUST_USE_RESULT; + ShapeHandle* out); // Returns in <*out> a sub-shape of , with dimensions [start:end:stride]. // and can be negative, to index from the end of the shape. // and are set to the rank of if > rank of . // can be negative, to reverse the . absl::Status Subshape(ShapeHandle s, int64_t start, int64_t end, - int64_t stride, ShapeHandle* out) TF_MUST_USE_RESULT; + int64_t stride, ShapeHandle* out); // Returns in <*out> the result of appending the dimensions of to those // of . - absl::Status Concatenate(ShapeHandle s1, ShapeHandle s2, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Concatenate(ShapeHandle s1, ShapeHandle s2, ShapeHandle* out); // Returns in the shape from replacing with // . absl::Status ReplaceDim(ShapeHandle s, int64_t dim_index, - DimensionHandle new_dim, - ShapeHandle* out) TF_MUST_USE_RESULT; + DimensionHandle new_dim, ShapeHandle* out); // Returns a new shape with the given dims. The returned value is owned by // this context. @@ -611,7 +605,7 @@ class InferenceContext { // value. If no attr with attr_name is found in def(), or the attr does not // have a matching type, a non-ok status will be returned. template - absl::Status GetAttr(StringPiece attr_name, T* value) const; + absl::Status GetAttr(absl::string_view attr_name, T* value) const; // Returns in the result of dividing by . // Returns an error if is not positive or if @@ -919,7 +913,8 @@ inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) { } template -absl::Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { +absl::Status InferenceContext::GetAttr(absl::string_view attr_name, + T* value) const { return GetNodeAttr(attrs_, attr_name, value); } diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index b4cd528a4470c6..98ed4a60833da3 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -203,7 +203,7 @@ absl::Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } else { // Parse it as a value. int64_t value = -1; - if (!strings::safe_strto64(expected_dim, &value)) { + if (!absl::SimpleAtoi(expected_dim, &value)) { return Unknown(err_prefix, ": the expected dimension value '", expected_dim, "' failed to parse as int64", err_suffix); diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index d65965b43c2b51..c9b9bd74a8515f 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -33,7 +33,7 @@ class Tensor; struct ShapeInferenceTestOp { typedef std::pair ShapeAndType; - explicit ShapeInferenceTestOp(StringPiece name) : name(string(name)) {} + explicit ShapeInferenceTestOp(absl::string_view name) : name(string(name)) {} string name; NodeDef node_def; std::vector input_tensors; diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index f2cd323101c625..efb228ee6b47a6 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -66,6 +66,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" +#include "tsl/platform/ml_dtypes.h" namespace tensorflow { @@ -179,8 +180,8 @@ struct Helper { template static void Encode(TensorBuffer* in, int64_t n, Destination* out) { DCHECK_EQ(in->size(), sizeof(T) * n); - port::AssignRefCounted(StringPiece(in->base(), in->size()), in, - out); + port::AssignRefCounted( + absl::string_view(in->base(), in->size()), in, out); } // Decoder of simple type T. Copy the bytes from "in" into the @@ -563,6 +564,18 @@ struct ProtoHelper : public Float8ProtoHelper {}; template <> struct ProtoHelper : public Float8ProtoHelper {}; +template <> +struct ProtoHelper + : public Float8ProtoHelper {}; + +template <> +struct ProtoHelper + : public Float8ProtoHelper {}; + +template <> +struct ProtoHelper + : public Float8ProtoHelper {}; + template Buffer::Buffer(Allocator* a, int64_t n) : BufferBase(a, TypedAllocator::Allocate(a, n, AllocationAttributes())), @@ -950,6 +963,9 @@ int Tensor::RefCount() const { CASE(Variant, SINGLE_ARG(STMTS)) \ CASE(float8_e5m2, SINGLE_ARG(STMTS)) \ CASE(float8_e4m3fn, SINGLE_ARG(STMTS)) \ + CASE(float8_e4m3fnuz, SINGLE_ARG(STMTS)) \ + CASE(float8_e4m3b11fnuz, SINGLE_ARG(STMTS)) \ + CASE(float8_e5m2fnuz, SINGLE_ARG(STMTS)) \ CASE(int4, SINGLE_ARG(STMTS)) \ CASE(uint4, SINGLE_ARG(STMTS)) \ case DT_INVALID: \ @@ -1509,9 +1525,10 @@ string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const { } } -StringPiece Tensor::tensor_data() const { - if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors - return StringPiece(static_cast(buf_->data()), TotalBytes()); +absl::string_view Tensor::tensor_data() const { + if (buf_ == nullptr) + return absl::string_view(); // Don't die for empty tensors + return absl::string_view(static_cast(buf_->data()), TotalBytes()); } void* Tensor::data() const { diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 6ca65799276f0a..8f80ea7c805da9 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -635,7 +635,7 @@ class Tensor { /// not get destroyed while the `StringPiece` is still used. /// /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`. - StringPiece tensor_data() const; + absl::string_view tensor_data() const; void* data() const; /// Copy the other tensor into this tensor, reshape it and reinterpret the diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc index c64f4157c57561..adddf678f218e4 100644 --- a/tensorflow/core/framework/tensor_slice.cc +++ b/tensorflow/core/framework/tensor_slice.cc @@ -88,7 +88,7 @@ absl::Status TensorSlice::Parse(const string& str, TensorSlice* slice) { } else { std::vector sl = str_util::Split(x, ',', str_util::SkipEmpty()); if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) || - !strings::safe_strto64(sl[1], &l)) { + !absl::SimpleAtoi(sl[1], &l)) { return errors::InvalidArgument( "Expected a pair of numbers or '-' " "but got '", diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 2d8d182da28da9..1b6da6bd858389 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -254,6 +254,40 @@ TEST(Tensor_Float8_E4m3fn, Simple) { TestCopies(t); } +TEST(Tensor_Float8_E4m3fnuz, Simple) { + Tensor t(DT_FLOAT8_E4M3FNUZ, TensorShape({5, 7})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7}))); + for (int64_t a = 0; a < t.shape().dim_size(0); a++) { + for (int64_t b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = static_cast(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_Float8_E4m3b11fnuz, Simple) { + Tensor t(DT_FLOAT8_E4M3B11FNUZ, TensorShape({5, 7})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7}))); + for (int64_t a = 0; a < t.shape().dim_size(0); a++) { + for (int64_t b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = + static_cast(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_Float8_E5m2fnuz, Simple) { + Tensor t(DT_FLOAT8_E5M2FNUZ, TensorShape({5, 7})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7}))); + for (int64_t a = 0; a < t.shape().dim_size(0); a++) { + for (int64_t b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = static_cast(a * b); + } + } + TestCopies(t); +} + TEST(Tensor_Float, Simple) { Tensor t(DT_FLOAT, TensorShape({10, 20})); EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20}))); diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc index 3c015b1828dbb3..3cae72d95ff79e 100644 --- a/tensorflow/core/framework/tensor_testutil.cc +++ b/tensorflow/core/framework/tensor_testutil.cc @@ -272,6 +272,12 @@ void ExpectEqual(const Tensor& x, const Tensor& y, Tolerance t) { return ExpectEqual(x, y, t); case DT_FLOAT8_E4M3FN: return ExpectEqual(x, y, t); + case DT_FLOAT8_E4M3FNUZ: + return ExpectEqual(x, y, t); + case DT_FLOAT8_E4M3B11FNUZ: + return ExpectEqual(x, y, t); + case DT_FLOAT8_E5M2FNUZ: + return ExpectEqual(x, y, t); case DT_INT4: return ExpectEqual(x, y, t); case DT_UINT4: diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc index f9131de632827a..ed44732409e7fd 100644 --- a/tensorflow/core/framework/tensor_util.cc +++ b/tensorflow/core/framework/tensor_util.cc @@ -39,12 +39,12 @@ Tensor DeepCopy(const Tensor& other) { void DeepCopy(const Tensor& input, Tensor* output) { if (DataTypeCanUseMemcpy(input.dtype())) { if (input.NumElements() > 0) { - StringPiece input_data = input.tensor_data(); + absl::string_view input_data = input.tensor_data(); // We use StringPiece as a convenient map over the tensor buffer, // but we cast the type to get to the underlying buffer to do the // copy. - StringPiece output_data = output->tensor_data(); + absl::string_view output_data = output->tensor_data(); memcpy(const_cast(output_data.data()), input_data.data(), input_data.size()); } @@ -85,12 +85,12 @@ absl::Status Concat(const absl::Span tensors, Tensor* result) { // We use StringPiece as a convenient map over the tensor buffer, // but we cast the type to get to the underlying buffer to do the // copy. - StringPiece to_data = result->tensor_data(); + absl::string_view to_data = result->tensor_data(); if (DataTypeCanUseMemcpy(dtype)) { int64_t offset = 0; for (const Tensor& tensor : tensors) { - StringPiece from_data = tensor.tensor_data(); + absl::string_view from_data = tensor.tensor_data(); CHECK_LE(offset + from_data.size(), to_data.size()); memcpy(const_cast(to_data.data()) + offset, from_data.data(), from_data.size()); @@ -134,7 +134,7 @@ absl::Status Split(const Tensor& tensor, const absl::Span sizes, "'tensor'"); } - StringPiece from_data = tensor.tensor_data(); + absl::string_view from_data = tensor.tensor_data(); if (DataTypeCanUseMemcpy(tensor.dtype())) { int64_t offset = 0; @@ -147,7 +147,7 @@ absl::Status Split(const Tensor& tensor, const absl::Span sizes, // We use StringPiece as a convenient map over the tensor buffer, // but we cast the type to get to the underlying buffer to do the // copy. - StringPiece to_data = split->tensor_data(); + absl::string_view to_data = split->tensor_data(); CHECK_LE(offset + to_data.size(), from_data.size()); memcpy(const_cast(to_data.data()), from_data.data() + offset, to_data.size()); diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index ee607ff5b8d5be..eec2bd3f018ddf 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -49,8 +49,7 @@ void DeepCopy(const Tensor& input, Tensor* output); // REQUIRES: Each member of 'tensors' must point to data stored in CPU memory. // REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it // is not appropriately memory-aligned. -absl::Status Concat(absl::Span tensors, - Tensor* result) TF_MUST_USE_RESULT; +absl::Status Concat(absl::Span tensors, Tensor* result); // Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th // dimension. The ith output tensor has 0th-dimension size 'sizes[i]'. @@ -63,7 +62,7 @@ absl::Status Concat(absl::Span tensors, // // Split() and Concat() are inverse operations. absl::Status Split(const Tensor& tensor, absl::Span sizes, - std::vector* result) TF_MUST_USE_RESULT; + std::vector* result); namespace internal { void SetTensorProtoShape(absl::Span shape, diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index d1e42814d75f92..1f2ba385744fc7 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -132,6 +132,12 @@ string DataTypeStringInternal(DataType dtype) { return "float8_e5m2"; case DT_FLOAT8_E4M3FN: return "float8_e4m3fn"; + case DT_FLOAT8_E4M3FNUZ: + return "float8_e4m3fnuz"; + case DT_FLOAT8_E4M3B11FNUZ: + return "float8_e4m3b11fnuz"; + case DT_FLOAT8_E5M2FNUZ: + return "float8_e5m2fnuz"; case DT_INT4: return "int4"; case DT_UINT4: @@ -155,7 +161,7 @@ string DataTypeString(DataType dtype) { return DataTypeStringInternal(dtype); } -bool DataTypeFromString(StringPiece sp, DataType* dt) { +bool DataTypeFromString(absl::string_view sp, DataType* dt) { if (absl::EndsWith(sp, "_ref")) { sp.remove_suffix(4); DataType non_ref; @@ -236,6 +242,15 @@ bool DataTypeFromString(StringPiece sp, DataType* dt) { } else if (sp == "float8_e4m3fn") { *dt = DT_FLOAT8_E4M3FN; return true; + } else if (sp == "float8_e4m3fnuz") { + *dt = DT_FLOAT8_E4M3FNUZ; + return true; + } else if (sp == "float8_e4m3b11fnuz") { + *dt = DT_FLOAT8_E4M3B11FNUZ; + return true; + } else if (sp == "float8_e5m2fnuz") { + *dt = DT_FLOAT8_E5M2FNUZ; + return true; } else if (sp == "int4") { *dt = DT_INT4; return true; @@ -291,6 +306,9 @@ int DataTypeSize(DataType dt) { TF_CALL_quint16(CASE); TF_CALL_float8_e5m2(CASE); TF_CALL_float8_e4m3fn(CASE); + TF_CALL_float8_e4m3fnuz(CASE); + TF_CALL_float8_e4m3b11fnuz(CASE); + TF_CALL_float8_e5m2fnuz(CASE); TF_CALL_int4(CASE); TF_CALL_uint4(CASE); @@ -327,6 +345,9 @@ DEFINE_DATATYPETOENUM_VALUE(bfloat16); DEFINE_DATATYPETOENUM_VALUE(Eigen::half); DEFINE_DATATYPETOENUM_VALUE(float8_e5m2); DEFINE_DATATYPETOENUM_VALUE(float8_e4m3fn); +DEFINE_DATATYPETOENUM_VALUE(float8_e4m3fnuz); +DEFINE_DATATYPETOENUM_VALUE(float8_e4m3b11fnuz); +DEFINE_DATATYPETOENUM_VALUE(float8_e5m2fnuz); DEFINE_DATATYPETOENUM_VALUE(int4); DEFINE_DATATYPETOENUM_VALUE(uint4); DEFINE_DATATYPETOENUM_VALUE(ResourceHandle); diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 85f1c519f8ae29..177de7e9fe0587 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -168,7 +168,7 @@ class DataTypeSet { // If "sp" names a valid type, store it in "*dt" and return true. Otherwise, // return false. -bool DataTypeFromString(StringPiece sp, DataType* dt); +bool DataTypeFromString(absl::string_view sp, DataType* dt); constexpr inline DataTypeSet ToSet(DataType dt) { return DataTypeSet(1u << static_cast(dt)); @@ -205,7 +205,8 @@ constexpr DataTypeSet kAllTypes = ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) | ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT8_E5M2) | ToSet(DT_FLOAT8_E4M3FN) | - ToSet(DT_INT4) | ToSet(DT_UINT4); + ToSet(DT_FLOAT8_E4M3FNUZ) | ToSet(DT_FLOAT8_E4M3B11FNUZ) | + ToSet(DT_FLOAT8_E5M2FNUZ) | ToSet(DT_INT4) | ToSet(DT_UINT4); inline const DataTypeSet& AllTypes() { return kAllTypes; } @@ -342,6 +343,9 @@ MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); MATCH_TYPE_AND_ENUM(float8_e5m2, DT_FLOAT8_E5M2); MATCH_TYPE_AND_ENUM(float8_e4m3fn, DT_FLOAT8_E4M3FN); +MATCH_TYPE_AND_ENUM(float8_e4m3fnuz, DT_FLOAT8_E4M3FNUZ); +MATCH_TYPE_AND_ENUM(float8_e4m3b11fnuz, DT_FLOAT8_E4M3B11FNUZ); +MATCH_TYPE_AND_ENUM(float8_e5m2fnuz, DT_FLOAT8_E5M2FNUZ); MATCH_TYPE_AND_ENUM(int4, DT_INT4); MATCH_TYPE_AND_ENUM(uint4, DT_UINT4); MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); @@ -421,7 +425,9 @@ constexpr DataTypeSet kDataTypesCanUseMemcpy = ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_BFLOAT16) | ToSet(DT_HALF) | ToSet(DT_FLOAT8_E5M2) | - ToSet(DT_FLOAT8_E4M3FN) | ToSet(DT_INT4) | ToSet(DT_UINT4); + ToSet(DT_FLOAT8_E4M3FN) | ToSet(DT_FLOAT8_E4M3FNUZ) | + ToSet(DT_FLOAT8_E4M3B11FNUZ) | ToSet(DT_FLOAT8_E5M2FNUZ) | ToSet(DT_INT4) | + ToSet(DT_UINT4); inline bool DataTypeCanUseMemcpy(DataType dt) { return kDataTypesCanUseMemcpy.Contains(dt); } @@ -429,7 +435,9 @@ inline bool DataTypeCanUseMemcpy(DataType dt) { // Returns true iff 'dt' is a real, non-quantized floating point type. constexpr DataTypeSet kDataTypeIsFloating = ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | - ToSet(DT_FLOAT8_E4M3FN) | ToSet(DT_FLOAT8_E5M2); + ToSet(DT_FLOAT8_E4M3FN) | ToSet(DT_FLOAT8_E5M2) | + ToSet(DT_FLOAT8_E4M3FNUZ) | ToSet(DT_FLOAT8_E4M3B11FNUZ) | + ToSet(DT_FLOAT8_E5M2FNUZ); inline bool DataTypeIsFloating(DataType dt) { return kDataTypeIsFloating.Contains(dt); } diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index d0a973845f9fed..1f4858d11c4a18 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -43,10 +43,13 @@ enum DataType { DT_FLOAT8_E5M2 = 24; // 5 exponent bits, 2 mantissa bits. DT_FLOAT8_E4M3FN = 25; // 4 exponent bits, 3 mantissa bits, finite-only, with // 2 NaNs (0bS1111111). - // TODO - b/299182407: Leaving room for remaining float8 types. - // DT_FLOAT8_E4M3FNUZ = 26; - // DT_FLOAT8_E4M3B11FNUZ = 27; - // DT_FLOAT8_E5M2FNUZ = 28; + DT_FLOAT8_E4M3FNUZ = 26; // 4 exponent bits, 3 mantissa bits, finite-only, + // with NaN. + DT_FLOAT8_E4M3B11FNUZ = 27; // 4 exponent bits, 3 mantissa bits, 11 bits + // bias, finite-only, with NaNs. + DT_FLOAT8_E5M2FNUZ = 28; // 5 exponent bits, 2 mantissa bits, finite-only, + // with NaN. + DT_INT4 = 29; DT_UINT4 = 30; @@ -78,10 +81,10 @@ enum DataType { DT_UINT64_REF = 123; DT_FLOAT8_E5M2_REF = 124; DT_FLOAT8_E4M3FN_REF = 125; - // TODO - b/299182407: Leaving room for remaining float8 types. - // DT_FLOAT8_E4M3FNUZ_REF = 126; - // DT_FLOAT8_E4M3B11FNUZ_REF = 127; - // DT_FLOAT8_E5M2FNUZ_REF = 128; + + DT_FLOAT8_E4M3FNUZ_REF = 126; + DT_FLOAT8_E4M3B11FNUZ_REF = 127; + DT_FLOAT8_E5M2FNUZ_REF = 128; DT_INT4_REF = 129; DT_UINT4_REF = 130; } diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc index 35fa1383a6cf48..031b4a4efe98e9 100644 --- a/tensorflow/core/framework/types_test.cc +++ b/tensorflow/core/framework/types_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" +#include #include "absl/strings/string_view.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.pb.h" @@ -109,6 +110,12 @@ TEST(TypesTest, DataTypeFromString) { EXPECT_EQ(DT_FLOAT8_E5M2, dt); ASSERT_TRUE(DataTypeFromString("float8_e4m3fn", &dt)); EXPECT_EQ(DT_FLOAT8_E4M3FN, dt); + ASSERT_TRUE(DataTypeFromString("float8_e4m3fnuz", &dt)); + EXPECT_EQ(DT_FLOAT8_E4M3FNUZ, dt); + ASSERT_TRUE(DataTypeFromString("float8_e4m3b11fnuz", &dt)); + EXPECT_EQ(DT_FLOAT8_E4M3B11FNUZ, dt); + ASSERT_TRUE(DataTypeFromString("float8_e5m2fnuz", &dt)); + EXPECT_EQ(DT_FLOAT8_E5M2FNUZ, dt); ASSERT_TRUE(DataTypeFromString("int4", &dt)); EXPECT_EQ(DT_INT4, dt); ASSERT_TRUE(DataTypeFromString("uint4", &dt)); @@ -144,6 +151,9 @@ TEST(TypesTest, QuantizedTypes) { EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16)); EXPECT_FALSE(DataTypeIsQuantized(DT_FLOAT8_E5M2)); EXPECT_FALSE(DataTypeIsQuantized(DT_FLOAT8_E4M3FN)); + EXPECT_FALSE(DataTypeIsQuantized(DT_FLOAT8_E4M3FNUZ)); + EXPECT_FALSE(DataTypeIsQuantized(DT_FLOAT8_E4M3B11FNUZ)); + EXPECT_FALSE(DataTypeIsQuantized(DT_FLOAT8_E5M2FNUZ)); EXPECT_FALSE(DataTypeIsQuantized(DT_UINT4)); EXPECT_FALSE(DataTypeIsQuantized(DT_INT4)); } diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index 306f6a6fec743d..225da86665613d 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -63,7 +63,7 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal() { } UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn( - StringPiece type_name) { + absl::string_view type_name) { auto found = decode_fns.find(type_name); if (found == decode_fns.end()) return nullptr; return &found->second; @@ -76,7 +76,7 @@ void UnaryVariantOpRegistry::RegisterDecodeFn( CHECK_EQ(existing, nullptr) << "Unary VariantDecodeFn for type_name: " << type_name << " already registered"; - decode_fns.insert(std::pair( + decode_fns.insert(std::pair( GetPersistentStringPiece(type_name), decode_fn)); } diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index f75177f712be74..c7d8680d31bfbe 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -105,7 +105,7 @@ class UnaryVariantOpRegistry { const VariantDecodeFn& decode_fn); // Returns nullptr if no decode function was found for the given TypeName. - VariantDecodeFn* GetDecodeFn(StringPiece type_name); + VariantDecodeFn* GetDecodeFn(absl::string_view type_name); // Add a copy-to-GPU function to the registry. void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, @@ -146,7 +146,7 @@ class UnaryVariantOpRegistry { // Returns nullptr if no unary op function was found for the given // op, device, and TypeName. - VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, + VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, absl::string_view device, const TypeIndex& type_index) { auto found = unary_op_fns.find({op, device, type_index}); if (found == unary_op_fns.end()) return nullptr; @@ -169,7 +169,7 @@ class UnaryVariantOpRegistry { // Returns nullptr if no binary op function was found for the given // op, device and TypeName. - VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, + VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, absl::string_view device, const TypeIndex& type_index) { auto found = binary_op_fns.find({op, device, type_index}); if (found == binary_op_fns.end()) return nullptr; @@ -195,7 +195,8 @@ class UnaryVariantOpRegistry { std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); } }; - gtl::FlatMap decode_fns; + gtl::FlatMap + decode_fns; // Map std::pair to function. struct PairHash { @@ -219,10 +220,11 @@ class UnaryVariantOpRegistry { // and references therein template struct FuncTuple { - FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index) + FuncTuple(const Op& op, const absl::string_view& dev, + const TypeIndex& type_index) : op_type_(op), device_(dev), type_index_(type_index) {} Op op_type_; - StringPiece device_; + absl::string_view device_; TypeIndex type_index_; }; // friend declaration for operator== @@ -232,7 +234,7 @@ class UnaryVariantOpRegistry { struct TupleHash { template std::size_t operator()( - const std::tuple& x) const { + const std::tuple& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast(std::get<0>(x)); ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); @@ -258,14 +260,14 @@ class UnaryVariantOpRegistry { // Find or insert a string into a persistent string storage // container; return the StringPiece pointing to the permanent string // location. - static StringPiece GetPersistentStringPiece(const std::string& str) { + static absl::string_view GetPersistentStringPiece(const std::string& str) { const auto string_storage = PersistentStringStorage(); auto found = string_storage->find(str); if (found == string_storage->end()) { auto inserted = string_storage->insert(str); - return StringPiece(*inserted.first); + return absl::string_view(*inserted.first); } else { - return StringPiece(*found); + return absl::string_view(*found); } } }; diff --git a/tensorflow/core/function/runtime_client/runtime_client.cc b/tensorflow/core/function/runtime_client/runtime_client.cc index 4566d80ab3bc6b..b38e293026111b 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.cc +++ b/tensorflow/core/function/runtime_client/runtime_client.cc @@ -93,7 +93,7 @@ EagerContext& GlobalPythonEagerContext() { return *ctx; } -absl::StatusOr Runtime::GetFunctionProto(StringPiece name) { +absl::StatusOr Runtime::GetFunctionProto(absl::string_view name) { EagerContext& ctx = this->eager_ctx_; const FunctionDef* f = ctx.FindFunctionDef(std::string(name)); @@ -134,8 +134,8 @@ absl::Status Runtime::CreateFunction(OpaqueTfFuncOp* fop) { return CreateFunction(fdef); } -absl::Status Runtime::TransformFunction(StringPiece name, - StringPiece pipeline_name, +absl::Status Runtime::TransformFunction(absl::string_view name, + absl::string_view pipeline_name, Dialect dialect) { // TODO(mdan): Use a longer-lived context. mlir::MLIRContext ctx; @@ -221,7 +221,7 @@ absl::Status Runtime::TransformFunction(StringPiece name, } absl::StatusOr Runtime::CallFunction( - StringPiece name, absl::Span args) { + absl::string_view name, absl::Span args) { EagerContext& ctx = this->eager_ctx_; ImmediateOpPtr op(ctx.CreateOperation()); diff --git a/tensorflow/core/function/runtime_client/runtime_client.h b/tensorflow/core/function/runtime_client/runtime_client.h index d26c09b3a9db3b..789788fbe37d09 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.h +++ b/tensorflow/core/function/runtime_client/runtime_client.h @@ -70,7 +70,7 @@ class Runtime { TF, }; - absl::StatusOr GetFunctionProto(StringPiece name); + absl::StatusOr GetFunctionProto(absl::string_view name); // TODO(mdan): Enforce creation or rename to SetFunction. absl::Status CreateFunction(const FunctionDef& fdef); @@ -82,11 +82,12 @@ class Runtime { // The pipeline may rename the function. If it does so, the old function // remains unchanged. If the new name specifies an existing function, it will // be overwritten. - absl::Status TransformFunction(StringPiece name, StringPiece pipeline_name, + absl::Status TransformFunction(absl::string_view name, + absl::string_view pipeline_name, Dialect dialect = Dialect::TFG); absl::StatusOr CallFunction( - StringPiece name, absl::Span args); + absl::string_view name, absl::Span args); private: EagerContext& eager_ctx_; diff --git a/tensorflow/core/function/runtime_client/runtime_client_pybind.pyi b/tensorflow/core/function/runtime_client/runtime_client_pybind.pyi index 20809986cb47ed..77fb63ec719c51 100644 --- a/tensorflow/core/function/runtime_client/runtime_client_pybind.pyi +++ b/tensorflow/core/function/runtime_client/runtime_client_pybind.pyi @@ -26,12 +26,10 @@ class Runtime: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h index 9f7aa35fbecc59..795d94720415b5 100644 --- a/tensorflow/core/graph/costmodel.h +++ b/tensorflow/core/graph/costmodel.h @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { -typedef std::unordered_map +typedef std::unordered_map NodeNameToCostIdMap; class StepStats; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a06187cdfeb8e5..cb9b7be66bbeae 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -400,7 +400,7 @@ NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(), ndef.experimental_debug_info()) {} NodeDebugInfo::NodeDebugInfo( - StringPiece node_name, bool has_experimental_debug_info, + absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info) : name(node_name) { if (has_experimental_debug_info) { @@ -750,7 +750,7 @@ absl::Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, return absl::OkStatus(); } -void Graph::AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { +void Graph::AddInput(NodeDef* dst, absl::string_view src_name, int src_slot) { if (src_slot == Graph::kControlSlot) { dst->add_input(strings::StrCat("^", src_name)); } else if (src_slot == 0) { @@ -911,7 +911,7 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id, } } -std::string Graph::NewName(StringPiece prefix) { +std::string Graph::NewName(absl::string_view prefix) { return strings::StrCat(prefix, "/_", name_counter_++); } @@ -1005,7 +1005,7 @@ int Graph::InternDeviceName(const std::string& device_name) { return index; } -absl::Status Graph::AddWhileContext(StringPiece frame_name, +absl::Status Graph::AddWhileContext(absl::string_view frame_name, std::vector enter_nodes, std::vector exit_nodes, OutputTensor cond_output, @@ -1034,7 +1034,7 @@ std::unordered_map Graph::BuildNodeNameIndex() const { return result; } -void Graph::SetNodeType(StringPiece name, const FullTypeDef& ft) { +void Graph::SetNodeType(absl::string_view name, const FullTypeDef& ft) { for (Node* n : op_nodes()) { if (n->name() == name) { NodeDef& node_def = n->props_->node_def; @@ -1045,7 +1045,7 @@ void Graph::SetNodeType(StringPiece name, const FullTypeDef& ft) { } } -void Graph::NodeType(StringPiece name, const FullTypeDef** result) { +void Graph::NodeType(absl::string_view name, const FullTypeDef** result) { *result = nullptr; for (Node* n : op_nodes()) { if (n->name() == name) { diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 68905818f403f9..6e70b0cdfa8322 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -388,7 +388,7 @@ struct NodeDebugInfo { NodeDebugInfo(const Node& n); NodeDebugInfo(const NodeDef& ndef); - NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info, + NodeDebugInfo(absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info); }; @@ -619,7 +619,7 @@ class Graph { // Add an input to dst that comes from the "src_slot" output of the // node named by "src_name". - static void AddInput(NodeDef* dst, StringPiece src_name, int src_slot); + static void AddInput(NodeDef* dst, absl::string_view src_name, int src_slot); // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a // "While" op during gradient construction, see AddInputWhileHack in @@ -719,7 +719,7 @@ class Graph { // Generate new node name with the specified prefix that is unique // across this graph. - std::string NewName(StringPiece prefix); + std::string NewName(absl::string_view prefix); // Access to the list of all nodes. Example usage: // for (Node* node : graph.nodes()) { ... } @@ -794,7 +794,7 @@ class Graph { // Create and return a new WhileContext owned by this graph. This is called // when a new while loop is created. `frame_name` must be unique among // WhileContexts in this graph. - absl::Status AddWhileContext(StringPiece frame_name, + absl::Status AddWhileContext(absl::string_view frame_name, std::vector enter_nodes, std::vector exit_nodes, OutputTensor cond_output, @@ -828,7 +828,7 @@ class Graph { // future, an alternative method could be added that takes in a flat_hash_map // of name: type and simply iterates through the graph once and annotates all // nodes. - void SetNodeType(StringPiece name, const FullTypeDef& type); + void SetNodeType(absl::string_view name, const FullTypeDef& type); // Get full type information for a node given its name. // Note that if this is called in a loop iterating over all the nodes @@ -836,7 +836,7 @@ class Graph { // future, an alternative method could be added that takes in flat_hash_map of // name: type and simply iterates through the graph once and stores all the // information in the map. - void NodeType(StringPiece name, const FullTypeDef** result); + void NodeType(absl::string_view name, const FullTypeDef** result); // Builds a GraphDebugInfo from the functions and nodes in this graph. Stack // traces associated with function definitions will have a key of the form diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index b8734f662c5fe8..168fc1a0da3da7 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -27,11 +27,11 @@ GraphDefBuilder::Options::Options(Graph* graph, absl::Status* status) GraphDefBuilder::Options::~Options() {} GraphDefBuilder::Options GraphDefBuilder::Options::WithName( - StringPiece name) const { + absl::string_view name) const { return Options(*this).WithNameImpl(name); } GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice( - StringPiece device) const { + absl::string_view device) const { return Options(*this).WithDeviceImpl(device); } GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput( @@ -43,12 +43,12 @@ GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs( return Options(*this).WithControlInputsImpl(control_inputs); } GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl( - StringPiece name) { + absl::string_view name) { name_ = string(name); return *this; } GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl( - StringPiece device) { + absl::string_view device) { device_ = string(device); return *this; } @@ -72,7 +72,7 @@ absl::Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { return status_; } -string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const { +string GraphDefBuilder::Options::GetNameForOp(absl::string_view op) const { if (name_.empty()) return graph_->NewName(op); return name_; } diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index bc44649302172f..b635ece0eab707 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -79,19 +79,19 @@ class GraphDefBuilder { // Methods for setting options. These are const methods: they // return a copy of *this with the option set. - Options WithName(StringPiece name) const; - Options WithDevice(StringPiece device) const; + Options WithName(absl::string_view name) const; + Options WithDevice(absl::string_view device) const; Options WithControlInput(Node* control_input) const; Options WithControlInputs(absl::Span control_inputs) const; // Override the default value for an optional attr. template - Options WithAttr(StringPiece attr_name, T&& value) const { + Options WithAttr(absl::string_view attr_name, T&& value) const { return Options(*this).WithAttrImpl(attr_name, std::forward(value)); } // Note: overload needed to allow {...} expressions for value. template - Options WithAttr(StringPiece attr_name, + Options WithAttr(absl::string_view attr_name, std::initializer_list value) const { return WithAttr>(attr_name, std::move(value)); } @@ -111,7 +111,7 @@ class GraphDefBuilder { // Given the Op type name, return a name for a node of that type. // Uses the value set in WithName() if that has been called. Otherwise, // returns a name built out of the Op type name. - string GetNameForOp(StringPiece op) const; + string GetNameForOp(absl::string_view op) const; // Sets the device, adds control inputs, adds attrs, and calls Finalize(). // If Finalize returns an error, it is saved and this function returns @@ -127,12 +127,12 @@ class GraphDefBuilder { } private: - Options WithNameImpl(StringPiece name); - Options WithDeviceImpl(StringPiece device); + Options WithNameImpl(absl::string_view name); + Options WithDeviceImpl(absl::string_view device); Options WithControlInputImpl(Node* control_input); Options WithControlInputsImpl(absl::Span control_inputs); template - Options WithAttrImpl(StringPiece name, T&& value) { + Options WithAttrImpl(absl::string_view name, T&& value) { attrs_.emplace_back(string(name), AttrValue()); SetAttrValue(std::forward(value), &attrs_.back().second); return *this; diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 8e31106e70a58f..0b08a127cdbd13 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -936,7 +936,7 @@ absl::Status AddControlEdges(const PartitionOptions& opts, // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation // if possible. void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) { - StringPiece op(ndef->op()); + absl::string_view op(ndef->op()); if (op != "_Send" && op != "_Recv") { // Not related to send/recv. return; diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 96e5768941228d..d5fb09171fe9b9 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -36,18 +36,18 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i) // NOLINT(runtime/explicit) NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {} -NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t) +NodeBuilder::NodeOut::NodeOut(absl::string_view n, int32_t i, DataType t) : node(nullptr), error(false), name(n), index(i), dt(t) {} NodeBuilder::NodeOut::NodeOut() : node(nullptr), error(true), index(0), dt(DT_FLOAT) {} -NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name, +NodeBuilder::NodeBuilder(absl::string_view name, absl::string_view op_name, const OpRegistryInterface* op_registry, const NodeDebugInfo* debug) : def_builder_(name, op_name, op_registry, debug) {} -NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def) +NodeBuilder::NodeBuilder(absl::string_view name, const OpDef* op_def) : def_builder_(name, op_def) {} NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder) @@ -102,17 +102,17 @@ NodeBuilder& NodeBuilder::ControlInputs(absl::Span src_nodes) { return *this; } -NodeBuilder& NodeBuilder::Device(StringPiece device_spec) { +NodeBuilder& NodeBuilder::Device(absl::string_view device_spec) { def_builder_.Device(device_spec); return *this; } -NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) { +NodeBuilder& NodeBuilder::AssignedDevice(absl::string_view device) { assigned_device_ = string(device); return *this; } -NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) { +NodeBuilder& NodeBuilder::XlaCluster(absl::string_view xla_cluster) { def_builder_.Attr("_XlaCluster", xla_cluster); return *this; } diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 0d5bf9fb9a240c..6f249371606b3e 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -56,7 +56,7 @@ class NodeBuilder { // useful when preparing a graph for ExtendSession or creating a // back edge to a node that hasn't been added to the graph yet, // but will be. - NodeOut(StringPiece name, int32_t i, DataType t); + NodeOut(absl::string_view name, int32_t i, DataType t); // Default constructor for std::vector. NodeOut(); @@ -76,10 +76,10 @@ class NodeBuilder { // the Op plus a registry) for the Node. Other fields are // specified by calling the methods below. // REQUIRES: The OpDef must satisfy ValidateOpDef(). - NodeBuilder(StringPiece name, StringPiece op_name, + NodeBuilder(absl::string_view name, absl::string_view op_name, const OpRegistryInterface* op_registry = OpRegistry::Global(), const NodeDebugInfo* debug = nullptr); - NodeBuilder(StringPiece name, const OpDef* op_def); + NodeBuilder(absl::string_view name, const OpDef* op_def); // Create a NodeBuilder from an existing NodeDefBuilder. NodeBuilder(const NodeDefBuilder& def_builder); @@ -100,13 +100,13 @@ class NodeBuilder { // Sets the "requested device spec" in the NodeDef (not the // "assigned device" in the Node). - NodeBuilder& Device(StringPiece device_spec); + NodeBuilder& Device(absl::string_view device_spec); // Sets the device name in the "assigned device" field in tensorflow::Node. - NodeBuilder& AssignedDevice(StringPiece device); + NodeBuilder& AssignedDevice(absl::string_view device); // Sets the _XlaCluster attribute in created node to `xla_cluster`. - NodeBuilder& XlaCluster(StringPiece xla_cluster); + NodeBuilder& XlaCluster(absl::string_view xla_cluster); // Set the value of an attr. attr_name must match the name of one of // attrs defined by the Op, and value must have the corresponding type @@ -114,9 +114,10 @@ class NodeBuilder { // types for value). Note that attrs will be set automatically if // they can be determined by the inputs. template - NodeBuilder& Attr(StringPiece attr_name, T&& value); + NodeBuilder& Attr(absl::string_view attr_name, T&& value); template - NodeBuilder& Attr(StringPiece attr_name, std::initializer_list value); + NodeBuilder& Attr(absl::string_view attr_name, + std::initializer_list value); // Validates the described node and adds it to *graph, adding edges // for all (non-back) inputs. If created_node is not nullptr, @@ -163,13 +164,13 @@ class NodeBuilder { // IMPLEMENTATION ------------------------------------------------------------- template -NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) { +NodeBuilder& NodeBuilder::Attr(absl::string_view attr_name, T&& value) { def_builder_.Attr(attr_name, std::forward(value)); return *this; } template -NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, +NodeBuilder& NodeBuilder::Attr(absl::string_view attr_name, std::initializer_list value) { def_builder_.Attr(attr_name, value); return *this; diff --git a/tensorflow/core/graph/regularization/util.cc b/tensorflow/core/graph/regularization/util.cc index 5df68d71cd4fd9..e81fbee6aa98b0 100644 --- a/tensorflow/core/graph/regularization/util.cc +++ b/tensorflow/core/graph/regularization/util.cc @@ -42,7 +42,7 @@ absl::StatusOr GetSuffixUID(absl::string_view function_name) { std::vector v = absl::StrSplit(function_name, '_'); int64_t uid; - if (!strings::safe_strto64(v.back(), &uid)) { + if (!absl::SimpleAtoi(v.back(), &uid)) { return errors::InvalidArgument(absl::StrCat( "Function name: `", function_name, "` does not end in an integer.")); } diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 8c73691fd6ba56..bb47f37ef7fbe3 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -43,7 +43,8 @@ namespace subgraph { namespace { -typedef std::unordered_map NameIndex; +typedef std::unordered_map + NameIndex; // Rewrite graph by replacing the output tensors specified in // "fed_outputs" with special feed nodes for each specified output diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index 9d86672dd94f37..248daa9b5f6651 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -312,7 +312,7 @@ TEST_F(SubgraphTest, ChainOfFools) { EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0)); } -static bool HasSubstr(StringPiece base, StringPiece substr) { +static bool HasSubstr(absl::string_view base, absl::string_view substr) { bool ok = absl::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc index fc04177363c441..7cdd046c48a806 100644 --- a/tensorflow/core/graph/tensor_id.cc +++ b/tensorflow/core/graph/tensor_id.cc @@ -28,10 +28,10 @@ SafeTensorId::SafeTensorId(const TensorId& id) : SafeTensorId(string(id.first), id.second) {} TensorId ParseTensorName(const string& name) { - return ParseTensorName(StringPiece(name.data(), name.size())); + return ParseTensorName(absl::string_view(name.data(), name.size())); } -TensorId ParseTensorName(StringPiece name) { +TensorId ParseTensorName(absl::string_view name) { // Parse either a name, ^name, or name:digits. To do so, we go backwards from // the end of the string, skipping over a run of digits. If we hit a ':' // character, then we know we are in the 'name:digits' regime. Otherwise, we @@ -49,11 +49,11 @@ TensorId ParseTensorName(StringPiece name) { } TensorId id; if (p > base && *p == ':' && mul > 1) { - id.first = StringPiece(base, p - base); + id.first = absl::string_view(base, p - base); id.second = index; } else if (absl::StartsWith(name, "^")) { // Control edge - id.first = StringPiece(base + 1); + id.first = absl::string_view(base + 1); id.second = Graph::kControlSlot; } else { id.first = name; diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h index c593f96b0b329d..0cdfb7d9cec6ed 100644 --- a/tensorflow/core/graph/tensor_id.h +++ b/tensorflow/core/graph/tensor_id.h @@ -30,8 +30,8 @@ struct SafeTensorId; // Identifier for a tensor within a step. // first == operation_name, second == output_index // Note: does not own backing storage for name. -struct TensorId : public std::pair { - typedef std::pair Base; +struct TensorId : public std::pair { + typedef std::pair Base; // Inherit the set of constructors. using Base::pair; @@ -41,7 +41,7 @@ struct TensorId : public std::pair { TensorId() : Base() {} TensorId(const SafeTensorId& id); - const StringPiece node() const { return first; } + const absl::string_view node() const { return first; } int index() const { return second; } string ToString() const { @@ -58,7 +58,7 @@ struct TensorId : public std::pair { }; TensorId ParseTensorName(const string& name); -TensorId ParseTensorName(StringPiece name); +TensorId ParseTensorName(absl::string_view name); bool IsTensorIdControl(const TensorId& tensor_id); diff --git a/tensorflow/core/graph/while_context.h b/tensorflow/core/graph/while_context.h index 5405e62be2f3c5..e23e9df90afd2d 100644 --- a/tensorflow/core/graph/while_context.h +++ b/tensorflow/core/graph/while_context.h @@ -34,7 +34,7 @@ namespace tensorflow { // differentiable. Figure out backwards compatibility story. class WhileContext { public: - WhileContext(StringPiece frame_name, std::vector enter_nodes, + WhileContext(absl::string_view frame_name, std::vector enter_nodes, std::vector exit_nodes, OutputTensor cond_output, std::vector body_inputs, std::vector body_outputs); diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index ae954ee3863c3d..caec0e11560e4f 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -91,6 +91,7 @@ cc_library( "//tensorflow/core/grappler/costs:analytical_cost_estimator", "//tensorflow/core/grappler/costs:op_level_cost_estimator", "//tensorflow/core/grappler/costs:virtual_scheduler", + "@com_google_absl//absl/status", ], ) @@ -106,6 +107,9 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -127,10 +131,16 @@ cc_library( "//tensorflow/core/common_runtime:core_cpu_lib", "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/common_runtime/gpu:gpu_id", + "//tensorflow/core/framework:cost_graph_proto_cc", + "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/grappler:utils", "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], alwayslink = 1, ) diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index 3b1d7d8347549d..a630c1d3941aa7 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -14,6 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/clusters/cluster.h" + +#include +#include + +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h index a3a3708cd3e164..36aec54c42a245 100644 --- a/tensorflow/core/grappler/clusters/cluster.h +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -23,10 +23,12 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 3fb2787f034e35..5113dc75d6cf47 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -16,15 +16,26 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/single_machine.h" #include +#include +#include #include +#include +#include +#include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/cc/training/queue_runner.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/kernels/ops_util.h" @@ -33,6 +44,8 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/public/session.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h index e049ca2fe09765..f3f36626767c52 100644 --- a/tensorflow/core/grappler/clusters/single_machine.h +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -16,11 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ #define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ +#include +#include +#include +#include +#include + +#include "absl/status/status.h" #include "tensorflow/cc/training/coordinator.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 0f2b6a6d2fdfff..e1775679e6ba54 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -15,11 +15,21 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include +#include +#include +#include +#include + +#include "absl/status/status.h" #include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index f42e1047ce2373..1204a34c7f3f8f 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -16,13 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ #define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ +#include #include +#include +#include +#include "absl/status/status.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "tensorflow/core/grappler/costs/virtual_scheduler.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/device_properties.pb.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc index a774b5e6ccc8af..251f02d407c093 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc @@ -16,15 +16,22 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/devices.h b/tensorflow/core/grappler/devices.h index a9bc76c3dbb87f..8a27bfacb07221 100644 --- a/tensorflow/core/grappler/devices.h +++ b/tensorflow/core/grappler/devices.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_DEVICES_H_ #define TENSORFLOW_CORE_GRAPPLER_DEVICES_H_ +#include #include #include diff --git a/tensorflow/core/grappler/inputs/BUILD b/tensorflow/core/grappler/inputs/BUILD index 3f2fddd7fef103..2bbd5885b07132 100644 --- a/tensorflow/core/grappler/inputs/BUILD +++ b/tensorflow/core/grappler/inputs/BUILD @@ -18,6 +18,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", ], ) @@ -33,6 +34,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/core/grappler/inputs/file_input_yielder.cc b/tensorflow/core/grappler/inputs/file_input_yielder.cc index 2df0378441df9c..67eb881e5da0e3 100644 --- a/tensorflow/core/grappler/inputs/file_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/file_input_yielder.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/core/grappler/inputs/file_input_yielder.h" +#include #include #include #include +#include #include "absl/log/check.h" #include "absl/log/log.h" diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc index 3c72721e5099a6..7f39582ba663f0 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -19,6 +19,9 @@ limitations under the License. #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/data_flow_ops.h" diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc index 6b2f380bd6a06d..294bb2cead1111 100644 --- a/tensorflow/core/grappler/inputs/utils.cc +++ b/tensorflow/core/grappler/inputs/utils.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/core/grappler/inputs/utils.h" +#include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h index 589dbc00f4560c..9caefcd836c171 100644 --- a/tensorflow/core/grappler/inputs/utils.h +++ b/tensorflow/core/grappler/inputs/utils.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/core/grappler/inputs/utils_test.cc b/tensorflow/core/grappler/inputs/utils_test.cc index 51a1c48b6adf5c..b32229a051fa86 100644 --- a/tensorflow/core/grappler/inputs/utils_test.cc +++ b/tensorflow/core/grappler/inputs/utils_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/core/grappler/inputs/utils.h" +#include +#include + +#include "absl/status/status.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index e441968d4c708d..df41e74b3390d6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -484,12 +484,12 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { return signature; } - void MarkWithTag(const StringPiece tag, NodeDef* node) { + void MarkWithTag(const absl::string_view tag, NodeDef* node) { AddNodeAttr(tag, true, node); } void MarkAllMembersWithTag(const OptimizedNodesGroup& group, - const StringPiece tag) const { + const absl::string_view tag) const { AddNodeAttr(tag, true, group.root_node); for (NodeDef* optimized_node : group.optimized_nodes) { AddNodeAttr(tag, true, optimized_node); @@ -506,12 +506,12 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { ctx().nodes_to_preserve->end(); } - bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const { + bool IsMarkedWithTag(const NodeDef& node, const absl::string_view tag) const { return HasNodeAttr(node, tag); } - bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1, - const StringPiece tag2) const { + bool IsMarkedWithAnyTag(const NodeDef& node, const absl::string_view tag1, + const absl::string_view tag2) const { return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2); } }; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 414452698f1eb0..b14cff6c4a5982 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" namespace tensorflow { namespace grappler { @@ -94,6 +95,18 @@ void VerifyGraphsMatch(const GraphDef& original_graph, } } } + +void VerifyTensorContent(const TensorProto& proto, + const string& expected_content) { + if (port::kLittleEndian) { + EXPECT_EQ(proto.tensor_content(), expected_content); + } else { + TensorProto protoCopy; + protoCopy.CopyFrom(proto); + TF_EXPECT_OK(ByteSwapTensorProto(&protoCopy)); + EXPECT_EQ(protoCopy.tensor_content(), expected_content); + } +} } // namespace TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -716,8 +729,8 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { ASSERT_NE(new_const, nullptr); ASSERT_EQ(new_const->input_size(), 1); EXPECT_EQ(new_const->input(0), "^x"); - EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(), - string("\0\0\0@", 4)); + VerifyTensorContent(new_const->attr().at("value").tensor(), + string("\0\0\0@", 4)); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); ASSERT_NE(new_mul, nullptr); @@ -763,8 +776,8 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { ASSERT_NE(new_const, nullptr); ASSERT_EQ(new_const->input_size(), 1); EXPECT_EQ(new_const->input(0), "^x"); - EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(), - string("\0\0\0@", 4)); + VerifyTensorContent(new_const->attr().at("value").tensor(), + string("\0\0\0@", 4)); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); ASSERT_NE(new_mul, nullptr); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 0c9aca41dd98a7..87ffa9d1d7a0e5 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -368,12 +368,12 @@ static absl::Status ConvertShapeToConstant(const string& op, // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class. bool ConstantFolding::OptimizedNodeExists(const NodeDef& node, - StringPiece suffix) const { + absl::string_view suffix) const { return node_map_->NodeExists(OptimizedNodeName(node, suffix)); } string ConstantFolding::OptimizedNodeName(const NodeDef& node, - StringPiece suffix) const { + absl::string_view suffix) const { return AddPrefixToNodeName(strings::StrCat(node.name(), suffix), kConstantFoldingConst); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 54490f8821e7ce..9c58f81e074d19 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -64,8 +64,8 @@ class ConstantFolding : public GraphOptimizer { private: bool ForwardInputs(NodeDef* node, absl::Span inputs_to_forward); - string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const; - bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const; + string OptimizedNodeName(const NodeDef& node, absl::string_view suffix) const; + bool OptimizedNodeExists(const NodeDef& node, absl::string_view suffix) const; bool IsReallyConstant(const NodeDef& node) const; diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h index 8941e58c55875b..0603463632d5ec 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.h +++ b/tensorflow/core/grappler/optimizers/data/function_utils.h @@ -59,8 +59,8 @@ void ReplaceReferences(const string& from, const string& to, FunctionDef* func); // Adds a function output to the function def, ensuring that the output key // is unique, and maps to output_tensor_name in the ret dict. -void AddFunctionOutputWithUniqueName(StringPiece prefix, - StringPiece output_tensor_name, +void AddFunctionOutputWithUniqueName(absl::string_view prefix, + absl::string_view output_tensor_name, FunctionDef* fdef, DataType dtype); // Adds an input to a FunctionDef. @@ -68,41 +68,45 @@ OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef, DataType dtype); // Adds a node to a FunctionDef. -NodeDef* AddNode(StringPiece name, StringPiece op, +NodeDef* AddNode(absl::string_view name, absl::string_view op, const std::vector& inputs, const std::vector>& attributes, FunctionDef* fd); // Checks whether the function contains a node with the given name. -bool ContainsFunctionNodeWithName(StringPiece name, +bool ContainsFunctionNodeWithName(absl::string_view name, const FunctionDef& function); // Checks whether the function contains a node with the given op. -bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); +bool ContainsFunctionNodeWithOp(absl::string_view op, + const FunctionDef& function); // Checks whether the function contains an output with the given name. -bool ContainsFunctionOutputWithName(StringPiece name, +bool ContainsFunctionOutputWithName(absl::string_view name, const FunctionDef& function); // Returns the index of the function input with the given name or -1 if the // function node does not exist. -int FindFunctionInputWithName(StringPiece name, const FunctionDef& function); +int FindFunctionInputWithName(absl::string_view name, + const FunctionDef& function); // Returns the index of the function output with the given name or -1 if the // function node does not exist. -int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function); +int FindFunctionOutputWithName(absl::string_view name, + const FunctionDef& function); // Returns the index of the function node with the given name or -1 if the // function node does not exist. -int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); +int FindFunctionNodeWithName(absl::string_view name, + const FunctionDef& function); // Returns the index of the function node with the given op or -1 if the // function node does not exist. -int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); +int FindFunctionNodeWithOp(absl::string_view op, const FunctionDef& function); // Sets the function node name using the `prefix` as a prefix while guaranteeing // the name is unique across the functions nodes. -void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, +void SetUniqueFunctionNodeName(absl::string_view prefix, FunctionDef* function, NodeDef* node); // Checks if the function is stateful by checking the function graph for diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index 20b5940f98102c..45b43e85814411 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -502,11 +502,14 @@ void LazyConjunctionOutput(const protobuf::Map& first_ret, *fused_ret = first_ret; } -FunctionDef* FuseFunctions( - const FunctionDef& first_function, const FunctionDef& second_function, - StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature, - const SetInputFn& set_input, const SetOutputFn& set_output, - const SetNodesFn& set_nodes, FunctionDefLibrary* library) { +FunctionDef* FuseFunctions(const FunctionDef& first_function, + const FunctionDef& second_function, + absl::string_view fused_name_prefix, + const SetFunctionSignatureFn& set_signature, + const SetInputFn& set_input, + const SetOutputFn& set_output, + const SetNodesFn& set_nodes, + FunctionDefLibrary* library) { auto has_unknown_attrs = [](const FunctionDef& func) { int known_attribute_size = 0; diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h index f7da097d4b1b09..d0b7ed7cb4de67 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h @@ -122,11 +122,14 @@ void LazyConjunctionNodes(const FunctionDef& first_function, // that are not conflicting with first function. This means that copied nodes // from second function can end up having different names. For explanation of // set up functions see the documentation of the functions types. -FunctionDef* FuseFunctions( - const FunctionDef& first_function, const FunctionDef& second_function, - StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature, - const SetInputFn& set_input, const SetOutputFn& set_output, - const SetNodesFn& set_nodes, FunctionDefLibrary* library); +FunctionDef* FuseFunctions(const FunctionDef& first_function, + const FunctionDef& second_function, + absl::string_view fused_name_prefix, + const SetFunctionSignatureFn& set_signature, + const SetInputFn& set_input, + const SetOutputFn& set_output, + const SetNodesFn& set_nodes, + FunctionDefLibrary* library); } // namespace fusion_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc index a212e250510002..e99da1c407aa1a 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -25,9 +25,10 @@ namespace tensorflow { namespace grappler { namespace graph_tests_utils { -NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name, - StringPiece batch_size_node_name, - StringPiece drop_remainder_node_name, +NodeDef MakeBatchV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view drop_remainder_node_name, bool parallel_copy) { return test::function::NDef( name, "BatchDatasetV2", @@ -38,11 +39,12 @@ NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name, {"output_types", absl::Span{}}}); } -NodeDef MakeParallelBatchNode(StringPiece name, StringPiece input_node_name, - StringPiece batch_size_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece drop_remainder_node_name, - StringPiece deterministic) { +NodeDef MakeParallelBatchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view drop_remainder_node_name, + absl::string_view deterministic) { return test::function::NDef( name, "ParallelBatchDataset", {string(input_node_name), string(batch_size_node_name), @@ -52,9 +54,10 @@ NodeDef MakeParallelBatchNode(StringPiece name, StringPiece input_node_name, {"deterministic", string(deterministic)}}); } -NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name, - StringPiece filename_node_name, - StringPiece cache_node_name) { +NodeDef MakeCacheV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view filename_node_name, + absl::string_view cache_node_name) { return test::function::NDef( name, "CacheDatasetV2", { @@ -68,8 +71,9 @@ NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, - StringPiece function_name) { +NodeDef MakeFilterNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view function_name) { return test::function::NDef( name, "FilterDataset", {string(input_node_name)}, {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))}, @@ -78,11 +82,12 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, {"output_types", absl::Span{}}}); } -NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name, - StringPiece batch_size_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece drop_remainder_node_name, - StringPiece function_name) { +NodeDef MakeMapAndBatchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view drop_remainder_node_name, + absl::string_view function_name) { return test::function::NDef( name, "MapAndBatchDataset", {string(input_node_name), string(batch_size_node_name), @@ -93,8 +98,8 @@ NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name, {"output_types", absl::Span{}}}); } -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, - StringPiece function_name) { +NodeDef MakeMapNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view function_name) { return test::function::NDef( name, "MapDataset", {string(input_node_name)}, {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, @@ -103,12 +108,12 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, {"output_types", absl::Span{}}}); } -NodeDef MakeParallelInterleaveV2Node(StringPiece name, - StringPiece input_node_name, - StringPiece cycle_length_node_name, - StringPiece block_length_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, bool sloppy) { +NodeDef MakeParallelInterleaveV2Node( + absl::string_view name, absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, bool sloppy) { return test::function::NDef( name, "ParallelInterleaveDatasetV2", {string(input_node_name), string(cycle_length_node_name), @@ -122,13 +127,12 @@ NodeDef MakeParallelInterleaveV2Node(StringPiece name, }); } -NodeDef MakeParallelInterleaveV4Node(StringPiece name, - StringPiece input_node_name, - StringPiece cycle_length_node_name, - StringPiece block_length_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, - StringPiece deterministic) { +NodeDef MakeParallelInterleaveV4Node( + absl::string_view name, absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, absl::string_view deterministic) { return test::function::NDef( name, "ParallelInterleaveDatasetV4", {string(input_node_name), string(cycle_length_node_name), @@ -142,11 +146,12 @@ NodeDef MakeParallelInterleaveV4Node(StringPiece name, }); } -NodeDef MakeInterleaveNode(StringPiece name, StringPiece input_node_name, - StringPiece cycle_length_node_name, - StringPiece block_length_node_name, - StringPiece function_name, - StringPiece deterministic) { +NodeDef MakeInterleaveNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view function_name, + absl::string_view deterministic) { return test::function::NDef( name, "InterleaveDataset", {string(input_node_name), string(cycle_length_node_name), @@ -160,9 +165,10 @@ NodeDef MakeInterleaveNode(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, bool sloppy) { +NodeDef MakeParallelMapNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, bool sloppy) { return test::function::NDef( name, "ParallelMapDataset", {string(input_node_name), string(num_parallel_calls_node_name)}, @@ -175,10 +181,11 @@ NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, - StringPiece deterministic, +NodeDef MakeParallelMapV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, + absl::string_view deterministic, bool use_unbounded_threadpool) { return test::function::NDef( name, "ParallelMapDatasetV2", @@ -193,8 +200,9 @@ NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name, - StringPiece num_parallel_calls_node_name, +NodeDef MakeParseExampleNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, bool sloppy) { return test::function::NDef( name, "ParseExampleDataset", @@ -206,9 +214,10 @@ NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name, - StringPiece buffer_size_node_name, - StringPiece seed_generator_node_name) { +NodeDef MakeShuffleV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view buffer_size_node_name, + absl::string_view seed_generator_node_name) { return test::function::NDef( name, "ShuffleDatasetV2", { @@ -222,8 +231,8 @@ NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeTakeNode(StringPiece name, StringPiece input_node_name, - StringPiece count_node_name) { +NodeDef MakeTakeNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view count_node_name) { return test::function::NDef( name, "TakeDataset", { @@ -236,7 +245,8 @@ NodeDef MakeTakeNode(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeTensorSliceNode(StringPiece name, StringPiece tensor_node_name, +NodeDef MakeTensorSliceNode(absl::string_view name, + absl::string_view tensor_node_name, bool replicate_on_split) { return test::function::NDef( name, "TensorSliceDataset", @@ -250,8 +260,8 @@ NodeDef MakeTensorSliceNode(StringPiece name, StringPiece tensor_node_name, }); } -NodeDef MakeSkipNode(StringPiece name, StringPiece input_node_name, - StringPiece count_node_name) { +NodeDef MakeSkipNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view count_node_name) { return test::function::NDef( name, "SkipDataset", { @@ -264,9 +274,9 @@ NodeDef MakeSkipNode(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakeShardNode(StringPiece name, StringPiece input_node_name, - StringPiece num_shards_node_name, - StringPiece index_node_name) { +NodeDef MakeShardNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view num_shards_node_name, + absl::string_view index_node_name) { return test::function::NDef( name, "ShardDataset", { @@ -280,8 +290,9 @@ NodeDef MakeShardNode(StringPiece name, StringPiece input_node_name, }); } -NodeDef MakePrefetchNode(StringPiece name, StringPiece input_node_name, - StringPiece buffer_size) { +NodeDef MakePrefetchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view buffer_size) { return test::function::NDef( name, "PrefetchDataset", {string(input_node_name), string(buffer_size)}, {{"output_shapes", absl::Span{}}, diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h index c5823d1a38607c..2b09eafc883705 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -24,104 +24,115 @@ namespace grappler { namespace graph_tests_utils { // Creates a test NodeDef for BatchDatasetV2. -NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name, - StringPiece batch_size_node_name, - StringPiece drop_remainder_node_name, +NodeDef MakeBatchV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view drop_remainder_node_name, bool parallel_copy); // Creates a test NodeDef for ParallelBatchDataset. -NodeDef MakeParallelBatchNode(StringPiece name, StringPiece input_node_name, - StringPiece batch_size_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece drop_remainder_node_name, - StringPiece deterministic); +NodeDef MakeParallelBatchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view drop_remainder_node_name, + absl::string_view deterministic); // Creates a test NodeDef for ShuffleDatasetV2. -NodeDef MakeCacheV2Node(StringPiece name, StringPiece input_node_name, - StringPiece filename_node_name, - StringPiece cache_node_name); +NodeDef MakeCacheV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view filename_node_name, + absl::string_view cache_node_name); // Creates a test NodeDef for FilterDataset. -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, - StringPiece function_name = "IsZero"); +NodeDef MakeFilterNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view function_name = "IsZero"); // Creates a test NodeDef for MapDataset. -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, - StringPiece function_name = "XTimesTwo"); +NodeDef MakeMapNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view function_name = "XTimesTwo"); // Creates a test NodeDef for MapAndBatchDataset. -NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name, - StringPiece batch_size_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece drop_remainder_node_name, - StringPiece function_name = "XTimesTwo"); +NodeDef MakeMapAndBatchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view drop_remainder_node_name, + absl::string_view function_name = "XTimesTwo"); // Creates a test NodeDef for ParallelInterleaveDatasetV2. -NodeDef MakeParallelInterleaveV2Node(StringPiece name, - StringPiece input_node_name, - StringPiece cycle_length_node_name, - StringPiece block_length_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, bool sloppy); +NodeDef MakeParallelInterleaveV2Node( + absl::string_view name, absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, bool sloppy); // Creates a test NodeDef for ParallelInterleaveDatasetV4. -NodeDef MakeParallelInterleaveV4Node(StringPiece name, - StringPiece input_node_name, - StringPiece cycle_length_node_name, - StringPiece block_length_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, - StringPiece deterministic); +NodeDef MakeParallelInterleaveV4Node( + absl::string_view name, absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, absl::string_view deterministic); // Creates a test NodeDef for InterleaveDataset. -NodeDef MakeInterleaveNode(StringPiece name, StringPiece input_node_name, - StringPiece cycle_length_node_name, - StringPiece block_length_node_name, - StringPiece function_name, - StringPiece deterministic); +NodeDef MakeInterleaveNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view function_name, + absl::string_view deterministic); // Creates a test NodeDef for ParallelMapDataset. -NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, bool sloppy); +NodeDef MakeParallelMapNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, bool sloppy); // Creates a test NodeDef for ParallelMapDatasetV2. -NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name, - StringPiece num_parallel_calls_node_name, - StringPiece function_name, - StringPiece deterministic, +NodeDef MakeParallelMapV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, + absl::string_view deterministic, bool use_unbounded_threadpool); // Creates a test NodeDef for ParseExampleDataset. -NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name, - StringPiece num_parallel_calls_node_name, +NodeDef MakeParseExampleNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, bool sloppy); // Creates a test NodeDef for ShuffleDatasetV2. -NodeDef MakeShuffleV2Node(StringPiece name, StringPiece input_node_name, - StringPiece buffer_size_node_name, - StringPiece seed_generator_node_name); +NodeDef MakeShuffleV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view buffer_size_node_name, + absl::string_view seed_generator_node_name); // Creates a test NodeDef for TakeDataset. -NodeDef MakeTakeNode(StringPiece name, StringPiece input_node_name, - StringPiece count_node_name); +NodeDef MakeTakeNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view count_node_name); // Creates a test NodeDef for TensorSliceDataset. -NodeDef MakeTensorSliceNode(StringPiece name, StringPiece tensor_node_name, +NodeDef MakeTensorSliceNode(absl::string_view name, + absl::string_view tensor_node_name, bool replicate_on_split); // Creates a test NodeDef for SkipDataset. -NodeDef MakeSkipNode(StringPiece name, StringPiece input_node_name, - StringPiece count_node_name); +NodeDef MakeSkipNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view count_node_name); // Creates a test NodeDef for ShardDataset. -NodeDef MakeShardNode(StringPiece name, StringPiece input_node_name, - StringPiece num_shards_node_name, - StringPiece index_node_name); +NodeDef MakeShardNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view num_shards_node_name, + absl::string_view index_node_name); // Creates a test NodeDef for PrefetchDataset. -NodeDef MakePrefetchNode(StringPiece name, StringPiece input_node_name, - StringPiece buffer_size); +NodeDef MakePrefetchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view buffer_size); } // namespace graph_tests_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 7d72da88abff29..746b3ebb22bffd 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -108,7 +108,7 @@ NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { return graph->AddNode(std::move(node)); } -NodeDef* AddNode(StringPiece name, StringPiece op, +NodeDef* AddNode(absl::string_view name, absl::string_view op, const std::vector& inputs, const std::vector>& attributes, MutableGraphView* graph) { @@ -159,7 +159,7 @@ NodeDef* AddScalarConstNode(int64_t v, MutableGraphView* graph) { } template <> -NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) { +NodeDef* AddScalarConstNode(absl::string_view v, MutableGraphView* graph) { return AddScalarConstNodeHelper( DT_STRING, [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); }, @@ -236,20 +236,20 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) { return true; } -bool ContainsGraphFunctionWithName(StringPiece name, +bool ContainsGraphFunctionWithName(absl::string_view name, const FunctionDefLibrary& library) { return FindGraphFunctionWithName(name, library) != -1; } -bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) { +bool ContainsGraphNodeWithName(absl::string_view name, const GraphDef& graph) { return FindGraphNodeWithName(name, graph) != -1; } -bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) { +bool ContainsNodeWithOp(absl::string_view op, const GraphDef& graph) { return FindGraphNodeWithOp(op, graph) != -1; } -int FindGraphFunctionWithName(StringPiece name, +int FindGraphFunctionWithName(absl::string_view name, const FunctionDefLibrary& library) { return GetFirstElementIndexWithPredicate( [&name](const FunctionDef& function) { @@ -258,13 +258,13 @@ int FindGraphFunctionWithName(StringPiece name, library.function()); } -int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) { +int FindGraphNodeWithName(absl::string_view name, const GraphDef& graph) { return GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, graph.node()); } -int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) { +int FindGraphNodeWithOp(absl::string_view op, const GraphDef& graph) { return GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); } @@ -300,7 +300,7 @@ absl::Status GetDatasetOutputTypesAttr(const NodeDef& node, node.name(), " with op: ", node.op()); } -void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, +void SetUniqueGraphNodeName(absl::string_view prefix, GraphDef* graph, NodeDef* node) { string name = string(prefix); int id = graph->node_size(); @@ -316,7 +316,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, node->set_name(std::move(name)); } -void SetUniqueGraphFunctionName(StringPiece prefix, +void SetUniqueGraphFunctionName(absl::string_view prefix, const FunctionDefLibrary* library, FunctionDef* function) { string name = string(prefix); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 0b3a8233921a3b..70d0c48085716a 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -49,7 +49,7 @@ int GetFirstElementIndexWithPredicate(const Predicate& predicate, } // Adds a node to the graph. -NodeDef* AddNode(StringPiece name, StringPiece op, +NodeDef* AddNode(absl::string_view name, absl::string_view op, const std::vector& inputs, const std::vector>& attributes, MutableGraphView* graph); @@ -78,7 +78,7 @@ NodeDef* AddScalarConstNode(int v, MutableGraphView* graph); template <> NodeDef* AddScalarConstNode(int64_t v, MutableGraphView* graph); template <> -NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph); +NodeDef* AddScalarConstNode(absl::string_view v, MutableGraphView* graph); // Retrieves the value of a const node. Returns an error // if the node is not const, or its value is of a different type. @@ -99,27 +99,27 @@ absl::Status GetScalarConstNodeValue(const NodeDef& node, bool* value); bool Compare(const GraphDef& g1, const GraphDef& g2); // Checks whether the graph contains a node with the given name. -bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph); +bool ContainsGraphNodeWithName(absl::string_view name, const GraphDef& graph); // Checks whether the library contains a function with the given name. -bool ContainsGraphFunctionWithName(StringPiece name, +bool ContainsGraphFunctionWithName(absl::string_view name, const FunctionDefLibrary& library); // Checks whether the graph contains a node with the given op. -bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph); +bool ContainsNodeWithOp(absl::string_view op, const GraphDef& graph); // Returns the index of the node with the given name or -1 if the node does // not exist. -int FindGraphNodeWithName(StringPiece name, const GraphDef& graph); +int FindGraphNodeWithName(absl::string_view name, const GraphDef& graph); // Returns the index of the function with the given name or -1 if the function // does not exist. -int FindGraphFunctionWithName(StringPiece name, +int FindGraphFunctionWithName(absl::string_view name, const FunctionDefLibrary& library); // Returns the index of the first node with the given op or -1 if no such node // exists. -int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph); +int FindGraphNodeWithOp(absl::string_view op, const GraphDef& graph); // Gets the 0th input to a node in the graph. NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph); @@ -139,11 +139,12 @@ std::vector FindAllGraphNodesWithOp(const string& op, // Sets the node name using `prefix` as a prefix while guaranteeing the name // is unique across the graph. -void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node); +void SetUniqueGraphNodeName(absl::string_view prefix, GraphDef* graph, + NodeDef* node); // Sets the function name using the `prefix` name as a prefix while guaranteeing // the name is unique across the function library. -void SetUniqueGraphFunctionName(StringPiece prefix, +void SetUniqueGraphFunctionName(absl::string_view prefix, const FunctionDefLibrary* library, FunctionDef* function); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 599801dacc0336..31ca40af244757 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -87,7 +87,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) { TEST(GraphUtilsTest, AddScalarConstNodeString) { GraphDef graph_def; MutableGraphView graph(&graph_def); - NodeDef* string_node = AddScalarConstNode("hello", &graph); + NodeDef* string_node = AddScalarConstNode("hello", &graph); EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph())); EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); } diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc index 077123ebf61184..0aaa95f77fbeb0 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc @@ -41,7 +41,7 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, range_attrs, &graph); NodeDef *captured_input_node = - graph_utils::AddScalarConstNode("hello", &graph); + graph_utils::AddScalarConstNode("hello", &graph); NodeDef *map_node; { @@ -124,7 +124,7 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) { NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, range_attrs, &graph); NodeDef *captured_input_node = - graph_utils::AddScalarConstNode("hello", &graph); + graph_utils::AddScalarConstNode("hello", &graph); NodeDef *map_node; { @@ -208,7 +208,7 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, range_attrs, &graph); NodeDef *captured_input_node = - graph_utils::AddScalarConstNode("hello", &graph); + graph_utils::AddScalarConstNode("hello", &graph); NodeDef *num_parallel_calls_node = graph_utils::AddScalarConstNode(2, &graph); @@ -294,7 +294,7 @@ TEST(MapAndBatchFusionTest, FuseParallelMapV2AndBatchNodesIntoOne) { NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, range_attrs, &graph); NodeDef *captured_input_node = - graph_utils::AddScalarConstNode("hello", &graph); + graph_utils::AddScalarConstNode("hello", &graph); NodeDef *num_parallel_calls_node = graph_utils::AddScalarConstNode(2, &graph); @@ -417,7 +417,7 @@ TEST(MapAndBatchFusionTest, NoChange_UnboundedThreadpoolParallelMap) { NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, range_attrs, &graph); NodeDef *captured_input_node = - graph_utils::AddScalarConstNode("hello", &graph); + graph_utils::AddScalarConstNode("hello", &graph); NodeDef *num_parallel_calls_node = graph_utils::AddScalarConstNode(2, &graph); diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index f173da58566920..2d34d97aaddbc1 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -90,7 +90,6 @@ bool SameDeterministicAttr(const NodeDef& parallel_map_node, // optimizing each function in that graph and later aggregating any new // functions introduced during these individual optimizations into that single // graph's collective function library). -// TODO(mpcallanan): Look at deduping names in a more generic fashion upstream. string GetFusedName(const NodeDef& parent, const NodeDef& child) { return absl::StrCat("map_fusion_nodes/", parent.name(), "/", child.name()); } diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc index fe12e5dd1fe592..173f3e463fdf6d 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc @@ -35,7 +35,7 @@ std::vector> GetCommonAttributes() { return commonAttributes; } -NodeDef *MakeNode(StringPiece node_type, std::vector params, +NodeDef *MakeNode(absl::string_view node_type, std::vector params, string input_node, MutableGraphView *graph) { std::vector node_params; for (int param : params) { @@ -50,7 +50,7 @@ NodeDef *MakeNode(StringPiece node_type, std::vector params, graph); } -NodeDef *MakeNonConstNode(StringPiece node_type, +NodeDef *MakeNonConstNode(absl::string_view node_type, std::vector param_dtypes, string input_node, MutableGraphView *graph) { std::vector node_params; @@ -68,7 +68,7 @@ NodeDef *MakeNonConstNode(StringPiece node_type, NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) { NodeDef *node_filename = - graph_utils::AddScalarConstNode("", graph); + graph_utils::AddScalarConstNode("", graph); return graph_utils::AddNode("", "CacheDataset", {std::move(input_node), node_filename->name()}, GetCommonAttributes(), graph); diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc index 02b4800cf31317..5e392d231f5d83 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc @@ -123,7 +123,8 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleV2AndRepeat) { NodeDef *buffer_size_node = graph_utils::AddScalarConstNode(128, &graph); NodeDef *seed_generator_node = - graph_utils::AddScalarConstNode("dummy_resource", &graph); + graph_utils::AddScalarConstNode("dummy_resource", + &graph); std::vector shuffle_inputs(3); shuffle_inputs[0] = range_node->name(); shuffle_inputs[1] = buffer_size_node->name(); @@ -190,7 +191,8 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleV3AndRepeat) { NodeDef *seed_node = graph_utils::AddScalarConstNode(-1, &graph); NodeDef *seed2_node = graph_utils::AddScalarConstNode(-1, &graph); NodeDef *seed_generator_node = - graph_utils::AddScalarConstNode("dummy_resource", &graph); + graph_utils::AddScalarConstNode("dummy_resource", + &graph); std::vector shuffle_inputs(5); shuffle_inputs[0] = range_node->name(); shuffle_inputs[1] = buffer_size_node->name(); diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc index 7350b0338ae115..d35f06daec1e9d 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc @@ -351,7 +351,7 @@ void DumpGraphToVLOG(const GraphDef& graph, int log_level) { } // namespace -void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name, +void ScopedAllocatorOptimizer::ExtendNodeAttr(absl::string_view name, const std::vector& values, NodeDef* node_def) { if (HasNodeAttr(*node_def, name)) { diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h index f0f1e5c094eac9..1b50f148264bd7 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h @@ -78,7 +78,8 @@ class ScopedAllocatorOptimizer : public GraphOptimizer { // Appends values to the attr value under name in node_def, if present. // If not present does an assignment. - static void ExtendNodeAttr(StringPiece name, const std::vector& values, + static void ExtendNodeAttr(absl::string_view name, + const std::vector& values, NodeDef* node_def); // Class that knows how to do graph rewriting for a particular kind of Op in diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 5a5a86536c9049..e437ebe0324fe6 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -69,8 +69,8 @@ inline int NodePositionIfSameNode(absl::string_view input_name, } // Returns the node name and position in a single call. -inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name, - int* position) { +inline absl::string_view ParseNodeNameAsStringPiece(absl::string_view name, + int* position) { const bool is_control = absl::StartsWith(name, "^"); TensorId id = ParseTensorName(name); if (position) { @@ -89,7 +89,7 @@ inline string ParseNodeName(const string& name, int* position) { // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. -inline StringPiece NodeNameAsStringPiece(const string& name) { +inline absl::string_view NodeNameAsStringPiece(const string& name) { return ParseNodeNameAsStringPiece(name, nullptr); } diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 9bc94d5f7b083e..df74004f0d9419 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -497,7 +497,7 @@ void BM_NodeNameAsStringPiece(::testing::benchmark::State& state) { string input(size + 3, 'x'); input[size] = ':'; for (auto s : state) { - StringPiece node_name = NodeNameAsStringPiece(input); + absl::string_view node_name = NodeNameAsStringPiece(input); CHECK_GT(node_name.size(), 0); } } diff --git a/tensorflow/core/ir/dialect.h b/tensorflow/core/ir/dialect.h index cba40b384dad51..d74faa8976d958 100644 --- a/tensorflow/core/ir/dialect.h +++ b/tensorflow/core/ir/dialect.h @@ -26,47 +26,50 @@ limitations under the License. namespace mlir { namespace tfg { // Include the relevant TensorFlow attrs/types directly in the TFG namespace. -using mlir::tf_type::Bfloat16RefType; // NOLINT -using mlir::tf_type::BoolRefType; // NOLINT -using mlir::tf_type::Complex128RefType; // NOLINT -using mlir::tf_type::Complex64RefType; // NOLINT -using mlir::tf_type::ControlType; // NOLINT -using mlir::tf_type::DoubleRefType; // NOLINT -using mlir::tf_type::Float8E4M3FNRefType; // NOLINT -using mlir::tf_type::Float8E5M2RefType; // NOLINT -using mlir::tf_type::FloatRefType; // NOLINT -using mlir::tf_type::FuncAttr; // NOLINT -using mlir::tf_type::HalfRefType; // NOLINT -using mlir::tf_type::Int16RefType; // NOLINT -using mlir::tf_type::Int32RefType; // NOLINT -using mlir::tf_type::Int4RefType; // NOLINT -using mlir::tf_type::Int64RefType; // NOLINT -using mlir::tf_type::Int8RefType; // NOLINT -using mlir::tf_type::OpaqueTensorType; // NOLINT -using mlir::tf_type::PlaceholderAttr; // NOLINT -using mlir::tf_type::Qint16RefType; // NOLINT -using mlir::tf_type::Qint16Type; // NOLINT -using mlir::tf_type::Qint32RefType; // NOLINT -using mlir::tf_type::Qint32Type; // NOLINT -using mlir::tf_type::Qint8RefType; // NOLINT -using mlir::tf_type::Qint8Type; // NOLINT -using mlir::tf_type::Quint16RefType; // NOLINT -using mlir::tf_type::Quint16Type; // NOLINT -using mlir::tf_type::Quint8RefType; // NOLINT -using mlir::tf_type::Quint8Type; // NOLINT -using mlir::tf_type::ResourceRefType; // NOLINT -using mlir::tf_type::ResourceType; // NOLINT -using mlir::tf_type::ShapeAttr; // NOLINT -using mlir::tf_type::StringRefType; // NOLINT -using mlir::tf_type::StringType; // NOLINT -using mlir::tf_type::Uint16RefType; // NOLINT -using mlir::tf_type::Uint32RefType; // NOLINT -using mlir::tf_type::Uint4RefType; // NOLINT -using mlir::tf_type::Uint64RefType; // NOLINT -using mlir::tf_type::Uint8RefType; // NOLINT -using mlir::tf_type::VariantRefType; // NOLINT -using mlir::tf_type::VariantType; // NOLINT -using mlir::tf_type::VersionAttr; // NOLINT +using mlir::tf_type::Bfloat16RefType; // NOLINT +using mlir::tf_type::BoolRefType; // NOLINT +using mlir::tf_type::Complex128RefType; // NOLINT +using mlir::tf_type::Complex64RefType; // NOLINT +using mlir::tf_type::ControlType; // NOLINT +using mlir::tf_type::DoubleRefType; // NOLINT +using mlir::tf_type::Float8E4M3B11FNUZRefType; // NOLINT +using mlir::tf_type::Float8E4M3FNRefType; // NOLINT +using mlir::tf_type::Float8E4M3FNUZRefType; // NOLINT +using mlir::tf_type::Float8E5M2FNUZRefType; // NOLINT +using mlir::tf_type::Float8E5M2RefType; // NOLINT +using mlir::tf_type::FloatRefType; // NOLINT +using mlir::tf_type::FuncAttr; // NOLINT +using mlir::tf_type::HalfRefType; // NOLINT +using mlir::tf_type::Int16RefType; // NOLINT +using mlir::tf_type::Int32RefType; // NOLINT +using mlir::tf_type::Int4RefType; // NOLINT +using mlir::tf_type::Int64RefType; // NOLINT +using mlir::tf_type::Int8RefType; // NOLINT +using mlir::tf_type::OpaqueTensorType; // NOLINT +using mlir::tf_type::PlaceholderAttr; // NOLINT +using mlir::tf_type::Qint16RefType; // NOLINT +using mlir::tf_type::Qint16Type; // NOLINT +using mlir::tf_type::Qint32RefType; // NOLINT +using mlir::tf_type::Qint32Type; // NOLINT +using mlir::tf_type::Qint8RefType; // NOLINT +using mlir::tf_type::Qint8Type; // NOLINT +using mlir::tf_type::Quint16RefType; // NOLINT +using mlir::tf_type::Quint16Type; // NOLINT +using mlir::tf_type::Quint8RefType; // NOLINT +using mlir::tf_type::Quint8Type; // NOLINT +using mlir::tf_type::ResourceRefType; // NOLINT +using mlir::tf_type::ResourceType; // NOLINT +using mlir::tf_type::ShapeAttr; // NOLINT +using mlir::tf_type::StringRefType; // NOLINT +using mlir::tf_type::StringType; // NOLINT +using mlir::tf_type::Uint16RefType; // NOLINT +using mlir::tf_type::Uint32RefType; // NOLINT +using mlir::tf_type::Uint4RefType; // NOLINT +using mlir::tf_type::Uint64RefType; // NOLINT +using mlir::tf_type::Uint8RefType; // NOLINT +using mlir::tf_type::VariantRefType; // NOLINT +using mlir::tf_type::VariantType; // NOLINT +using mlir::tf_type::VersionAttr; // NOLINT class TFGraphOpAsmInterface; class TFOp; diff --git a/tensorflow/core/ir/importexport/convert_attributes.h b/tensorflow/core/ir/importexport/convert_attributes.h index e2df6a9ae42329..250a32e5319c4b 100644 --- a/tensorflow/core/ir/importexport/convert_attributes.h +++ b/tensorflow/core/ir/importexport/convert_attributes.h @@ -33,17 +33,16 @@ namespace tfg { // Convert the list of MLIR Attributes `attrs` to the `tensorflow::AttrValueMap` // `values`. -tensorflow::Status ConvertAttributes(ArrayRef attrs, - ArrayRef attrs_to_ignore, - bool remove_ref_type, - tensorflow::AttrValueMap* values); +absl::Status ConvertAttributes(ArrayRef attrs, + ArrayRef attrs_to_ignore, + bool remove_ref_type, + tensorflow::AttrValueMap* values); // Convert the MLIR attribute `attr` and return a `tensorflow::AttrValue`. absl::StatusOr ConvertAttribute(Attribute attr); -tensorflow::Status SetShapeAttribute(absl::string_view name, - ShapedType shaped_type, - tensorflow::AttrValueMap* values); +absl::Status SetShapeAttribute(absl::string_view name, ShapedType shaped_type, + tensorflow::AttrValueMap* values); // Converts an MLIR shaped type to a TensorFlow shape attribute. ShapeAttr ConvertTypeToTensorShapeAttr(const Type& type); @@ -78,8 +77,8 @@ absl::StatusOr ConvertHandleData( // Convert an array of handle data into the `handle_data` field of the provided // ArgDef. Each entry of the array is expected to be a TensorType. -tensorflow::Status ConvertHandleData(ArrayAttr handle_data_arr, - tensorflow::OpDef::ArgDef* arg); +absl::Status ConvertHandleData(ArrayAttr handle_data_arr, + tensorflow::OpDef::ArgDef* arg); } // namespace tfg } // namespace mlir diff --git a/tensorflow/core/ir/importexport/convert_tensor.h b/tensorflow/core/ir/importexport/convert_tensor.h index 0a20af3157e7af..15bbe282ac58f4 100644 --- a/tensorflow/core/ir/importexport/convert_tensor.h +++ b/tensorflow/core/ir/importexport/convert_tensor.h @@ -69,12 +69,12 @@ void SetTensorShapeProto(ShapeContainerT shape, } // Converts an MLIR elements attribute to a TensorFlow tensor proto. -tensorflow::Status ConvertToTensorProto(ElementsAttr attr, - tensorflow::TensorProto* output_tensor); +absl::Status ConvertToTensorProto(ElementsAttr attr, + tensorflow::TensorProto* output_tensor); // Converts an MLIR elements attribute to a TensorFlow tensor. -tensorflow::Status ConvertToTensor(ElementsAttr attr, - tensorflow::Tensor* output_tensor); +absl::Status ConvertToTensor(ElementsAttr attr, + tensorflow::Tensor* output_tensor); // Converts a TF shape to MLIR shape, i.e. -1 becomes kDynamicSize. llvm::SmallVector ConvertTFShapeToMlir(llvm::ArrayRef shape); diff --git a/tensorflow/core/ir/importexport/convert_types.h b/tensorflow/core/ir/importexport/convert_types.h index 3941e1d1a6bf9c..d3f1756caf0b50 100644 --- a/tensorflow/core/ir/importexport/convert_types.h +++ b/tensorflow/core/ir/importexport/convert_types.h @@ -26,25 +26,24 @@ limitations under the License. namespace mlir { namespace tfg { // Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. -tensorflow::Status ConvertDataType(tensorflow::DataType dtype, Builder& builder, - Type* type); +absl::Status ConvertDataType(tensorflow::DataType dtype, Builder& builder, + Type* type); // Converts a scalar MLIR type to a TensorFlow Datatype. -tensorflow::Status ConvertScalarTypeToDataType(Type type, - tensorflow::DataType* dtype); +absl::Status ConvertScalarTypeToDataType(Type type, + tensorflow::DataType* dtype); // Converts an MLIR type to TensorFlow DataType. If 'type' is a scalar type, it // is converted directly. If it is a shaped type, the element type is converted. -tensorflow::Status ConvertToDataType(Type type, tensorflow::DataType* dtype); +absl::Status ConvertToDataType(Type type, tensorflow::DataType* dtype); // Converts an TensorFlow shape to the one used in MLIR. void ConvertToMlirShape(const tensorflow::TensorShape& input_shape, SmallVectorImpl* shape); // Converts an TensorFlow shape proto to the one used in MLIR. -tensorflow::Status ConvertToMlirShape( - const tensorflow::TensorShapeProto& input_shape, - SmallVectorImpl* shape); +absl::Status ConvertToMlirShape(const tensorflow::TensorShapeProto& input_shape, + SmallVectorImpl* shape); // Given a tensor shape and dtype, get the corresponding MLIR tensor type. absl::StatusOr ConvertToMlirTensorType( diff --git a/tensorflow/core/ir/importexport/functiondef_import.h b/tensorflow/core/ir/importexport/functiondef_import.h index 4bd76d1a50f5f4..7e9aba69b9e1e0 100644 --- a/tensorflow/core/ir/importexport/functiondef_import.h +++ b/tensorflow/core/ir/importexport/functiondef_import.h @@ -26,9 +26,9 @@ namespace tfg { // Import the FunctionDef `func` as a TFG generic function (see GraphFuncOp // documentation). The function will be inserted using the provided `builder`. -tensorflow::Status ConvertGenericFunction(GraphFuncOp func_op, - const tensorflow::FunctionDef& func, - OpBuilder& builder); +absl::Status ConvertGenericFunction(GraphFuncOp func_op, + const tensorflow::FunctionDef& func, + OpBuilder& builder); } // namespace tfg } // namespace mlir diff --git a/tensorflow/core/ir/importexport/graphdef_export.h b/tensorflow/core/ir/importexport/graphdef_export.h index 0f4a90d90733a5..74af12fbf6be8c 100644 --- a/tensorflow/core/ir/importexport/graphdef_export.h +++ b/tensorflow/core/ir/importexport/graphdef_export.h @@ -37,18 +37,17 @@ absl::StatusOr GetValueName(Value value, TFGraphDialect *dialect); // Convert a TFG graph directly to GraphDef. Graph functions in the module are // added to the GraphDef's function library. -tensorflow::Status ConvertToGraphDef(ModuleOp module, - tensorflow::GraphDef *graph); +absl::Status ConvertToGraphDef(ModuleOp module, tensorflow::GraphDef *graph); // Convert a single TFG op to NodeDef. This utliity function requires a callback // `get_value_name` that returns the edge name of the given operand. -tensorflow::Status ConvertToNodeDef( +absl::Status ConvertToNodeDef( Operation *op, tensorflow::NodeDef *node, TFGraphDialect *dialect, function_ref(Value)> get_value_name); // Convert a single TFG function to a FunctionDef and add it to the function // library. If a function with the same name already exists, replace it. -tensorflow::Status ConvertToFunctionDef( +absl::Status ConvertToFunctionDef( GraphFuncOp func, tensorflow::FunctionLibraryDefinition &library); } // namespace tfg diff --git a/tensorflow/core/ir/importexport/load_proto.cc b/tensorflow/core/ir/importexport/load_proto.cc index acaf2987b41e78..4adfd5effcfa47 100644 --- a/tensorflow/core/ir/importexport/load_proto.cc +++ b/tensorflow/core/ir/importexport/load_proto.cc @@ -30,7 +30,8 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) { } } // namespace -Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto) { +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::Message* proto) { // Attempt to parse as text. if (mlir::tfg::ParseTextProto(input, "", proto).ok()) return absl::OkStatus(); @@ -38,8 +39,8 @@ Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto) { return LoadProtoFromBuffer(input, static_cast(proto)); } -Status LoadProtoFromBuffer(absl::string_view input, - protobuf::MessageLite* proto) { +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto) { // Attempt to parse as binary. protobuf::io::ArrayInputStream binary_stream(input.data(), input.size()); if (proto->ParseFromZeroCopyStream(&binary_stream)) return absl::OkStatus(); @@ -49,7 +50,7 @@ Status LoadProtoFromBuffer(absl::string_view input, } template -Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { +absl::Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { const auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename)); if (std::error_code error = file_or_err.getError()) { @@ -64,13 +65,13 @@ Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { return LoadProtoFromBuffer(content, proto); } -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::Message* proto) { +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto) { return LoadProtoFromFileImpl(input_filename, proto); } -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::MessageLite* proto) { +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto) { return LoadProtoFromFileImpl(input_filename, proto); } diff --git a/tensorflow/core/ir/importexport/load_proto.h b/tensorflow/core/ir/importexport/load_proto.h index 2d6be1590ac26e..9644411c12d2e3 100644 --- a/tensorflow/core/ir/importexport/load_proto.h +++ b/tensorflow/core/ir/importexport/load_proto.h @@ -26,18 +26,19 @@ namespace tensorflow { // buffer. Returns error status of the file is not found or malformed proto. // Note that text protos can only be parsed when full protobuf::Message protos // are used, and will fail for protobuf::MessageLite protos. -Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto); -Status LoadProtoFromBuffer(absl::string_view input, - protobuf::MessageLite* proto); +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::Message* proto); +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto); // Reads text (.pbtext) or binary (.pb) format of a proto message from the given // file path. Returns error status of the file is not found or malformed proto. // Note that text protos can only be parsed when full protobuf::Message protos // are used, and will fail for protobuf::MessageLite protos. -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::Message* proto); -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::MessageLite* proto); +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto); +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto); } // namespace tensorflow diff --git a/tensorflow/core/ir/importexport/mangling.h b/tensorflow/core/ir/importexport/mangling.h index 98bcddccc9df6c..a85be927bd31d9 100644 --- a/tensorflow/core/ir/importexport/mangling.h +++ b/tensorflow/core/ir/importexport/mangling.h @@ -54,20 +54,20 @@ MangledKind GetMangledKind(absl::string_view str); // Return a TensorShapeProto mangled as a string. std::string MangleShape(const tensorflow::TensorShapeProto& shape); // Demangle a string mangled with MangleShape. -tensorflow::Status DemangleShape(absl::string_view str, - tensorflow::TensorShapeProto* proto); +absl::Status DemangleShape(absl::string_view str, + tensorflow::TensorShapeProto* proto); // Return a TensorProto mangled as a string. std::string MangleTensor(const tensorflow::TensorProto& tensor); // Demangle a string mangled with MangleTensor. -tensorflow::Status DemangleTensor(absl::string_view str, - tensorflow::TensorProto* proto); +absl::Status DemangleTensor(absl::string_view str, + tensorflow::TensorProto* proto); // Return a DataType mangled as a string. std::string MangleDataType(const tensorflow::DataType& dtype); // Demangle a string mangled with MangleDataType. -tensorflow::Status DemangleDataType(absl::string_view str, - tensorflow::DataType* proto); +absl::Status DemangleDataType(absl::string_view str, + tensorflow::DataType* proto); } // namespace mangling_util } // namespace tfg diff --git a/tensorflow/core/ir/importexport/parse_text_proto.h b/tensorflow/core/ir/importexport/parse_text_proto.h index 913081de7eed44..00a7d83ebc2782 100644 --- a/tensorflow/core/ir/importexport/parse_text_proto.h +++ b/tensorflow/core/ir/importexport/parse_text_proto.h @@ -26,16 +26,15 @@ namespace tfg { // Sets output to the given input with `prefix` stripped, or returns an error if // the prefix doesn't exist. -tensorflow::Status ConsumePrefix(absl::string_view str, - absl::string_view prefix, - absl::string_view* output); +absl::Status ConsumePrefix(absl::string_view str, absl::string_view prefix, + absl::string_view* output); // Strips `prefix_to_strip` from `text_proto`, parses, and returns the parsed // proto. -tensorflow::Status ParseTextProto(absl::string_view text_proto, - absl::string_view prefix_to_strip, - tensorflow::protobuf::Message* parsed_proto); -inline tensorflow::Status ParseTextProto( +absl::Status ParseTextProto(absl::string_view text_proto, + absl::string_view prefix_to_strip, + tensorflow::protobuf::Message* parsed_proto); +inline absl::Status ParseTextProto( absl::string_view /* text_proto */, absl::string_view /* prefix_to_strip */, tensorflow::protobuf::MessageLite* /* parsed_proto */) { return tensorflow::errors::Unavailable("Cannot parse text protos on mobile."); diff --git a/tensorflow/core/ir/importexport/savedmodel_export.cc b/tensorflow/core/ir/importexport/savedmodel_export.cc index b4148dde56b965..b2a74aa678bff2 100644 --- a/tensorflow/core/ir/importexport/savedmodel_export.cc +++ b/tensorflow/core/ir/importexport/savedmodel_export.cc @@ -25,7 +25,7 @@ limitations under the License. namespace mlir { namespace tfg { -tensorflow::Status ExportMlirToSavedModel( +absl::Status ExportMlirToSavedModel( mlir::ModuleOp module, const tensorflow::SavedModel &original_saved_model, tensorflow::SavedModel *output_saved_model) { if (original_saved_model.meta_graphs_size() == 0) { diff --git a/tensorflow/core/ir/importexport/savedmodel_export.h b/tensorflow/core/ir/importexport/savedmodel_export.h index 0d9811fd6a8409..b270ce9ca764bc 100644 --- a/tensorflow/core/ir/importexport/savedmodel_export.h +++ b/tensorflow/core/ir/importexport/savedmodel_export.h @@ -29,7 +29,7 @@ namespace tfg { // The module must contain at most a single Graph operation and zero or more // TFFunc operations. `original_saved_model` is used as only a GraphDef portion // of a saved model represented in the MLIR module. -tensorflow::Status ExportMlirToSavedModel( +absl::Status ExportMlirToSavedModel( mlir::ModuleOp module, const tensorflow::SavedModel &original_saved_model, tensorflow::SavedModel *output_saved_model); diff --git a/tensorflow/core/ir/importexport/tests/saved_model/saved_model_roundtrip_test.cc b/tensorflow/core/ir/importexport/tests/saved_model/saved_model_roundtrip_test.cc index f585be6452ebc4..97bf1f09bc1769 100644 --- a/tensorflow/core/ir/importexport/tests/saved_model/saved_model_roundtrip_test.cc +++ b/tensorflow/core/ir/importexport/tests/saved_model/saved_model_roundtrip_test.cc @@ -28,8 +28,8 @@ limitations under the License. namespace { -tensorflow::Status ReadModelProto(const std::string& input_file, - tensorflow::SavedModel* out) { +absl::Status ReadModelProto(const std::string& input_file, + tensorflow::SavedModel* out) { return tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input_file, out); } diff --git a/tensorflow/core/ir/tests/types.mlir b/tensorflow/core/ir/tests/types.mlir index 67dc7a5158e7d4..bb885415af8281 100644 --- a/tensorflow/core/ir/tests/types.mlir +++ b/tensorflow/core/ir/tests/types.mlir @@ -66,6 +66,12 @@ module attributes {tfg.type = !tf_type.halfref} {} module attributes {tfg.type = !tf_type.float8e4m3fnref} {} // CHECK: module attributes {tfg.type = !tf_type.float8e5m2ref module attributes {tfg.type = !tf_type.float8e5m2ref} {} +// CHECK: module attributes {tfg.type = !tf_type.float8e4m3fnuzref +module attributes {tfg.type = !tf_type.float8e4m3fnuzref} {} +// CHECK: module attributes {tfg.type = !tf_type.float8e4m3b11fnuzref +module attributes {tfg.type = !tf_type.float8e4m3b11fnuzref} {} +// CHECK: module attributes {tfg.type = !tf_type.float8e5m2fnuzref +module attributes {tfg.type = !tf_type.float8e5m2fnuzref} {} // CHECK: module attributes {tfg.type = !tf_type.control module attributes {tfg.type = !tf_type.control} {} // CHECK: module attributes {tfg.type = !tf_type.tensor diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc index db175cfa089936..891ec4744b8477 100644 --- a/tensorflow/core/ir/types/dialect.cc +++ b/tensorflow/core/ir/types/dialect.cc @@ -546,6 +546,12 @@ TensorFlowType TensorFlowRefType::get(Type type) { return Float8E4M3FNRefType::get(ctx); } else if (type.isFloat8E5M2()) { return Float8E5M2RefType::get(ctx); + } else if (type.isFloat8E4M3FNUZ()) { + return Float8E4M3FNUZRefType::get(ctx); + } else if (type.isFloat8E4M3B11FNUZ()) { + return Float8E4M3B11FNUZRefType::get(ctx); + } else if (type.isFloat8E5M2FNUZ()) { + return Float8E5M2FNUZRefType::get(ctx); } else if (auto complex_type = mlir::dyn_cast(type)) { Type etype = complex_type.getElementType(); if (etype.isF32()) { @@ -596,6 +602,12 @@ Type TensorFlowRefType::RemoveRef() { if (mlir::isa(*this)) return FloatType::getFloat8E4M3FN(ctx); if (mlir::isa(*this)) return FloatType::getFloat8E5M2(ctx); + if (mlir::isa(*this)) + return FloatType::getFloat8E4M3FNUZ(ctx); + if (mlir::isa(*this)) + return FloatType::getFloat8E4M3B11FNUZ(ctx); + if (mlir::isa(*this)) + return FloatType::getFloat8E5M2FNUZ(ctx); if (mlir::isa(*this)) return IntegerType::get(ctx, 1); if (mlir::isa(*this)) return IntegerType::get(ctx, 4, IntegerType::Signed); diff --git a/tensorflow/core/ir/types/types.def b/tensorflow/core/ir/types/types.def index 64f73bfdf67e7a..ba3743ea9702fa 100644 --- a/tensorflow/core/ir/types/types.def +++ b/tensorflow/core/ir/types/types.def @@ -68,6 +68,9 @@ HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref") HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref") HANDLE_TF_REF_TYPE(Float8E4M3FNRef, FLOAT8_E4M3FN_REF, "float8e4m3fnref") HANDLE_TF_REF_TYPE(Float8E5M2Ref, FLOAT8_E5M2_REF, "float8e5m2ref") +HANDLE_TF_REF_TYPE(Float8E4M3FNUZRef, FLOAT8_E4M3FNUZ_REF, "float8e4m3fnuzref") +HANDLE_TF_REF_TYPE(Float8E4M3B11FNUZRef, FLOAT8_E4M3B11FNUZ_REF, "float8e4m3b11fnuzref") +HANDLE_TF_REF_TYPE(Float8E5M2FNUZRef, FLOAT8_E5M2FNUZ_REF, "float8e5m2fnuzref") #ifndef HANDLE_LAST_TF_TYPE #define HANDLE_LAST_TF_TYPE(class, enumerant, name) \ diff --git a/tensorflow/core/ir/utils/shape_inference_utils.cc b/tensorflow/core/ir/utils/shape_inference_utils.cc index a78a29a2e2390b..753ad1450b8a9e 100644 --- a/tensorflow/core/ir/utils/shape_inference_utils.cc +++ b/tensorflow/core/ir/utils/shape_inference_utils.cc @@ -95,7 +95,7 @@ NamedAttrList GetAllAttributesFromOperation(Operation* op) { std::optional GetShapeFromMlirType(Type t) { if (auto ranked_type = t.dyn_cast()) { tensorflow::PartialTensorShape shape; - const tensorflow::Status status = + const absl::Status status = tensorflow::PartialTensorShape::BuildPartialTensorShape( ConvertMlirShapeToTF(ranked_type.getShape()), &shape); if (status.ok()) return shape; @@ -232,7 +232,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( tensorflow::AttrValueMap attributes; if (get_attr_values_fn) { - tensorflow::Status status = + absl::Status status = get_attr_values_fn(op, op_name, op_reg_data, /*ignore_unregistered_attrs=*/true, &attributes); if (!status.ok()) { @@ -243,7 +243,7 @@ LogicalResult InferReturnTypeComponentsForTFOp( } else { auto* dialect = cast(op->getDialect()); tensorflow::NodeDef node_def; - tensorflow::Status status = ConvertToNodeDef( + absl::Status status = ConvertToNodeDef( op, &node_def, dialect, [&](Value value) { return GetValueName(value, dialect); }); if (!status.ok()) { diff --git a/tensorflow/core/ir/utils/shape_inference_utils.h b/tensorflow/core/ir/utils/shape_inference_utils.h index c2385095ecec92..273f4ceed01480 100644 --- a/tensorflow/core/ir/utils/shape_inference_utils.h +++ b/tensorflow/core/ir/utils/shape_inference_utils.h @@ -56,7 +56,7 @@ using ResultElementTypeFn = llvm::function_ref; // Extracts the attributes of a MLIR operation and populates the converted // attributes in a proto map. This is used by operation // defined in TF dialect which has different attributes format than TFG dialect. -using GetAttrValuesFn = llvm::function_ref; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a0cd508ceeec60..c50e7c7d1021fe 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4933,6 +4933,9 @@ cc_library( tf_kernel_library( name = "stateful_random_ops", + copts = [ + "-Wno-thread-safety-analysis", # TODO(b/384723765): Remove this once the bug is fixed. + ], features = if_cuda(["-layering_check"]), prefix = "stateful_random_ops", deps = [ @@ -8085,7 +8088,6 @@ tf_cc_shared_library( "@com_googlesource_code_re2//:__subpackages__", "@compute_library//:__subpackages__", "@curl//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@farmhash_archive//:__subpackages__", "@farmhash_gpu_archive//:__subpackages__", diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc index 985b4059716ed1..12e1622963da07 100644 --- a/tensorflow/core/kernels/as_string_op.cc +++ b/tensorflow/core/kernels/as_string_op.cc @@ -197,8 +197,8 @@ class AsStringOp : public OpKernel { case (DT_STRING): { const auto& input_flat = input_tensor->flat(); for (int i = 0; i < input_flat.size(); ++i) { - output_flat(i) = strings::Printf(format_.c_str(), - StringPiece(input_flat(i)).data()); + output_flat(i) = strings::Printf( + format_.c_str(), absl::string_view(input_flat(i)).data()); } } break; case (DT_VARIANT): { diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 250ce16b500c5f..a1b9f9778b3eb4 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -118,8 +118,7 @@ int32 NumBatchThreadsFromEnvironmentWithDefault(int default_num_batch_threads) { int32_t num; const char* val = std::getenv("TF_NUM_BATCH_THREADS"); - return (val && strings::safe_strto32(val, &num)) ? num - : default_num_batch_threads; + return (val && absl::SimpleAtoi(val, &num)) ? num : default_num_batch_threads; } static thread::ThreadPool* GetOrCreateBatchThreadsPool() { diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 4c1cfe162052c1..43e3e5ffa820f5 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -214,6 +214,17 @@ void RecordBatchDelayUsV2(int64_t batch_delay_us, const string& model_name, ->Add(static_cast(batch_delay_us)); } +void RecordBatchTaskSizeSum(int32_t batch_task_size, + int32_t unbatched_task_size, + const string& model_name, const string& op_name) { + static auto* cell = tensorflow::monitoring::Counter<3>::New( + "/tensorflow/serving/batching/batch_task_size_sum", + "Tracks the sum of the task sizes in a batch.", "model_name", "op_name", + "is_batched"); + cell->GetCell(model_name, op_name, "true")->IncrementBy(batch_task_size); + cell->GetCell(model_name, op_name, "false")->IncrementBy(unbatched_task_size); +} + void RecordBatchParamBatchTimeoutMicros(int64_t batch_timeout_micros, const string& model_name, const string& op_name) { @@ -694,6 +705,9 @@ absl::Status BatchResourceBase::ConcatInputTensors( {"padding_amount", padding_amount}, {"disable_padding", disable_padding}}); }); + RecordBatchTaskSizeSum(batch.size(), unbatched_tasks_size, + GetModelName(context), context->op_kernel().name()); + // TODO(b/316379576): Add metrics for the breakdown between the size of the // original batch size and the unbatched task size and update the batch size // to include the unbatched tasks. diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index c50b29f3d1b3ed..e853fc482eeb57 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -67,7 +67,7 @@ struct BatchResourceOptions { class BatchResourceBase : public ResourceBase { public: // Given a BatchTask (from one op invocation) with 'num_outputs'== M and - // splitted into N sub tasks, TensorMatrix is a N X M matrix. + // split into N sub tasks, TensorMatrix is a N X M matrix. // Namely, TensorMatrix[i][j] indicates the i-th split tensor of j-th output; // concatenating tensors along the 2nd dimension gives a output tensor. typedef std::vector> TensorMatrix; diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc index a8a70a6aa6b944..f8e08f7b0021f7 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc @@ -243,7 +243,7 @@ TEST(TaskQueueTest, RemoveAllTasksWhenArgGreaterThanTaskSize) { EXPECT_EQ(3, task_queue.num_tasks()); EXPECT_EQ(6, task_queue.size()); - // All tasks upto the size 6 shoule be remove when the size 8 is specified. + // All tasks upto the size 6 should be remove when the size 8 is specified. EXPECT_THAT(task_queue.RemoveTask(8), ElementsAre(Pointee(Property(&FakeTask::size, Eq(1))), Pointee(Property(&FakeTask::size, Eq(2))), diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 5415182fe8e8f3..347f300836cd94 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -1167,7 +1167,7 @@ absl::Status Queue::ValidateLowPriorityTaskQueueCapacity( options_.low_priority_queue_options.max_execution_batch_size) { return absl::UnavailableError(absl::StrFormat( "The low priority task queue to which this task was submitted does not " - "have the capcity to handle this task; currently the low priority " + "have the capacity to handle this task; currently the low priority " "queue has %d tasks enqueued and the submitted task size is %d while " "max_enqueued_batches=%d and max_execution_batch_size=%d", low_priority_tasks_.size(), task.size(), diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index 84f6e817e3c642..34b23da6099d26 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -74,7 +74,7 @@ class FakeTask : public BatchTask { void operator=(const FakeTask&) = delete; }; -// Fake task taht doesn't inherit BatchTask and doesn't define criticality. The +// Fake task that doesn't inherit BatchTask and doesn't define criticality. The // shared batch scheduler should still work with this task. class FakeTaskWithoutCriticality { public: @@ -1202,7 +1202,7 @@ TEST_P(SharedBatchSchedulerPriorityTest, testing::StatusIs( absl::StatusCode::kUnavailable, HasSubstr("The low priority task queue to which this task was " - "submitted does not have the capcity to handle this task; " + "submitted does not have the capacity to handle this task; " "currently the low priority queue has 20 tasks enqueued " "and the submitted task size is 1 while " "max_enqueued_batches=2 and max_execution_batch_size=10"))); diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index f77d6cd010c1ae..eaae6db6725584 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -184,7 +184,7 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { auto output_shape = c->input(0).shape(); OP_REQUIRES_ASYNC(c, output_shape.dims() > 0, errors::InvalidArgument("input should have rank > 0, ", - "recieved ", output_shape.dims()), + "received ", output_shape.dims()), done); output_shape.set_dim( 0, output_shape.dim_size(0) * col_params_->group.group_size); diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index aeafb0db6745c5..42e114ad33581d 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -51,7 +51,7 @@ int ConvBackpropDimensions::SpatialPadding(const Padding& padding, namespace { absl::Status ConvBackpropExtractAndVerifyDimension( - StringPiece label, const TensorShape& input_shape, + absl::string_view label, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& output_shape, const absl::Span dilations, const std::vector& strides, Padding padding, int64_t padding_before, int64_t padding_after, @@ -93,8 +93,9 @@ absl::Status ConvBackpropExtractAndVerifyDimension( } // namespace absl::Status ConvBackpropComputeDimensionsV2( - StringPiece label, int num_spatial_dims, const TensorShape& input_shape, - const TensorShape& filter_shape, const TensorShape& out_backprop_shape, + absl::string_view label, int num_spatial_dims, + const TensorShape& input_shape, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const absl::Span dilations, const std::vector& strides, Padding padding, absl::Span explicit_paddings, TensorFormat data_format, ConvBackpropDimensions* dims) { @@ -158,10 +159,10 @@ absl::Status ConvBackpropComputeDimensionsV2( } absl::Status ConvBackpropComputeDimensions( - StringPiece label, int num_spatial_dims, const TensorShape& input_shape, - const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - const std::vector& strides, Padding padding, - TensorFormat data_format, ConvBackpropDimensions* dims) { + absl::string_view label, int num_spatial_dims, + const TensorShape& input_shape, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const std::vector& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { static constexpr std::array one_dilations = {{1, 1, 1, 1, 1}}; return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_shape, filter_shape, out_backprop_shape, diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.h b/tensorflow/core/kernels/conv_grad_shape_utils.h index 9fdc0ce9bcabdc..d83c1bb25ee02f 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.h +++ b/tensorflow/core/kernels/conv_grad_shape_utils.h @@ -67,20 +67,21 @@ struct ConvBackpropDimensions { // Conv?DBackpropFilter. Verifies that the dimensions all match, and computes // sizes/padding for the spatial dimensions. Does not support explicit padding. absl::Status ConvBackpropComputeDimensions( - StringPiece label, int num_spatial_dims, const TensorShape& input_shape, - const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - const std::vector& strides, Padding padding, - TensorFormat data_format, ConvBackpropDimensions* dims); + absl::string_view label, int num_spatial_dims, + const TensorShape& input_shape, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const std::vector& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims); // The V2 version computes the same outputs with arbitrary dilation rate and // supports explicit padding. // TODO(b/67112639): Merge V2 versions and the original versions eventually. absl::Status ConvBackpropComputeDimensionsV2( - StringPiece label, int num_spatial_dims, const TensorShape& input_shape, - const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - absl::Span dilations, const std::vector& strides, - Padding padding, absl::Span explicit_paddings, - TensorFormat data_format, ConvBackpropDimensions* dims); + absl::string_view label, int num_spatial_dims, + const TensorShape& input_shape, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, absl::Span dilations, + const std::vector& strides, Padding padding, + absl::Span explicit_paddings, TensorFormat data_format, + ConvBackpropDimensions* dims); // Computes the shape of the in_backprop. absl::Status Conv2DBackpropComputeInputShape( diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index aa0a364988331f..eff4f1c145518f 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -507,7 +507,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { if (dataset()->env_->FileExists(lockfile_).ok()) { // Attempt to read the contents of the lockfile. char contents_scratch[151] = {0}; // Initialize all to 0. - StringPiece contents; + absl::string_view contents; std::unique_ptr file; if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) { file->Read(0, 150, &contents, contents_scratch).IgnoreError(); @@ -621,7 +621,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { *end_of_sequence = true; return absl::OkStatus(); } - StringPiece key = reader_.key(); + absl::string_view key = reader_.key(); DCHECK_EQ(key, dataset()->FormatName(cur_index_, i)); TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i])); TF_RETURN_IF_ERROR(reader_.status()); diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 7c190403c44b89..32c7cd028b34ae 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -816,6 +816,11 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:dataset_options_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", ], ) @@ -840,6 +845,9 @@ tf_kernel_library( "//tensorflow/core/platform:platform_port", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", ], @@ -855,7 +863,10 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/kernels/data/experimental/sql", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", ], ) @@ -868,6 +879,8 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:summary_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", ], ) @@ -880,6 +893,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", ], ) @@ -894,6 +908,10 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/data:captured_function", "//tensorflow/core/data:dataset_utils", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:dataset_options_proto_cc", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/status", ], ) @@ -907,6 +925,9 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/data:dataset_utils", + "//tensorflow/core/framework:dataset_options_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@eigen_archive//:eigen3", ], ) @@ -920,7 +941,10 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/data:dataset_utils", "//tensorflow/core/data:root_dataset", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", ], ) @@ -934,6 +958,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/framework:dataset_options_proto_cc", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", @@ -952,6 +977,9 @@ tf_kernel_library( "//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@eigen_archive//:eigen3", ], ) @@ -966,6 +994,7 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/data:captured_function", "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:dataset_options_proto_cc", "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -991,7 +1020,10 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/data:dataset_test_base", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/kernels/data:tensor_slice_dataset_op", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", ], ) diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index f296c4a0a96070..e1999ad1dbae41 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -265,7 +265,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { inputs.push_back(iteration_counter_handle); // Attributes - std::vector> attrs; + std::vector> attrs; AttrValue task_refresh_interval_hint_ms; b->BuildAttrValue(absl::ToInt64Milliseconds(task_refresh_interval_), &task_refresh_interval_hint_ms); diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 4c0184e1b4b36e..524621ab3b0ba9 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -151,7 +151,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { } else { // search a new pattern current_pattern_ = dataset()->patterns_[current_pattern_index_]; - StringPiece current_pattern_view = StringPiece(current_pattern_); + absl::string_view current_pattern_view = + absl::string_view(current_pattern_); // Windows paths contain backslashes and Windows APIs accept forward // and backslashes equivalently, so we convert the pattern to use @@ -168,7 +169,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { isWindows_ = false; } - StringPiece fixed_prefix = current_pattern_view.substr( + absl::string_view fixed_prefix = current_pattern_view.substr( 0, current_pattern_view.find_first_of("*?[\\")); string current_dir(io::Dirname(fixed_prefix)); @@ -277,8 +278,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { absl::Status UpdateIterator(IteratorContext* ctx, FileSystem* fs, const string& dir, const string& eval_pattern) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - StringPiece fixed_prefix = - StringPiece(eval_pattern) + absl::string_view fixed_prefix = + absl::string_view(eval_pattern) .substr(0, eval_pattern.find_first_of("*?[\\")); filepath_queue_.push(PathStatus(dir, true)); diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 662adc5295bfc4..c92e9a57bbbdd9 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -203,7 +203,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node)); inputs.emplace_back(input_index++, prefetch_input_elements_node); - std::vector> attrs; + std::vector> attrs; AttrValue f; b->BuildAttrValue(captured_func_->func(), &f); diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index 5c7c6013ae8aad..3390c25af62491 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -303,7 +303,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { dense_defaults_nodes.emplace_back(node); } - std::vector> attrs; + std::vector> attrs; AttrValue sparse_keys_attr; b->BuildAttrValue(sparse_keys_, &sparse_keys_attr); diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc index b099c7caea2365..a3e38ce4aeab90 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc @@ -94,12 +94,11 @@ class RandomDatasetParams : public DatasetParams { ResourceHandle CreateDummyResourceHandle() { return ResourceHandle(); } - virtual std::vector GetInputTensors() const override { + std::vector GetInputTensors() const override { return {seed_, seed2_, seed_generator_resource_}; } - virtual absl::Status GetInputNames( - std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {RandomDatasetOp::kSeed, RandomDatasetOp::kSeed2}; if (op_version_ == 2) { input_names->emplace_back("seed_generator"); @@ -107,8 +106,7 @@ class RandomDatasetParams : public DatasetParams { return absl::OkStatus(); } - virtual absl::Status GetAttributes( - AttributeVector* attributes) const override { + absl::Status GetAttributes(AttributeVector* attributes) const override { *attributes = {{"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; @@ -119,9 +117,7 @@ class RandomDatasetParams : public DatasetParams { return absl::OkStatus(); } - virtual string dataset_type() const override { - return RandomDatasetOp::kDatasetType; - } + string dataset_type() const override { return RandomDatasetOp::kDatasetType; } private: Tensor seed_; diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc index 1657cef0a092a9..69716a21df3c98 100644 --- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc @@ -12,10 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include +#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index a8f3e1ed9a38fc..f1d7e58c141158 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h" #include +#include +#include #include #include #include @@ -22,15 +24,27 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/time/clock.h" +#include "absl/time/time.h" #include "tensorflow/core/data/hash_utils.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/raw_coding.h" @@ -1914,7 +1928,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { absl::StrSplit(split_filename.back(), '.'); std::string max_num_str = split_snapshot_filename[0]; uint64 max_num; - if (!strings::safe_strtou64(max_num_str, &max_num)) { + if (!absl::SimpleAtoi(max_num_str, &max_num)) { return errors::Internal("Could not parse: ", max_num, " as uint64"); } next_file_index_ = max_num + 1; diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h index fb1fa875af264d..7faaa570ab846b 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_ #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_ +#include #include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/core/data/captured_function.h" diff --git a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc index 414773a48e9e1d..bca17788d33386 100644 --- a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc @@ -12,11 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/data/experimental/sql/driver_manager.h" #include "tensorflow/core/kernels/data/experimental/sql/query_connection.h" #include "tensorflow/core/lib/io/inputbuffer.h" diff --git a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc index d338c42fda59de..b07f82314e6142 100644 --- a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc @@ -12,8 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include +#include +#include "absl/status/status.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/resource_op_kernel.h" diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc index 20c3ff46139a8e..a14d2b28b3a72c 100644 --- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc @@ -12,7 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/status/status.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc index 067e4ca32d3189..2570f680944042 100644 --- a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc @@ -12,16 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include #include +#include "absl/status/status.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/data/captured_function.h" #include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index 819bf0f254a805..08c0fd13842a4e 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -14,10 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h" +#include +#include #include +#include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/refcount.h" diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h index 88d5ef7a4c341a..1255365d5fe525 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_THREADPOOL_DATASET_OP_H_ #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_THREADPOOL_DATASET_OP_H_ +#include + #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/platform/platform.h" diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc index 6d088a3c01daf3..c10a46c9fc6e5f 100644 --- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc +++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc @@ -12,6 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/root_dataset.h" #include "tensorflow/core/framework/dataset.h" @@ -19,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/record_writer.h" @@ -38,7 +45,7 @@ class ToTFRecordOp : public AsyncOpKernel { template absl::Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, + const absl::string_view& argument_name, T* output) { const Tensor* argument_t; TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc index 5682d1966eba4a..f74e5a3d98620a 100644 --- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc index 24e5aa6cd6e19b..750a86047c8be3 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -14,8 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/experimental/unique_dataset_op.h" +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc index 4f16c1b856eab8..b218f27516f14e 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op_test.cc @@ -11,7 +11,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/experimental/unique_dataset_op.h" +#include +#include +#include + +#include +#include "absl/status/status.h" #include "tensorflow/core/data/dataset_test_base.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/kernels/data/experimental/weighted_flat_map_dataset_op.cc b/tensorflow/core/kernels/data/experimental/weighted_flat_map_dataset_op.cc index 2560d2427fec0b..1d5715d34e44f0 100644 --- a/tensorflow/core/kernels/data/experimental/weighted_flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/weighted_flat_map_dataset_op.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/data/captured_function.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 54a1aae03d00c9..e996fac56ae648 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -289,7 +289,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { if (s.ok()) { bytes_counter->IncrementBy(dataset()->record_bytes_); lookahead_cache_.append(record); - StringPiece lookahead_cache_view(lookahead_cache_); + absl::string_view lookahead_cache_view(lookahead_cache_); record = tstring( lookahead_cache_view.substr(0, dataset()->record_bytes_)); lookahead_cache_ = tstring( diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc index 84de715927d369..63c0d465e431e9 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc @@ -167,7 +167,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { Node* drop_remainder = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder)); - std::vector> attrs; + std::vector> attrs; // Attr: parallel_copy AttrValue parallel_copy_attr; b->BuildAttrValue(parallel_copy_, ¶llel_copy_attr); diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index acddbcd222496e..2f90731b0cf13b 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -331,7 +331,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); inputs.emplace_back(input_index++, num_parallel_calls_node); - std::vector> attrs; + std::vector> attrs; AttrValue f; b->BuildAttrValue(captured_func_->func(), &f); attrs.emplace_back(kFunc, f); diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 5c25b52f48b71c..68680cc217c71a 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -213,7 +213,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( b->AddScalar(num_parallel_calls_, &num_parallel_calls)); } - std::vector> attrs; + std::vector> attrs; // Attr: f AttrValue f_attr; diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 92607656b52a00..f417caf2b4774c 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -140,7 +140,7 @@ class BaseDebugOp : public OpKernel { if (name_items.size() == 2) { node_name = name_items[0]; OP_REQUIRES( - context, strings::safe_strto32(name_items[1], &output_slot), + context, absl::SimpleAtoi(name_items[1], &output_slot), errors::InvalidArgument("Invalid string value for output_slot: \"", name_items[1], "\"")); } else if (name_items.size() == 1) { diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index 1cd77c7218ae1a..9d9ff205096aba 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -106,7 +106,7 @@ class DecodeCSVOp : public OpKernel { output[f]->flat()(i) = record_defaults[f].flat()(0); } else { int32_t value; - OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value), + OP_REQUIRES(ctx, absl::SimpleAtoi(fields[f], &value), errors::InvalidArgument( "Field ", f, " in record ", i, " is not a valid int32: ", fields[f])); @@ -127,7 +127,7 @@ class DecodeCSVOp : public OpKernel { record_defaults[f].flat()(0); } else { int64_t value; - OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value), + OP_REQUIRES(ctx, absl::SimpleAtoi(fields[f], &value), errors::InvalidArgument( "Field ", f, " in record ", i, " is not a valid int64: ", fields[f])); @@ -146,7 +146,7 @@ class DecodeCSVOp : public OpKernel { output[f]->flat()(i) = record_defaults[f].flat()(0); } else { float value; - OP_REQUIRES(ctx, strings::safe_strtof(fields[f], &value), + OP_REQUIRES(ctx, absl::SimpleAtof(fields[f], &value), errors::InvalidArgument( "Field ", f, " in record ", i, " is not a valid float: ", fields[f])); @@ -166,7 +166,7 @@ class DecodeCSVOp : public OpKernel { record_defaults[f].flat()(0); } else { double value; - OP_REQUIRES(ctx, strings::safe_strtod(fields[f], &value), + OP_REQUIRES(ctx, absl::SimpleAtod(fields[f], &value), errors::InvalidArgument( "Field ", f, " in record ", i, " is not a valid double: ", fields[f])); diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc index 68f5bd57256392..dcb50c3c2b88ab 100644 --- a/tensorflow/core/kernels/deep_conv2d.cc +++ b/tensorflow/core/kernels/deep_conv2d.cc @@ -82,7 +82,7 @@ static int64_t GetDirectConvCost(int filter_rows, int filter_cols, int in_depth, static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) { const char* tf_env_var_val = getenv(env_var_name); if (tf_env_var_val != nullptr) { - StringPiece tf_env_var_val_str(tf_env_var_val); + absl::string_view tf_env_var_val_str(tf_env_var_val); if (tf_env_var_val_str == "0") { return false; } diff --git a/tensorflow/core/kernels/depthtospace_op.cc b/tensorflow/core/kernels/depthtospace_op.cc index 6f720190c9652b..fd5caa4dafc028 100644 --- a/tensorflow/core/kernels/depthtospace_op.cc +++ b/tensorflow/core/kernels/depthtospace_op.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/core/kernels/depthtospace_op.h" +#include +#include #include #include #include @@ -50,9 +52,13 @@ class DepthToSpaceOp : public OpKernel { errors::InvalidArgument("Invalid data format")); OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_)); - OP_REQUIRES(context, block_size_ > 1, - errors::InvalidArgument("Block size should be > 1, but was: ", - block_size_)); + // This upper bound is needed to avoid an overflow when the block size value + // is squared in the output computation. + int block_size_limit = sqrt(std::numeric_limits::max()); + OP_REQUIRES(context, block_size_ > 1 && block_size_ <= block_size_limit, + errors::InvalidArgument( + "Block size should be > 1 and <= ", block_size_limit, + " but was: ", block_size_)); if (std::is_same::value) { OP_REQUIRES( diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index 5148e849b307bd..665abb9b4823fc 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -301,7 +301,7 @@ static void WriteStringAdapter(int field_number, const tstring& value, CodedOutputStream* output) { // Unfortunately, external proto does not accept string_view. #if defined(PLATFORM_GOOGLE) - WireFormatLite::WriteString(field_number, StringPiece(value), output); + WireFormatLite::WriteString(field_number, absl::string_view(value), output); #else WireFormatLite::WriteString(field_number, string(value), output); #endif @@ -311,7 +311,7 @@ static void WriteBytesAdapter(int field_number, const tstring& value, CodedOutputStream* output) { // Unfortunately, external proto does not accept string_view. #if defined(PLATFORM_GOOGLE) - WireFormatLite::WriteBytes(field_number, StringPiece(value), output); + WireFormatLite::WriteBytes(field_number, absl::string_view(value), output); #else WireFormatLite::WriteBytes(field_number, string(value), output); #endif diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index 163e89bc0b4b0f..d7fb26a35722c4 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -56,9 +56,9 @@ class ParseExampleOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* names; const Tensor* serialized; - std::vector dense_keys_t; - std::vector sparse_keys_t; - std::vector ragged_keys_t; + std::vector dense_keys_t; + std::vector sparse_keys_t; + std::vector ragged_keys_t; OpInputList dense_defaults; // Grab the inputs. @@ -102,8 +102,8 @@ class ParseExampleOp : public OpKernel { protected: // Copies keys from tensor to std::vector. - absl::Status GetTensorKeys(OpKernelContext* ctx, StringPiece input_name, - std::vector* keys) const { + absl::Status GetTensorKeys(OpKernelContext* ctx, absl::string_view input_name, + std::vector* keys) const { const Tensor* key_t; TF_RETURN_IF_ERROR(ctx->input(input_name, &key_t)); keys->reserve(key_t->NumElements()); @@ -115,8 +115,9 @@ class ParseExampleOp : public OpKernel { } // Copies keys from OpInputList of scalar to std::vector. - absl::Status GetInputListKeys(OpKernelContext* ctx, StringPiece input_name, - std::vector* keys) const { + absl::Status GetInputListKeys(OpKernelContext* ctx, + absl::string_view input_name, + std::vector* keys) const { OpInputList key_list; TF_RETURN_IF_ERROR(ctx->input_list(input_name, &key_list)); keys->reserve(key_list.size()); @@ -130,9 +131,9 @@ class ParseExampleOp : public OpKernel { absl::Status CheckInputShapes( const Tensor* serialized, const Tensor* names, const OpInputList& dense_defaults, - const std::vector& dense_keys_t, - const std::vector& sparse_keys_t, - const std::vector& ragged_keys_t) const { + const std::vector& dense_keys_t, + const std::vector& sparse_keys_t, + const std::vector& ragged_keys_t) const { if (op_version_ == 2) { if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) { return errors::InvalidArgument( @@ -211,9 +212,9 @@ class ParseExampleOp : public OpKernel { // Populates the FastParseExampleConfig from keys & defaults. example::FastParseExampleConfig MakeConfig( - const std::vector& dense_keys_t, - const std::vector& sparse_keys_t, - const std::vector& ragged_keys_t, + const std::vector& dense_keys_t, + const std::vector& sparse_keys_t, + const std::vector& ragged_keys_t, const OpInputList& dense_defaults) const { example::FastParseExampleConfig config; config.dense.reserve(attrs_.num_dense); diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index a00cb9d2742c45..0ee9dbbfd29b6a 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -64,6 +64,9 @@ DEFINE_SETZERO_CPU(complex128); DEFINE_SETZERO_CPU(Variant); DEFINE_SETZERO_CPU(float8_e5m2); DEFINE_SETZERO_CPU(float8_e4m3fn); +DEFINE_SETZERO_CPU(float8_e4m3fnuz); +DEFINE_SETZERO_CPU(float8_e4m3b11fnuz); +DEFINE_SETZERO_CPU(float8_e5m2fnuz); DEFINE_SETZERO_CPU(int4); DEFINE_SETZERO_CPU(uint4); #undef DEFINE_SETZERO_CPU @@ -94,6 +97,9 @@ DEFINE_SETONE_CPU(complex64); DEFINE_SETONE_CPU(complex128); DEFINE_SETONE_CPU(float8_e5m2); DEFINE_SETONE_CPU(float8_e4m3fn); +DEFINE_SETONE_CPU(float8_e4m3fnuz); +DEFINE_SETONE_CPU(float8_e4m3b11fnuz); +DEFINE_SETONE_CPU(float8_e5m2fnuz); DEFINE_SETONE_CPU(int4); DEFINE_SETONE_CPU(uint4); #undef DEFINE_SETONE_CPU @@ -132,6 +138,9 @@ DEFINE_FILL_CPU(qint16); DEFINE_FILL_CPU(qint32); DEFINE_FILL_CPU(float8_e5m2); DEFINE_FILL_CPU(float8_e4m3fn); +DEFINE_FILL_CPU(float8_e4m3fnuz); +DEFINE_FILL_CPU(float8_e4m3b11fnuz); +DEFINE_FILL_CPU(float8_e5m2fnuz); TF_CALL_int4(DEFINE_FILL_CPU); TF_CALL_uint4(DEFINE_FILL_CPU); #undef DEFINE_FILL_CPU diff --git a/tensorflow/core/kernels/image/decode_image_op.cc b/tensorflow/core/kernels/image/decode_image_op.cc index d9c4def13bc044..e87d79fe67b5f2 100644 --- a/tensorflow/core/kernels/image/decode_image_op.cc +++ b/tensorflow/core/kernels/image/decode_image_op.cc @@ -67,7 +67,7 @@ enum FileFormat { }; // Classify the contents of a file based on starting bytes (the magic number). -FileFormat ClassifyFileFormat(StringPiece data) { +FileFormat ClassifyFileFormat(absl::string_view data) { if (absl::StartsWith(data, kJpegMagicBytes)) return kJpgFormat; if (absl::StartsWith(data, kPngMagicBytes)) return kPngFormat; if (absl::StartsWith(data, kGifMagicBytes)) return kGifFormat; @@ -197,7 +197,7 @@ class DecodeImageV2Op : public OpKernel { context, TensorShapeUtils::IsScalar(contents.shape()), errors::InvalidArgument("`contents` must be scalar but got shape", contents.shape().DebugString())); - const StringPiece input = contents.scalar()(); + const absl::string_view input = contents.scalar()(); OP_REQUIRES(context, !input.empty(), errors::InvalidArgument("Input is empty.")); OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), @@ -226,7 +226,7 @@ class DecodeImageV2Op : public OpKernel { } } - void DecodeJpegV2(OpKernelContext* context, StringPiece input) { + void DecodeJpegV2(OpKernelContext* context, absl::string_view input) { OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3, errors::InvalidArgument("JPEG does not support 4 channels")); @@ -327,7 +327,7 @@ class DecodeImageV2Op : public OpKernel { } } - void DecodePngV2(OpKernelContext* context, StringPiece input) { + void DecodePngV2(OpKernelContext* context, absl::string_view input) { int channel_bits = (data_type_ == DataType::DT_UINT8) ? 8 : 16; png::DecodeContext decode; OP_REQUIRES( @@ -430,7 +430,7 @@ class DecodeImageV2Op : public OpKernel { } } - void DecodeGifV2(OpKernelContext* context, StringPiece input) { + void DecodeGifV2(OpKernelContext* context, absl::string_view input) { // GIF has 3 channels. OP_REQUIRES(context, channels_ == 0 || channels_ == 3, errors::InvalidArgument("channels must be 0 or 3 for GIF, got ", @@ -532,7 +532,7 @@ class DecodeImageV2Op : public OpKernel { } } - void DecodeBmpV2(OpKernelContext* context, StringPiece input) { + void DecodeBmpV2(OpKernelContext* context, absl::string_view input) { OP_REQUIRES( context, channels_ != 1, errors::InvalidArgument( diff --git a/tensorflow/core/kernels/image/extract_jpeg_shape_op.cc b/tensorflow/core/kernels/image/extract_jpeg_shape_op.cc index c74245dcf85ccc..38bcd35d4fd35b 100644 --- a/tensorflow/core/kernels/image/extract_jpeg_shape_op.cc +++ b/tensorflow/core/kernels/image/extract_jpeg_shape_op.cc @@ -41,7 +41,7 @@ class ExtractJpegShapeOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), errors::InvalidArgument("contents must be scalar, got shape ", contents.shape().DebugString())); - const StringPiece input = contents.scalar()(); + const absl::string_view input = contents.scalar()(); OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), errors::InvalidArgument("JPEG contents are too large for int: ", input.size())); diff --git a/tensorflow/core/kernels/image/sampling_kernels.cc b/tensorflow/core/kernels/image/sampling_kernels.cc index ae62a1b2e3dacd..d03247fc7487bf 100644 --- a/tensorflow/core/kernels/image/sampling_kernels.cc +++ b/tensorflow/core/kernels/image/sampling_kernels.cc @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace functor { -SamplingKernelType SamplingKernelTypeFromString(const StringPiece str) { +SamplingKernelType SamplingKernelTypeFromString(const absl::string_view str) { const string lower_case = absl::AsciiStrToLower(str); if (lower_case == "lanczos1") return Lanczos1Kernel; if (lower_case == "lanczos3") return Lanczos3Kernel; diff --git a/tensorflow/core/kernels/image/sampling_kernels.h b/tensorflow/core/kernels/image/sampling_kernels.h index 1903e675038b86..6f889adde3f5fe 100644 --- a/tensorflow/core/kernels/image/sampling_kernels.h +++ b/tensorflow/core/kernels/image/sampling_kernels.h @@ -62,7 +62,7 @@ enum SamplingKernelType { // Converts a string into the corresponding kernel type. // Returns SamplingKernelTypeEnd if the string couldn't be converted. -SamplingKernelType SamplingKernelTypeFromString(const StringPiece str); +SamplingKernelType SamplingKernelTypeFromString(const absl::string_view str); // A function object for a Lanczos kernel. struct LanczosKernelFunc { diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc index e1edbadf9f210f..8dd06efebf1430 100644 --- a/tensorflow/core/kernels/immutable_constant_op_test.cc +++ b/tensorflow/core/kernels/immutable_constant_op_test.cc @@ -68,7 +68,7 @@ class TestFileSystem : public NullFileSystem { const string& fname, TransactionToken* token, std::unique_ptr* result) override { float val = 0; - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(fname, &scheme, &host, &path); // For the tests create in-memory regions with float values equal to the // region name. @@ -153,8 +153,8 @@ absl::Status CreateTempFileFloat(Env* env, float value, uint64 size, std::unique_ptr file; TF_RETURN_IF_ERROR(env->NewWritableFile(*filename, &file)); for (uint64 i = 0; i < size; ++i) { - StringPiece sp(static_cast(static_cast(&value)), - sizeof(value)); + absl::string_view sp(static_cast(static_cast(&value)), + sizeof(value)); TF_RETURN_IF_ERROR(file->Append(sp)); } TF_RETURN_IF_ERROR(file->Close()); diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc index a9640f553da2b8..f92d55919e7656 100644 --- a/tensorflow/core/kernels/logging_ops.cc +++ b/tensorflow/core/kernels/logging_ops.cc @@ -37,8 +37,8 @@ static mutex* file_mutex = new mutex(); // Appends the given data to the specified file. It will create the file if it // doesn't already exist. -absl::Status AppendStringToFile(const std::string& fname, StringPiece data, - Env* env) { +absl::Status AppendStringToFile(const std::string& fname, + absl::string_view data, Env* env) { // TODO(ckluk): If opening and closing on every log causes performance issues, // we can reimplement using reference counters. mutex_lock l(*file_mutex); diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index 1a382261519992..838655b47688a2 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -217,7 +217,7 @@ class TextFileLineIterator switch (dtype) { case DT_INT32: { int32_t value; - if (!strings::safe_strto32(token.c_str(), &value)) { + if (!absl::SimpleAtoi(token.c_str(), &value)) { valid_ = false; return errors::InvalidArgument("Field ", token, " in line ", next_id_, " is not a valid int32."); @@ -226,7 +226,7 @@ class TextFileLineIterator } break; case DT_INT64: { int64_t value; - if (!strings::safe_strto64(token.c_str(), &value)) { + if (!absl::SimpleAtoi(token.c_str(), &value)) { valid_ = false; return errors::InvalidArgument("Field ", token, " in line ", next_id_, " is not a valid int64."); @@ -235,7 +235,7 @@ class TextFileLineIterator } break; case DT_FLOAT: { float value; - if (!strings::safe_strtof(token.c_str(), &value)) { + if (!absl::SimpleAtof(token.c_str(), &value)) { valid_ = false; return errors::InvalidArgument("Field ", token, " in line ", next_id_, " is not a valid float."); @@ -244,7 +244,7 @@ class TextFileLineIterator } break; case DT_DOUBLE: { double value; - if (!strings::safe_strtod(token.c_str(), &value)) { + if (!absl::SimpleAtod(token.c_str(), &value)) { valid_ = false; return errors::InvalidArgument("Field ", token, " in line ", next_id_, " is not a valid double."); diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h index ca0e93833b04cb..677c6a5659fc23 100644 --- a/tensorflow/core/kernels/lookup_util.h +++ b/tensorflow/core/kernels/lookup_util.h @@ -33,19 +33,19 @@ namespace lookup { // passed by attribute with name input_name, returns null if the table // doesn't exist. Use GetResourceLookupTable() or GetReferenceLookupTable() if // the input dtype is known. -absl::Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, +absl::Status GetLookupTable(absl::string_view input_name, OpKernelContext* ctx, LookupInterface** table); -absl::Status GetResourceLookupTable(StringPiece input_name, +absl::Status GetResourceLookupTable(absl::string_view input_name, OpKernelContext* ctx, LookupInterface** table); -absl::Status GetReferenceLookupTable(StringPiece input_name, +absl::Status GetReferenceLookupTable(absl::string_view input_name, OpKernelContext* ctx, LookupInterface** table); // Gets the InitializableLookupTable stored in the // ctx->resource_manager() with key passed by attribute with name // input_name, returns null if the table doesn't exist. -absl::Status GetInitializableLookupTable(StringPiece input_name, +absl::Status GetInitializableLookupTable(absl::string_view input_name, OpKernelContext* ctx, InitializableLookupTable** table); diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc index 677907cc4e7c70..02c8a6d48c92c7 100644 --- a/tensorflow/core/kernels/matmul_op_fused.cc +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -39,7 +39,6 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -51,14 +50,13 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/matmul_autotune.h" #include "tensorflow/core/util/tensor_format.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) #include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" #endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/redzone_allocator.h" -#include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/kernels/matmul_op_impl.h" @@ -71,6 +69,8 @@ limitations under the License. #include "tensorflow/core/util/autotune_maps/conv_parameters.h" #include "tensorflow/core/util/proto/proto_utils.h" #include "tensorflow/core/util/use_cudnn.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { @@ -202,7 +202,7 @@ namespace { /* hipBLASLt support Epilogue: https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#hipblasltepilogue-t -*/ +*/ StatusOr GetBlasLtEpilogOp( FusedComputationType fusion) { if (fusion == FusedComputationType::kBiasAdd) { @@ -484,12 +484,6 @@ struct LaunchFusedMatMulOp { #if !(GOOGLE_CUDA || TF_HIPBLASLT) use_cudnn = true; #endif - const auto& cc = stream->parent()->GetDeviceDescription(). - gpu_compute_capability(); - if (auto *procm = std::get_if< se::RocmComputeCapability >(&cc)) { - use_cudnn = !procm->gfx9_mi200_or_later(); - } - // use_cudnn is for hipblaslt doesn't support yet switch (fusion) { case FusedComputationType::kBiasAddWithGeluExact: @@ -525,7 +519,15 @@ struct LaunchFusedMatMulOp { default: use_cudnn = false; } +#if !(GOOGLE_CUDA || TF_HIPBLASLT) + use_cudnn = true; +#endif + const auto& cc = + stream->parent()->GetDeviceDescription().gpu_compute_capability(); + if (auto* procm = std::get_if(&cc)) { + use_cudnn = !procm->gfx9_mi200_or_later(); + } BlasScratchAllocator scratch_allocator(context); // The Gelu exact fusion is supported by the cuDNN. @@ -605,9 +607,9 @@ struct LaunchFusedMatMulOp { auto launch_func = [&](BlasScratchAllocator& scratch_allocator, size_t alg_idx, se::blas::ProfileResult* profile_result) { - return BlasLtMatmulPlanCache::ExecuteOnStream( - stream, entry, a_ptr, b_ptr, c_ptr, alg_idx, - scratch_allocator, bias_ptr, profile_result); + return BlasLtMatmulPlanCache::ExecuteOnStream( + stream, entry, a_ptr, b_ptr, c_ptr, alg_idx, scratch_allocator, + bias_ptr, profile_result); }; size_t alg_idx = 0; @@ -619,7 +621,7 @@ struct LaunchFusedMatMulOp { } OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr)); -#endif // GOOGLE_CUDA || TF_HIPBLASLT +#endif // GOOGLE_CUDA || TF_HIPBLASLT } }; diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index db9fde9f0e5296..85037618a32c97 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -28,7 +28,6 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -45,17 +44,18 @@ limitations under the License. #include "tensorflow/core/util/matmul_autotune.h" #include "tensorflow/core/util/matmul_bcast.h" #include "tensorflow/core/util/work_sharder.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) #include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" #endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/host_or_device_scalar.h" #include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/kernels/matmul_util.h" #include "tensorflow/core/kernels/numeric_options_utils.h" #include "tensorflow/core/platform/stream_executor.h" +#include "xla/stream_executor/host_or_device_scalar.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" @@ -602,9 +602,9 @@ struct LaunchBatchMatMul { static const bool use_autotune = MatmulAutotuneEnable(); bool bCublasLtSupport = true; - const auto& cc = stream->parent()->GetDeviceDescription(). - gpu_compute_capability(); - if(auto *procm = std::get_if< se::RocmComputeCapability >(&cc)) { + const auto& cc = + stream->parent()->GetDeviceDescription().gpu_compute_capability(); + if (auto* procm = std::get_if(&cc)) { bCublasLtSupport = procm->gfx9_mi200_or_later(); } @@ -660,10 +660,10 @@ struct LaunchBatchMatMul { // scratch space is deallocated between runs. BlasScratchAllocator scratch_allocator(context, max_scratch_size); Status cublas_launch_status = - BlasLtMatmulPlanCache::ExecuteOnStream(stream, - *plan_and_algorithms, - *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator, - se::DeviceMemoryBase{}, &profile_result); + BlasLtMatmulPlanCache::ExecuteOnStream( + stream, *plan_and_algorithms, *a_ptrs[0], *b_ptrs[0], + *c_ptrs[0], i, scratch_allocator, se::DeviceMemoryBase{}, + &profile_result); VLOG(4) << " Autotune algorithm " << i << " result: " << profile_result.elapsed_time_in_ms() @@ -701,12 +701,10 @@ struct LaunchBatchMatMul { << "trans_x = " << trans_x << "trans_y = " << trans_y << "adj_x = " << adj_x << "adj_y = " << adj_y; - OP_REQUIRES_OK( - context, - BlasLtMatmulPlanCache::ExecuteOnStream(stream, - *plan_and_algorithms, - *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], - algorithm_idx, scratch_allocator, se::DeviceMemoryBase{})); + OP_REQUIRES_OK(context, BlasLtMatmulPlanCache::ExecuteOnStream( + stream, *plan_and_algorithms, *a_ptrs[0], + *b_ptrs[0], *c_ptrs[0], algorithm_idx, + scratch_allocator, se::DeviceMemoryBase{})); } else { // requires mixed broadcasting const std::vector& a_batch_indices = bcast.x_batch_indices(); const std::vector& b_batch_indices = bcast.y_batch_indices(); diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index 897d8fd1772b07..1f953e40738a07 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -373,6 +373,11 @@ static auto GetActivations(DataType dtype) { TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x128x64WithActivation) { for (const string& activation : GetActivations(this->kTValueType)) { + if (this->kTValueType == DT_HALF) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip investigate " + "issue with Eigen::half"; + } + this->VerifyConv2DWithBiasAndActivation(256, 128, 64, false, false, activation); this->VerifyConv2DWithBiasAndActivation(256, 128, 64, true, false, @@ -386,6 +391,10 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x128x64WithActivation) { TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) { for (const string& activation : GetActivations(this->kTValueType)) { + if (this->kTValueType == DT_HALF) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip investigate " + "issue with Eigen::half"; + } this->VerifyConv2DWithBiasAndActivation(1, 256, 256, false, false, activation); } @@ -393,6 +402,10 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) { TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) { for (const string& activation : GetActivations(this->kTValueType)) { + if (this->kTValueType == DT_HALF) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip investigate " + "issue with Eigen::half"; + } this->VerifyConv2DWithBiasAndActivation(256, 256, 1, false, false, activation); } @@ -400,6 +413,10 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) { TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) { for (const string& activation : GetActivations(this->kTValueType)) { + if (this->kTValueType == DT_HALF) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip investigate " + "issue with Eigen::half"; + } this->VerifyConv2DWithBiasAndActivation(1, 256, 1, false, false, activation); } diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc index 8f95e9a9336fe2..6612513676fb3c 100644 --- a/tensorflow/core/kernels/matmul_util.cc +++ b/tensorflow/core/kernels/matmul_util.cc @@ -14,19 +14,19 @@ limitations under the License. #if GOOGLE_CUDA || TF_HIPBLASLT +#include #include #include -#include #include -#include "xla/status_macros.h" -#include "xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/matmul_autotune.h" +#include "xla/status_macros.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/xla_data.pb.h" namespace tensorflow { @@ -93,10 +93,11 @@ StatusOr GetBlasComputationType( } // namespace -/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream *stream) { +/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i( + se::Stream* stream) { static absl::Mutex m(absl::kConstInit); // Each GPU gets different cache instance - static std::deque< BlasLtMatmulPlanCache > meta(8); + static std::deque meta(8); absl::MutexLock lock(&m); size_t dev_id = stream->parent()->device_ordinal(); if (dev_id >= meta.size()) meta.resize(dev_id + 1); @@ -105,17 +106,16 @@ StatusOr GetBlasComputationType( /* static */ auto BlasLtMatmulPlanCache::GetOrCreate( se::Stream* stream, const BlasLtMatmulPlanParams& params, - absl::Mutex** ppmu, std::optional max_algorithm_count) -> StatusOr{ + absl::Mutex** ppmu, std::optional max_algorithm_count) + -> StatusOr { static const int64_t max_scratch_size = GetWorkspaceLimit(1LL << 32); // 4GB by default static const int64_t max_autotune_algorithm_count = MatmulMaxAutotuneAlgorithmCount(); if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count; - auto& self = BlasLtMatmulPlanCache::i(stream); - absl::MutexLock lock(self.mutex_.get()); auto [ptr, inserted] = self.map_.emplace(params, Entry{}); auto& entry = ptr->second; if (inserted) { @@ -167,7 +167,7 @@ StatusOr GetBlasComputationType( }; TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan( - stream, cfg, params.epilogue)); + stream, cfg, params.epilogue)); TF_ASSIGN_OR_RETURN( entry.algorithms, @@ -177,31 +177,23 @@ StatusOr GetBlasComputationType( return &entry; } -/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream, - const Entry& entry, - const se::DeviceMemoryBase& a, - const se::DeviceMemoryBase& b, - se::DeviceMemoryBase& c, - size_t algorithm_idx, - se::ScratchAllocator& scratch_allocator, - const se::DeviceMemoryBase& bias, - se::blas::ProfileResult* profile_result) { - - return entry.plan->ExecuteOnStream( - stream, a, b, c, c, - bias, // bias_buffer - se::DeviceMemoryBase{}, // aux_buffer - se::DeviceMemoryBase{}, // a_scale_buffer - se::DeviceMemoryBase{}, // b_scale_buffer - se::DeviceMemoryBase{}, // c_scale_buffer - se::DeviceMemoryBase{}, // d_scale_buffer - se::DeviceMemoryBase{}, // d_amax_buffer - entry.algorithms[algorithm_idx], - scratch_allocator, - profile_result); +/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream( + se::Stream* stream, const Entry& entry, const se::DeviceMemoryBase& a, + const se::DeviceMemoryBase& b, se::DeviceMemoryBase& c, + size_t algorithm_idx, se::ScratchAllocator& scratch_allocator, + const se::DeviceMemoryBase& bias, se::blas::ProfileResult* profile_result) { + return entry.plan->ExecuteOnStream(stream, a, b, c, c, + bias, // bias_buffer + se::DeviceMemoryBase{}, // aux_buffer + se::DeviceMemoryBase{}, // a_scale_buffer + se::DeviceMemoryBase{}, // b_scale_buffer + se::DeviceMemoryBase{}, // c_scale_buffer + se::DeviceMemoryBase{}, // d_scale_buffer + se::DeviceMemoryBase{}, // d_amax_buffer + + entry.algorithms[algorithm_idx], + scratch_allocator, profile_result); } - - } // namespace tensorflow -#endif \ No newline at end of file +#endif diff --git a/tensorflow/core/kernels/matmul_util.h b/tensorflow/core/kernels/matmul_util.h index dbf85eab41242c..fab1929a5ecbf7 100644 --- a/tensorflow/core/kernels/matmul_util.h +++ b/tensorflow/core/kernels/matmul_util.h @@ -22,10 +22,10 @@ limitations under the License. #if GOOGLE_CUDA || TF_HIPBLASLT #include "absl/container/node_hash_map.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "tensorflow/core/framework/types.h" #include "tsl/platform/types.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" namespace tensorflow { @@ -35,7 +35,6 @@ namespace tensorflow { int64_t GetWorkspaceLimit(int64_t default_value_in_bytes); struct BlasLtMatmulPlanParams { - std::string ToString() const { return "NOP"; } bool operator==(const BlasLtMatmulPlanParams& other) const; @@ -67,41 +66,36 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) { } struct BlasLtMatmulPlanCache { - struct Entry { + struct Entry { se::gpu::BlasLt::MatmulPlanPtr plan; - std::vector< se::gpu::BlasLt::MatmulAlgorithm > algorithms; + std::vector algorithms; }; - static StatusOr GetOrCreate( - se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu, - std::optional max_algorithm_count = std::nullopt - ); + static StatusOr GetOrCreate( + se::Stream* stream, const BlasLtMatmulPlanParams& params, + absl::Mutex** pmu, std::optional max_algorithm_count = std::nullopt); // helper function for plan execution - static Status ExecuteOnStream(se::Stream* stream, - const Entry& entry, - const se::DeviceMemoryBase& a, - const se::DeviceMemoryBase& b, - se::DeviceMemoryBase& c, - size_t algorithm_idx, - se::ScratchAllocator& scratch_allocator, - const se::DeviceMemoryBase& bias, - se::blas::ProfileResult* profile_result = nullptr); - - BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) { - } - -private: - static BlasLtMatmulPlanCache& i(se::Stream *stream); + static Status ExecuteOnStream( + se::Stream* stream, const Entry& entry, const se::DeviceMemoryBase& a, + const se::DeviceMemoryBase& b, se::DeviceMemoryBase& c, + size_t algorithm_idx, se::ScratchAllocator& scratch_allocator, + const se::DeviceMemoryBase& bias, + se::blas::ProfileResult* profile_result = nullptr); + + BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {} + + private: + static BlasLtMatmulPlanCache& i(se::Stream* stream); std::unique_ptr mutex_; absl::node_hash_map map_ - ABSL_GUARDED_BY(mutex_); + ABSL_GUARDED_BY(mutex_); -}; // BlasLtMatmulPlanCache +}; // BlasLtMatmulPlanCache } // namespace tensorflow -#endif // GOOGLE_CUDA || TF_HIPBLASLT +#endif // GOOGLE_CUDA || TF_HIPBLASLT #endif // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_ diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index b4d9043ae45896..b405721f465ac9 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -275,7 +275,7 @@ tf_kernel_library( ":base_gpu_op", ":gpu_cast_kernels", "@eigen_archive//:eigen3", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -414,6 +414,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:tensorflow", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], @@ -562,6 +563,7 @@ tf_cuda_cc_test( ":base_ops_test", "//tensorflow/core/common_runtime:device", "//tensorflow/core/common_runtime:device_factory", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/kernels/mlir_generated/base_ops_test.cc b/tensorflow/core/kernels/mlir_generated/base_ops_test.cc index a45cc9b9ec4098..693426bf4178b7 100644 --- a/tensorflow/core/kernels/mlir_generated/base_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/base_ops_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/kernels/mlir_generated/base_ops_test.h" +#include + namespace tensorflow { namespace test { diff --git a/tensorflow/core/kernels/mlir_generated/base_ops_test.h b/tensorflow/core/kernels/mlir_generated/base_ops_test.h index d7a2a2d0e9886a..45568f88fc7498 100644 --- a/tensorflow/core/kernels/mlir_generated/base_ops_test.h +++ b/tensorflow/core/kernels/mlir_generated/base_ops_test.h @@ -17,10 +17,15 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_ #include +#include +#include +#include #include #include +#include #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_large_tensor_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_large_tensor_test.cc index c2bd533edca18a..18ddf0d5b358b9 100755 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_large_tensor_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_large_tensor_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include "tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h" #include "tensorflow/core/kernels/mlir_generated/base_ops_test.h" diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc index dc67845a22db73..a72c7939c24037 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc @@ -1015,6 +1015,12 @@ T baseline_mul(T lhs, T rhs) { return lhs * rhs; } +template +std::complex baseline_cmulf(std::complex lhs, std::complex rhs) { + return std::complex(lhs.real() * rhs.real() - lhs.imag() * rhs.imag(), + lhs.real() * rhs.imag() + lhs.imag() * rhs.real()); +} + GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half, baseline_mul, test::OpsTestConfig().ExpectStrictlyEqual()) @@ -1056,7 +1062,7 @@ TEST_F(BinaryOpsTest, MulComplex64SpecialCases) { test::NearZeroInfAndNanInput>(), test::RepeatElements(test::NearZeroInfAndNanInput>(), 64), - baseline_mul, test::OpsTestConfig()); + baseline_cmulf, test::OpsTestConfig()); } TEST_F(BinaryOpsTest, MulComplex128SpecialCases) { @@ -1066,7 +1072,7 @@ TEST_F(BinaryOpsTest, MulComplex128SpecialCases) { test::NearZeroInfAndNanInput>(), test::RepeatElements(test::NearZeroInfAndNanInput>(), 64), - baseline_mul, test::OpsTestConfig()); + baseline_cmulf, test::OpsTestConfig()); } #endif diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc index 2509293f6b66ca..7249c9e790092c 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_abs.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc index 09339cd15ded24..eb15e749016170 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_acos.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc index 5e4040daddfa7b..289c60a7a63c7a 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_acosh.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc index 191648ee3c2402..8f45c8d8ee109a 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_add.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc index 7eb26c6d16a187..f7fe7b04f26254 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_angle.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc index d163aa5cb3cc96..196fb0f25e9d0a 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_asin.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc index 2875ba3d30f069..59f12c665da5f7 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_asinh.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc index b4199318728b74..b7d2a0f8409ee0 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atan.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc index 378f20cf3edee8..a6da29897b4364 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atan2.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc index bef09f530ab34b..ae77d17fd8d25b 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc index 219c09ee582c0d..c25266d64f5bb8 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_and.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc index 884a084ddf151f..f8b11b7a5d36c6 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_or.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc index b357edf812a508..8913ab7e84c507 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_bitwise_xor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc index b29bd93570b2f9..ea86c49cb61510 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cast.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc index c0435cfe4bf8d4..aa6db38c260074 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_ceil.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc index df3cf9372398cd..481de25799ad82 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_complex.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc index 6c80459cb9cbc5..c2575f1c298119 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_complex_abs.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc index 38e3a24327666a..49d4813ff39092 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_conj.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc index af2c2b75a18f17..a44ae3eb93c8f4 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cos.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc index c861da60cd632d..3a92b49a7eff27 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_cosh.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc index 6253ffc05dd4cb..5c5a5ff8242edf 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_div.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_div_no_nan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_div_no_nan.cc index f3280486b53b4b..1969c1f46619f4 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_div_no_nan.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_div_no_nan.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_elu.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_elu.cc index 15e718a8048c93..c6793f5b720d49 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_elu.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_elu.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc index 2b4d456b05cdd2..055eccd60cbae7 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_equal.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc index ac815748ff94c9..ce3470fb588ae5 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_erf.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc index fde8304258c0a6..eb449ff9a5a57a 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_erfc.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc index 22ca3fe0ff636b..ac1074cc5956f8 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_exp.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc index a5f0d698916f64..84584a73b78f25 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_expm1.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc index 7acb10ed23dfdf..b3c41babafd8c6 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc index 4fcb8011ebfd71..4cd3aafe31b732 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_div.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_floor_mod.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_mod.cc index a2d8a7352cd7a5..7f4cb9f6d4c44d 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_floor_mod.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_floor_mod.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc index 9d5e994df6f58f..1bf584adf5578c 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_greater.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc index 81527b440cc46b..b6fd6a78414055 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_greater_equal.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc index 0e078e595191d6..8fd6bd76b48561 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_imag.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc index 8fec3eb037a903..c553fdc0ed1375 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_invert.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc index 906a311c7e1b78..b733b8b0893db2 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_finite.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc index 580cef6ed8c7a3..92657c0e502ece 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_inf.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc index 42a46113601c21..05ea0485fcb806 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_is_nan.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc index d2aa3548e77788..def5ceb661fa7e 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_left_shift.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc index c8e6166a3688e0..4ce0eeeb31b1eb 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_less.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc index 7aff9a28cb661b..b8cab6bde48e26 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_less_equal.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc index d5b85626b62295..ceb8146593b46b 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_lgamma.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc index 773da880b704d6..f163f91ddac148 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_log.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc index e026bc3e649ef0..f0cd0fe80e95e2 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_log1p.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc index 887ca37c2242dc..9914aabe559a3c 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc index 2c400aca27abb1..44cfdb083b5f27 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc index 4a2e9744215fcd..0c375ddb1b2fa5 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc @@ -14,12 +14,16 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include +#include +#include #include #include #include #include +#include #include #include diff --git a/tensorflow/core/kernels/ragged_range_op.cc b/tensorflow/core/kernels/ragged_range_op.cc index 90c2060c33f386..65054ae5bd843a 100644 --- a/tensorflow/core/kernels/ragged_range_op.cc +++ b/tensorflow/core/kernels/ragged_range_op.cc @@ -87,16 +87,29 @@ class RaggedRangeOp : public OpKernel { size = 0; } else if constexpr (std::is_integral::value) { // The following is copied from tensorflow::RangeOp::Compute(). - size = Eigen::divup(Eigen::numext::abs(limit - start), - Eigen::numext::abs(delta)); + uint64_t range; + if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { + range = static_cast(Eigen::numext::abs(limit)) + + static_cast(Eigen::numext::abs(start)); + } else { + range = static_cast(Eigen::numext::abs(limit - start)); + } + + uint64_t size_unsigned = Eigen::divup( + range, static_cast(Eigen::numext::abs(delta))); + OP_REQUIRES(context, + size_unsigned <= std::numeric_limits::max(), + InvalidArgument("Requires ((limit - start) / delta) <= ", + std::numeric_limits::max())); + size = static_cast(size_unsigned); } else { // The following is copied from tensorflow::RangeOp::Compute(). auto size_auto = Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta)); OP_REQUIRES( - context, size_auto <= std::numeric_limits::max(), + context, size_auto <= std::numeric_limits::max(), errors::InvalidArgument("Requires ((limit - start) / delta) <= ", - std::numeric_limits::max())); + std::numeric_limits::max())); size = static_cast(size_auto); } OP_REQUIRES(context, size >= 0, InvalidArgument("Requires size >= 0")); @@ -122,7 +135,9 @@ class RaggedRangeOp : public OpKernel { T delta = broadcast_deltas ? deltas(0) : deltas(row); for (SPLITS_TYPE i = 0; i < row_size; ++i) { rt_dense_values(value_index++) = T(value); - value += delta; + if (i < row_size - 1) { + value += delta; + } } } } diff --git a/tensorflow/core/kernels/ragged_range_op_test.cc b/tensorflow/core/kernels/ragged_range_op_test.cc index 79514173547006..699531a8d3647c 100644 --- a/tensorflow/core/kernels/ragged_range_op_test.cc +++ b/tensorflow/core/kernels/ragged_range_op_test.cc @@ -89,6 +89,17 @@ TEST_F(RaggedRangeOpTest, RangeSizeOverflow) { RunOpKernel().message()); } +TEST_F(RaggedRangeOpTest, RangeSizeOverflow2) { + BuildRaggedRangeGraph(); + AddInputFromArray(TensorShape({}), {static_cast(5e18)}); + AddInputFromArray(TensorShape({}), {static_cast(-5e18)}); + AddInputFromArray(TensorShape({}), {-1}); + + EXPECT_EQ(absl::StrCat("Requires ((limit - start) / delta) <= ", + std::numeric_limits::max()), + RunOpKernel().message()); +} + TEST_F(RaggedRangeOpTest, BroadcastDeltas) { BuildRaggedRangeGraph(); AddInputFromArray(TensorShape({3}), {0, 5, 8}); // starts diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc index 971b849ce71a59..409b5448243d90 100644 --- a/tensorflow/core/kernels/range_sampler.cc +++ b/tensorflow/core/kernels/range_sampler.cc @@ -297,7 +297,7 @@ absl::Status FixedUnigramSampler::LoadFromFile(Env* env, // Skip entries that do not belong to this shard. if (word_id % num_shards_ == shard_) { float w = 0.0; - if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) { + if (!absl::SimpleAtof(cols.at(cols.size() - 1), &w)) { return errors::InvalidArgument("Wrong vocabulary format at line: ", line); } diff --git a/tensorflow/core/kernels/reduce_join_op.cc b/tensorflow/core/kernels/reduce_join_op.cc index 72c41f8ab1420d..6ee2ef0139a427 100644 --- a/tensorflow/core/kernels/reduce_join_op.cc +++ b/tensorflow/core/kernels/reduce_join_op.cc @@ -161,7 +161,7 @@ class ReduceJoinOp : public OpKernel { const int64_t reduction_iter_size = GetReductionIterSize(reduced_indices, input_shape); - absl::InlinedVector curr_strings(reduction_iter_size); + absl::InlinedVector curr_strings(reduction_iter_size); for (int64_t output_index = 0; output_index < output_shape.num_elements(); ++output_index) { int64_t output_full_index = LinearSubIndexToFullIndex( diff --git a/tensorflow/core/kernels/restore_v2_op_test.cc b/tensorflow/core/kernels/restore_v2_op_test.cc index b9f289f01bb90f..c102cc42e2063f 100644 --- a/tensorflow/core/kernels/restore_v2_op_test.cc +++ b/tensorflow/core/kernels/restore_v2_op_test.cc @@ -60,7 +60,7 @@ class RestoreV2OpTest : public OpsTestBase { TF_ASSERT_OK(InitOp()); } - void RunTest(StringPiece save_op_to_use) { + void RunTest(absl::string_view save_op_to_use) { const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple-", save_op_to_use); const std::vector tensor_names = { diff --git a/tensorflow/core/kernels/risc/experimental/risc_reshape_op.cc b/tensorflow/core/kernels/risc/experimental/risc_reshape_op.cc index bdcbdc0fe98f38..7d1dd915ee3383 100644 --- a/tensorflow/core/kernels/risc/experimental/risc_reshape_op.cc +++ b/tensorflow/core/kernels/risc/experimental/risc_reshape_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/risc/experimental/risc_shape_op.cc b/tensorflow/core/kernels/risc/experimental/risc_shape_op.cc index 510abf196c72ca..98273b64cf6a7d 100644 --- a/tensorflow/core/kernels/risc/experimental/risc_shape_op.cc +++ b/tensorflow/core/kernels/risc/experimental/risc_shape_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index 5256db35a1f228..ca793a9c2a2520 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -93,8 +93,21 @@ class RangeOp : public OpKernel { } int64_t size; if constexpr (std::is_integral::value) { - size = Eigen::divup(Eigen::numext::abs(limit - start), - Eigen::numext::abs(delta)); + uint64_t range; + if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { + range = static_cast(Eigen::numext::abs(limit)) + + static_cast(Eigen::numext::abs(start)); + } else { + range = static_cast(Eigen::numext::abs(limit - start)); + } + + uint64_t size_unsigned = + Eigen::divup(range, static_cast(Eigen::numext::abs(delta))); + OP_REQUIRES( + context, size_unsigned <= std::numeric_limits::max(), + errors::InvalidArgument("Requires ((limit - start) / delta) <= ", + std::numeric_limits::max())); + size = static_cast(size_unsigned); } else { auto size_auto = Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta)); diff --git a/tensorflow/core/kernels/sequence_ops_test.cc b/tensorflow/core/kernels/sequence_ops_test.cc index d0a079f1827428..84943e2d5d2f46 100644 --- a/tensorflow/core/kernels/sequence_ops_test.cc +++ b/tensorflow/core/kernels/sequence_ops_test.cc @@ -115,6 +115,18 @@ TEST_F(RangeOpTest, Large_Double) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(RangeOpTest, Range_Size_Overflow) { + MakeOp(DT_INT64); + + AddInputFromArray(TensorShape({}), {static_cast(5e18)}); + AddInputFromArray(TensorShape({}), {static_cast(-5e18)}); + AddInputFromArray(TensorShape({}), {-1}); + + EXPECT_EQ(absl::StrCat("Requires ((limit - start) / delta) <= ", + std::numeric_limits::max()), + RunOpKernel().message()); +} + TEST_F(LinSpaceOpTest, Simple_D32) { MakeOp(DT_FLOAT, DT_INT32); diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 1f10def306145d..2cf1388afa2a74 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -161,14 +161,14 @@ tstring KeyedSparseTensorColumn::Feature(int64_t batch, int64_t n, } template <> -StringPiece SparseTensorColumn::Feature(int64_t batch, int64_t n, - bool strong_hash) const { +absl::string_view SparseTensorColumn::Feature( + int64_t batch, int64_t n, bool strong_hash) const { const int64_t start = feature_start_indices_[batch]; return values_.vec().data()[start + n]; } template <> -StringPiece KeyedSparseTensorColumn::Feature( +absl::string_view KeyedSparseTensorColumn::Feature( int64_t batch, int64_t n, bool strong_hash) const { const int64_t start = feature_start_indices_[batch]; return values_.vec().data()[start + n]; @@ -259,13 +259,13 @@ tstring KeyedDenseTensorColumn::Feature(int64_t batch, int64_t n, } template <> -StringPiece DenseTensorColumn::Feature(int64_t batch, int64_t n, - bool strong_hash) const { +absl::string_view DenseTensorColumn::Feature( + int64_t batch, int64_t n, bool strong_hash) const { return tensor_.matrix()(batch, n); } template <> -StringPiece KeyedDenseTensorColumn::Feature( +absl::string_view KeyedDenseTensorColumn::Feature( int64_t batch, int64_t n, bool strong_hash) const { return tensor_.matrix()(batch, n); } @@ -961,7 +961,7 @@ REGISTER_KERNEL_BUILDER(Name("SparseCross") .Device(DEVICE_CPU) .TypeConstraint("out_type") .TypeConstraint("internal_type"), - SparseCrossOp); + SparseCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseCross") .Device(DEVICE_CPU) diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc index 27115f3153458f..92ef7528dab075 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc @@ -33,7 +33,7 @@ namespace tensorflow { namespace { -static void ExpectHasSubstr(StringPiece s, StringPiece expected) { +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/core/kernels/sparse_reduce_op.cc b/tensorflow/core/kernels/sparse_reduce_op.cc index 97dd91523ebc7f..222280e5468969 100644 --- a/tensorflow/core/kernels/sparse_reduce_op.cc +++ b/tensorflow/core/kernels/sparse_reduce_op.cc @@ -143,7 +143,7 @@ struct SumOp { static void Run(OpKernelContext *ctx, typename TTypes::Scalar &s, const typename TTypes::UnalignedVec &v) { s.device(ctx->eigen_cpu_device()) = v.sum(); } - static StringPiece Name() { + static absl::string_view Name() { return "sum"; } }; @@ -153,7 +153,7 @@ struct MaxOp { static void Run(OpKernelContext *ctx, typename TTypes::Scalar &s, const typename TTypes::UnalignedVec &v) { s.device(ctx->eigen_cpu_device()) = v.maximum(); } - static StringPiece Name() { + static absl::string_view Name() { return "max"; } }; diff --git a/tensorflow/core/kernels/spectrogram_test_utils.cc b/tensorflow/core/kernels/spectrogram_test_utils.cc index 684cbc19e77a12..82c708b8dce918 100644 --- a/tensorflow/core/kernels/spectrogram_test_utils.cc +++ b/tensorflow/core/kernels/spectrogram_test_utils.cc @@ -166,7 +166,7 @@ void ReadCSVFileToArrayOrDie(const string& filename, std::vector split_line = str_util::Split(lines[l], ","); for (const string& token : split_line) { float tmp; - CHECK(strings::safe_strtof(token, &tmp)); + CHECK(absl::SimpleAtof(token, &tmp)); values.push_back(tmp); } array->push_back(values); @@ -181,8 +181,9 @@ bool WriteDoubleVectorToFile(const string& file_name, return false; } for (int i = 0; i < data.size(); ++i) { - if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), - sizeof(data[i]))) + if (!file + ->Append(absl::string_view( + reinterpret_cast(&(data[i])), sizeof(data[i]))) .ok()) { LOG(ERROR) << "Failed to append to file " << file_name; return false; @@ -203,8 +204,9 @@ bool WriteFloatVectorToFile(const string& file_name, return false; } for (int i = 0; i < data.size(); ++i) { - if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), - sizeof(data[i]))) + if (!file + ->Append(absl::string_view( + reinterpret_cast(&(data[i])), sizeof(data[i]))) .ok()) { LOG(ERROR) << "Failed to append to file " << file_name; return false; @@ -225,8 +227,9 @@ bool WriteDoubleArrayToFile(const string& file_name, int size, return false; } for (int i = 0; i < size; ++i) { - if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), - sizeof(data[i]))) + if (!file + ->Append(absl::string_view( + reinterpret_cast(&(data[i])), sizeof(data[i]))) .ok()) { LOG(ERROR) << "Failed to append to file " << file_name; return false; @@ -247,8 +250,9 @@ bool WriteFloatArrayToFile(const string& file_name, int size, return false; } for (int i = 0; i < size; ++i) { - if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), - sizeof(data[i]))) + if (!file + ->Append(absl::string_view( + reinterpret_cast(&(data[i])), sizeof(data[i]))) .ok()) { LOG(ERROR) << "Failed to append to file " << file_name; return false; @@ -272,16 +276,18 @@ bool WriteComplexVectorToRawFloatFile( for (int i = 0; i < data.size(); ++i) { for (int j = 0; j < data[i].size(); ++j) { const float real_part(real(data[i][j])); - if (!file->Append(StringPiece(reinterpret_cast(&real_part), - sizeof(real_part))) + if (!file->Append( + absl::string_view(reinterpret_cast(&real_part), + sizeof(real_part))) .ok()) { LOG(ERROR) << "Failed to append to file " << file_name; return false; } const float imag_part(imag(data[i][j])); - if (!file->Append(StringPiece(reinterpret_cast(&imag_part), - sizeof(imag_part))) + if (!file->Append( + absl::string_view(reinterpret_cast(&imag_part), + sizeof(imag_part))) .ok()) { LOG(ERROR) << "Failed to append to file " << file_name; return false; diff --git a/tensorflow/core/kernels/string_join_op.cc b/tensorflow/core/kernels/string_join_op.cc index 336be40b1927e0..4eebde744d1cbe 100644 --- a/tensorflow/core/kernels/string_join_op.cc +++ b/tensorflow/core/kernels/string_join_op.cc @@ -62,7 +62,7 @@ class StringJoinOp : public OpKernel { &output_tensor)); auto output_flat = output_tensor->flat(); - std::vector strings(input_list.size()); + std::vector strings(input_list.size()); for (size_t i = 0; i < input_shape.num_elements(); ++i) { for (int j = 0; j < input_list.size(); ++j) { strings[j] = (is_scalar[j]) ? inputs[j](0) : inputs[j](i); diff --git a/tensorflow/core/kernels/string_lower_op.cc b/tensorflow/core/kernels/string_lower_op.cc index 23b83b66d17a6f..51c614502c3ee7 100644 --- a/tensorflow/core/kernels/string_lower_op.cc +++ b/tensorflow/core/kernels/string_lower_op.cc @@ -50,7 +50,7 @@ class StringLowerOp : public OpKernel { if (encoding_.empty()) { for (int64_t i = 0; i < input.size(); ++i) { - StringPiece entry(input(i)); + absl::string_view entry(input(i)); output(i) = absl::AsciiStrToLower(entry); } } else { diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc index dc8564fe74ee92..650c74a83238df 100644 --- a/tensorflow/core/kernels/string_split_op.cc +++ b/tensorflow/core/kernels/string_split_op.cc @@ -34,13 +34,13 @@ namespace { // a series of finds in the input string, making it much more efficient than // SplitOnCharSet. template -std::vector SplitOnChar(const tstring& str, const char delim, - Predicate p) { - std::vector result; - StringPiece text(str); +std::vector SplitOnChar(const tstring& str, const char delim, + Predicate p) { + std::vector result; + absl::string_view text(str); auto f = text.find(delim); - while (f != StringPiece::npos) { - StringPiece token = text.substr(0, f); + while (f != absl::string_view::npos) { + absl::string_view token = text.substr(0, f); if (p(token)) { result.emplace_back(token); } @@ -58,15 +58,17 @@ std::vector SplitOnChar(const tstring& str, const char delim, // is valid. // Based on str_util::Split. template -std::vector SplitOnCharSet(const tstring& str, - const tstring& delim_set, Predicate p) { - std::vector result; - StringPiece text(str); - StringPiece delims(delim_set); +std::vector SplitOnCharSet(const tstring& str, + const tstring& delim_set, + Predicate p) { + std::vector result; + absl::string_view text(str); + absl::string_view delims(delim_set); size_t token_start = 0; for (size_t i = 0; i < text.size() + 1; i++) { - if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { - StringPiece token(text.data() + token_start, i - token_start); + if ((i == text.size()) || + (delims.find(text[i]) != absl::string_view::npos)) { + absl::string_view token(text.data() + token_start, i - token_start); if (p(token)) { result.emplace_back(token); } @@ -80,16 +82,17 @@ std::vector SplitOnCharSet(const tstring& str, // Returns a vector of StringPieces which are valid as long as input `str` // is valid. template -std::vector Split(const tstring& str, const tstring& delimiter, - Predicate predicate) { +std::vector Split(const tstring& str, + const tstring& delimiter, + Predicate predicate) { if (str.empty()) { - return std::vector(); + return std::vector(); } if (delimiter.empty()) { - std::vector result; + std::vector result; result.resize(str.size()); for (size_t i = 0; i < str.size(); ++i) { - result[i] = StringPiece(str.data() + i, 1); + result[i] = absl::string_view(str.data() + i, 1); } return result; } @@ -99,8 +102,8 @@ std::vector Split(const tstring& str, const tstring& delimiter, return SplitOnCharSet(str, delimiter, predicate); } -std::vector SplitV2(const tstring& str, StringPiece sep, - int maxsplit) { +std::vector SplitV2(const tstring& str, + absl::string_view sep, int maxsplit) { // This SplitV2 method matches the behavior of python's str.split: // If sep is given, consecutive delimiters are not grouped together // and are deemed to delimit empty strings (for example, '1,,2'.split(',') @@ -115,16 +118,16 @@ std::vector SplitV2(const tstring& str, StringPiece sep, // splitting an empty string or a string consisting of just whitespace // with a None separator returns []. - std::vector result; + std::vector result; - StringPiece text(str); + absl::string_view text(str); if (maxsplit == 0) { result.emplace_back(text); return result; } if (sep.empty()) { - StringPiece token; + absl::string_view token; // Remove leading whitespaces. str_util::RemoveLeadingWhitespace(&text); int split = 0; @@ -142,13 +145,13 @@ std::vector SplitV2(const tstring& str, StringPiece sep, auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); int split = 0; while (p != text.end()) { - StringPiece token = text.substr(0, p - text.begin()); + absl::string_view token = text.substr(0, p - text.begin()); result.push_back(token); text.remove_prefix(token.size()); text.remove_prefix(sep.size()); ++split; if (maxsplit > 0 && split == maxsplit) { - result.push_back(StringPiece(text)); + result.push_back(absl::string_view(text)); return result; } p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); @@ -190,7 +193,7 @@ class StringSplitOp : public OpKernel { const auto delimiter_vec = delimiter_tensor->flat(); const tstring& delimiter = delimiter_vec(0); // Empty delimiter means split the input character by character. - std::vector tokens; + std::vector tokens; // Guess that we'll be unpacking a handful of tokens per example. static constexpr int kReserveSize = 4; tokens.reserve(batch_size * kReserveSize); @@ -199,7 +202,7 @@ class StringSplitOp : public OpKernel { int64_t max_num_entries = 0; std::vector num_indices(batch_size); for (int64_t i = 0; i < batch_size; ++i) { - std::vector parts = + std::vector parts = skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty()) : Split(input_vec(i), delimiter, str_util::AllowEmpty()); int64_t n_entries = parts.size(); @@ -262,8 +265,8 @@ class StringSplitV2Op : public OpKernel { errors::InvalidArgument("sep must be a scalar, got shape: ", sep_tensor->shape().DebugString())); const auto sep_vec = sep_tensor->flat(); - StringPiece sep(sep_vec(0)); - std::vector tokens; + absl::string_view sep(sep_vec(0)); + std::vector tokens; // Guess that we'll be unpacking a handful of tokens per example. static constexpr int kReserveSize = 4; tokens.reserve(batch_size * kReserveSize); @@ -272,7 +275,8 @@ class StringSplitV2Op : public OpKernel { int64_t max_num_entries = 0; std::vector num_indices(batch_size); for (int64_t i = 0; i < batch_size; ++i) { - std::vector parts = SplitV2(input_vec(i), sep, maxsplit_); + std::vector parts = + SplitV2(input_vec(i), sep, maxsplit_); int64_t n_entries = parts.size(); num_indices[i] = n_entries; output_size += n_entries; diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index dbc8f9d02c48a4..6a0dabef7c0330 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -41,7 +41,7 @@ class StringStripOp : public OpKernel { auto output = output_tensor->flat(); for (int64_t i = 0; i < input.size(); ++i) { - StringPiece entry(input(i)); + absl::string_view entry(input(i)); str_util::RemoveWhitespaceContext(&entry); output(i) = string(entry); } diff --git a/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h b/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h index 73c55cf54dc8d5..f9119259f4d934 100644 --- a/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h +++ b/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { -template +template class StringToHashBucketOp : public OpKernel { public: explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) { diff --git a/tensorflow/core/kernels/string_upper_op.cc b/tensorflow/core/kernels/string_upper_op.cc index f948c2d5e30632..0a427dcc294c73 100644 --- a/tensorflow/core/kernels/string_upper_op.cc +++ b/tensorflow/core/kernels/string_upper_op.cc @@ -49,7 +49,7 @@ class StringUpperOp : public OpKernel { auto output = output_tensor->flat(); if (encoding_.empty()) { for (int64_t i = 0; i < input.size(); ++i) { - StringPiece entry(input(i)); + absl::string_view entry(input(i)); output(i) = absl::AsciiStrToUpper(entry); } } else { diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h index 9dda609a5b7d62..58230d3d3e3cf4 100644 --- a/tensorflow/core/kernels/string_util.h +++ b/tensorflow/core/kernels/string_util.h @@ -48,7 +48,7 @@ int32 UTF8StrLen(const string& str); // the end of the string is reached before the requested characters, then the // position will point to the end of string and this function will return false. template -bool ForwardNUTF8CharPositions(const StringPiece in, +bool ForwardNUTF8CharPositions(const absl::string_view in, const T num_utf8_chars_to_shift, T* pos) { const size_t size = in.size(); T utf8_chars_counted = 0; @@ -69,7 +69,7 @@ bool ForwardNUTF8CharPositions(const StringPiece in, // the string is reached before the requested character, then the position will // point to the beginning of the string and this function will return false. template -bool BackNUTF8CharPositions(const StringPiece in, +bool BackNUTF8CharPositions(const absl::string_view in, const T num_utf8_chars_to_shift, T* pos) { const size_t start = 0; T utf8_chars_counted = 0; diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index a7880ccc681eff..3ea53cbe70f542 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -78,7 +78,7 @@ class SubstrOp : public OpKernel { const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - StringPiece in(input(i)); + absl::string_view in(input(i)); T byte_pos = pos; T byte_len = len; switch (unit_) { @@ -95,7 +95,7 @@ class SubstrOp : public OpKernel { errors::InvalidArgument("pos ", pos, " out of range for ", "string b'", in, "' at index ", i)); } - StringPiece sub_in = in.substr(byte_pos, byte_len); + absl::string_view sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } else { @@ -103,7 +103,7 @@ class SubstrOp : public OpKernel { auto pos_flat = pos_tensor.flat(); auto len_flat = len_tensor.flat(); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - StringPiece in(input(i)); + absl::string_view in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); T byte_pos = pos; @@ -122,7 +122,7 @@ class SubstrOp : public OpKernel { errors::InvalidArgument("pos ", pos, " out of range for ", "string b'", in, "' at index ", i)); } - StringPiece sub_in = in.substr(byte_pos, byte_len); + absl::string_view sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } @@ -174,7 +174,7 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { - StringPiece in(input(input.dimension(0) > 1 ? i : 0)); + absl::string_view in(input(input.dimension(0) > 1 ? i : 0)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); T byte_pos = pos; @@ -193,7 +193,7 @@ class SubstrOp : public OpKernel { errors::InvalidArgument("pos ", pos, " out of range for ", "string b'", in, "' at index ", i)); } - StringPiece sub_in = in.substr(byte_pos, byte_len); + absl::string_view sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } break; @@ -228,8 +228,8 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { for (int j = 0; j < output_shape.dim_size(1); ++j) { - StringPiece in(input(input.dimension(0) > 1 ? i : 0, - input.dimension(1) > 1 ? j : 0)); + absl::string_view in(input(input.dimension(0) > 1 ? i : 0, + input.dimension(1) > 1 ? j : 0)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = @@ -251,7 +251,7 @@ class SubstrOp : public OpKernel { "string b'", in, "' at index (", i, ", ", j, ")")); } - StringPiece sub_in = in.substr(byte_pos, byte_len); + absl::string_view sub_in = in.substr(byte_pos, byte_len); output(i, j).assign(sub_in.data(), sub_in.size()); } } @@ -268,7 +268,8 @@ class SubstrOp : public OpKernel { private: // This adjusts the requested position. Note it does not perform any bound // checks. - static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + static inline T AdjustedPosIndex(const T pos_requested, + const absl::string_view s) { if (pos_requested < 0) { return s.size() + pos_requested; } @@ -277,7 +278,7 @@ class SubstrOp : public OpKernel { // Return true if successful; otherwise, return false if the `pos` argument // is out of range in the string. - static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, + static inline bool UpdatePosAndLenForUtf8(const absl::string_view in, T* pos, T* len) { if (*pos >= 0) { return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); @@ -286,9 +287,9 @@ class SubstrOp : public OpKernel { } } - static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, - const T len, T* char_pos, - T* char_len) { + static bool UpdatePositivePosAndLenForUtf8(const absl::string_view in, + const T pos, const T len, + T* char_pos, T* char_len) { *char_pos = 0; // Determine byte position of the substring start. if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { @@ -307,9 +308,9 @@ class SubstrOp : public OpKernel { // This function expects a negative position relative to the end of the // string, but will update the character position to a positive number // relative to the beginning of the string. - static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, - const T len, T* char_pos, - T* char_len) { + static bool UpdateNegativePosAndLenForUtf8(const absl::string_view in, + const T pos, const T len, + T* char_pos, T* char_len) { // Initially treat the length as position of the end of the substring. *char_len = in.size(); // This is the number of character to skip from the end of the string to diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index fe318a58803fb6..291818e6992270 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -326,13 +326,13 @@ class TensorArrayGradOp : public TensorArrayCreationOp { } else { container = "_tensor_arrays"; const auto& resource = ctx->input(0).flat()(0); - if (StringPiece(resource.name()).substr(0, container.size()) != + if (absl::string_view(resource.name()).substr(0, container.size()) != container) { return errors::InvalidArgument("Wrong input container. ", resource.name()); } tensor_array_name = - string(StringPiece(resource.name()).substr(container.size())); + string(absl::string_view(resource.name()).substr(container.size())); } auto output_handle = tensor_array_output_handle->flat(); diff --git a/tensorflow/core/kernels/tensor_list.cc b/tensorflow/core/kernels/tensor_list.cc index 2fbd871f688630..b65d4a96907d44 100644 --- a/tensorflow/core/kernels/tensor_list.cc +++ b/tensorflow/core/kernels/tensor_list.cc @@ -58,7 +58,7 @@ bool TensorList::Decode(const VariantTensorData& data) { string metadata; data.get_metadata(&metadata); uint64 scratch; - StringPiece iter(metadata); + absl::string_view iter(metadata); std::vector invalid_indices; core::GetVarint64(&iter, &scratch); size_t num_invalid_tensors = static_cast(scratch); diff --git a/tensorflow/core/kernels/word2vec_kernels.cc b/tensorflow/core/kernels/word2vec_kernels.cc index 7f1dddce884009..5ab33ae10b74b0 100644 --- a/tensorflow/core/kernels/word2vec_kernels.cc +++ b/tensorflow/core/kernels/word2vec_kernels.cc @@ -33,9 +33,9 @@ const int kSentenceSize = 1000; namespace { -bool ScanWord(StringPiece* input, string* word) { +bool ScanWord(absl::string_view* input, string* word) { str_util::RemoveLeadingWhitespace(input); - StringPiece tmp; + absl::string_view tmp; if (str_util::ConsumeNonWhitespace(input, &tmp)) { word->assign(tmp.data(), tmp.size()); return true; @@ -180,7 +180,7 @@ class SkipgramOp : public OpKernel { absl::Status Init(Env* env, const string& filename) { string data; TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data)); - StringPiece input = data; + absl::string_view input = data; string w; corpus_size_ = 0; std::unordered_map word_freq; diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc index 65f6492e50cd9d..30fbae40b6c5dd 100644 --- a/tensorflow/core/lib/db/sqlite.cc +++ b/tensorflow/core/lib/db/sqlite.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/lib/db/sqlite.h" +#include + #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -92,7 +94,7 @@ sqlite3_stmt* PrepareRawOrDie(sqlite3* db, const char* sql) { } absl::Status SetPragma(Sqlite* db, const char* pragma, - const StringPiece& value) { + const absl::string_view& value) { if (value.empty()) return absl::OkStatus(); for (auto p = value.begin(); p < value.end(); ++p) { if (!(('0' <= *p && *p <= '9') || ('A' <= *p && *p <= 'Z') || @@ -107,9 +109,9 @@ absl::Status SetPragma(Sqlite* db, const char* pragma, return stmt.Step(&unused_done); } -const StringPiece GetEnv(const char* var) { +const absl::string_view GetEnv(const char* var) { const char* val = std::getenv(var); - return (val == nullptr) ? StringPiece() : StringPiece(val); + return (val == nullptr) ? absl::string_view() : absl::string_view(val); } absl::Status EnvPragma(Sqlite* db, const char* pragma, const char* var) { @@ -171,7 +173,8 @@ Sqlite::~Sqlite() { CHECK_EQ(SQLITE_OK, sqlite3_close(db_)); } -absl::Status Sqlite::Prepare(const StringPiece& sql, SqliteStatement* stmt) { +absl::Status Sqlite::Prepare(const absl::string_view& sql, + SqliteStatement* stmt) { SqliteLock lock(*this); sqlite3_stmt* ps = nullptr; int rc = sqlite3_prepare_v2(db_, sql.data(), static_cast(sql.size()), diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h index 35fc40d3e66ff2..992001e448e617 100644 --- a/tensorflow/core/lib/db/sqlite.h +++ b/tensorflow/core/lib/db/sqlite.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_DB_SQLITE_H_ #define TENSORFLOW_CORE_LIB_DB_SQLITE_H_ +#include +#include #include #include "absl/log/check.h" @@ -89,8 +91,8 @@ class TF_LOCKABLE Sqlite : public core::RefCounted { /// routine will retry automatically and then possibly fail. /// /// The returned statement holds a reference to this object. - absl::Status Prepare(const StringPiece& sql, SqliteStatement* stmt); - SqliteStatement PrepareOrDie(const StringPiece& sql); + absl::Status Prepare(const absl::string_view& sql, SqliteStatement* stmt); + SqliteStatement PrepareOrDie(const absl::string_view& sql); /// \brief Returns extended result code of last error. /// @@ -231,22 +233,22 @@ class SqliteStatement { /// /// When using the unsafe methods, the data must not be changed or /// freed until this statement is Reset() or finalized. - void BindText(int parameter, const StringPiece& text) { + void BindText(int parameter, const absl::string_view& text) { Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), SQLITE_TRANSIENT, SQLITE_UTF8), parameter); size_ += text.size(); } - void BindText(const char* parameter, const StringPiece& text) { + void BindText(const char* parameter, const absl::string_view& text) { BindText(GetParameterIndex(parameter), text); } - void BindTextUnsafe(int parameter, const StringPiece& text) { + void BindTextUnsafe(int parameter, const absl::string_view& text) { Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), SQLITE_STATIC, SQLITE_UTF8), parameter); size_ += text.size(); } - void BindTextUnsafe(const char* parameter, const StringPiece& text) { + void BindTextUnsafe(const char* parameter, const absl::string_view& text) { BindTextUnsafe(GetParameterIndex(parameter), text); } @@ -254,22 +256,22 @@ class SqliteStatement { /// /// When using the unsafe methods, the data must not be changed or /// freed until this statement is Reset() or finalized. - void BindBlob(int parameter, const StringPiece& blob) { + void BindBlob(int parameter, const absl::string_view& blob) { Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), SQLITE_TRANSIENT), parameter); size_ += blob.size(); } - void BindBlob(const char* parameter, const StringPiece& blob) { + void BindBlob(const char* parameter, const absl::string_view& blob) { BindBlob(GetParameterIndex(parameter), blob); } - void BindBlobUnsafe(int parameter, const StringPiece& blob) { + void BindBlobUnsafe(int parameter, const absl::string_view& blob) { Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), SQLITE_STATIC), parameter); size_ += blob.size(); } - void BindBlobUnsafe(const char* parameter, const StringPiece& text) { + void BindBlobUnsafe(const char* parameter, const absl::string_view& text) { BindBlobUnsafe(GetParameterIndex(parameter), text); } @@ -312,7 +314,7 @@ class SqliteStatement { /// Empty values are returned as NULL. The returned memory will no /// longer be valid the next time Step() or Reset() is called. No NUL /// terminator is added. - StringPiece ColumnStringUnsafe(int column) const TF_MUST_USE_RESULT { + absl::string_view ColumnStringUnsafe(int column) const TF_MUST_USE_RESULT { return {static_cast(sqlite3_column_blob(stmt_, column)), static_cast(ColumnSize(column))}; } @@ -444,7 +446,7 @@ class TF_SCOPED_LOCKABLE SqliteTransaction { TF_EXCLUSIVE_LOCKS_REQUIRED(__VA_ARGS__) #define SQLITE_TRANSACTIONS_EXCLUDED(...) TF_LOCKS_EXCLUDED(__VA_ARGS__) -inline SqliteStatement Sqlite::PrepareOrDie(const StringPiece& sql) { +inline SqliteStatement Sqlite::PrepareOrDie(const absl::string_view& sql) { SqliteStatement stmt; TF_CHECK_OK(Prepare(sql, &stmt)); return stmt; diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index ec394f262c65e7..a3551ca1aa5664 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -169,7 +169,7 @@ TEST_F(SqliteTest, UnsafeColumn) { TF_ASSERT_OK(stmt.StepAndReset()); stmt = db_->PrepareOrDie("SELECT b FROM T ORDER BY a"); TF_ASSERT_OK(stmt.Step(&is_done_)); - StringPiece p = stmt.ColumnStringUnsafe(0); + absl::string_view p = stmt.ColumnStringUnsafe(0); EXPECT_EQ('h', *p.data()); TF_ASSERT_OK(stmt.Step(&is_done_)); // This will actually happen, but it's not safe to test this behavior. diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h index 9f60c3618f5448..200e129be83c50 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.h +++ b/tensorflow/core/lib/jpeg/jpeg_mem.h @@ -137,7 +137,7 @@ struct CompressFlags { int y_density = 300; // If not empty, embed this XMP metadata in the image header - StringPiece xmp_metadata; + absl::string_view xmp_metadata; // The distance in bytes from one scanline to the other. Should be at least // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc index eedd12533513b6..8e9380998a4800 100644 --- a/tensorflow/core/lib/png/png_io.cc +++ b/tensorflow/core/lib/png/png_io.cc @@ -140,7 +140,7 @@ void CommonFreeDecode(DecodeContext* context) { } } -bool DecodeHeader(StringPiece png_string, int* width, int* height, +bool DecodeHeader(absl::string_view png_string, int* width, int* height, int* components, int* channel_bit_depth, std::vector >* metadata) { DecodeContext context; @@ -201,7 +201,7 @@ bool DecodeHeader(StringPiece png_string, int* width, int* height, return true; } -bool CommonInitDecode(StringPiece png_string, int desired_channels, +bool CommonInitDecode(absl::string_view png_string, int desired_channels, int desired_channel_bits, DecodeContext* context) { CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) << "desired_channel_bits = " << desired_channel_bits; diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h index f2d173ab3e82dd..a7fff84c1961ef 100644 --- a/tensorflow/core/lib/png/png_io.h +++ b/tensorflow/core/lib/png/png_io.h @@ -59,7 +59,7 @@ struct DecodeContext { DecodeContext() : png_ptr(nullptr), info_ptr(nullptr) {} }; -bool DecodeHeader(StringPiece png_string, int* width, int* height, +bool DecodeHeader(absl::string_view png_string, int* width, int* height, int* components, int* channel_bit_depth, std::vector >* metadata); @@ -74,7 +74,7 @@ bool DecodeHeader(StringPiece png_string, int* width, int* height, // // desired_channels may be 0 to detected it from the input. -bool CommonInitDecode(StringPiece png_string, int desired_channels, +bool CommonInitDecode(absl::string_view png_string, int desired_channels, int desired_channel_bits, DecodeContext* context); bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context); diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc index 414bc520a010d4..5b971accbd71a6 100644 --- a/tensorflow/core/lib/strings/ordered_code.cc +++ b/tensorflow/core/lib/strings/ordered_code.cc @@ -161,7 +161,7 @@ const char* OrderedCode::TEST_SkipToNextSpecialByte(const char* start, // Helper routine to encode "s" and append to "*dest", escaping special // characters. -inline static void EncodeStringFragment(string* dest, StringPiece s) { +inline static void EncodeStringFragment(string* dest, absl::string_view s) { const char* p = s.data(); const char* limit = p + s.size(); const char* copy_start = p; @@ -188,7 +188,7 @@ inline static void EncodeStringFragment(string* dest, StringPiece s) { } } -void OrderedCode::WriteString(string* dest, StringPiece s) { +void OrderedCode::WriteString(string* dest, absl::string_view s) { EncodeStringFragment(dest, s); AppendBytes(dest, kEscape1_Separator, 2); } @@ -213,7 +213,7 @@ void OrderedCode::WriteNumIncreasing(string* dest, uint64 val) { // If parse succeeds, return true, consume encoding from // "*src", and if result != NULL append the decoded string to "*result". // Otherwise, return false and leave both undefined. -inline static bool ReadStringInternal(StringPiece* src, string* result) { +inline static bool ReadStringInternal(absl::string_view* src, string* result) { const char* start = src->data(); const char* string_limit = src->data() + src->size(); @@ -268,11 +268,11 @@ inline static bool ReadStringInternal(StringPiece* src, string* result) { return false; } -bool OrderedCode::ReadString(StringPiece* src, string* result) { +bool OrderedCode::ReadString(absl::string_view* src, string* result) { return ReadStringInternal(src, result); } -bool OrderedCode::ReadNumIncreasing(StringPiece* src, uint64* result) { +bool OrderedCode::ReadNumIncreasing(absl::string_view* src, uint64* result) { if (src->empty()) { return false; // Not enough bytes } @@ -452,7 +452,8 @@ void OrderedCode::WriteSignedNumIncreasing(string* dest, int64_t val) { dest->append(begin, len); } -bool OrderedCode::ReadSignedNumIncreasing(StringPiece* src, int64_t* result) { +bool OrderedCode::ReadSignedNumIncreasing(absl::string_view* src, + int64_t* result) { if (src->empty()) return false; const uint64 xor_mask = (!((*src)[0] & 0x80)) ? ~0ULL : 0ULL; const unsigned char first_byte = (*src)[0] ^ (xor_mask & 0xff); diff --git a/tensorflow/core/lib/strings/ordered_code.h b/tensorflow/core/lib/strings/ordered_code.h index bfccfc54938d7a..e7485bd57f7e15 100644 --- a/tensorflow/core/lib/strings/ordered_code.h +++ b/tensorflow/core/lib/strings/ordered_code.h @@ -54,7 +54,7 @@ class OrderedCode { // Encoding routines: each one of the following routines append // one item to "*dest" in an encoding where larger values are // ordered lexicographically after smaller values. - static void WriteString(string* dest, StringPiece str); + static void WriteString(string* dest, absl::string_view str); static void WriteNumIncreasing(string* dest, uint64 num); static void WriteSignedNumIncreasing(string* dest, int64_t num); @@ -66,9 +66,9 @@ class OrderedCode { // result. In case of string result, the decoded string is appended to // "*result". Returns true if the next item was read successfully, false // otherwise. - static bool ReadString(StringPiece* src, string* result); - static bool ReadNumIncreasing(StringPiece* src, uint64* result); - static bool ReadSignedNumIncreasing(StringPiece* src, int64_t* result); + static bool ReadString(absl::string_view* src, string* result); + static bool ReadNumIncreasing(absl::string_view* src, uint64* result); + static bool ReadSignedNumIncreasing(absl::string_view* src, int64_t* result); // Helper for testing: corrupt "*str" by changing the kth item separator // in the string. diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc index ed18d12478e0be..4717007fc27fc2 100644 --- a/tensorflow/core/lib/strings/ordered_code_test.cc +++ b/tensorflow/core/lib/strings/ordered_code_test.cc @@ -47,7 +47,7 @@ string RandomString(random::SimplePhilox* rnd, size_t len) { template void OCWriteIncreasing(string* dest, const T& val); template -bool OCReadIncreasing(StringPiece* src, T* result); +bool OCReadIncreasing(absl::string_view* src, T* result); // Read/WriteIncreasing template <> @@ -55,7 +55,7 @@ void OCWriteIncreasing(string* dest, const string& val) { OrderedCode::WriteString(dest, val); } template <> -bool OCReadIncreasing(StringPiece* src, string* result) { +bool OCReadIncreasing(absl::string_view* src, string* result) { return OrderedCode::ReadString(src, result); } @@ -65,7 +65,7 @@ void OCWriteIncreasing(string* dest, const uint64& val) { OrderedCode::WriteNumIncreasing(dest, val); } template <> -bool OCReadIncreasing(StringPiece* src, uint64* result) { +bool OCReadIncreasing(absl::string_view* src, uint64* result) { return OrderedCode::ReadNumIncreasing(src, result); } @@ -75,7 +75,7 @@ void OCWriteIncreasing(string* dest, const int64_t& val) { OrderedCode::WriteSignedNumIncreasing(dest, val); } template <> -bool OCReadIncreasing(StringPiece* src, int64_t* result) { +bool OCReadIncreasing(absl::string_view* src, int64_t* result) { return OrderedCode::ReadSignedNumIncreasing(src, result); } @@ -92,7 +92,7 @@ void OCWriteToString(string* result, T val) { } template -bool OCRead(StringPiece* s, T* val) { +bool OCRead(absl::string_view* s, T* val) { return OCReadIncreasing(s, val); } @@ -103,12 +103,12 @@ template T TestRead(const string& a) { // gracefully reject any proper prefix of an encoding for (int i = 0; i < a.size() - 1; ++i) { - StringPiece s(a.data(), i); + absl::string_view s(a.data(), i); CHECK(!OCRead(&s, nullptr)); CHECK_EQ(s, a.substr(0, i)); } - StringPiece s(a); + absl::string_view s(a); T v; CHECK(OCRead(&s, &v)); CHECK(s.empty()); @@ -304,7 +304,7 @@ inline string StrNot(const string& s) { template void TestInvalidEncoding(const string& s) { - StringPiece p(s); + absl::string_view p(s); EXPECT_FALSE(OCRead(&p, nullptr)); EXPECT_EQ(s, p); } @@ -338,7 +338,7 @@ TEST(OrderedCodeInvalidEncodingsDeathTest, NonCanonical) { EXPECT_NE(OCWrite(0), non_minimal); #ifndef NDEBUG - StringPiece s(non_minimal); + absl::string_view s(non_minimal); EXPECT_DEATH(OrderedCode::ReadNumIncreasing(&s, nullptr), "invalid encoding"); #else @@ -357,7 +357,7 @@ TEST(OrderedCodeInvalidEncodingsDeathTest, NonCanonical) { EXPECT_NE(OCWrite(0), non_minimal); #ifndef NDEBUG - StringPiece s(non_minimal); + absl::string_view s(non_minimal); EXPECT_DEATH(OrderedCode::ReadSignedNumIncreasing(&s, nullptr), "invalid encoding") << n; @@ -408,7 +408,7 @@ void BM_ReadNum(::testing::benchmark::State& state, T multiplier) { uint32 index = 0; for (auto i : state) { T val; - StringPiece s = values[index++ % kValues]; + absl::string_view s = values[index++ % kValues]; OCRead(&s, &val); } } @@ -449,8 +449,8 @@ TEST(String, EncodeDecode) { OCWriteToString(&out, b); string a2, b2, dummy; - StringPiece s = out; - StringPiece s2 = out; + absl::string_view s = out; + absl::string_view s2 = out; CHECK(OCRead(&s, &a2)); CHECK(OCRead(&s2, nullptr)); CHECK_EQ(s, s2); @@ -472,7 +472,7 @@ TEST(String, EncodeDecode) { // 'str' is a string literal that may contain '\0'. #define STATIC_STR(str) StringPiece((str), sizeof(str) - 1) -string EncodeStringIncreasing(StringPiece value) { +string EncodeStringIncreasing(absl::string_view value) { string encoded; OrderedCode::WriteString(&encoded, value); return encoded; @@ -526,7 +526,7 @@ TEST(EncodingIsExpected, String) { OrderedCode::WriteString(&result, t.first); EXPECT_EQ(t.second, result); - StringPiece in = result; + absl::string_view in = result; string decoded; EXPECT_TRUE(OrderedCode::ReadString(&in, &decoded)); EXPECT_EQ(t.first, decoded); @@ -758,7 +758,7 @@ TEST(EncodingIsExpected, Unsigned) { OrderedCode::WriteNumIncreasing(&result, num); EXPECT_EQ(t.second, result) << std::hex << num; - StringPiece in = result; + absl::string_view in = result; uint64 decoded; EXPECT_TRUE(OrderedCode::ReadNumIncreasing(&in, &decoded)); EXPECT_EQ(num, decoded); @@ -1205,7 +1205,7 @@ TEST(EncodingIsExpected, Signed) { OrderedCode::WriteSignedNumIncreasing(&result, num); EXPECT_EQ(t.second, result) << std::hex << num; - StringPiece in = result; + absl::string_view in = result; int64_t decoded; EXPECT_TRUE(OrderedCode::ReadSignedNumIncreasing(&in, &decoded)); EXPECT_EQ(num, decoded); @@ -1242,7 +1242,7 @@ void BM_ReadString(::testing::benchmark::State& state, int len) { for (auto i : state) { result.clear(); - StringPiece s = data; + absl::string_view s = data; OCRead(&s, &result); } state.SetBytesProcessed(state.iterations() * len); diff --git a/tensorflow/core/lib/strings/proto_text_util.cc b/tensorflow/core/lib/strings/proto_text_util.cc index 38ea40b1cc45e2..a1b646448eff02 100644 --- a/tensorflow/core/lib/strings/proto_text_util.cc +++ b/tensorflow/core/lib/strings/proto_text_util.cc @@ -21,7 +21,7 @@ namespace tensorflow { namespace strings { bool ProtoParseBoolFromScanner(Scanner* scanner, bool* value) { - StringPiece bool_str; + absl::string_view bool_str; if (!scanner->RestartCapture() .Many(Scanner::LETTER_DIGIT) .GetResult(nullptr, &bool_str)) { @@ -43,7 +43,7 @@ bool ProtoParseStringLiteralFromScanner(Scanner* scanner, string* value) { const char quote = scanner->Peek(); if (quote != '\'' && quote != '"') return false; - StringPiece value_sp; + absl::string_view value_sp; if (!scanner->One(Scanner::ALL) .RestartCapture() .ScanEscapedUntil(quote) diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h index af288e0738011f..ef73108b057557 100644 --- a/tensorflow/core/lib/strings/proto_text_util.h +++ b/tensorflow/core/lib/strings/proto_text_util.h @@ -100,7 +100,8 @@ class ProtoTextOutput { } private: - void AppendFieldAndValue(const char field_name[], StringPiece value_text) { + void AppendFieldAndValue(const char field_name[], + absl::string_view value_text) { absl::StrAppend(output_, level_empty_ ? "" : field_separator_, indent_, field_name, kColonSeparator, value_text); level_empty_ = false; @@ -132,7 +133,7 @@ inline void ProtoSpaceAndComments(Scanner* scanner) { // failed. template bool ProtoParseNumericFromScanner(Scanner* scanner, T* value) { - StringPiece numeric_str; + absl::string_view numeric_str; scanner->RestartCapture(); if (!scanner->Many(Scanner::LETTER_DIGIT_DOT_PLUS_MINUS) .GetResult(nullptr, &numeric_str)) { diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 1e263139ebc2c5..7814918c5dab69 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1640,8 +1640,21 @@ absl::Status RangeSize(const Tensor* start_t, const Tensor* limit_t, int64_t size; if (std::is_integral::value) { - size = Eigen::divup(static_cast(Eigen::numext::abs(limit - start)), - static_cast(Eigen::numext::abs(delta))); + uint64_t range; + if ((limit > 0 && start < 0) || (limit < 0 && start > 0)) { + range = static_cast(Eigen::numext::abs(limit)) + + static_cast(Eigen::numext::abs(start)); + } else { + range = static_cast(Eigen::numext::abs(limit - start)); + } + + uint64_t size_unsigned = + Eigen::divup(range, static_cast(Eigen::numext::abs(delta))); + if (size_unsigned > std::numeric_limits::max()) { + return errors::InvalidArgument("Requires ((limit - start) / delta) <= ", + std::numeric_limits::max()); + } + size = static_cast(size_unsigned); } else { auto size_auto = Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta)); diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 2440915a66d0aa..b92cfc5a938901 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -810,6 +810,7 @@ cc_library( hdrs = ["stringpiece.h"], compatible_with = get_compatible_with_portable(), deps = [ + "@com_google_absl//absl/base:core_headers", "@local_tsl//tsl/platform:stringpiece", ], ) @@ -950,6 +951,7 @@ cc_library( ":bfloat16", ":platform", ":tstring", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:types", "@local_xla//xla/tsl/framework:device_type", ], @@ -1152,6 +1154,8 @@ cc_library( deps = [ "//tensorflow/core/lib/core:status", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:cord", ], ) @@ -1219,7 +1223,7 @@ tf_cc_tests( "//tensorflow/core:lib_test_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@local_tsl//tsl/platform:logging", + "@local_xla//xla/tsl/platform:logging", "@zlib", ], ) @@ -1300,6 +1304,9 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/platform/build_config.default.bzl b/tensorflow/core/platform/build_config.default.bzl index c50a06ce635c2d..04f5bb79e08a69 100644 --- a/tensorflow/core/platform/build_config.default.bzl +++ b/tensorflow/core/platform/build_config.default.bzl @@ -32,7 +32,7 @@ def tf_additional_binary_deps(): Label("@local_xla//xla/stream_executor:cuda_platform"), ]) + if_rocm([ "@local_xla//xla/stream_executor:rocm_platform", - "@local_xla//xla/stream_executor/rocm:rocm_rpath", + "@local_config_rocm//rocm:rocm_rpath", ]) + if_mkl_ml([ Label("@local_xla//xla/tsl/mkl:intel_binary_blob"), ]) diff --git a/tensorflow/core/platform/error_payloads.cc b/tensorflow/core/platform/error_payloads.cc index 257f80b908f733..b78143ec50c8de 100644 --- a/tensorflow/core/platform/error_payloads.cc +++ b/tensorflow/core/platform/error_payloads.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/core/platform/error_payloads.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "tensorflow/core/protobuf/core_platform_payloads.pb.h" + namespace tsl { using ::tensorflow::core::platform::ErrorSourceProto; diff --git a/tensorflow/core/platform/error_payloads.h b/tensorflow/core/platform/error_payloads.h index e976dfc0c470dc..7f1d8b61f8c3b3 100644 --- a/tensorflow/core/platform/error_payloads.h +++ b/tensorflow/core/platform/error_payloads.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_ERROR_PAYLOADS_H_ #define TENSORFLOW_CORE_PLATFORM_ERROR_PAYLOADS_H_ +#include "absl/status/status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" // This file contains macros and payload keys for the error counter in diff --git a/tensorflow/core/platform/fake_python_env_test.cc b/tensorflow/core/platform/fake_python_env_test.cc index b521db3c054bff..6547331fcb587c 100644 --- a/tensorflow/core/platform/fake_python_env_test.cc +++ b/tensorflow/core/platform/fake_python_env_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include -#include + +#include #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc index b07e72b2b187c9..2c848292ed13cc 100644 --- a/tensorflow/core/platform/file_system_test.cc +++ b/tensorflow/core/platform/file_system_test.cc @@ -17,6 +17,16 @@ limitations under the License. #include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "absl/strings/strip.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/path.h" @@ -127,7 +137,7 @@ class InterPlanetaryFileSystem : public NullFileSystem { } void ParsePath(const string& name, string* parsed_path) { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; this->ParseURI(name, &scheme, &host, &path); ASSERT_EQ(scheme, "ipfs"); ASSERT_EQ(host, "solarsystem"); @@ -163,10 +173,10 @@ string Match(InterPlanetaryFileSystem* ipfs, const string& suffix_pattern) { if (!s.ok()) { return s.ToString(); } else { - std::vector trimmed_results; + std::vector trimmed_results; std::sort(results.begin(), results.end()); for (const string& result : results) { - StringPiece trimmed_result(result); + absl::string_view trimmed_result(result); EXPECT_TRUE( absl::ConsumePrefix(&trimmed_result, strings::StrCat(kPrefix, "/"))); trimmed_results.push_back(trimmed_result); diff --git a/tensorflow/core/platform/float8.h b/tensorflow/core/platform/float8.h index e2cad449d4aa13..dd80b37a4f4519 100644 --- a/tensorflow/core/platform/float8.h +++ b/tensorflow/core/platform/float8.h @@ -21,6 +21,9 @@ limitations under the License. namespace tensorflow { typedef tsl::float8_e4m3fn float8_e4m3fn; typedef tsl::float8_e5m2 float8_e5m2; +typedef tsl::float8_e4m3fnuz float8_e4m3fnuz; +typedef tsl::float8_e4m3b11fnuz float8_e4m3b11fnuz; +typedef tsl::float8_e5m2fnuz float8_e5m2fnuz; } // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_FLOAT8_H_ diff --git a/tensorflow/core/platform/numbers.h b/tensorflow/core/platform/numbers.h index 08732fcf6ca056..3164aab44ff76e 100644 --- a/tensorflow/core/platform/numbers.h +++ b/tensorflow/core/platform/numbers.h @@ -45,8 +45,6 @@ using tsl::strings::safe_strtof; using tsl::strings::safe_strtou32; using tsl::strings::safe_strtou64; using tsl::strings::SafeStringToNumeric; -using tsl::strings::StringToFp; -using tsl::strings::Uint64ToHexString; // NOLINTEND(misc-unused-using-decls) } // namespace strings } // namespace tensorflow diff --git a/tensorflow/core/platform/profile_utils/BUILD b/tensorflow/core/platform/profile_utils/BUILD index f9e43d033eed88..12bf382af647e2 100644 --- a/tensorflow/core/platform/profile_utils/BUILD +++ b/tensorflow/core/platform/profile_utils/BUILD @@ -53,8 +53,8 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", "@com_google_absl//absl/base", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", "@local_xla//xla/tsl/platform/profile_utils:profile_utils_cpu_utils", ], alwayslink = 1, diff --git a/tensorflow/core/platform/stringpiece.h b/tensorflow/core/platform/stringpiece.h index 66040fc997173c..43f3d4a9c38a78 100644 --- a/tensorflow/core/platform/stringpiece.h +++ b/tensorflow/core/platform/stringpiece.h @@ -26,11 +26,17 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_ #define TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_ +#include "absl/base/macros.h" #include "tsl/platform/stringpiece.h" // IWYU pragma: export +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + namespace tensorflow { -using StringPiece = absl::string_view; +using StringPiece ABSL_DEPRECATE_AND_INLINE() = absl::string_view; } // namespace tensorflow diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc index 38f1d26508722f..b5aa5ffe150c8e 100644 --- a/tensorflow/core/platform/tensor_coding.cc +++ b/tensorflow/core/platform/tensor_coding.cc @@ -32,7 +32,8 @@ limitations under the License. namespace tensorflow { namespace port { -void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) { +void AssignRefCounted(absl::string_view src, core::RefCounted* obj, + string* out) { out->assign(src.data(), src.size()); } @@ -55,7 +56,7 @@ void EncodeStringList(const tstring* strings, int64_t n, string* out) { bool DecodeStringList(const string& src, tstring* strings, int64_t n) { std::vector sizes(n); - StringPiece reader(src); + absl::string_view reader(src); int64_t tot = 0; for (auto& v : sizes) { if (!core::GetVarint32(&reader, &v)) return false; @@ -130,7 +131,7 @@ class StringListDecoderImpl : public StringListDecoder { } private: - StringPiece reader_; + absl::string_view reader_; }; std::unique_ptr NewStringListEncoder(string* out) { @@ -142,7 +143,8 @@ std::unique_ptr NewStringListDecoder(const string& in) { } #if defined(TENSORFLOW_PROTOBUF_USES_CORD) -void AssignRefCounted(StringPiece src, core::RefCounted* obj, absl::Cord* out) { +void AssignRefCounted(absl::string_view src, core::RefCounted* obj, + absl::Cord* out) { obj->Ref(); *out = absl::MakeCordFromExternal(src, [obj] { obj->Unref(); }); } diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h index fb10b14b757f94..b024e1432e9fd7 100644 --- a/tensorflow/core/platform/tensor_coding.h +++ b/tensorflow/core/platform/tensor_coding.h @@ -31,7 +31,8 @@ namespace port { // Store src contents in *out. If backing memory for src is shared with *out, // will ref obj during the call and will arrange to unref obj when no // longer needed. -void AssignRefCounted(StringPiece src, core::RefCounted* obj, std::string* out); +void AssignRefCounted(absl::string_view src, core::RefCounted* obj, + std::string* out); // Copy contents of src to dst[0,src.size()-1]. inline void CopyToArray(const std::string& src, char* dst) { @@ -100,7 +101,8 @@ std::unique_ptr NewStringListDecoder(const string& in); // Store src contents in *out. If backing memory for src is shared with *out, // will ref obj during the call and will arrange to unref obj when no // longer needed. -void AssignRefCounted(StringPiece src, core::RefCounted* obj, absl::Cord* out); +void AssignRefCounted(absl::string_view src, core::RefCounted* obj, + absl::Cord* out); // TODO(kmensah): Macro guard this with a check for Cord support. inline void CopyToArray(const absl::Cord& src, char* dst) { diff --git a/tensorflow/core/platform/testdata/test_echo_argv_1.cc b/tensorflow/core/platform/testdata/test_echo_argv_1.cc index e7563315ce6ea3..78034a6790b427 100644 --- a/tensorflow/core/platform/testdata/test_echo_argv_1.cc +++ b/tensorflow/core/platform/testdata/test_echo_argv_1.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include int main(int argc, char** argv) { std::cout << argv[1]; diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h index a3159bfe8abea9..5e4498717ec096 100644 --- a/tensorflow/core/platform/types.h +++ b/tensorflow/core/platform/types.h @@ -38,8 +38,11 @@ using tsl::int4; using tsl::int64; using tsl::int8; +using tsl::float8_e4m3b11fnuz; using tsl::float8_e4m3fn; +using tsl::float8_e4m3fnuz; using tsl::float8_e5m2; +using tsl::float8_e5m2fnuz; static const uint8 kuint8max = tsl::kuint8max; static const uint16 kuint16max = tsl::kuint16max; diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index a8eee043f09545..c363133a7ec92d 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_alias", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -106,7 +106,6 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/utils:math_utils", - "@com_google_absl//absl/strings", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", "@local_xla//xla/tsl/profiler/utils:tf_op_utils", ], @@ -377,15 +376,16 @@ cc_library( "//tensorflow/core/profiler/utils:device_caps_utils", "//tensorflow/core/profiler/utils:event_span", "//tensorflow/core/profiler/utils:hardware_type_utils", + "//tensorflow/core/profiler/utils:hlo_module_map", "//tensorflow/core/profiler/utils:hlo_proto_map", "//tensorflow/core/profiler/utils:kernel_stats_utils", + "//tensorflow/core/profiler/utils:op_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:math_utils", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", @@ -422,7 +422,6 @@ tf_cc_test( ":repository", ":step_events_to_steps_db", ":xplane_to_op_stats", - ":xplane_to_step_events", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:test", @@ -433,13 +432,11 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:group_events", - "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -708,6 +705,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:op_profile_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", + "//tensorflow/core/profiler/protobuf:roofline_model_proto_cc", "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", "//tensorflow/core/profiler/utils:hardware_type_utils", @@ -1007,7 +1005,6 @@ cc_library( "//tensorflow/core/profiler/convert/trace_viewer:trace_events_util", "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", "//tensorflow/core/profiler/protobuf:trace_events_raw_proto_cc", - "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", @@ -1309,6 +1306,56 @@ cc_library( ], ) +cc_library( + name = "profile_time_breakdown", + srcs = ["profile_time_breakdown.cc"], + hdrs = ["profile_time_breakdown.h"], + visibility = ["@local_xla//xla/tsl/profiler:friends"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", + ], +) + +cc_library( + name = "tpu_input_pipeline_analysis_constants", + srcs = [tf_profiler_alias("//tensorflow/core/profiler/convert/", "tpu_input_pipeline_analysis_constants.cc")], + hdrs = ["tpu_input_pipeline_analysis_constants.h"], + visibility = ["@local_xla//xla/tsl/profiler:friends"], + deps = [ + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:macros", + ], +) + +cc_library( + name = "duty_cycle_tracker", + srcs = ["duty_cycle_tracker.cc"], + hdrs = ["duty_cycle_tracker.h"], + deps = [ + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log:check", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:timespan", + ], +) + +tf_cc_test( + name = "duty_cycle_tracker_test", + srcs = ["duty_cycle_tracker_test.cc"], + deps = [ + ":duty_cycle_tracker", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/log:check", + "@local_xla//xla/tsl/profiler/utils:timespan", + ], +) + tf_cc_test( name = "compute_inference_latency_test", srcs = ["compute_inference_latency_test.cc"], diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker.cc b/tensorflow/core/profiler/convert/duty_cycle_tracker.cc new file mode 100644 index 00000000000000..fa17ad7c98aa1e --- /dev/null +++ b/tensorflow/core/profiler/convert/duty_cycle_tracker.cc @@ -0,0 +1,97 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" + +#include + +#include +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/log/check.h" +#include "xla/tsl/profiler/utils/timespan.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::Timespan; + +DutyCycleTracker::ActiveTimeSpans::const_iterator +DutyCycleTracker::MergeOrInsert(const Timespan& timespan, + ActiveTimeSpans::const_iterator hint) { + ActiveTimeSpans::const_iterator merge_begin = hint; + while (merge_begin != active_time_spans_.end() && + merge_begin->end_ps() < timespan.begin_ps()) { + ++merge_begin; + } + + // timespan is fully contained in an existing timespan. + if (merge_begin != active_time_spans_.end() && + merge_begin->Includes(timespan)) { + return merge_begin; + } + + ActiveTimeSpans::const_iterator merge_end = merge_begin; + while (merge_end != active_time_spans_.end() && + merge_end->begin_ps() <= timespan.end_ps()) { + ++merge_end; + } + if (merge_begin != merge_end) { + Timespan merged = Timespan::FromEndPoints( + std::min(timespan.begin_ps(), merge_begin->begin_ps()), + std::max(timespan.end_ps(), std::prev(merge_end)->end_ps())); + merge_end = active_time_spans_.erase(merge_begin, merge_end); + return active_time_spans_.insert(merge_end, merged); + } else { + // There is no overlap with the existing timespans. + return active_time_spans_.insert(merge_begin, timespan); + } +} + +void DutyCycleTracker::AddInterval(tsl::profiler::Timespan time_span, + bool is_active) { + total_time_span_.ExpandToInclude(time_span); + if (!is_active) { + return; + } + + MergeOrInsert(time_span, active_time_spans_.lower_bound(time_span)); +} + +void DutyCycleTracker::Union(const DutyCycleTracker& other) { + total_time_span_.ExpandToInclude(other.total_time_span_); + if (other.active_time_spans_.empty()) return; + ActiveTimeSpans::const_iterator hint_it = + active_time_spans_.lower_bound(*other.active_time_spans_.begin()); + for (const auto& interval : other.active_time_spans_) { + hint_it = MergeOrInsert(interval, hint_it); + } +} + +uint64_t DutyCycleTracker::GetActiveTimePs() const { + uint64_t active_time_ps = 0; + for (const auto& interval : active_time_spans_) { + DCHECK(!interval.Empty()); + active_time_ps += interval.duration_ps(); + } + return active_time_ps; +} + +uint64_t DutyCycleTracker::GetIdleTimePs() const { + return total_time_span_.duration_ps() - GetActiveTimePs(); +} +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker.h b/tensorflow/core/profiler/convert/duty_cycle_tracker.h new file mode 100644 index 00000000000000..fa89aeb3597ed3 --- /dev/null +++ b/tensorflow/core/profiler/convert/duty_cycle_tracker.h @@ -0,0 +1,72 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ + +#include + +#include "absl/container/btree_set.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/timespan.h" + +namespace tensorflow { +namespace profiler { + +// Tracks the active time intervals for a given TPU core. +// Disjoint intervals of time in ps for which this core was active. +class DutyCycleTracker { + public: + explicit DutyCycleTracker() : active_time_spans_() {} + ~DutyCycleTracker() = default; + void AddInterval(tsl::profiler::Timespan time_span, bool is_active); + void Union(const DutyCycleTracker& other); + uint64_t GetActiveTimePs() const; + uint64_t GetIdleTimePs() const; + uint64_t GetDurationPs() const { return total_time_span_.duration_ps(); } + double DutyCycle() const { + return tsl::profiler::SafeDivide(GetActiveTimePs(), GetDurationPs()); + } + + private: + struct TimespanComparator { + // Order by increasing begin_ps, then decreasing duration_ps. + bool operator()(const tsl::profiler::Timespan& a, + const tsl::profiler::Timespan& b) const { + return a.begin_ps() < b.begin_ps() || (a.begin_ps() == b.begin_ps() && + a.duration_ps() > b.duration_ps()); + } + }; + using ActiveTimeSpans = + absl::btree_set; + + /** + * Merge or insert the given timespan into the set of active time spans. + * + * @param timespan The timespan to merge or insert. + * @param hint The iterator indicating where to begin the merge search. + * @return The iterator where the timespan was merged or inserted. + */ + ActiveTimeSpans::const_iterator MergeOrInsert( + const tsl::profiler::Timespan& timespan, + ActiveTimeSpans::const_iterator hint); + + ActiveTimeSpans active_time_spans_; + tsl::profiler::Timespan total_time_span_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DUTY_CYCLE_TRACKER_H_ diff --git a/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc b/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc new file mode 100644 index 00000000000000..e257f45f6335ae --- /dev/null +++ b/tensorflow/core/profiler/convert/duty_cycle_tracker_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/profiler/convert/duty_cycle_tracker.h" + +#include + +#include +#include + +#include "absl/log/check.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace profiler { +namespace { + +using ::tsl::profiler::Timespan; + +TEST(DutyCycleTrackerTest, TimeIntervalsTest) { + DutyCycleTracker tracker; + tracker.AddInterval(Timespan::FromEndPoints(0, 10), true); + tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); + EXPECT_EQ(tracker.GetActiveTimePs(), 20); + EXPECT_EQ(tracker.GetIdleTimePs(), 10); + EXPECT_EQ(tracker.GetDurationPs(), 30); +} + +TEST(DutyCycleTrackerTest, UnionTest) { + DutyCycleTracker tracker; + tracker.AddInterval(Timespan::FromEndPoints(0, 10), true); + tracker.AddInterval(Timespan::FromEndPoints(20, 30), true); + + DutyCycleTracker other_tracker; + other_tracker.AddInterval(Timespan::FromEndPoints(10, 20), true); + other_tracker.AddInterval(Timespan::FromEndPoints(30, 40), true); + + tracker.Union(other_tracker); + EXPECT_EQ(tracker.GetActiveTimePs(), 40); + EXPECT_EQ(tracker.GetIdleTimePs(), 0); + EXPECT_EQ(tracker.GetDurationPs(), 40); +} + +TEST(DutyCycleTrackerTest, ActiveTimeTest) { + DutyCycleTracker tracker; + EXPECT_EQ(tracker.GetActiveTimePs(), 0); + tracker.AddInterval(Timespan::FromEndPoints(0, 10), true); + EXPECT_EQ(tracker.GetActiveTimePs(), 10); +} + +void BM_DutyCycleTracker_AddInterval(::testing::benchmark::State& state) { + std::vector timespans; + timespans.reserve(state.range(0)); + for (uint64_t i = 0; i < state.range(0); ++i) { + timespans.push_back(Timespan::FromEndPoints(i * 2, i * 2 + 1)); + } + for (auto s : state) { + DutyCycleTracker tracker; + for (const auto& timespan : timespans) { + tracker.AddInterval(timespan, true); + } + } + state.SetItemsProcessed(state.iterations() * timespans.size()); +} + +BENCHMARK(BM_DutyCycleTracker_AddInterval)->Range(1 << 15, 1 << 21); + +void BM_DutyCycleTracker_AddInterval_Merge(::testing::benchmark::State& state) { + std::vector timespans; + timespans.reserve(state.range(0)); + for (uint64_t i = 0; i < state.range(0); ++i) { + timespans.push_back(Timespan::FromEndPoints(i, i + 1)); + } + for (auto s : state) { + DutyCycleTracker tracker; + for (const auto& timespan : timespans) { + tracker.AddInterval(timespan, true); + } + } + state.SetItemsProcessed(state.iterations() * timespans.size()); +} + +BENCHMARK(BM_DutyCycleTracker_AddInterval_Merge)->Range(1 << 15, 1 << 21); + +void BM_DutyCycleTracker_Union(::testing::benchmark::State& state) { + DCHECK_GT(state.range(1), 1); + DCHECK_LT(state.range(1), state.range(0)); + DutyCycleTracker tracker_a; + DutyCycleTracker tracker_b; + uint64_t merge_rate = state.range(1); + for (uint64_t i = 0; i < state.range(0); ++i) { + tracker_a.AddInterval(Timespan::FromEndPoints(i * 2, i * 2 + 1), true); + if (i % merge_rate == 0) { + tracker_b.AddInterval( + Timespan::FromEndPoints(i * 2, (i + merge_rate - 1) * 2), true); + } + } + for (auto s : state) { + DutyCycleTracker unioned_tracker; + unioned_tracker.Union(tracker_a); + unioned_tracker.Union(tracker_b); + } + state.SetItemsProcessed(state.iterations() * + (state.range(0) + state.range(0) / merge_rate)); +} + +BENCHMARK(BM_DutyCycleTracker_Union)->RangePair(1 << 10, 1 << 16, 2, 10); + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc index d0cb6d46078eca..2f32b0ba45c9de 100644 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc @@ -41,7 +41,7 @@ absl::Status ConvertMultiXSpacesToCombinedOpStats( TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, session_snapshot.GetXSpace(i)); PreprocessSingleHostXSpace(xspace.get(), /*step_grouping=*/true, - /*derived_timeline=*/false); + /*derived_timeline=*/true); all_op_stats.push_back(ConvertXSpaceToOpStats(*xspace, options)); } diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc index 9e68980c72db16..978ce5c60e0e2e 100644 --- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc +++ b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc @@ -125,6 +125,8 @@ void OpMetricsDbCombiner::Combine(const OpMetricsDb& src, dst->total_host_infeed_enq_start_timestamp_ps_diff()); dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps()); dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps()); + dst->set_idle_time_ps(src.idle_time_ps() + dst->idle_time_ps()); + dst->set_busy_time_ps(src.busy_time_ps() + dst->busy_time_ps()); CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats()); for (const auto& src_metrics : src.metrics_db()) { diff --git a/tensorflow/core/profiler/convert/op_profile_builder.cc b/tensorflow/core/profiler/convert/op_profile_builder.cc index e94e09e2036957..8b318a9cf6e686 100644 --- a/tensorflow/core/profiler/convert/op_profile_builder.cc +++ b/tensorflow/core/profiler/convert/op_profile_builder.cc @@ -128,10 +128,13 @@ void FinalizeDeduplicatedNodes(bool by_program, Node* root) { for (Node& program_node : *root->mutable_children()) { for (Node& category_node : *program_node.mutable_children()) { for (Node& deduplicated_node : *category_node.mutable_children()) { - // Skip for non deduplicated nodes. Those nodes already have name set. - if (!deduplicated_node.name().empty() || - deduplicated_node.children().empty()) + // Node with 1 child doesn't have deduplication, the child is itself. + // Removing the dedup layer. + if (deduplicated_node.children_size() == 1) { + Node child = *deduplicated_node.mutable_children(0); + deduplicated_node = child; continue; + } CopySymbolDetailsToDeduplicatedNode( deduplicated_node.mutable_children(0), &deduplicated_node); } @@ -140,10 +143,13 @@ void FinalizeDeduplicatedNodes(bool by_program, Node* root) { } else { for (Node& category_node : *root->mutable_children()) { for (Node& deduplicated_node : *category_node.mutable_children()) { - // Skip for non deduplicated nodes. Those nodes already have name set. - if (!deduplicated_node.name().empty() || - deduplicated_node.children().empty()) + // Node with 1 child doesn't have deduplication, the child is itself. + // Removing the dedup layer. + if (deduplicated_node.children_size() == 1) { + Node child = *deduplicated_node.mutable_children(0); + deduplicated_node = child; continue; + } CopySymbolDetailsToDeduplicatedNode( deduplicated_node.mutable_children(0), &deduplicated_node); } @@ -281,12 +287,62 @@ Node* OpProfileBuilder::AddOpNode(const OpMetrics& op_metrics, return leaf; } +// Function to create deduplicated aggregation layer. +// 1. Empty deduplicated_name in op_metrics means either: +// (1) a grouping op of a deduplicated op list. (fusion.3 in the example below) +// (2) an op that does not have duplicates. (fusion.4 in the example below) +// We create dedup layer for both cases due to lack of clue which case it is. +// The op name is used directly as the hash key for the dedup group. The dedup +// layer will be removed in the 2nd pass for case (2). +// 2. Non-empty deduplicated_name means this op can be grouped to a +// deduplicated op list (fusion.1 in the example below). +// Example: +// op_metrics { +// name: "fusion.1" +// deduplicated_name: "fusion.3" +// category: "convolution" +// } +// op_metrics { +// name: "fusion.3" +// deduplicated_name: "" +// category: "convolution" +// } +// op_metrics { +// name: "fusion.4" +// deduplicated_name: "" +// category: "convolution" +// } +// The data above will create the following tree after calling the function +// repeatedly: +// root(by_program) +// - jit.xx +// - convolution +// - fusion.3 +// - fusion.1 +// - fusion.2 +// - fusion.3 +// - fusion.4 +// - fusion.4 +// After finalization, the tree will look like: +// root(by_program) +// - jit.xx +// - convolution +// - fusion.3 and its duplicate(s) +// - fusion.1 +// - fusion.2 +// - fusion.3 +// - fusion.4 Node* OpProfileBuilder::LookupOrAddDeduplicatedNode(const OpMetrics& op_metrics, Category* category) { - Node*& deduplicated_node = - category->deduplicated_nodes[op_metrics.deduplicated_name()]; + std::string deduplicated_name = op_metrics.deduplicated_name().empty() + ? op_metrics.name() + : op_metrics.deduplicated_name(); + Node*& deduplicated_node = category->deduplicated_nodes[deduplicated_name]; if (deduplicated_node == nullptr) { deduplicated_node = category->node->add_children(); + // Set deduplicated name which is the hash key for the dedup group. + // Symbol details will be added in finalization step. + deduplicated_node->set_name(deduplicated_name); } return deduplicated_node; } @@ -341,8 +397,7 @@ void OpProfileBuilder::AddOp(const OpMetrics& op_metrics) { nested_grouping_nodes.push_back(category->node); Node* deduplicated_node = nullptr; - if (options_.group_by_deduplicated_name && - !op_metrics.deduplicated_name().empty()) { + if (options_.group_by_deduplicated_name) { deduplicated_node = LookupOrAddDeduplicatedNode(op_metrics, category); nested_grouping_nodes.push_back(deduplicated_node); } diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc index 34e102dca8aa7f..2bc15581f79b56 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner.cc +++ b/tensorflow/core/profiler/convert/op_stats_combiner.cc @@ -118,6 +118,11 @@ void CombineRunEnvironment(const RunEnvironment& src, RunEnvironment* dst) { } else if (dst->device_type().empty()) { dst->set_device_type(src.device_type()); } + if (src.hardware_type() != dst->hardware_type()) { + // Select the highest hardware type as TPU/GPU should override CPU_ONLY + // (e.g. coordinator). + dst->set_hardware_type(std::max(src.hardware_type(), dst->hardware_type())); + } dst->set_task_count(src.task_count() + dst->task_count()); // Only overwrite the dst if profile_duration_ms in dst is not defined or // is zero and profile_duration_ms in src is greater than zero. diff --git a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc index 9268f72539703a..cd5e97fe3c7e18 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner_test.cc +++ b/tensorflow/core/profiler/convert/op_stats_combiner_test.cc @@ -107,6 +107,18 @@ TEST(CombineAllOpStatsTest, CombinePerfEnvOrderZero) { EXPECT_EQ(100, dst_op_stats2.perf_env().peak_tera_flops_per_second()); } +TEST(CombineAllOpStatsTest, CombineRunEnvironmentWithMismatchHardwareType) { + OpStats coordinator_op_stats, device_op_stats, dst_op_stats; + coordinator_op_stats.mutable_run_environment()->set_hardware_type( + HardwareType::CPU_ONLY); + device_op_stats.mutable_run_environment()->set_hardware_type( + HardwareType::TPU); + CombineAllOpStats({OpStatsInfo(&coordinator_op_stats, CPU_ONLY, 0), + OpStatsInfo(&device_op_stats, TPU, 1)}, + StepIntersection(1, {}), &dst_op_stats); + EXPECT_EQ(dst_op_stats.run_environment().hardware_type(), HardwareType::TPU); +} + } // namespace } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index e13e0cb73a2ab5..bd21fae928c3de 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -114,27 +114,6 @@ double GetTimeInMs(const Collection& type_ps, EventType event_type) { return PicoToMilli(gtl::FindWithDefault(type_ps, event_type, /*value=*/0)); } -StepSummary GetStepSummaryForSampleStats( - const tsl::Stat& sample_stats) { - StepSummary step_time_summary; - double avg, sdv, min, max; - if (sample_stats.empty()) { - // If sample_stats is empty, sample_stats.avg() will return NaN. However, we - // prefer to show an 0 instead. - avg = sdv = min = max = 0.0; - } else { - avg = sample_stats.avg(); - sdv = sqrt(sample_stats.sample_variance()); - min = sample_stats.min(); - max = sample_stats.max(); - } - step_time_summary.set_average(avg); - step_time_summary.set_standard_deviation(sdv); - step_time_summary.set_minimum(min); - step_time_summary.set_maximum(max); - return step_time_summary; -} - GenericStepTimeBreakdown ComputeGenericStepTimeBreakdownInMs( const InputPipelineAnalysisResult& analysis) { tsl::Stat unknown_time_ms; @@ -484,6 +463,27 @@ std::string DatasetIntroDoc() { } // namespace +StepSummary GetStepSummaryForSampleStats( + const tsl::Stat& sample_stats) { + StepSummary step_time_summary; + double avg, sdv, min, max; + if (sample_stats.empty()) { + // If sample_stats is empty, sample_stats.avg() will return NaN. However, we + // prefer to show an 0 instead. + avg = sdv = min = max = 0.0; + } else { + avg = sample_stats.avg(); + sdv = sqrt(sample_stats.sample_variance()); + min = sample_stats.min(); + max = sample_stats.max(); + } + step_time_summary.set_average(avg); + step_time_summary.set_standard_deviation(sdv); + step_time_summary.set_minimum(min); + step_time_summary.set_maximum(max); + return step_time_summary; +} + void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db, InputPipelineAnalysisResult* result) { InputOpMetrics input_op_metrics = SelectInputOpMetrics(host_tf_metrics_db); diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h index cc54a7ea684f43..c9de162eb8c058 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h @@ -20,6 +20,7 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" @@ -31,6 +32,8 @@ limitations under the License. namespace tensorflow { namespace profiler { +StepSummary GetStepSummaryForSampleStats(const tsl::Stat& sample_stats); + // If the percent of input-time spent on host-to-device transfer is greater than // kHostToDeviceTimePercentAsSignificant, we should advise the // user to optimize this transfer. diff --git a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc index c81d71a629aea6..fc827f55b24d9b 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_roofline_model.cc @@ -244,7 +244,9 @@ RooflineModelDatabase InitializeRooflineModelDatabaseFromOpStats( RooflineModelDatabase ConvertOpStatsToRooflineModel( const OpStats& op_stats, bool include_infeed_outfeed) { HardwareType hardware_type = op_stats.run_environment().hardware_type(); - DCHECK(hardware_type == GPU || hardware_type == TPU); + if (hardware_type != GPU && hardware_type != TPU) { + return RooflineModelDatabase(); + } RooflineModelDatabase roofline_model_db = InitializeRooflineModelDatabaseFromOpStats(op_stats, diff --git a/tensorflow/core/profiler/convert/oss/BUILD b/tensorflow/core/profiler/convert/oss/BUILD new file mode 100644 index 00000000000000..b2a4a71ee08bf7 --- /dev/null +++ b/tensorflow/core/profiler/convert/oss/BUILD @@ -0,0 +1,4 @@ +exports_files( + ["tpu_input_pipeline_analysis_constants.cc"], + visibility = ["//tensorflow/core/profiler/convert:__pkg__"], +) diff --git a/tensorflow/core/profiler/convert/oss/tpu_input_pipeline_analysis_constants.cc b/tensorflow/core/profiler/convert/oss/tpu_input_pipeline_analysis_constants.cc new file mode 100644 index 00000000000000..006f4c2cc0a421 --- /dev/null +++ b/tensorflow/core/profiler/convert/oss/tpu_input_pipeline_analysis_constants.cc @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h" + +#include "absl/strings/string_view.h" + +namespace tensorflow { +namespace profiler { + +constexpr absl::string_view kProfileAllHostsDoc = + "https://cloud.google.com/tpu/docs/troubleshooting/troubleshoot-multislice"; +constexpr absl::string_view kSparseCoreV0Name = "SparseCoreV0"; + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/profile_time_breakdown.cc b/tensorflow/core/profiler/convert/profile_time_breakdown.cc new file mode 100644 index 00000000000000..e1826a7119f9a2 --- /dev/null +++ b/tensorflow/core/profiler/convert/profile_time_breakdown.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/profiler/convert/profile_time_breakdown.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" + +namespace tensorflow { +namespace profiler { + +void ProfileTimeBreakdown::SetCategoryTimePs(absl::string_view category, + uint64_t time_ps) { + time_ps_by_category_.insert_or_assign(category, time_ps); +} + +uint64_t ProfileTimeBreakdown::PopCategoryTimePs(absl::string_view category) { + uint64_t time_ps = 0; + auto iter = time_ps_by_category_.find(category); + if (iter != time_ps_by_category_.end()) { + time_ps = iter->second; + time_ps_by_category_.erase(iter); + } + return time_ps; +} + +void ProfileTimeBreakdown::BreakdownSparseCoreV0Infeed() { + // Infeed from SparseCoreV0 and outfeed to SparseCoreV0 are mostly identical + // in compute since they do the same transformation. We can subtract out the + // outfeed time from the infeed time to know how much time the TensorCore + // actually spent waiting on SparseCoreV0. + uint64_t bc_infeed_ps = + PopCategoryTimePs(tsl::profiler::kHloSparseCoreV0Infeed); + if (bc_infeed_ps == 0) return; + uint64_t bc_outfeed_ps = + CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); + + uint64_t bc_infeed_transform_ps = std::min(bc_infeed_ps, bc_outfeed_ps); + uint64_t bc_infeed_wait_ps = bc_infeed_ps - bc_infeed_transform_ps; + + SetCategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait, + bc_infeed_wait_ps); + SetCategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform, + bc_infeed_transform_ps); +} + +std::string ProfileTimeBreakdown::DebugString() const { + std::string str; + for (const auto& [category, time_ps] : time_ps_by_category_) { + absl::StrAppend(&str, category, ": ", tsl::profiler::PicoToUni(time_ps), + "\n"); + } + absl::StrAppend( + &str, "total_time: ", tsl::profiler::PicoToUni(total_time_ps_), "\n"); + absl::StrAppend( + &str, "profile_time: ", tsl::profiler::PicoToUni(profile_time_ps_), "\n"); + return str; +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/profile_time_breakdown.h b/tensorflow/core/profiler/convert/profile_time_breakdown.h new file mode 100644 index 00000000000000..1e3379beb4c457 --- /dev/null +++ b/tensorflow/core/profiler/convert/profile_time_breakdown.h @@ -0,0 +1,244 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" + +namespace tensorflow { +namespace profiler { + +// Allows accumulating time spent in different HLO instruction categories to +// breakdown the total profile time and compute metrics of interest. +class ProfileTimeBreakdown { + public: + // Category should be the operator category disambiguated by xprof instead of + // the original category from XLA. + // For a correct time breakdown, we need to use the self time of operators, + // instead of total time to avoid double counting. Note that for leaf ops, + // self time and total time are the same. + void IncrementCategoryTimePs(absl::string_view category, + uint64_t self_time_ps) { + time_ps_by_category_[category] += self_time_ps; + total_time_ps_ += self_time_ps; + } + + // Profile time cannot be smaller than the total time in all categories. + // If combining profiles across multiple cores, profile time should be the + // profiling duration multiplied by the number of cores that were profiled. + // go/autograppler_profile_time + void SetProfileTimePs(uint64_t profile_time_ps) { + DCHECK_LE(total_time_ps_, profile_time_ps); + profile_time_ps_ = profile_time_ps; + } + + // Breaks down "sparsecorev0 infeed" into two components: + // 1) "sparsecorev0 infeed wait": Time spent waiting on the SparseCoreV0. + // 2) "sparsecorev0 infeed transform": Time spent transforming activations in + // SparseCoreV0 layout into XLA layout. + // Even though 2) is part of the overall embedding computation, it is time + // spent doing work on the TensorCore. + void BreakdownSparseCoreV0Infeed(); + + // Duty cycle is the fraction of time an accelerator is being actively used. + // go/accelerator-metrics-definitions#common-accelerator-metrics + // go/ag-tpu-duty-cycle + double DutyCycle() const { return TimeFraction(OnDutyTimePs()); } + + double IdleFraction() const { return TimeFraction(IdleTimePs()); } + + double InfeedFraction() const { + return CategoryFraction(tsl::profiler::kHloInfeed); + } + + double OutfeedFraction() const { + return CategoryFraction(tsl::profiler::kHloOutfeed); + } + + double SparseCoreV0InfeedFraction() const { + return CategoriesFraction({tsl::profiler::kHloSparseCoreV0Infeed, + tsl::profiler::kHloSparseCoreV0InfeedWait, + tsl::profiler::kHloSparseCoreV0InfeedTransform}); + } + + double SparseCoreV0OutfeedFraction() const { + return CategoryFraction(tsl::profiler::kHloSparseCoreV0Outfeed); + } + + double AllReduceFraction() const { + return CategoryFraction(tsl::profiler::kHloAllReduce); + } + + double AllReduceFusionFraction() const { + return CategoryFraction(tsl::profiler::kHloAllReduceFusion); + } + + double SendRecvFraction() const { + return CategoriesFraction( + {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone, + tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); + } + + double HostSendRecvFraction() const { + return CategoriesFraction( + {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, + tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); + } + + double CategoriesFraction( + const std::initializer_list& categories) const { + return TimeFraction(CategoriesTimePs(categories)); + } + + double CategoryFraction(absl::string_view category) const { + return TimeFraction(CategoryTimePs(category)); + } + + uint64_t ProfileTimePs() const { return profile_time_ps_; } + + uint64_t TotalTimePs() const { return total_time_ps_; } + + uint64_t IdleTimePs() const { return profile_time_ps_ - total_time_ps_; } + + uint64_t OnDutyTimePs() const { return profile_time_ps_ - OffDutyTimePs(); } + + uint64_t OffDutyTimePs() const { + return IdleTimePs() + + CategoriesTimePs( + {tsl::profiler::kHloInfeed, tsl::profiler::kHloOutfeed, + tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, + tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone, + tsl::profiler::kHloMegacoreFusion}); + } + + uint64_t InfeedTimePs() const { + return CategoryTimePs(tsl::profiler::kHloInfeed); + } + + uint64_t OutfeedTimePs() const { + return CategoryTimePs(tsl::profiler::kHloOutfeed); + } + + uint64_t SparseCoreV0InfeedWaitTimePs() const { + return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait); + } + + uint64_t SparseCoreV0InfeedTransformTimePs() const { + return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform); + } + + uint64_t SparseCoreV0OutfeedTimePs() const { + return CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); + } + + uint64_t AllReduceOrAllToAllTimePs() const { + return CategoriesTimePs({tsl::profiler::kHloAllReduce, + tsl::profiler::kHloAllReduceFusion, + tsl::profiler::kHloAllToAll}); + } + + uint64_t SendTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone}); + } + + uint64_t RecvTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); + } + + uint64_t HostSendTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone}); + } + + uint64_t HostRecvTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); + } + + // Megacore fusion runs different operations on each core, e.g., a convolution + // on one core and an all-reduce on the other core. In a trace, megacore + // fusion is the parent operation, and its self time is the time that the core + // executing the faster operation waits for the core executing the slower + // operation to reach the synchronization point. + uint64_t MegacoreFusionTimePs() const { + return CategoryTimePs(tsl::profiler::kHloMegacoreFusion); + } + + uint64_t HighFlopsComputeTimePs() const { + return CategoriesTimePs({tsl::profiler::kHloConvolution, + tsl::profiler::kHloConvolutionBaseDilated, + tsl::profiler::kHloConvolutionWindowDilated, + tsl::profiler::kHloConvolutionFusion, + tsl::profiler::kHloOutputFusion}); + } + + // Calculated according to the "TC busy time" defined in go/tpu_kpis + uint64_t TensorCoreBusyTimePs() const { + return profile_time_ps_ - OffDutyTimePs() - SparseCoreV0InfeedWaitTimePs(); + } + + uint64_t CategoriesTimePs( + const std::initializer_list& categories) const { + uint64_t time_ps = 0; + for (auto category : categories) { + time_ps += CategoryTimePs(category); + } + return time_ps; + } + + uint64_t CategoryTimePs(absl::string_view category) const { + auto iter = time_ps_by_category_.find(category); + return (iter == time_ps_by_category_.end()) ? 0 : iter->second; + } + + template + void ComputeCategoryFractions(Map& category_fractions) { + for (const auto& [category, time_ps] : time_ps_by_category_) { + category_fractions[category] = TimeFraction(time_ps); + } + } + + std::string DebugString() const; + + private: + // Overwrites the time attributed to the given category. + void SetCategoryTimePs(absl::string_view category, uint64_t time_ps); + + // Removes and returns the time attributed to the given category. + uint64_t PopCategoryTimePs(absl::string_view category); + + double TimeFraction(uint64_t time_ps) const { + return tsl::profiler::SafeDivide(time_ps, profile_time_ps_); + } + + absl::flat_hash_map time_ps_by_category_; + uint64_t total_time_ps_ = 0; // Sum of values in time_ps_by_category_. + uint64_t profile_time_ps_ = 0; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ diff --git a/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h b/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h new file mode 100644 index 00000000000000..352a2b774fc2da --- /dev/null +++ b/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ + +#include "absl/strings/string_view.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { +namespace profiler { + +TF_CONST_INIT extern const absl::string_view kProfileAllHostsDoc; +TF_CONST_INIT extern const absl::string_view kSparseCoreV0Name; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc index 63399de65677c9..512809127e405b 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc @@ -241,6 +241,8 @@ TEST(ConvertXPlaneToOpMetricsDb, TpuDeviceOpMetricsDb) { hlo_module_id: 1 self_time_ps: 10000 flops: 68 + model_flops: 68 + num_cores: 1 occurrences: 2 name: "MatMul" time_ps: 10000 diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 1b33e5fbe7b949..85bcf086cc3a8f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -42,8 +42,10 @@ limitations under the License. #include "tensorflow/core/profiler/utils/device_caps_utils.h" #include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/hardware_type_utils.h" +#include "tensorflow/core/profiler/utils/hlo_module_map.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" +#include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" @@ -78,6 +80,22 @@ PerfEnv MakePerfEnv(double peak_tera_flops_per_second, return result; } +PerfEnv MakePerfEnvForTpu(double peak_tera_flops_per_second, + std::vector peak_bws, bool has_merged_vmem, + bool has_megacore) { + PerfEnv result = MakePerfEnv(peak_tera_flops_per_second, peak_bws); + result.set_has_cmem(peak_bws[MemBwType::MEM_BW_TYPE_CMEM_RD] > 0 || + peak_bws[MemBwType::MEM_BW_TYPE_CMEM_WR] > 0); + result.set_has_merged_vmem(has_merged_vmem); + result.set_has_megacore(has_megacore); + return result; +} + +PerfEnv MakePerfEnvForGpu(double peak_tera_flops_per_second, + std::vector peak_bws) { + return MakePerfEnv(peak_tera_flops_per_second, peak_bws); +} + PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { DeviceCapabilities cap = GetDeviceCaps(device_plane); if (!absl::StartsWith(device_plane.name(), kTpuPlanePrefix)) { @@ -91,10 +109,10 @@ PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { tsl::profiler::UniToGiga(GetSharedMemoryBandwidthPerSM(cap)); // Note that treat SRAM_RD and SRAM_WR as the same. So in future, we could // only use one for shared memory / L1 cache, one for another like L2. - return MakePerfEnv(peak_tera_flops_per_second, - {/*HBM_RW=*/hbm_bw_giga_bytes_per_second, - /*SRAM_RD=*/shm_giga_bytes_per_second, - /*SRAM_WR=*/shm_giga_bytes_per_second}); + return MakePerfEnvForGpu(peak_tera_flops_per_second, + {/*HBM_RW=*/hbm_bw_giga_bytes_per_second, + /*SRAM_RD=*/shm_giga_bytes_per_second, + /*SRAM_WR=*/shm_giga_bytes_per_second}); } else { XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(&device_plane); std::optional peak_tera_flops_per_second = @@ -145,14 +163,24 @@ PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { vmem_wr_bw_giga_bytes_per_second.has_value() ? vmem_wr_bw_giga_bytes_per_second->DoubleValue() : 0.0; - return MakePerfEnv(peak_tera_flops_per_second_val, - {/*HBM_RW=*/peak_hbm_bw_giga_bytes_per_second_val, - /*SRAM_RD=*/peak_sram_rd_bw_giga_bytes_per_second_val, - /*SRAM_WR=*/peak_sram_wr_bw_giga_bytes_per_second_val, - /**CMEM_RD=*/cmem_rd_bw_giga_bytes_per_second_val, - /**CMEM_WR=*/cmem_wr_bw_giga_bytes_per_second_val, - /**VMEM_RD=*/vmem_rd_bw_giga_bytes_per_second_val, - /**VMEM_WR=*/vmem_wr_bw_giga_bytes_per_second_val}); + std::optional has_megacore = + visitor.GetStat(StatType::kDevHasMegacore); + bool has_megacore_val = + has_megacore.has_value() ? has_megacore->BoolValue() : false; + std::optional has_merged_vmem = + visitor.GetStat(StatType::kDevHasMergedVmem); + bool has_merged_vmem_val = + has_merged_vmem.has_value() ? has_merged_vmem->BoolValue() : false; + return MakePerfEnvForTpu( + peak_tera_flops_per_second_val, + {/*HBM_RW=*/peak_hbm_bw_giga_bytes_per_second_val, + /*SRAM_RD=*/peak_sram_rd_bw_giga_bytes_per_second_val, + /*SRAM_WR=*/peak_sram_wr_bw_giga_bytes_per_second_val, + /**CMEM_RD=*/cmem_rd_bw_giga_bytes_per_second_val, + /**CMEM_WR=*/cmem_wr_bw_giga_bytes_per_second_val, + /**VMEM_RD=*/vmem_rd_bw_giga_bytes_per_second_val, + /**VMEM_WR=*/vmem_wr_bw_giga_bytes_per_second_val}, + has_merged_vmem_val, has_megacore_val); } } @@ -217,6 +245,13 @@ void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, } } +void UpdateOpMetricsDbFromHloModuleMap(OpMetricsDb& op_metrics_db, + const HloModuleMap& hlo_module_map) { + for (OpMetrics& op_metrics : *op_metrics_db.mutable_metrics_db()) { + EnterOpMetadataFromHloModuleMap(&op_metrics, hlo_module_map); + } +} + OpStats ConvertXSpaceToOpStats(const XSpace& space, const OpStatsOptions& options) { OpStats op_stats; @@ -245,6 +280,8 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, if (!op_stats.has_perf_env()) { *op_stats.mutable_perf_env() = GetPerfEnvFromXPlane(*device_trace); } + HloModuleMap hlo_module_map; + ProcessHloModuleMapFromXSpace(hlo_module_map, &space); if (!is_tpu) { OpMetricsDb device_op_metrics_db = ConvertDeviceTraceXPlaneToOpMetricsDb(*device_trace); @@ -254,6 +291,7 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, use_aggregated_xplane = true; OpMetricsDb device_op_metrics_db = ConvertTpuDeviceTraceXPlaneToOpMetricsDb(aggregated_xplane); + UpdateOpMetricsDbFromHloModuleMap(device_op_metrics_db, hlo_module_map); op_metrics_db_combiner.Combine(device_op_metrics_db); } } @@ -319,11 +357,27 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, } } - // TODO(bvandermoon): Add the TPU equivalent for setting core details hostname if (!is_tpu) { CoreDetails& details = (*op_stats.mutable_core_id_to_details())[kDefaultGpuLocalCoreId]; details.set_hostname(Hostname(space)); + } else { + std::string hostname = Hostname(space); + auto& core_id_to_details = *op_stats.mutable_core_id_to_details(); + for (const XPlane* device_plane : device_planes) { + XPlaneVisitor visitor = + tsl::profiler::CreateTfXPlaneVisitor(device_plane); + auto stat = visitor.GetStat(StatType::kCoreDetails); + if (stat.has_value()) { + CoreDetails core_details; + absl::string_view core_details_bytes = stat->BytesValue(); + if (core_details.ParseFromArray(core_details_bytes.data(), + core_details_bytes.size())) { + core_details.set_hostname(hostname); + core_id_to_details[device_plane->id()] = core_details; + } + } + } } // Set program_id_to_name map in OpStats from Xspace diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index 233d3574ee5e03..bfbcc9077beea0 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -216,7 +216,7 @@ TEST(ConvertXPlaneToOpStats, GpuStepDbTest) { options, &op_stats)); const StepDatabaseResult& step_db = op_stats.step_db(); - EXPECT_EQ(step_db.step_sequence_size(), 0); + EXPECT_EQ(step_db.step_sequence_size(), 1); PrecisionStats precision_stats = op_stats.device_op_metrics_db().precision_stats(); diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index 47d1aa8c5f3588..bb75eb1a480b72 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -301,11 +301,11 @@ StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { // one more step than the "Step" line. We need to intersect them to get // the common step numbers. stream_step_events = - ConvertTpuDeviceTraceXLineToStepEvents(*tpu_core_id, line); + ConvertTpuDeviceTraceXLineToStepEvents(plane.Id(), line); IntersectCombineStepEvents(stream_step_events, &device_step_events); } else if (sc_core_id.has_value()) { - stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents( - kSparseCoreIndexStart + *sc_core_id, line); + stream_step_events = + ConvertTpuDeviceTraceXLineToStepEvents(plane.Id(), line); IntersectCombineStepEvents(stream_step_events, &device_step_events); } else { stream_step_events = diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc index e01a932fdfaa91..77b13defbbb58a 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tool_names.cc @@ -46,6 +46,7 @@ absl::StatusOr GetAvailableToolNames( tools.push_back("op_profile"); tools.push_back("inference_profile"); tools.push_back("hlo_stats"); + tools.push_back("roofline_model"); TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, session_snapshot.GetXSpace(0)); diff --git a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc b/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc index 73a79240343c78..414ace9b95c669 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc @@ -123,6 +123,7 @@ TEST_P(XPlaneToToolsTest, ToolsList) { "tf_data_bottleneck_analysis", "op_profile", "hlo_stats", + "roofline_model", "inference_profile", }; expected_tools.insert(expected_tools.end(), test_case.expected_tools.begin(), diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc index c12743a416b5b8..432fad90bb7474 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/op_stats_to_op_profile.h" #include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" #include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" +#include "tensorflow/core/profiler/convert/op_stats_to_roofline_model.h" #include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" #include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h" #include "tensorflow/core/profiler/convert/process_megascale_dcn.h" @@ -58,6 +59,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/op_profile.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/overview_page.pb.h" +#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" #include "tensorflow/core/profiler/utils/hardware_type_utils.h" @@ -277,6 +279,22 @@ absl::StatusOr ConvertMultiXSpacesToHloStats( return ConvertOpStatsToHloStats(combined_op_stats).SerializeAsString(); } +absl::StatusOr ConvertMultiXSpacesToRooflineModel( + const SessionSnapshot& session_snapshot) { + OpStatsOptions op_stats_options; + op_stats_options.generate_op_metrics_db = true; + OpStats combined_op_stats; + TF_RETURN_IF_ERROR(ConvertMultiXSpacesToCombinedOpStats( + session_snapshot, op_stats_options, &combined_op_stats)); + RooflineModelDatabase result = + ConvertOpStatsToRooflineModel(combined_op_stats, true); + RooflineModelDatabase result_without_infeed_outfeed = + ConvertOpStatsToRooflineModel(combined_op_stats, false); + result.mutable_roofline_model_record()->MergeFrom( + result_without_infeed_outfeed.roofline_model_record()); + return result.SerializeAsString(); +} + absl::StatusOr ConvertMultiXSpacesToOpProfileViewer( const SessionSnapshot& session_snapshot) { OpStatsOptions options; @@ -377,6 +395,8 @@ absl::StatusOr ConvertMultiXSpacesToToolData( return ConvertMultiXSpacesToOpProfileViewer(session_snapshot); } else if (tool_name == "hlo_stats") { return ConvertMultiXSpacesToHloStats(session_snapshot); + } else if (tool_name == "roofline_model") { + return ConvertMultiXSpacesToRooflineModel(session_snapshot); } else if (tool_name == "memory_viewer" || tool_name == "graph_viewer") { return ConvertHloProtoToToolData(session_snapshot, tool_name, options); } else if (tool_name == "dcn_collective_stats") { diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc index 396b8d20b6da5c..2962d5cd46e75b 100644 --- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc +++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h" +#include #include #include #include diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 6b1ca8e6be8744..6ffe0258f4fce9 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -226,6 +226,7 @@ cc_library( "//tensorflow/core/platform", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/profiler/lib:scoped_annotation", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:annotation_stack", diff --git a/tensorflow/core/profiler/lib/profiler_disabled_test.cc b/tensorflow/core/profiler/lib/profiler_disabled_test.cc index f55b50ad0375f8..42c3c16a432508 100644 --- a/tensorflow/core/profiler/lib/profiler_disabled_test.cc +++ b/tensorflow/core/profiler/lib/profiler_disabled_test.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "absl/status/statusor.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/profiler/profiler.cc b/tensorflow/core/profiler/profiler.cc index dbcac0a2858c93..58d2bbc4a8fecc 100644 --- a/tensorflow/core/profiler/profiler.cc +++ b/tensorflow/core/profiler/profiler.cc @@ -16,8 +16,9 @@ limitations under the License. #include #include +#include +#include #include -#include #include #include #include diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index d88ed4d0835038..a29c64df4c674e 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -290,3 +290,9 @@ tf_proto_library( "//learning/serving/tools/servo_model_profiler:__subpackages__", ], ) + +tf_proto_library( + name = "tpu_input_pipeline_proto", + srcs = ["tpu_input_pipeline.proto"], + protodeps = [":input_pipeline_proto"], +) diff --git a/tensorflow/core/profiler/protobuf/op_metrics.proto b/tensorflow/core/profiler/protobuf/op_metrics.proto index 2d0ab71bbc0f48..c30557b6d96ed2 100644 --- a/tensorflow/core/profiler/protobuf/op_metrics.proto +++ b/tensorflow/core/profiler/protobuf/op_metrics.proto @@ -170,7 +170,7 @@ message PrecisionStats { } // A database for OpMetrics. -// Next ID: 14 +// Next ID: 16 message OpMetricsDb { // A bunch of OpMetrics. repeated OpMetrics metrics_db = 10; @@ -185,5 +185,11 @@ message OpMetricsDb { uint64 total_op_time_ps = 12; // Precision-related stats. PrecisionStats precision_stats = 13; + // The below two stats will be different from the total time ps and total op + // time ps because they are unioned all cores (and not summed). + // For duty cycle, a device is idle if all the cores are idle. + uint64 idle_time_ps = 14; + // For duty cycle, a device is busy if any of the cores is busy. + uint64 busy_time_ps = 15; reserved 1, 4, 5, 6, 7, 8, 9; } diff --git a/tensorflow/core/profiler/protobuf/steps_db.proto b/tensorflow/core/profiler/protobuf/steps_db.proto index c1077d6089cabd..5fb524b3c4d384 100644 --- a/tensorflow/core/profiler/protobuf/steps_db.proto +++ b/tensorflow/core/profiler/protobuf/steps_db.proto @@ -19,6 +19,95 @@ message GenericStepBreakdown { map category_ps = 2; } +// Breakdown of step-time on TPU. +// Next ID: 20 +message TpuStepBreakdown { + // The infeed duration (host to TensorCore) in picoseconds. + uint64 infeed_duration_ps = 1; + + // The outfeed duration (TensorCore to host) in picoseconds. + uint64 host_outfeed_ps = 2; + + // The TensorCore time that is waiting for SparseCoreV0 in picoseconds. + uint64 wait_for_scv0_duration_ps = 3; + + // The TensorCore time spent transforming activations in SparseCoreV0 layout + // into XLA layout. + uint64 scv0_infeed_transform_ps = 4; + + // The outfeed duration (TensorCore to SparseCoreV0) in picoseconds. + uint64 scv0_outfeed_ps = 5; + + // The time spent on all-reduce (used to be cross-replica-sum) in picoseconds. + uint64 crs_duration_ps = 6; + + // The percentage of the SparseCoreV0 time that spends on infeed from host + // (including both data and instruction). + double scv0_infeed_percent = 7; + + // The time spent on send operation. + uint64 send_duration_ps = 8; + + // The time spent on recv operation. + uint64 recv_duration_ps = 9; + + // The time spent on host send operation. + uint64 host_send_duration_ps = 15; + + // The time spent on host recv operation. + uint64 host_recv_duration_ps = 16; + + // Megacore fusion runs different operations on each core, e.g., a convolution + // on one core and an all-reduce on the other core. This is the time that the + // core executing the faster operation waits for the core executing the slower + // operation to reach the synchronization point. + uint64 wait_for_megacore_fusion_peer_duration_ps = 14; + + // The time waiting for overlay DMAs in picoseconds. + uint64 overlay_wait_duration_ps = 11; + + // The time spent running high flops ops, such as convolution and output + // fusion. + uint64 high_flops_compute_ps = 12; + + // The time that the Tensorcore is idle but not waiting for input or + // SparseCoreV0. + uint64 tc_idle_ps = 13; + + // The TensorCore time that is busy in picoseconds. + uint64 tc_busy_ps = 17; + + // The SparseCoreV0 time that is busy in picoseconds (equal to + // SparseCoreV0 time - HOST_INSTRUCTION_STALL - HOST_DATA_STALL - + // TENSOR_CORE_STALL). + uint64 scv0_busy_ps = 18; + + // SparseCoreV0 step time in picoseconds (equal to SparseCoreV0 time - + // TENSOR_CORE_STALL). + uint64 scv0_step_ps = 19; + + reserved 10; +} + +// Breakdown of step-time on SparseCore. +message SparseCoreStepBreakdown { + // SparseCore step time in picoseconds (equal to SparseCore time - sc_idle - + // sc_wait_time). + uint64 sc_compute_ps = 1; + + // Host to sparse core time in picoseconds. + uint64 sc_infeed_ps = 2; + + // SparseCore to host time in picoseconds. + uint64 sc_outfeed_ps = 3; + + // Idle time but not waiting for input in picoseconds. + uint64 sc_idle_ps = 4; + + // SparseCore busy time in picoseconds. + uint64 sc_busy_ps = 5; +} + // Information about memory transfer to/from device memory. message DeviceMemoryTransfer { uint64 occurrence = 1; diff --git a/tensorflow/core/profiler/protobuf/tpu_input_pipeline.proto b/tensorflow/core/profiler/protobuf/tpu_input_pipeline.proto new file mode 100644 index 00000000000000..b68a104b6cb26c --- /dev/null +++ b/tensorflow/core/profiler/protobuf/tpu_input_pipeline.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package tensorflow.profiler; + +import "tensorflow/core/profiler/protobuf/input_pipeline.proto"; + +// Per-step details on TPU. +// Next ID: 25 +message PerTpuStepDetails { + // The step number of a step. + int32 step_number = 1; + + // The TensorCore compute time in this step. + double tc_compute_time_ms = 13; + + // The maximum TensorCore idle time that is due to host overhead (but not + // input-related). + double tc_idle_time_ms = 14; + + // The part of a step (in ms) TC spends sending data to the host via outfeed. + double tc_outfeed_time_ms = 15; + + // The part of a step (in ms) on TC that is waiting for input data from the + // host. + double tc_infeed_time_ms = 3; + + // Average infeed-dequeue time across cores (as percentage of step time). + double infeed_percent_average = 4; + + // Minimum infeed-dequeue time across cores (as percentage of step time). + double infeed_percent_minimum = 5; + + // Maximum infeed-dequeue time across cores (as percentage of step time). + double infeed_percent_maximum = 6; + + // The core with the maximum infeed time in this step. + uint32 coreid_max_infeed_time = 7; + + // The part of a step (in ms) that is spent on the all-reduce compute. + double all_reduce_compute_time_ms = 11; + + // The part of a step (in ms) that is spent on the all-reduce synchronization. + double all_reduce_sync_time_ms = 12; + + // The part of a step (in ms) that is spent on SparseCoreV0 compute. + double scv0_compute_time_ms = 16; + + // The part of a step (in ms) that spent on infeed from host to SparseCoreV0. + double scv0_infeed_time_ms = 17; + + // The part of the step (in ms) that is spent waiting for device to host or + // host to device transfer. + double host_transfer_ms = 18; + + // The SparseCore compute time in this step. + double sc_compute_time_ms = 20; + + // The maximum SparseCore idle time that is due to host overhead (but not + // input-related). + double sc_idle_time_ms = 21; + + // The part of a step (in ms) SC spends sending data to the host via outfeed. + double sc_outfeed_time_ms = 22; + + // The part of a step (in ms) on SC that is waiting for input data from the + // host. + double sc_infeed_time_ms = 23; + + // Sparse core step time in ms. + double sc_step_time_ms = 24; + + reserved 2, 8, 9, 10; +} + +// Next Id: 9 +message TpuStepTimeBreakdown { + // Summary of all TensorCore compute op duration as a part of step in ms. + tensorflow.profiler.StepSummary tc_compute_ms_summary = 1; + + // Summary of all SparseCoreV0 compute op duration as a part of step in ms. + tensorflow.profiler.StepSummary scv0_compute_ms_summary = 2; + + // Summary of all TensorCore infeed op duration as a part of step in ms. + tensorflow.profiler.StepSummary tc_infeed_ms_summary = 3; + + // Summary of all TensorCore outfeed op duration as a part of step in ms. + tensorflow.profiler.StepSummary tc_outfeed_ms_summary = 6; + + // Summary of all SparseCoreV0 infeed op duration as a part of step in ms. + tensorflow.profiler.StepSummary scv0_infeed_ms_summary = 4; + + // Summary of all TensorCore idle (but not input-related) duration as a part + // of step in ms. + tensorflow.profiler.StepSummary tc_idle_ms_summary = 5; + + // Summary of all Host to Device and Device to Host transfer part of the step + // in ms. + tensorflow.profiler.StepSummary host_transfer_ms_summary = 7; + // Summary of all sparsecore step summary info. + SparseCoreStepSummary sparse_core_step_summary = 8; +} + +// Similar to TpuStepTimeBreakdown, this is for sparse core step time info. +message SparseCoreStepSummary { + // Summary of all SparseCore compute op duration as a part of step in ms. + tensorflow.profiler.StepSummary sc_compute_ms_summary = 1; + // Summary of all SparseCore infeed op duration as a part of step in ms. + tensorflow.profiler.StepSummary sc_infeed_ms_summary = 2; + // Summary of all SparseCore outfeed op duration as a part of step in ms. + tensorflow.profiler.StepSummary sc_outfeed_ms_summary = 3; + // Summary of all SparseCore idle (but not input-related) duration as a part + // of step in ms. + tensorflow.profiler.StepSummary sc_idle_ms_summary = 4; + // Summary of all SparseCore step time in ms. + tensorflow.profiler.StepSummary sc_step_time_ms_summary = 5; +} + +message TpuBottleneckAnalysis { + // Percentage of step time that is spent on input. + double input_percent = 11; + + // Indicates if input is a bottleneck. Possible values: "host", "device", + // "both", or "unknown" + string input_classification = 1; + + // A human-readable description of the input bottleneck. + string input_statement = 2; + + // Indicates if output is a bottleneck. Possible values: "host", "device", + // "both", or "unknown" + double output_percent = 12; + + // Percentage of step time that is spent on output. + string output_classification = 9; + + // A human-readable description of the output bottleneck. + string output_statement = 10; + + // Percentage of step time where the TC is idle (other than I/O). + double tc_idle_percent = 13; + + // Indicates if TensorCore being idle (other than input) is a bottleneck. + // Possible values: "no", "yes". + string tc_idle_classification = 3; + + // A human-readable description of the TC-idle bottleneck. + string tc_idle_statement = 4; + + // Indicates if SparseCoreV0 is a bottleneck. Possible values: "no", + // "moderate", "high". + string scv0_classification = 5; + + // A human-readable description of the SparseCoreV0 bottleneck. + string scv0_statement = 6; + + // Indicates if all-reduce is a bottleneck. Possible values: "no", "yes". + string all_reduce_classification = 7; + + // A human-readable description of the all-reduce bottleneck. + string all_reduce_statement = 8; + + // Percentage of step time that is spent on compute. + double compute_percent = 14; +} diff --git a/tensorflow/core/profiler/tfprof_options.cc b/tensorflow/core/profiler/tfprof_options.cc index 8e96deebc7512d..a31fddbcef3821 100644 --- a/tensorflow/core/profiler/tfprof_options.cc +++ b/tensorflow/core/profiler/tfprof_options.cc @@ -15,6 +15,11 @@ limitations under the License. #include "tensorflow/core/profiler/tfprof_options.h" +#include +#include +#include +#include + #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h index 7d24aaf4625b25..61143b49705138 100644 --- a/tensorflow/core/profiler/tfprof_options.h +++ b/tensorflow/core/profiler/tfprof_options.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_ #define TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_ +#include +#include #include #include #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 4a9ece46f2c33d..cf0dd5728ce3c9 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -127,12 +127,14 @@ cc_library( hdrs = ["op_utils.h"], copts = tf_profiler_copts(), deps = [ + ":hlo_module_map", ":op_metrics_db_utils", "//tensorflow/core:lib", "//tensorflow/core/platform:protobuf", "//tensorflow/core/profiler/convert:op_metrics_db_combiner", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "@com_google_absl//absl/strings", + "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/profiler/utils:timespan", ], @@ -452,6 +454,7 @@ tf_cuda_library( ], visibility = [":friends"], deps = [ + ":hlo_module_utils", ":hlo_proto_map", ":hlo_proto_to_module", "//tensorflow/core/platform:path", @@ -464,13 +467,37 @@ tf_cuda_library( "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/service:hlo_cost_analysis", "@local_xla//xla/service:hlo_proto_cc", + "@local_xla//xla/tsl/profiler/convert:xla_op_utils", ], ) cc_library( name = "hlo_module_utils", hdrs = ["hlo_module_utils.h"], - deps = ["@local_xla//xla/hlo/ir:hlo"], + visibility = [ + ":friends", + # copybara:uncomment "//tensorflow/compiler/mlir/lite/experimental/google/tooling/google:__subpackages__", + ], + deps = [ + "@com_google_absl//absl/strings", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + ], +) + +tf_cc_test( + name = "hlo_module_utils_test", + srcs = ["hlo_module_utils_test.cc"], + deps = [ + ":hlo_module_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/tests:hlo_test_base", + ], ) cc_library( @@ -488,6 +515,20 @@ cc_library( ], ) +cc_library( + name = "tpu_step_breakdown_utils", + hdrs = ["tpu_step_breakdown_utils.h"], + visibility = [":friends"], + deps = ["//tensorflow/core/profiler/protobuf:steps_db_proto_cc"], +) + +cc_library( + name = "tpu_step_details_utils", + hdrs = ["tpu_step_details_utils.h"], + visibility = [":friends"], + deps = ["//tensorflow/core/profiler/protobuf:tpu_input_pipeline_proto_cc"], +) + tf_cc_test( name = "xprof_gpu_cost_analysis_test", srcs = ["xprof_gpu_cost_analysis_test.cc"], diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc index 27ddddf1e4d195..bcae01bd4a49c0 100644 --- a/tensorflow/core/profiler/utils/event_span.cc +++ b/tensorflow/core/profiler/utils/event_span.cc @@ -283,7 +283,7 @@ void StepDetails::AddMarker(const StepMarker& m) { markers_.push_back(m); } void StepDetails::AddEvent(const EventTypeSpan& e) { events_.push_back(e); } void StepDetails::AggregateDeviceMemoryTransfers( - const std::vector device_memory_transfers) { + const std::vector& device_memory_transfers) { if (device_memory_transfers.size() != device_memory_transfers_.size()) { return; // Sanity check. } diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h index 4100390b88959b..f1e3a5b7600151 100644 --- a/tensorflow/core/profiler/utils/event_span.h +++ b/tensorflow/core/profiler/utils/event_span.h @@ -203,7 +203,7 @@ class StepDetails { private: // Accumulates the device memory transfers from another step to this step. void AggregateDeviceMemoryTransfers( - const std::vector device_memory_transfers); + const std::vector& device_memory_transfers); // All step-markers found for marking this step in the traces. There could be // multiple step-markers for a single step for different reasons. One such diff --git a/tensorflow/core/profiler/utils/hlo_module_map.cc b/tensorflow/core/profiler/utils/hlo_module_map.cc index dda6a26b4d1157..245d0018cb297e 100644 --- a/tensorflow/core/profiler/utils/hlo_module_map.cc +++ b/tensorflow/core/profiler/utils/hlo_module_map.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tensorflow/core/profiler/utils/hlo_module_utils.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" @@ -55,7 +56,9 @@ HloInstructionWrapper::HloInstructionWrapper( : instr_(instr), op_full_name_( tsl::profiler::TraceMeOp(Metadata().op_name(), Metadata().op_type())), - category_(instr_->ToCategory()) { + category_(instr_->ToCategory()), + expression_(tensorflow::profiler::UncachedExpression( + instr_, false, tensorflow::profiler::kMaxHlolNameSize)) { ProcessXlaCostAnalysis(cost_analysis); } diff --git a/tensorflow/core/profiler/utils/hlo_module_map.h b/tensorflow/core/profiler/utils/hlo_module_map.h index 92f99db42eb301..de37d5dff97619 100644 --- a/tensorflow/core/profiler/utils/hlo_module_map.h +++ b/tensorflow/core/profiler/utils/hlo_module_map.h @@ -45,6 +45,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "tensorflow/core/profiler/utils/hlo_module_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { @@ -64,9 +66,12 @@ class HloInstructionInterface { virtual std::string source_info() const = 0; virtual bool isRoot() const = 0; virtual bool IsFusion() const = 0; + virtual const std::string& Expression() const = 0; virtual void ProcessXlaCostAnalysis( const xla::HloCostAnalysis* cost_analysis) = 0; + virtual std::string OpLocationStack(int32_t frame_id) const = 0; + virtual tsl::profiler::OpSourceInfo SourceInfo() const = 0; }; // This wrapper allows caching the results of HloInstruction methods. @@ -77,7 +82,7 @@ class HloInstructionWrapper : public HloInstructionInterface { const xla::HloInstruction* instr, const xla::HloCostAnalysis* cost_analysis = nullptr); - // Non copiable + // Non copyable HloInstructionWrapper(const HloInstructionWrapper&) = delete; HloInstructionWrapper& operator=(const HloInstructionWrapper&) = delete; // Movable. @@ -114,6 +119,8 @@ class HloInstructionWrapper : public HloInstructionInterface { bytes_accessed_ = cost_analysis->bytes_accessed(*instr_); } + const std::string& Expression() const override { return expression_; } + void AddFusedChild(const HloInstructionWrapper* child) { fused_children_.push_back(child); }; @@ -122,6 +129,14 @@ class HloInstructionWrapper : public HloInstructionInterface { return fused_children_; } + std::string OpLocationStack(int32_t frame_id) const override { + return GetOpLocationStack(frame_id, instr_); + } + + tsl::profiler::OpSourceInfo SourceInfo() const override { + return GetSourceInfo(instr_); + } + private: const xla::HloInstruction* instr_; std::vector fused_children_; @@ -129,6 +144,7 @@ class HloInstructionWrapper : public HloInstructionInterface { size_t flops_ = 0; size_t bytes_accessed_ = 0; std::string category_; + std::string expression_; }; // Helper class for accessing HloModule. diff --git a/tensorflow/core/profiler/utils/hlo_module_utils.h b/tensorflow/core/profiler/utils/hlo_module_utils.h index ab15a06c669e6a..f86161a6704f60 100644 --- a/tensorflow/core/profiler/utils/hlo_module_utils.h +++ b/tensorflow/core/profiler/utils/hlo_module_utils.h @@ -16,15 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ +#include +#include #include +#include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/tsl/profiler/convert/xla_op_utils.h" namespace tensorflow { namespace profiler { +// Sometimes HLO produce a huge string (>100MB). Limit the name size to 1MB. +static constexpr size_t kMaxHlolNameSize = 1000000; + inline const xla::HloInstruction* FindInstruction(const xla::HloModule& module, std::string node_name) { if (absl::StartsWith(node_name, "%")) { @@ -54,6 +61,52 @@ inline const xla::HloComputation* FindComputation( } return nullptr; } + +inline std::string UncachedExpression(const xla::HloInstruction* instr, + bool skip_expression, size_t max_size) { + if (skip_expression) { + return ""; + } + static const auto* hlo_print_options = + new xla::HloPrintOptions(xla::HloPrintOptions() + .set_print_metadata(false) + .set_print_backend_config(false) + .set_print_infeed_outfeed_config(false)); + std::string expression = instr->ToString(*hlo_print_options); + if (expression.size() > max_size) { + expression.resize(max_size); + } + return expression; +} + +inline std::string GetOpLocationStack(int32_t frame_id, + const xla::HloInstruction* instr) { + std::string stack_lines; + xla::HloModule* hlo_module = instr->GetModule(); + while (frame_id != 0) { + xla::HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); + if (frame.empty()) { + break; + } + stack_lines.insert(0, absl::StrCat(frame.file_name, ":", frame.line, ":", + frame.column, "\n")); + frame_id = frame.parent_frame_id; + } + + return stack_lines; +}; + +inline tsl::profiler::OpSourceInfo GetSourceInfo( + const xla::HloInstruction* instr) { + if (int32_t stack_frame_id = instr->metadata().stack_frame_id(); + stack_frame_id != 0) { + return {.source_file = instr->metadata().source_file(), + .source_line = instr->metadata().source_line(), + .stack_frame = GetOpLocationStack(stack_frame_id, instr)}; + } + return {.source_file = instr->metadata().source_file(), + .source_line = instr->metadata().source_line()}; +}; } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_module_utils_test.cc b/tensorflow/core/profiler/utils/hlo_module_utils_test.cc new file mode 100644 index 00000000000000..18eb2a2cdce7ce --- /dev/null +++ b/tensorflow/core/profiler/utils/hlo_module_utils_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/hlo_module_utils.h" + +#include + +#include +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace profiler { +namespace { + +class HloModuleUtilsTest : public xla::HloTestBase { + protected: + absl::StatusOr> GetModuleWithStackFrames() { + const char file_name[] = "main.py"; + const char function_name[] = "func1"; + const int line_number = 10; + const int column_number = 5; + const int frame_id = 1; + const char text[] = R"( + HloModule a_module + + ENTRY main { + %c = s32[] constant(1) + ROOT %result = s32[] parameter(0) + } + )"; + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(text)); + + auto module_proto = module->ToProto(); + auto index = module_proto.mutable_stack_frame_index(); + index->add_file_names(file_name); + index->add_function_names(function_name); + auto location = index->add_file_locations(); + location->set_file_name_id(frame_id); + location->set_function_name_id(1); + location->set_line(line_number); + location->set_column(column_number); + + auto frame = index->add_stack_frames(); + frame->set_file_location_id(1); + + // Set the stack frame id of the root instruction. + for (auto& computation : *module_proto.mutable_computations()) { + if (computation.id() == module_proto.entry_computation_id()) { + for (auto& instruction : *computation.mutable_instructions()) { + if (instruction.id() == computation.root_id()) { + instruction.mutable_metadata()->set_stack_frame_id(frame_id); + instruction.mutable_metadata()->set_source_file(file_name); + instruction.mutable_metadata()->set_source_line(line_number); + } + } + } + } + + return xla::HloModule::CreateFromProto(module_proto, module->config()); + } +}; + +TEST_F(HloModuleUtilsTest, TestGetLocationStack) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_with_stack_frames, + GetModuleWithStackFrames()); + auto root_instruction = + module_with_stack_frames->entry_computation()->root_instruction(); + EXPECT_EQ(GetOpLocationStack(1, root_instruction), "main.py:10:5\n"); +} + +TEST_F(HloModuleUtilsTest, TestGetSourceInfo) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_with_stack_frames, + GetModuleWithStackFrames()); + auto root_instruction = + module_with_stack_frames->entry_computation()->root_instruction(); + auto source_info = GetSourceInfo(root_instruction); + EXPECT_EQ(source_info.source_file, "main.py"); + EXPECT_EQ(source_info.source_line, 10); + EXPECT_EQ(source_info.stack_frame, "main.py:10:5\n"); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/host_offload_utils.cc b/tensorflow/core/profiler/utils/host_offload_utils.cc index edf9ebe4c14088..7f135985d0b1c6 100644 --- a/tensorflow/core/profiler/utils/host_offload_utils.cc +++ b/tensorflow/core/profiler/utils/host_offload_utils.cc @@ -175,11 +175,12 @@ void HostOffloadEventProcessor::ProcessHostOffloadOpEvent( event_builder.AddStatValue(async_stat, 1); // Set metadata stats for the event. - const XStatMetadata& bytes_stat = *plane_builder_->GetOrCreateStatMetadata( - GetStatTypeStr(StatType::kBytesAccessed)); + const XStatMetadata& raw_bytes_stat = + *plane_builder_->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kRawBytesAccessed)); event.Metadata().ForEachStat([&](const XStatVisitor& stat) { - if (stat.Type() == StatType::kBytesAccessed) { - event_builder.AddStatValue(bytes_stat, stat.IntValue()); + if (stat.Type() == StatType::kRawBytesAccessed) { + event_builder.AddStatValue(raw_bytes_stat, stat.IntValue()); } }); const XStatMetadata& shape_with_layout_str = diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index 5c8f13e58e8e0d..50feae968b1130 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -127,6 +127,9 @@ void SetOpMetadataFromHloEventMetadata( case StatType::kFlops: op_metrics->set_flops(stat.IntOrUintValue()); break; + case StatType::kModelFlops: + op_metrics->set_model_flops(stat.IntOrUintValue()); + break; case StatType::kBytesAccessed: op_metrics->set_bytes_accessed(stat.IntOrUintValue()); break; @@ -184,6 +187,7 @@ void SetOpMetricsFromHloEvent(const tsl::profiler::XEventVisitor& hlo_event, op_metrics->set_min_time_ps(min_duration_ps); op_metrics->set_self_time_ps(self_duration_ps); op_metrics->set_dma_stall_ps(dma_stall_ps); + op_metrics->set_num_cores(1); } else { op_metrics->set_occurrences(op_metrics->occurrences() + hlo_event.NumOccurrences()); @@ -197,6 +201,12 @@ void SetOpMetricsFromHloEvent(const tsl::profiler::XEventVisitor& hlo_event, void AdjustFlopsAndBytesAccessed(OpMetrics& op_metrics) { op_metrics.set_flops(op_metrics.flops() * op_metrics.occurrences()); + if (op_metrics.model_flops() > 0) { + op_metrics.set_model_flops(op_metrics.model_flops() * + op_metrics.occurrences()); + } else { + op_metrics.set_model_flops(op_metrics.flops()); + } op_metrics.set_bytes_accessed(op_metrics.bytes_accessed() * op_metrics.occurrences()); for (auto& memory_access : *op_metrics.mutable_memory_accessed_breakdown()) { @@ -209,7 +219,7 @@ void AdjustFlopsAndBytesAccessed(OpMetrics& op_metrics) { OpMetricsDbBuilder::OpMetricsDbBuilder(OpMetricsDb* db) : db_(db) { DCHECK_NE(db_, nullptr); - DCHECK_EQ(db_->metrics_db_size(), 0); + DCHECK_EQ(db_->metrics_db_size(), db->metrics_db_size()); } OpMetrics* OpMetricsDbBuilder::LookupOrInsertNewOpMetrics( diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc index 292abfb3edd177..cb126c1a3419d3 100644 --- a/tensorflow/core/profiler/utils/op_utils.cc +++ b/tensorflow/core/profiler/utils/op_utils.cc @@ -20,11 +20,13 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/utils/hlo_module_map.h" namespace tensorflow { namespace profiler { @@ -41,6 +43,51 @@ double GetCappedPerf(double perf, uint64 time, double rate_limit) { } // namespace +// Annotate the op_metrics with the metadata from the instr_wrapper. +void EnterOpMetadata(OpMetrics* op_metrics, + const HloInstructionWrapper* instr_wrapper) { + if (op_metrics->name().empty() && op_metrics->category().empty() && + op_metrics->provenance().empty()) { + op_metrics->set_name(std::string(instr_wrapper->Name())); + op_metrics->set_category(std::string(instr_wrapper->Category())); + op_metrics->set_deduplicated_name( + instr_wrapper->Metadata().deduplicated_name()); + op_metrics->set_provenance(std::string(instr_wrapper->op_full_name())); + op_metrics->set_num_cores(1); + op_metrics->set_occurrences(op_metrics->occurrences() + 1); + op_metrics->set_flops(op_metrics->flops() + instr_wrapper->flops()); + op_metrics->set_bytes_accessed(op_metrics->bytes_accessed() + + instr_wrapper->bytes_accessed()); + op_metrics->set_long_name(instr_wrapper->Expression()); + } +} + +void AddFusionChildrenToOpMetricsFromHloInstruction( + OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper) { + if (instr_wrapper->FusedChildren().empty()) return; + for (const HloInstructionWrapper* child : instr_wrapper->FusedChildren()) { + if (child->HloOpcode() == xla::HloOpcode::kParameter || + child->HloOpcode() == xla::HloOpcode::kTuple) + continue; + OpMetrics* child_op_metrics = + op_metrics->mutable_children()->add_metrics_db(); + // DeviceOpMetricsDbBuilder children_db_builder( + // op_metrics->mutable_children()); + EnterOpMetadata(child_op_metrics, child); + // children_db_builder.EnterOpMetadata(child_op_metrics, child); + AddFusionChildrenToOpMetricsFromHloInstruction(child_op_metrics, child); + } +} + +void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, + const HloModuleMap& hlo_module_map) { + const HloInstructionWrapper* instr_wrapper = GetHloInstruction( + hlo_module_map, op_metrics->hlo_module_id(), op_metrics->name()); + if (instr_wrapper != nullptr) { + AddFusionChildrenToOpMetricsFromHloInstruction(op_metrics, instr_wrapper); + } +} + void HostOpMetricsDbBuilder::EnterOp(absl::string_view name, absl::string_view category, bool is_eager, uint64 time_ps, uint64 children_time_ps) { @@ -75,6 +122,15 @@ void HostOpMetricsDbBuilder::EnterHostInfeedEnqueue( last_host_infeed_enqueue_ = host_infeed_enqueue; } +void DeviceOpMetricsDbBuilder::EnterOpMetadataFromHloModuleMap( + uint64 program_id, absl::string_view op_name, + const HloModuleMap& hlo_module_map) { + OpMetrics* op_metrics = + LookupOrInsertNewOpMetrics(program_id, op_name, /*fingerprint=*/0); + tensorflow::profiler::EnterOpMetadataFromHloModuleMap(op_metrics, + hlo_module_map); +} + void DeviceOpMetricsDbBuilder::EnterOpMetadata( uint64 program_id, absl::string_view program_name, absl::string_view category, absl::string_view provenance, diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h index d83eb0c0942575..b3329b08e9e95f 100644 --- a/tensorflow/core/profiler/utils/op_utils.h +++ b/tensorflow/core/profiler/utils/op_utils.h @@ -21,11 +21,21 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/utils/hlo_module_map.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" namespace tensorflow { namespace profiler { +// Annotate the op_metrics with the metadata from the instr_wrapper. +void EnterOpMetadata(OpMetrics* op_metrics, + const HloInstructionWrapper* instr_wrapper); +void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, + const HloModuleMap& hlo_module_map); + +void AddFusionChildrenToOpMetricsFromHloInstruction( + OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper); + class HostOpMetricsDbBuilder : public OpMetricsDbBuilder { public: explicit HostOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} @@ -84,6 +94,10 @@ class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder { absl::string_view category, absl::string_view provenance, absl::string_view deduplicated_name, bool is_eager, absl::string_view long_name = ""); + + void EnterOpMetadataFromHloModuleMap(uint64 program_id, + absl::string_view op_name, + const HloModuleMap& hlo_module_map); }; } // namespace profiler diff --git a/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h b/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h new file mode 100644 index 00000000000000..731481a4da8612 --- /dev/null +++ b/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ + +#include + +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" + +namespace tensorflow { +namespace profiler { + +// Total duration of infeed from host or SparseCoreV0 to TensorCore. +inline uint64_t InfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.infeed_duration_ps() + tpu.wait_for_scv0_duration_ps() + + tpu.scv0_infeed_transform_ps(); +} + +// Total duration of outfeed from TensorCore to host or SparseCoreV0. +inline uint64_t OutfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.host_outfeed_ps() + tpu.scv0_outfeed_ps(); +} + +// Total duration of infeed from host to SparseCoreV0. +inline uint64_t ScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.wait_for_scv0_duration_ps() * tpu.scv0_infeed_percent() / 100.0; +} + +// Total duration of SparseCoreV0 compute. +inline uint64_t ScV0ComputeDurationPs(const TpuStepBreakdown& tpu) { + return tpu.wait_for_scv0_duration_ps() - ScV0InfeedDurationPs(tpu); +} + +// Total duration of infeed from host to TensorCore or SparseCoreV0. +inline uint64_t TcPlusScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.infeed_duration_ps() + ScV0InfeedDurationPs(tpu); +} + +// Total duration of send and recv ops. +inline uint64_t SendRecvDurationPs(const TpuStepBreakdown& tpu) { + return tpu.send_duration_ps() + tpu.recv_duration_ps(); +} + +// Total duration of host send and host recv ops. +inline uint64_t HostSendRecvDurationPs(const TpuStepBreakdown& tpu) { + return tpu.host_send_duration_ps() + tpu.host_recv_duration_ps(); +} + +// Total duration TensorCore spends waiting for host. +inline uint64_t WaitForHostDurationPs(const TpuStepBreakdown& tpu) { + return tpu.infeed_duration_ps() + tpu.host_outfeed_ps() + + HostSendRecvDurationPs(tpu) + tpu.tc_idle_ps(); +} + +// Total duration TensorCore spends waiting for host or SparseCoreV0. +inline uint64_t WaitForHostOrScV0DurationPs(const TpuStepBreakdown& tpu) { + return WaitForHostDurationPs(tpu) + tpu.wait_for_scv0_duration_ps(); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/tpu_step_details_utils.h b/tensorflow/core/profiler/utils/tpu_step_details_utils.h new file mode 100644 index 00000000000000..d26e4973d757de --- /dev/null +++ b/tensorflow/core/profiler/utils/tpu_step_details_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ + +#include + +#include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h" + +namespace tensorflow { +namespace profiler { + +inline double ComputeTimeMs(const PerTpuStepDetails& details) { + return details.tc_compute_time_ms() + details.scv0_compute_time_ms(); +} + +inline double InfeedTimeMs(const PerTpuStepDetails& details) { + return details.tc_infeed_time_ms() + details.scv0_infeed_time_ms(); +} + +inline double AllReduceTimeMs(const PerTpuStepDetails& details) { + return details.all_reduce_compute_time_ms() + + details.all_reduce_sync_time_ms(); +} + +inline double NonIdleTimeMs(const PerTpuStepDetails& details) { + return ComputeTimeMs(details) + InfeedTimeMs(details) + + AllReduceTimeMs(details) + details.tc_outfeed_time_ms(); +} + +// Time spent by a training step on TPU. +inline double StepTimeMs(const PerTpuStepDetails& details) { + return NonIdleTimeMs(details) + details.tc_idle_time_ms(); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 0a8e80a249b09f..ac1c068a53fc46 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2072 // Updated: 2024/12/10 +#define TF_GRAPH_DEF_VERSION 2106 // Updated: 2025/1/13 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/runtime_fallback/kernel/BUILD b/tensorflow/core/runtime_fallback/kernel/BUILD index 3e4b3e12970b94..9ef8dba666689e 100644 --- a/tensorflow/core/runtime_fallback/kernel/BUILD +++ b/tensorflow/core/runtime_fallback/kernel/BUILD @@ -128,6 +128,7 @@ tf_cc_test( deps = [ ":attr_util", "//tensorflow/c:tf_tensor", + "@com_google_absl//absl/status", "@tf_runtime//:core_runtime", "@tf_runtime//:hostcontext", "@tf_runtime//:support", diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util.cc b/tensorflow/core/runtime_fallback/kernel/attr_util.cc index 9f3040aea45835..82bb7ce1b89b57 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util.cc +++ b/tensorflow/core/runtime_fallback/kernel/attr_util.cc @@ -39,7 +39,7 @@ namespace tensorflow { // TODO(annarev): merge this file with attr_util.cc // after reducing attr_util dependencies. -DataType ParseTFDataType(StringPiece dtype) { +DataType ParseTFDataType(absl::string_view dtype) { if (dtype == "DT_INT8") { return DataType::DT_INT8; } else if (dtype == "DT_INT32") { @@ -56,7 +56,7 @@ DataType ParseTFDataType(StringPiece dtype) { } } -bool ParseBoolAttrValue(StringPiece attr_value) { +bool ParseBoolAttrValue(absl::string_view attr_value) { if (attr_value == "false") { return false; } else if (attr_value == "true") { @@ -67,12 +67,12 @@ bool ParseBoolAttrValue(StringPiece attr_value) { } } -Status ParseValue(StringPiece input, bool* value) { +absl::Status ParseValue(absl::string_view input, bool* value) { *value = ParseBoolAttrValue(input); return absl::OkStatus(); } -Status ParseValue(StringPiece input, int32* value) { +absl::Status ParseValue(absl::string_view input, int32* value) { bool parse_result = absl::SimpleAtoi(input, value); if (!parse_result) { return errors::InvalidArgument("Could not parse int32 from ", input); @@ -80,17 +80,17 @@ Status ParseValue(StringPiece input, int32* value) { return absl::OkStatus(); } -Status ParseValue(StringPiece input, DataType* value) { +absl::Status ParseValue(absl::string_view input, DataType* value) { *value = ParseTFDataType(input); return absl::OkStatus(); } -Status ParseValue(StringPiece input, std::string* value) { +absl::Status ParseValue(absl::string_view input, std::string* value) { *value = std::string(input); return absl::OkStatus(); } -Status ParseValue(StringPiece input, std::vector* value) { +absl::Status ParseValue(absl::string_view input, std::vector* value) { std::vector parts = str_util::Split(input, ","); value->reserve(parts.size()); for (const auto& value_str : parts) { @@ -105,13 +105,13 @@ Status ParseValue(StringPiece input, std::vector* value) { return absl::OkStatus(); } -Status ParseValue(StringPiece input, Padding* value) { +absl::Status ParseValue(absl::string_view input, Padding* value) { return GetPaddingFromString(input, value); } -Status AddOpAttr(const std::string& name, const std::string& attr_value, - tfrt::OpAttrs* opattrs) { - Status s; +absl::Status AddOpAttr(const std::string& name, const std::string& attr_value, + tfrt::OpAttrs* opattrs) { + absl::Status s; // Splits attr_value into type and value std::vector value_split = tfd::AttrValueSplit(attr_value); auto& type = value_split[0]; @@ -140,14 +140,15 @@ Status AddOpAttr(const std::string& name, const std::string& attr_value, return s; } -Status FillOpAttrs(tfrt::RemainingAttributes attrs, tfrt::OpAttrs* opattrs) { +absl::Status FillOpAttrs(tfrt::RemainingAttributes attrs, + tfrt::OpAttrs* opattrs) { int num_tf_attrs = attrs.size() / 2; - Status status; + absl::Status status; for (int i = 0; i < num_tf_attrs; ++i) { // Each TF attribute is represented as a pair of name and value strings. std::string name = attrs.GetStringAttribute(i * 2).str(); std::string attr_value = attrs.GetStringAttribute(i * 2 + 1).str(); - Status s = AddOpAttr(name, attr_value, opattrs); + absl::Status s = AddOpAttr(name, attr_value, opattrs); status.Update(s); } return status; diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util.h b/tensorflow/core/runtime_fallback/kernel/attr_util.h index 387f227f1c8cb4..4abbb4f8b31c58 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util.h +++ b/tensorflow/core/runtime_fallback/kernel/attr_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "llvm/ADT/StringMap.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" @@ -36,17 +37,18 @@ namespace tensorflow { typedef llvm::StringMap AttrMap; // Parse value from the given string input. -Status ParseValue(StringPiece input, bool* value); -Status ParseValue(StringPiece input, int32* value); -Status ParseValue(StringPiece input, DataType* value); -Status ParseValue(StringPiece input, std::string* value); -Status ParseValue(StringPiece input, std::vector* value); -Status ParseValue(StringPiece input, Padding* value); - -Status AddOpAttr(const std::string& name, const std::string& attr_value, - tfrt::OpAttrs* opattrs); - -Status FillOpAttrs(tfrt::RemainingAttributes attrs, tfrt::OpAttrs* opattrs); +absl::Status ParseValue(absl::string_view input, bool* value); +absl::Status ParseValue(absl::string_view input, int32* value); +absl::Status ParseValue(absl::string_view input, DataType* value); +absl::Status ParseValue(absl::string_view input, std::string* value); +absl::Status ParseValue(absl::string_view input, std::vector* value); +absl::Status ParseValue(absl::string_view input, Padding* value); + +absl::Status AddOpAttr(const std::string& name, const std::string& attr_value, + tfrt::OpAttrs* opattrs); + +absl::Status FillOpAttrs(tfrt::RemainingAttributes attrs, + tfrt::OpAttrs* opattrs); } // namespace tensorflow #endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_ATTR_UTIL_H_ diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc index 5b881676decffa..79d80b13ff501a 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc +++ b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" @@ -50,7 +51,7 @@ TEST(AttrUtilTest, TestGetIntAttr) { ASSERT_EQ(opattrs.GetAsserting("bar"), 0); ASSERT_EQ(opattrs.GetAsserting("baz"), 123); - Status s = AddOpAttr("invalid", "i32$4.5", &opattrs); + absl::Status s = AddOpAttr("invalid", "i32$4.5", &opattrs); ASSERT_FALSE(s.ok()); } diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc index d8ab1b18c5f483..f380f6003cb07f 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.cc @@ -148,7 +148,7 @@ static std::function)>* GetDefaultRunner() { return default_runner; } -Status SetUpKernelFallbackCompatRequestContext( +absl::Status SetUpKernelFallbackCompatRequestContext( tfrt::RequestContextBuilder* builder, const tensorflow::DeviceMgr* device_manager, const tensorflow::ProcessFunctionLibraryRuntime* pflr, diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h index 201eae2e1c6f5d..6cfbf88ca3f2cf 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -229,7 +229,7 @@ class KernelFallbackCompatRequestState { // function library runtime. They will be forwarded to tensorflow::OpKernel as // in tensorflow::Executor. If `runner` is nullptr, internally it will use a // default runner that executes tasks in the caller thread. -Status SetUpKernelFallbackCompatRequestContext( +absl::Status SetUpKernelFallbackCompatRequestContext( tfrt::RequestContextBuilder* builder, const tensorflow::DeviceMgr* device_manager, const tensorflow::ProcessFunctionLibraryRuntime* pflr, diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc index aa48bcf6be10f0..0dd34564e39d81 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.cc @@ -80,7 +80,7 @@ void KernelFallbackEmitError( const KernelFallbackCompatRequestState* fallback_request_state, tfrt::string_view op_name, tfrt::AsyncValueRef* op_chain, llvm::MutableArrayRef> results, - const tensorflow::Status& status) { + const absl::Status& status) { // Set all results to error, with the correct TFRT error code according to the // error propagated from runtime fallback execution. auto model_info = @@ -117,7 +117,7 @@ ConvertInputTensors(llvm::ArrayRef arguments) { return input_tf_tensors; } -static Status ValidateInputTypes( +static absl::Status ValidateInputTypes( tfrt::string_view op_name, const absl::InlinedVector& input_tf_tensors, const DataTypeVector& input_types) { @@ -261,7 +261,7 @@ tfrt::AsyncValueRef KernelFallbackExecuteCompatCoreRuntimeDispatch( const KernelFallbackCompatRequestState& fallback_request_state, const OpKernelRunner& op_kernel_runner) { auto op_chain = tfrt::GetReadyChain(); - tensorflow::Status status; + absl::Status status; auto expected_input_tf_tensors = ConvertInputTensors(arguments); if (!expected_input_tf_tensors) { diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.cc index 745074cf568bd6..d70627e97c8c43 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.cc @@ -37,7 +37,7 @@ constexpr char kFallbackResourceArray[] = "FallbackResourceArray"; } // namespace -Status SetUpKernelFallbackCompatRequestContext( +absl::Status SetUpKernelFallbackCompatRequestContext( tfrt::RequestContextBuilder* builder, tfrt_stub::OpKernelRunnerTable* runner_table, tensorflow::EagerContext* eager_context, diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h index 05c302e9299b5c..cf3e00149a1c82 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h @@ -26,7 +26,7 @@ namespace tfd { // Runner_table can be nullptr. In that case, kernel_fallback will use // the default runner_table. -Status SetUpKernelFallbackCompatRequestContext( +absl::Status SetUpKernelFallbackCompatRequestContext( tfrt::RequestContextBuilder* builder, tfrt_stub::OpKernelRunnerTable* runner_table, tensorflow::EagerContext* eager_context, diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_kernels.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_kernels.cc index 6a03357a2592aa..da93625c5111c2 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_kernels.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_kernels.cc @@ -55,7 +55,7 @@ static void TFDForwardKernel(tfrt::RemainingArguments arguments, } std::string op_name_str = op_name.str(); tfrt::OpAttrs opattrs; - Status s = FillOpAttrs(attributes, &opattrs); + absl::Status s = FillOpAttrs(attributes, &opattrs); if (!s.ok()) { frame->ReportError("TFDForwardKernel: Error while parsing attributes: ", s.message()); diff --git a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.cc b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.cc index dd5f8c5774ebae..b6641193aa72e2 100644 --- a/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.cc +++ b/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.cc @@ -155,7 +155,7 @@ Expected KernelFallbackOpHandler::MakeOp(string_view op_name) { op_name.consume_front("tf."); return CoreRuntimeOp( [op_name = op_name.str(), this](const OpInvocation& invocation) { - auto propagate_error = [&invocation](Status s) { + auto propagate_error = [&invocation](absl::Status s) { auto error = tfrt::EmitErrorAsync( invocation.exec_ctx, absl::Status( diff --git a/tensorflow/core/runtime_fallback/kernel/tensor_util.cc b/tensorflow/core/runtime_fallback/kernel/tensor_util.cc index 2bbc121d549e56..aa3cc3142353e6 100644 --- a/tensorflow/core/runtime_fallback/kernel/tensor_util.cc +++ b/tensorflow/core/runtime_fallback/kernel/tensor_util.cc @@ -53,7 +53,7 @@ llvm::Expected GetTfDevice(const tfrt::ExecutionContext& exec_ctx, return eager_context_expected.takeError(); } Device* tf_device; - Status s = eager_context_expected.get()->FindDeviceFromName( + absl::Status s = eager_context_expected.get()->FindDeviceFromName( device.name().data(), &tf_device); if (!s.ok()) { return tfrt::MakeStringError(s.message()); diff --git a/tensorflow/core/runtime_fallback/kernel/tensor_util.h b/tensorflow/core/runtime_fallback/kernel/tensor_util.h index 8e0ab312d35be5..6126f10457338e 100644 --- a/tensorflow/core/runtime_fallback/kernel/tensor_util.h +++ b/tensorflow/core/runtime_fallback/kernel/tensor_util.h @@ -90,19 +90,19 @@ tfrt::AsyncValueRef TransferTensorToDevice( // the GPU. With that setup, Sync()ing across all 3 streams should be // sufficient but more than necessary (since it waits for operations // that might have nothing to do with this tensor to complete). - Status s = src_device->Sync(); + absl::Status s = src_device->Sync(); if (!s.ok()) { result.SetError(absl::InternalError(s.message())); return; } tensorflow::Notification n; - tensorflow::Status status; + absl::Status status; tensorflow::CopyTensor::ViaDMA( "copy", src_device_context, dst_device_context, src_device, dst_device, tensorflow::AllocatorAttributes(), tensorflow::AllocatorAttributes(), &src, &dst, 0 /*dev_to_dev_stream_index*/, - [&status, &n](const tensorflow::Status& s) { + [&status, &n](const absl::Status& s) { status = s; n.Notify(); }); diff --git a/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.cc b/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.cc index 6272e343a0e2ea..41e7cfae0637e7 100644 --- a/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.cc +++ b/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.cc @@ -41,13 +41,13 @@ TFRTOpKernelConstruction::TFRTOpKernelConstruction( const tfrt::OpAttrsRef& attributes) : attributes_(std::move(attributes)) {} -Status MissingAttributeError(StringPiece attr_name) { +absl::Status MissingAttributeError(absl::string_view attr_name) { return errors::InvalidArgument("Missing attribute: ", attr_name); } template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - std::string* value) const { +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + std::string* value) const { tfrt::string_view view; bool success = attributes_.GetString( llvm::StringRef(attr_name.data(), attr_name.size()), &view); @@ -59,8 +59,8 @@ Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, } template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - DataType* value) const { +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + DataType* value) const { tfrt::OpAttrType attrtype; bool success = attributes_.Get( llvm::StringRef(attr_name.data(), attr_name.size()), &attrtype); @@ -72,16 +72,16 @@ Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, } template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - Padding* value) const { +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + Padding* value) const { std::string padding_str; TF_RETURN_IF_ERROR(GetAttr(attr_name, &padding_str)); return GetPaddingFromString(padding_str, value); } template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - std::vector* value) const { +absl::Status TFRTOpKernelConstruction::GetAttr( + absl::string_view attr_name, std::vector* value) const { llvm::ArrayRef arrayref; bool success = attributes_.GetArray( llvm::StringRef(attr_name.data(), attr_name.size()), &arrayref); @@ -92,16 +92,17 @@ Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, return absl::OkStatus(); } -void TFRTOpKernelConstruction::CtxFailure(const Status& s) { +void TFRTOpKernelConstruction::CtxFailure(const absl::Status& s) { error_ = tfrt::MakeStatusString(s); } -void TFRTOpKernelConstruction::CtxFailureWithWarning(const Status& s) { +void TFRTOpKernelConstruction::CtxFailureWithWarning(const absl::Status& s) { CtxFailure(s); } namespace { -std::string FillFailureMessage(const char* file, int line, const Status& s) { +std::string FillFailureMessage(const char* file, int line, + const absl::Status& s) { std::string error; llvm::raw_string_ostream sstr(error); sstr << "OP_REQUIRES failed at " << file << ":" << line << " : " @@ -112,12 +113,12 @@ std::string FillFailureMessage(const char* file, int line, const Status& s) { } // namespace void TFRTOpKernelConstruction::CtxFailure(const char* file, int line, - const Status& s) { + const absl::Status& s) { error_ = FillFailureMessage(file, line, s); } void TFRTOpKernelConstruction::CtxFailureWithWarning(const char* file, int line, - const Status& s) { + const absl::Status& s) { CtxFailure(file, line, s); } @@ -156,15 +157,16 @@ void TFRTOpKernelContext::set_output(int index, const Tensor& tensor) { outputs_[index] = tensor; } -Status TFRTOpKernelContext::allocate_temp(DataType type, - const TensorShape& shape, - Tensor* out_temp) { +absl::Status TFRTOpKernelContext::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { *out_temp = Tensor(type, shape); return absl::OkStatus(); } -Status TFRTOpKernelContext::allocate_output(int index, const TensorShape& shape, - Tensor** tensor) { +absl::Status TFRTOpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** tensor) { // Fetch output DataType from the op's TFRTOpMeta. DataType output_type = op_meta_->output_type(index); outputs_[index] = Tensor(output_type, shape); @@ -176,16 +178,18 @@ DataType TFRTOpKernelContext::expected_output_dtype(int i) const { return op_meta_->output_type(i); } -void TFRTOpKernelContext::CtxFailure(const Status& s) { error_ = s.message(); } -void TFRTOpKernelContext::CtxFailureWithWarning(const Status& s) { +void TFRTOpKernelContext::CtxFailure(const absl::Status& s) { + error_ = s.message(); +} +void TFRTOpKernelContext::CtxFailureWithWarning(const absl::Status& s) { CtxFailure(s); } void TFRTOpKernelContext::CtxFailure(const char* file, int line, - const Status& s) { + const absl::Status& s) { error_ = FillFailureMessage(file, line, s); } void TFRTOpKernelContext::CtxFailureWithWarning(const char* file, int line, - const Status& s) { + const absl::Status& s) { CtxFailure(file, line, s); } @@ -204,11 +208,12 @@ DataType TFRTOpMeta::output_type(int index) const { return output_types_[index]; } -TFRTOpMetaBuilder::TFRTOpMetaBuilder(StringPiece op_name) : op_name_(op_name) {} +TFRTOpMetaBuilder::TFRTOpMetaBuilder(absl::string_view op_name) + : op_name_(op_name) {} namespace { -DataType ParseInputOutputSpec(StringPiece spec) { +DataType ParseInputOutputSpec(absl::string_view spec) { std::vector name_type = absl::StrSplit(spec, absl::MaxSplits(':', 2)); DataType data_type; @@ -221,16 +226,16 @@ DataType ParseInputOutputSpec(StringPiece spec) { } // anonymous namespace -TFRTOpMetaBuilder& TFRTOpMetaBuilder::Output(StringPiece output_spec) { +TFRTOpMetaBuilder& TFRTOpMetaBuilder::Output(absl::string_view output_spec) { output_types_.push_back(ParseInputOutputSpec(output_spec)); return *this; } -TFRTOpMetaBuilder& TFRTOpMetaBuilder::Input(StringPiece input_spec) { +TFRTOpMetaBuilder& TFRTOpMetaBuilder::Input(absl::string_view input_spec) { return *this; } -TFRTOpMetaBuilder& TFRTOpMetaBuilder::Attr(StringPiece attr_spec) { +TFRTOpMetaBuilder& TFRTOpMetaBuilder::Attr(absl::string_view attr_spec) { return *this; } @@ -249,7 +254,7 @@ void TFRTOpMetaMap::RegisterOpMeta(const TFRTOpMetaBuilder& op_builder) { (void)insert_result; } -const TFRTOpMeta* TFRTOpMetaMap::GetOpMeta(StringPiece op_name) const { +const TFRTOpMeta* TFRTOpMetaMap::GetOpMeta(absl::string_view op_name) const { auto it = op_metas_.find(llvm::StringRef(op_name.data(), op_name.size())); if (it == op_metas_.end()) return nullptr; @@ -270,19 +275,19 @@ llvm::ManagedStatic tfrt_forwarding_kernel_factories; TFRTOpKernelFactories::TFRTOpKernelFactories() = default; -void TFRTOpKernelFactories::RegisterFactory(StringPiece kernel_class_name, +void TFRTOpKernelFactories::RegisterFactory(absl::string_view kernel_class_name, TFRTOpKernelReg kernel_info) { factories_[std::string(kernel_class_name)].push_back(kernel_info); } // Returns true if kernel attributes match given type constraints. -Status ValidKernelAttr(StringPiece kernel_class_name, - TFRTOpKernelConstruction* construction, - const llvm::StringMap& constraints) { +absl::Status ValidKernelAttr(absl::string_view kernel_class_name, + TFRTOpKernelConstruction* construction, + const llvm::StringMap& constraints) { for (const auto& constraint : constraints) { auto attr_name = std::string(constraint.first()); DataType type; - Status s = construction->GetAttr(attr_name, &type); + absl::Status s = construction->GetAttr(attr_name, &type); if (!s.ok()) { return errors::InvalidArgument( "Kernel ", kernel_class_name, @@ -299,7 +304,7 @@ Status ValidKernelAttr(StringPiece kernel_class_name, } std::unique_ptr TFRTOpKernelFactories::CreateKernel( - StringPiece kernel_class_name, + absl::string_view kernel_class_name, TFRTOpKernelConstruction* op_kernel_construction) const { auto it = factories_.find(std::string(kernel_class_name)); if (it == factories_.end()) { @@ -308,10 +313,10 @@ std::unique_ptr TFRTOpKernelFactories::CreateKernel( "Could not find kernel ", kernel_class_name, " in the registry.")); return std::unique_ptr(nullptr); } - Status status; + absl::Status status; for (const auto& kernel_info : it->second) { - Status s = ValidKernelAttr(kernel_class_name, op_kernel_construction, - kernel_info.type_constraints); + absl::Status s = ValidKernelAttr(kernel_class_name, op_kernel_construction, + kernel_info.type_constraints); if (s.ok()) { return kernel_info.callback(op_kernel_construction); } diff --git a/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h b/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h index b1070a6375b67b..e370fde54e23db 100644 --- a/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h +++ b/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h @@ -65,15 +65,15 @@ class TFRTOpKernelConstruction { explicit TFRTOpKernelConstruction(const tfrt::OpAttrsRef& attributes); template - Status GetAttr(StringPiece attr_name, T* value) const; + absl::Status GetAttr(absl::string_view attr_name, T* value) const; - void CtxFailure(const Status& s); - void CtxFailureWithWarning(const Status& s); - void CtxFailure(const char* file, int line, const Status& s); - void CtxFailureWithWarning(const char* file, int line, const Status& s); + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); - Status MatchSignature(const DataTypeSlice expected_inputs, - const DataTypeSlice expected_outputs) { + absl::Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs) { // TODO(annarev): Move MatchSignatureHelper out of op_kernel.h // and call it here. return absl::OkStatus(); @@ -88,26 +88,26 @@ class TFRTOpKernelConstruction { }; template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - std::string* value) const; +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + std::string* value) const; template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - DataType* value) const; +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + DataType* value) const; template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - Padding* value) const; +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + Padding* value) const; template <> -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - std::vector* value) const; +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + std::vector* value) const; -Status MissingAttributeError(StringPiece attr_name); +absl::Status MissingAttributeError(absl::string_view attr_name); template -Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name, - T* value) const { +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + T* value) const { bool success = attributes_.Get( llvm::StringRef(attr_name.data(), attr_name.size()), value); if (!success) { @@ -137,18 +137,19 @@ class TFRTOpKernelContext { Tensor** output) { return false; } - Status allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp); - Status allocate_output(int index, const TensorShape& shape, Tensor** tensor); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + absl::Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor); DataType expected_output_dtype(int i) const; template const EigenDeviceType& eigen_device() const; - void CtxFailure(const Status& s); - void CtxFailureWithWarning(const Status& s); - void CtxFailure(const char* file, int line, const Status& s); - void CtxFailureWithWarning(const char* file, int line, const Status& s); + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); private: llvm::ArrayRef> inputs_; @@ -201,10 +202,10 @@ class TFRTOpMeta { // AddN. class TFRTOpMetaBuilder { public: - explicit TFRTOpMetaBuilder(StringPiece op_name); - TFRTOpMetaBuilder& Output(StringPiece output_spec); - TFRTOpMetaBuilder& Input(StringPiece input_spec); - TFRTOpMetaBuilder& Attr(StringPiece attr_spec); + explicit TFRTOpMetaBuilder(absl::string_view op_name); + TFRTOpMetaBuilder& Output(absl::string_view output_spec); + TFRTOpMetaBuilder& Input(absl::string_view input_spec); + TFRTOpMetaBuilder& Attr(absl::string_view attr_spec); const string& op_name() const; TFRTOpMeta BuildMeta() const; @@ -221,7 +222,7 @@ class TFRTOpMetaMap { void RegisterOpMeta(const TFRTOpMetaBuilder& op_builder); // Returns nullptr if there is no metadata for op_name. - const TFRTOpMeta* GetOpMeta(StringPiece op_name) const; + const TFRTOpMeta* GetOpMeta(absl::string_view op_name) const; private: llvm::StringMap op_metas_; @@ -270,7 +271,7 @@ struct TFRTOpKernelReg { class TFRTOpKernelFactories { public: TFRTOpKernelFactories(); - void RegisterFactory(StringPiece kernel_class_name, + void RegisterFactory(absl::string_view kernel_class_name, TFRTOpKernelReg kernel_info); // Creates a kernel with the given name and passes op_kernel_construction @@ -284,7 +285,7 @@ class TFRTOpKernelFactories { // Note that we consider a constraint to be "not matched" if attribute // it applies to is not in op_kernel_construction. std::unique_ptr CreateKernel( - StringPiece kernel_class_name, + absl::string_view kernel_class_name, TFRTOpKernelConstruction* op_kernel_construction) const; private: diff --git a/tensorflow/core/runtime_fallback/runtime/BUILD b/tensorflow/core/runtime_fallback/runtime/BUILD index 45f433d2d732a9..0d78c029b6de76 100644 --- a/tensorflow/core/runtime_fallback/runtime/BUILD +++ b/tensorflow/core/runtime_fallback/runtime/BUILD @@ -67,6 +67,11 @@ cc_library( "//tensorflow/core/tfrt/utils:error_util", "//tensorflow/core/tfrt/utils:fallback_tensor", "//tensorflow/core/tfrt/utils:tensor_util", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -206,6 +211,8 @@ cc_library( "//tensorflow/core/tfrt/fallback:op_kernel_runner", "//tensorflow/core/tfrt/utils:error_util", "//tensorflow/core/tfrt/utils:fallback_tensor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", diff --git a/tensorflow/core/runtime_fallback/runtime/conversion_function.cc b/tensorflow/core/runtime_fallback/runtime/conversion_function.cc index cc9a5b3983b789..b525d6e222a819 100644 --- a/tensorflow/core/runtime_fallback/runtime/conversion_function.cc +++ b/tensorflow/core/runtime_fallback/runtime/conversion_function.cc @@ -18,9 +18,12 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/runtime/conversion_function.h" +#include +#include #include #include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h" #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h" @@ -42,7 +45,7 @@ tfrt::Expected ConvertRuntimeFallbackTensorToDenseHostTensor( const RuntimeFallbackTensor &tensor, const tfrt::CpuDevice &src, const tfrt::CpuDevice &dst, const tfrt::ExecutionContext &exec_ctx) { - tensorflow::Status status; + absl::Status status; // Resolve ensures Tensor is on host CPU. OwnedAbstractTensorInterface tensor_interface{ tensor.GetTensorHandle()->Resolve(&status)}; @@ -68,7 +71,7 @@ ConvertRuntimeFallbackTensorToStringHostTensor( const RuntimeFallbackTensor &tensor, const tfrt::Device &src, const tfrt::CpuDevice &dst, const tfrt::ExecutionContext &exec_ctx) { auto *host_ctx = exec_ctx.host(); - tensorflow::Status status; + absl::Status status; // Resolve ensures Tensor is on host CPU. OwnedAbstractTensorInterface tensor_interface{ tensor.GetTensorHandle()->Resolve(&status)}; @@ -151,7 +154,8 @@ TransferRuntimeFallbackToAnotherDevice(const RuntimeFallbackTensor &tensor, auto *th = tensor.GetTensorHandle(); Device *tf_device; - Status s = eager_context->FindDeviceFromName(dst.name().data(), &tf_device); + absl::Status s = + eager_context->FindDeviceFromName(dst.name().data(), &tf_device); if (!s.ok()) return tfrt::MakeStringError(s.message()); auto *host = exec_ctx.host(); diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc index 227b5b1a65650b..ba62080807f112 100644 --- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc +++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/batching_util/bounded_executor.h" @@ -52,8 +54,7 @@ int32 BatchFunctionFallbackKernelBase:: int32_t num; const char* val = std::getenv("TF_NUM_BATCH_THREADS"); - return (val && strings::safe_strto32(val, &num)) ? num - : default_num_batch_threads; + return (val && absl::SimpleAtoi(val, &num)) ? num : default_num_batch_threads; } thread::ThreadPool* @@ -149,7 +150,8 @@ BatchFunctionFallbackKernelBase::BatchFunctionFallbackKernelBase( OP_REQUIRES_OK(c, ValidateAllowedBatchSizes()); } -Status BatchFunctionFallbackKernelBase::ValidateAllowedBatchSizes() const { +absl::Status BatchFunctionFallbackKernelBase::ValidateAllowedBatchSizes() + const { if (allowed_batch_sizes_.empty()) { return absl::OkStatus(); } diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h index 86772a2a38d437..ef45282a9d7e1f 100644 --- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h +++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_FALLBACK_BATCH_KERNEL_H_ #define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_FALLBACK_BATCH_KERNEL_H_ +#include +#include #include #include #include @@ -51,7 +53,7 @@ class BatchFunctionFallbackKernelBase : public AsyncOpKernel { protected: // Validates 'allowed_batch_sizes_'. The entries must increase monotonically, // and the last one must equal 'max_batch_size_'. - Status ValidateAllowedBatchSizes() const; + absl::Status ValidateAllowedBatchSizes() const; // Initialize vars by reading from op-kernel-construction. // Vars @@ -265,7 +267,7 @@ void BatchFunctionFallbackKernel::ComputeAsync( auto create_batch_task_fn = [c]() { return BatchResourceType::CreateBatchTask(c); }; - Status status; + absl::Status status; if (serving::ShouldWarmupAllBatchSizes(c)) { status = (*br)->get()->RegisterWarmupInputs(guid, c, batcher_queue_, create_batch_task_fn, done); diff --git a/tensorflow/core/runtime_fallback/runtime/kernel_utils.cc b/tensorflow/core/runtime_fallback/runtime/kernel_utils.cc index 34beb55a7fbcff..655b23fff72048 100644 --- a/tensorflow/core/runtime_fallback/runtime/kernel_utils.cc +++ b/tensorflow/core/runtime_fallback/runtime/kernel_utils.cc @@ -35,14 +35,14 @@ tfrt::Expected InitEagerContext( bool is_async) { // Copied from TFE_NewContext. std::vector> devices; - tensorflow::Status status = tensorflow::DeviceFactory::AddDevices( + absl::Status status = tensorflow::DeviceFactory::AddDevices( session_opts, "/job:localhost/replica:0/task:0", &devices); if (!status.ok()) { return tfrt::MakeStringError(status.message()); } if (device_mgr != nullptr) { - Status s = device_mgr->AddDevices(std::move(devices)); + absl::Status s = device_mgr->AddDevices(std::move(devices)); DCHECK_OK(s) << "Failed to initialize device manager."; auto r = tsl::core::RefCountPtr( new tensorflow::IntraProcessRendezvous(device_mgr)); diff --git a/tensorflow/core/runtime_fallback/runtime/kernel_utils.h b/tensorflow/core/runtime_fallback/runtime/kernel_utils.h index fc201927f1f4c7..e4978b80475068 100644 --- a/tensorflow/core/runtime_fallback/runtime/kernel_utils.h +++ b/tensorflow/core/runtime_fallback/runtime/kernel_utils.h @@ -18,6 +18,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_KERNEL_UTILS_H_ #define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_KERNEL_UTILS_H_ +#include +#include #include #include #include @@ -58,7 +60,7 @@ using OwnedAbstractTensorInterface = AutoReleasePtr; // Check if a TensorHandle physically resides on GPU. inline bool IsGpuTensorHandle(const tensorflow::TensorHandle& handle) { - tensorflow::Status dummy_status; + absl::Status dummy_status; // BackingDeviceName is where the tensor is physically located, not where the // op that produces the tensor is. // Note that dummy_status is never set in TensorHandle::BackingDeviceName. @@ -134,9 +136,9 @@ class EagerContextResource { llvm::Error AddDevices(std::vector> devices) { if (!ctx_) return ctx_.takeError(); - Status s = dynamic_cast( - ctx_.get()->local_device_mgr()) - ->AddDevices(std::move(devices)); + absl::Status s = dynamic_cast( + ctx_.get()->local_device_mgr()) + ->AddDevices(std::move(devices)); if (!s.ok()) return tfrt::MakeStringError(s.message()); ctx_.get()->InitPrioritizedDeviceTypeList(); ctx_.get()->pflr()->InitializeDeviceAndFlr(); diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index ffba3837db5176..3fd21fcd49a187 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -20,7 +21,12 @@ limitations under the License. #include #include -#include "absl/strings/str_format.h" +#include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -56,8 +62,8 @@ using ::tfrt::AsyncValue; using ::tfrt::HostContext; using ::tfrt::RCReference; -Status GetTfrtExecutionContext(OpKernelContext* c, - const tfrt::ExecutionContext** exec_ctx) { +absl::Status GetTfrtExecutionContext(OpKernelContext* c, + const tfrt::ExecutionContext** exec_ctx) { // ExecutionContext's address is passed in as an I64 input. exec_ctx is only // valid during the period of one bef execution. It should not be stored and // accessed after bef execution completes. @@ -109,11 +115,12 @@ class FallbackBatchResource : public tensorflow::serving::BatchResourceBase { return batch_function->name(); } - static Status Create(OpKernelContext* c, - const serving::BatchResourceOptions& options, - tsl::RCReference bef_func, - bool enable_large_batch_splitting, bool disable_padding, - std::unique_ptr* resource) { + static absl::Status Create(OpKernelContext* c, + const serving::BatchResourceOptions& options, + tsl::RCReference bef_func, + bool enable_large_batch_splitting, + bool disable_padding, + std::unique_ptr* resource) { const tfrt::ExecutionContext* exec_ctx = nullptr; TF_RETURN_IF_ERROR(GetTfrtExecutionContext(c, &exec_ctx)); @@ -147,7 +154,7 @@ class FallbackBatchResource : public tensorflow::serving::BatchResourceBase { return absl::OkStatus(); } - static Status Create( + static absl::Status Create( OpKernelContext* c, AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options, int32_t max_batch_size, int32_t batch_timeout_micros, @@ -232,7 +239,7 @@ class FallbackBatchResource : public tensorflow::serving::BatchResourceBase { void ProcessFuncBatchImpl( const BatchTask& last_task, absl::Span inputs, std::vector* combined_outputs, - std::function done) const override; + std::function done) const override; HostContext* const host_ctx_; tfrt::ResourceContext* const resource_context_; @@ -246,7 +253,7 @@ tfrt::AsyncValueRef TFTensorToFallbackTensor( return tfrt::MakeAvailableAsyncValueRef(tf_tensor); } -Status SetUpKernelFallbackCompatRequestContextForBatch( +absl::Status SetUpKernelFallbackCompatRequestContextForBatch( tfrt::RequestContextBuilder* builder, tfrt_stub::OpKernelRunnerTable* runner_table, tfd::FallbackResourceArray* resource_array, @@ -305,7 +312,7 @@ absl::StatusOr> SetUpRequestContext( void FallbackBatchResource::ProcessFuncBatchImpl( const BatchTask& last_task, absl::Span inputs, std::vector* combined_outputs, - std::function done) const { + std::function done) const { std::vector> arguments; arguments.reserve(inputs.size() + 1); // The first argument is a Chain. @@ -365,7 +372,7 @@ void FallbackBatchResource::ProcessFuncBatchImpl( result->get().tensor(); } // Aggregate errors. - Status final_status; + absl::Status final_status; if (!errors.empty()) { if (errors.size() > 1) { auto last = std::unique(errors.begin(), errors.end()); diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.cc index 9abbdf411c2149..49af29d381ec36 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.cc @@ -19,12 +19,20 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h" #include +#include +#include +#include +#include +#include #include #include #include -#include "absl/strings/str_split.h" -#include "absl/synchronization/mutex.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -40,6 +48,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -135,7 +144,7 @@ static AsyncValueRef CreateRuntimeFallbackTensor( TensorHandle* handle, HostContext* host) { OwnedTensorHandle th(handle); int rank; - tensorflow::Status status = th->NumDims(&rank); + absl::Status status = th->NumDims(&rank); if (!status.ok()) return tfrt::MakeErrorAsyncValueRef(tfrt::StrCat( "error getting rank from TF tensor handle: ", status.message())); @@ -244,7 +253,7 @@ OwnedTFTensor MoveDHTToTFTensor(DenseHostTensor&& dht, HostContext* host) { return tf_tensor; } -static tensorflow::Status DecodeDenseAttrToTensorInterface( +static absl::Status DecodeDenseAttrToTensorInterface( const DenseAttr& dense_attr, HostContext* host, tensorflow::TensorInterface* result) { Expected dht = @@ -268,11 +277,11 @@ static tensorflow::Status DecodeDenseAttrToTensorInterface( // Note we currently do not support the following attribute value types: // TFE_OpSetAttrFunction // TFE_OpSetAttrFunctionName -static tensorflow::Status PrepareAttributes(EagerOperation* eager_op, - const OpAttrsRef& attrs, - HostContext* host, - EagerContext* eager_ctx) { - tensorflow::Status status; +static absl::Status PrepareAttributes(EagerOperation* eager_op, + const OpAttrsRef& attrs, + HostContext* host, + EagerContext* eager_ctx) { + absl::Status status; attrs.IterateEntries([eager_op, eager_ctx, status_ptr = &status, host, &attrs](const OpAttrsRawEntry& entry) { // TFE does not expect a device attribute. @@ -450,13 +459,12 @@ static tensorflow::Status PrepareAttributes(EagerOperation* eager_op, return status; } -Status CallEagerExecute(const tfrt::ExecutionContext& exec_ctx, - EagerContext* eager_ctx, const char* op_name, - const char* device_name, - llvm::ArrayRef input_tensor_handles, - const OpAttrsRef& attrs, - llvm::MutableArrayRef - result_tensor_handles) { +absl::Status CallEagerExecute( + const tfrt::ExecutionContext& exec_ctx, EagerContext* eager_ctx, + const char* op_name, const char* device_name, + llvm::ArrayRef input_tensor_handles, const OpAttrsRef& attrs, + llvm::MutableArrayRef + result_tensor_handles) { assert(eager_ctx != nullptr && "EagerContext is NULL"); // Create TF EagerOperation. @@ -492,7 +500,7 @@ AsyncValueRef RuntimeFallbackExecute( const char* op_name, const char* device_name, llvm::ArrayRef arguments, const OpAttrsRef& attrs, llvm::MutableArrayRef> results) { - auto emit_error = [&exec_ctx, results](const tensorflow::Status& status) { + auto emit_error = [&exec_ctx, results](const absl::Status& status) { // Set the correct TFRT error code according to the error propagated from // runtime fallback execution. auto error = EmitErrorAsync(exec_ctx, status); @@ -511,7 +519,7 @@ AsyncValueRef RuntimeFallbackExecute( int num_retvals = results.size(); llvm::SmallVector result_tensor_handles( num_retvals); - Status status; + absl::Status status; if (!ShouldAddHostContextAttr(op_name)) { status = CallEagerExecute(exec_ctx, eager_ctx, op_name, device_name, @@ -682,7 +690,7 @@ static void RuntimeFallbackKernel( int num_retvals = output_tensors.size(); llvm::SmallVector retvals(num_retvals); - tensorflow::Status status = eager_op->Execute( + absl::Status status = eager_op->Execute( absl::MakeSpan(retvals.data(), num_retvals), &num_retvals); TFD_REPORT_AND_RETURN_IF_ERROR(handler, status); @@ -935,7 +943,8 @@ static void RuntimeFallbackExecuteOp( // Get device. Device* device = nullptr; - Status s = eager_ctx->local_device_mgr()->LookupDevice(device_name, &device); + absl::Status s = + eager_ctx->local_device_mgr()->LookupDevice(device_name, &device); if (!s.ok()) { // The device name can be invalid in certain cases. Use default CPU device. VLOG(1) << s.message() << " using default CPU device."; @@ -985,7 +994,7 @@ static void RuntimeFallbackExecuteOp( auto& runtime_fallback_tensor = tfrt_tensor_results[i]->get(); const tensorflow::Tensor* tf_tensor = nullptr; - tensorflow::Status s = + absl::Status s = runtime_fallback_tensor.GetTensorHandle()->Tensor(&tf_tensor); DCHECK(s.ok()) << s; results[i] = @@ -1039,7 +1048,7 @@ static OwnedTensorHandle ConvertTFRTTensorToTFTensorHandle( static llvm::Expected ConvertTFTensorHandleToTFRTTensor( OwnedTensorHandle tensor_handle, HostContext* host) { - tensorflow::Status status; + absl::Status status; // Resolve ensures Tensor is on host CPU. OwnedAbstractTensorInterface tensor_interface{ tensor_handle->Resolve(&status)}; diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h index d0b8c0bfb242f3..833b92f7f24a8f 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h @@ -36,13 +36,13 @@ namespace tfd { // Create an EagerOperation to run the op, taking tensorflow::TensorHandle and // returning tensorflow::AbstractTensorHandle*. -Status CallEagerExecute(const tfrt::ExecutionContext& exec_ctx, - EagerContext* eager_ctx, const char* op_name, - const char* device_name, - llvm::ArrayRef input_tensor_handles, - const tfrt::OpAttrsRef& attrs, - llvm::MutableArrayRef - result_tensor_handles); +absl::Status CallEagerExecute( + const tfrt::ExecutionContext& exec_ctx, EagerContext* eager_ctx, + const char* op_name, const char* device_name, + llvm::ArrayRef input_tensor_handles, + const tfrt::OpAttrsRef& attrs, + llvm::MutableArrayRef + result_tensor_handles); // Take and return RuntimeFallbackTensors. tfrt::AsyncValueRef RuntimeFallbackExecute( diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.cc index d383c78b0f3292..00a65db62024e2 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h" +#include #include #include #include @@ -115,7 +116,7 @@ struct RuntimeFallbackOpEntry { static Expected> GetDeviceFromFallbackTensor( const RuntimeFallbackTensor& result_tensor, const ExecutionContext& exec_ctx) { - tensorflow::Status status; + absl::Status status; // Obtain the device. Please note that this device is probably not // the device that the TensorHandle is located on. E.g. for a TPU resource // its device is TPU but it is physicially located on CPU. diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.cc index 15b652086e2c12..b876bc5d9b1ec8 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h" +#include +#include #include #include @@ -100,7 +102,7 @@ Expected CopyTfStringTensorToStringHostTensor( // TODO(jingdong): Format the tensor in more user-friendly format, especially // for large tensors. See tensorflow::Tensor::DebugString(). void RuntimeFallbackTensor::Print(tfrt::raw_ostream& os) const { - tensorflow::Status status; + absl::Status status; OwnedAbstractTensorInterface tensor_interface{ tensor_handle_->Resolve(&status)}; assert(status.ok()); @@ -149,7 +151,7 @@ tfrt::Expected CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th, HostContext* host) { int rank; - tensorflow::Status status = owned_th->NumDims(&rank); + absl::Status status = owned_th->NumDims(&rank); if (!status.ok()) return tfrt::MakeStringError(tfrt::StrCat( "error getting rank from TF tensor handle: ", status.message())); diff --git a/tensorflow/core/runtime_fallback/util/attr_util.cc b/tensorflow/core/runtime_fallback/util/attr_util.cc index c7285ac6118687..3551b1ba0e6057 100644 --- a/tensorflow/core/runtime_fallback/util/attr_util.cc +++ b/tensorflow/core/runtime_fallback/util/attr_util.cc @@ -267,7 +267,7 @@ llvm::Error FillAttrValueMapUsingScalar(const OpAttrsRawEntry& entry, } // namespace -Status ParseTfDataType(absl::string_view dtype, DataType* data_type) { +absl::Status ParseTfDataType(absl::string_view dtype, DataType* data_type) { if (dtype == "DT_INT8") { *data_type = DataType::DT_INT8; return absl::OkStatus(); @@ -429,7 +429,7 @@ tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type) { } } -Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val) { +absl::Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val) { if (attr_value == "false") { *bool_val = false; return absl::OkStatus(); @@ -451,8 +451,8 @@ absl::Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val) { return absl::OkStatus(); } -Status ParseTensorAttrValue(absl::string_view attr_value, - tensorflow::Tensor* tensor) { +absl::Status ParseTensorAttrValue(absl::string_view attr_value, + tensorflow::Tensor* tensor) { if (std::is_base_of()) { tensorflow::TensorProto tensor_proto; @@ -476,8 +476,8 @@ Status ParseTensorAttrValue(absl::string_view attr_value, } } -Status ParseTensorShapeAttrValue(absl::string_view attr_value, - std::vector* shape_val) { +absl::Status ParseTensorShapeAttrValue(absl::string_view attr_value, + std::vector* shape_val) { if (attr_value.size() < 2 || attr_value[0] != '[' || attr_value[attr_value.size() - 1] != ']') { return errors::InvalidArgument( @@ -548,8 +548,8 @@ tensorflow::Tensor CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr) { return tensor; } -Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr, - tensorflow::AttrValue* tf_attr) { +absl::Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr, + tensorflow::AttrValue* tf_attr) { if (auto shape_attr = bef_attr.dyn_cast()) { if (shape_attr.HasRank()) { tensorflow::PartialTensorShape tf_shape(shape_attr.GetShape()); @@ -579,8 +579,8 @@ Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr, return absl::OkStatus(); } -Status SetUpScalarFunctionAttr(tfrt::StringAttr func_attr, - tensorflow::AttrValue& tf_attr) { +absl::Status SetUpScalarFunctionAttr(tfrt::StringAttr func_attr, + tensorflow::AttrValue& tf_attr) { tfrt::string_view func_name = func_attr.GetValue(); tf_attr.mutable_func()->set_name(func_name.data(), func_name.size()); return absl::OkStatus(); @@ -603,8 +603,8 @@ void AddTensorToAttrList(tfrt::DenseAttr dense_attr, tf_tensor.AsProtoTensorContent(list->add_tensor()); } -Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr, - tensorflow::AttrValue* tf_attr) { +absl::Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr, + tensorflow::AttrValue* tf_attr) { auto* list = tf_attr->mutable_list(); for (int i = 0; i < aggregate_attr.GetNumElements(); ++i) { auto base = aggregate_attr.GetAttribute(i); @@ -621,8 +621,8 @@ Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr, return absl::OkStatus(); } -Status SetUpListAttr(tfrt::ArrayAttr array_attr, - tensorflow::AttrValue* tf_attr) { +absl::Status SetUpListAttr(tfrt::ArrayAttr array_attr, + tensorflow::AttrValue* tf_attr) { auto* list = tf_attr->mutable_list(); // Handle an empty array case. @@ -669,9 +669,9 @@ Status SetUpListAttr(tfrt::ArrayAttr array_attr, } // namespace -Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array, - tfrt::AggregateAttr op_func_attr_array, - tensorflow::AttrValueMap* attr_value_map) { +absl::Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array, + tfrt::AggregateAttr op_func_attr_array, + tensorflow::AttrValueMap* attr_value_map) { auto obtain_name_attr_pair = [](tfrt::AggregateAttr attr_array, int i) -> std::pair { diff --git a/tensorflow/core/runtime_fallback/util/attr_util.h b/tensorflow/core/runtime_fallback/util/attr_util.h index 481c7663a7836b..2bb7f1379e1251 100644 --- a/tensorflow/core/runtime_fallback/util/attr_util.h +++ b/tensorflow/core/runtime_fallback/util/attr_util.h @@ -57,24 +57,22 @@ tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type); // Parses the tensor valued `attr_value` and constructs the tensor with its // contents in `tensor`. Returns OK status on success, INVALID_ARGUMENT on // failure. -tensorflow::Status ParseTensorAttrValue(absl::string_view attr_value, - tensorflow::Tensor* tensor); +absl::Status ParseTensorAttrValue(absl::string_view attr_value, + tensorflow::Tensor* tensor); // Parses a string of the form "[1,2,3,...]" in `attr_value` and returns the // constituent dimension sizes (shape) in `int_list_val`. Returns // INVALID_ARGUMENT on invalid input. -tensorflow::Status ParseTensorShapeAttrValue(absl::string_view attr_value, - std::vector* shape_val); +absl::Status ParseTensorShapeAttrValue(absl::string_view attr_value, + std::vector* shape_val); // Parses a boolean from `attr_value` into `bool_val` and returns OK status on // success. Returns INVALID_ARGUMENT on invalid input. -tensorflow::Status ParseBoolAttrValue(absl::string_view attr_value, - bool* bool_val); +absl::Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val); // Parses an int64_t from `attr_value` into `int_val` and returns OK status on // success. Returns INVLAID_ARGUMENT on invalid input. -tensorflow::Status ParseIntAttrValue(absl::string_view attr_value, - int64_t* int_val); +absl::Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val); inline std::vector AttrValueSplit(absl::string_view str) { return absl::StrSplit(str, absl::MaxSplits('$', 1)); @@ -91,9 +89,9 @@ llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs, AttrValueMap* attr_value_map); // Fills in the passed in AttrValueMap `attr_value_map`. -tensorflow::Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array, - tfrt::AggregateAttr op_func_attr_array, - tensorflow::AttrValueMap* attr_value_map); +absl::Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array, + tfrt::AggregateAttr op_func_attr_array, + tensorflow::AttrValueMap* attr_value_map); } // namespace tfd } // namespace tensorflow diff --git a/tensorflow/core/summary/BUILD b/tensorflow/core/summary/BUILD index 81b600f036716c..918007d927a5cd 100644 --- a/tensorflow/core/summary/BUILD +++ b/tensorflow/core/summary/BUILD @@ -23,6 +23,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core/lib/db:sqlite", + "@com_google_absl//absl/status", ], ) @@ -50,6 +51,11 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:summary_interface", "//tensorflow/core/lib/db:sqlite", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:histogram_proto_cc", ], ) @@ -65,6 +71,9 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/lib/db:sqlite", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@local_xla//xla/tsl/protobuf:histogram_proto_cc", ], ) @@ -80,7 +89,9 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:summary_interface", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -97,6 +108,8 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) @@ -113,6 +126,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/png:png_io", + "@com_google_absl//absl/status", ], ) @@ -128,6 +142,9 @@ tf_cc_binary( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/db:sqlite", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -139,5 +156,6 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/lib/db:sqlite", + "@com_google_absl//absl/log", ], ) diff --git a/tensorflow/core/summary/loader.cc b/tensorflow/core/summary/loader.cc index 8d06f49a66e507..1144fed77165f3 100644 --- a/tensorflow/core/summary/loader.cc +++ b/tensorflow/core/summary/loader.cc @@ -13,13 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include -#include "tensorflow/core/summary/schema.h" -#include "tensorflow/core/summary/summary_db_writer.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/summary/schema.h" +#include "tensorflow/core/summary/summary_db_writer.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/event.pb.h" diff --git a/tensorflow/core/summary/schema.cc b/tensorflow/core/summary/schema.cc index 3b6f3d6c5d3ce7..209d2fa9e341a7 100644 --- a/tensorflow/core/summary/schema.cc +++ b/tensorflow/core/summary/schema.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/schema.h" +#include "absl/status/status.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/core/summary/schema.h b/tensorflow/core/summary/schema.h index 4361088c8be7a0..dc13bbfb0e8895 100644 --- a/tensorflow/core/summary/schema.h +++ b/tensorflow/core/summary/schema.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_SUMMARY_SCHEMA_H_ #define TENSORFLOW_CORE_SUMMARY_SCHEMA_H_ +#include "absl/status/status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/db/sqlite.h" diff --git a/tensorflow/core/summary/schema_test.cc b/tensorflow/core/summary/schema_test.cc index fa21b45b62cca2..08fc3b60936172 100644 --- a/tensorflow/core/summary/schema_test.cc +++ b/tensorflow/core/summary/schema_test.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/schema.h" -#include - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/summary/summary_converter.cc b/tensorflow/core/summary/summary_converter.cc index 53ed1dfded5b55..458307697ffadf 100644 --- a/tensorflow/core/summary/summary_converter.cc +++ b/tensorflow/core/summary/summary_converter.cc @@ -14,6 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/summary_converter.h" +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/summary/summary_converter.h b/tensorflow/core/summary/summary_converter.h index d77d4c670e8d8d..ab19669298ff4f 100644 --- a/tensorflow/core/summary/summary_converter.h +++ b/tensorflow/core/summary/summary_converter.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_SUMMARY_SUMMARY_CONVERTER_H_ #define TENSORFLOW_CORE_SUMMARY_SUMMARY_CONVERTER_H_ +#include "absl/status/status.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/core/summary/summary_db_writer.cc b/tensorflow/core/summary/summary_db_writer.cc index b2d12f5785f7af..eba02509eafffd 100644 --- a/tensorflow/core/summary/summary_db_writer.cc +++ b/tensorflow/core/summary/summary_db_writer.cc @@ -14,16 +14,30 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/summary_db_writer.h" +#include +#include +#include #include - -#include "tensorflow/core/summary/summary_converter.h" +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/summary/summary_converter.h" #include "tensorflow/core/util/event.pb.h" // TODO(jart): Break this up into multiple files with excellent unit tests. @@ -269,8 +283,8 @@ class GraphWriter { int64_t is_control = 0; size_t i = name.rfind(':'); if (i != StringPiece::npos) { - if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1), - &input_node_idx)) { + if (!absl::SimpleAtoi(name.substr(i + 1, name.size() - i - 1), + &input_node_idx)) { return errors::DataLoss("Bad NodeDef.input: ", name); } name.remove_suffix(name.size() - i); diff --git a/tensorflow/core/summary/summary_db_writer.h b/tensorflow/core/summary/summary_db_writer.h index 9b4644b91bde24..545f849e0a1160 100644 --- a/tensorflow/core/summary/summary_db_writer.h +++ b/tensorflow/core/summary/summary_db_writer.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_SUMMARY_SUMMARY_DB_WRITER_H_ #define TENSORFLOW_CORE_SUMMARY_SUMMARY_DB_WRITER_H_ +#include "absl/status/status.h" #include "tensorflow/core/kernels/summary_interface.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/db/sqlite.h" diff --git a/tensorflow/core/summary/summary_db_writer_test.cc b/tensorflow/core/summary/summary_db_writer_test.cc index 8ddf4ebae66a48..da07ee81cd84b2 100644 --- a/tensorflow/core/summary/summary_db_writer_test.cc +++ b/tensorflow/core/summary/summary_db_writer_test.cc @@ -14,16 +14,25 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/summary_db_writer.h" -#include "tensorflow/core/summary/schema.h" +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/summary/schema.h" #include "tensorflow/core/util/event.pb.h" namespace tensorflow { diff --git a/tensorflow/core/summary/summary_file_writer.cc b/tensorflow/core/summary/summary_file_writer.cc index 89d6c2fb76ef4f..2821edc777842c 100644 --- a/tensorflow/core/summary/summary_file_writer.cc +++ b/tensorflow/core/summary/summary_file_writer.cc @@ -14,17 +14,25 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/summary_file_writer.h" +#include +#include #include +#include +#include +#include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/summary/summary_converter.h" +#include "tensorflow/core/util/event.pb.h" #include "tensorflow/core/util/events_writer.h" namespace tensorflow { diff --git a/tensorflow/core/summary/summary_file_writer.h b/tensorflow/core/summary/summary_file_writer.h index 6d58438de81b7a..847e7cb8d396b1 100644 --- a/tensorflow/core/summary/summary_file_writer.h +++ b/tensorflow/core/summary/summary_file_writer.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_SUMMARY_SUMMARY_FILE_WRITER_H_ #define TENSORFLOW_CORE_SUMMARY_SUMMARY_FILE_WRITER_H_ +#include "absl/status/status.h" #include "tensorflow/core/kernels/summary_interface.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/summary/summary_file_writer_test.cc b/tensorflow/core/summary/summary_file_writer_test.cc index 84f209f10256a8..4c8bf2eb407bb5 100644 --- a/tensorflow/core/summary/summary_file_writer_test.cc +++ b/tensorflow/core/summary/summary_file_writer_test.cc @@ -14,6 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/summary/summary_file_writer.h" +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/summary/vacuum.cc b/tensorflow/core/summary/vacuum.cc index 5febe63f061204..1268b93d040b17 100644 --- a/tensorflow/core/summary/vacuum.cc +++ b/tensorflow/core/summary/vacuum.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "absl/log/log.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/core/tfrt/common/pjrt_state.cc b/tensorflow/core/tfrt/common/pjrt_state.cc index 12a8937d389c9a..bf4290bfca990c 100644 --- a/tensorflow/core/tfrt/common/pjrt_state.cc +++ b/tensorflow/core/tfrt/common/pjrt_state.cc @@ -66,8 +66,8 @@ absl::StatusOr PjRtState::GetOrCreatePjRtClient( return clients_[device_type].get(); } -Status PjRtState::SetPjRtClient(const DeviceType& device_type, - std::unique_ptr client) { +absl::Status PjRtState::SetPjRtClient(const DeviceType& device_type, + std::unique_ptr client) { absl::MutexLock lock(&mu_); if (auto it = clients_.find(device_type); it != clients_.end()) { unused_.push_back(std::move(it->second)); @@ -76,7 +76,7 @@ Status PjRtState::SetPjRtClient(const DeviceType& device_type, return absl::OkStatus(); } -Status PjRtState::MovePjRtClientToUnused(const DeviceType& device_type) { +absl::Status PjRtState::MovePjRtClientToUnused(const DeviceType& device_type) { absl::MutexLock lock(&mu_); if (auto it = clients_.find(device_type); it != clients_.end()) { unused_.push_back(std::move(it->second)); @@ -87,7 +87,7 @@ Status PjRtState::MovePjRtClientToUnused(const DeviceType& device_type) { device_type); } -Status PjRtState::SetPjRtGpuClientCreationInfo( +absl::Status PjRtState::SetPjRtGpuClientCreationInfo( std::unique_ptr info) { absl::MutexLock lock(&mu_); pjrt_gpu_client_creation_info_ = std::move(info); diff --git a/tensorflow/core/tfrt/common/pjrt_state.h b/tensorflow/core/tfrt/common/pjrt_state.h index 84a669f4154394..c3df6806baa2dd 100644 --- a/tensorflow/core/tfrt/common/pjrt_state.h +++ b/tensorflow/core/tfrt/common/pjrt_state.h @@ -57,11 +57,11 @@ class PjRtState : public ResourceBase { absl::StatusOr GetPjRtClient(const DeviceType& device_type); absl::StatusOr GetOrCreatePjRtClient( const DeviceType& device_type); - Status SetPjRtClient(const DeviceType& device_type, - std::unique_ptr client); + absl::Status SetPjRtClient(const DeviceType& device_type, + std::unique_ptr client); // Moves PJRT client to `unused_`. The PJRT client moved to `unused_` will not // be returned by `GetPjRtClient`. - Status MovePjRtClientToUnused(const DeviceType& device_type); + absl::Status MovePjRtClientToUnused(const DeviceType& device_type); string DebugString() const override; // Saves information needed to create a PJRT client (to enable creating a diff --git a/tensorflow/core/tfrt/common/pjrt_util.cc b/tensorflow/core/tfrt/common/pjrt_util.cc index 54ed3060adbc08..dbd4787599c90b 100644 --- a/tensorflow/core/tfrt/common/pjrt_util.cc +++ b/tensorflow/core/tfrt/common/pjrt_util.cc @@ -31,7 +31,7 @@ limitations under the License. namespace tensorflow { -Status SetPjRtClientInTFGlobalResourceManager( +absl::Status SetPjRtClientInTFGlobalResourceManager( const DeviceType& device_type, std::unique_ptr client) { ResourceMgr* rmgr = tfrt_global::GetTFGlobalResourceMgr(); PjRtState* pjrt_state; diff --git a/tensorflow/core/tfrt/common/pjrt_util.h b/tensorflow/core/tfrt/common/pjrt_util.h index 2895f22bf4ea92..aaba7ad959e765 100644 --- a/tensorflow/core/tfrt/common/pjrt_util.h +++ b/tensorflow/core/tfrt/common/pjrt_util.h @@ -29,14 +29,14 @@ namespace tensorflow { // for this device_type already exists, the existing PJRT client will not be // destroyed, and will be kept alive in an "unused client" vector. PJRT API // semantics require the PJRT client to outlive PJRT buffers. -Status SetPjRtClientInTFGlobalResourceManager( +absl::Status SetPjRtClientInTFGlobalResourceManager( const DeviceType& device_type, std::unique_ptr client); // Gets (the most recent) PJRT client for device_type from // TFGlobalResourceManager. absl::StatusOr GetPjRtClient(const DeviceType& device_type); -Status SetPjRtGpuClientCreationInfoInTFGlobalResourceManager( +absl::Status SetPjRtGpuClientCreationInfoInTFGlobalResourceManager( std::unique_ptr info); absl::StatusOr GetPjRtGpuClientCreationInfo(); diff --git a/tensorflow/core/tfrt/fallback/BUILD b/tensorflow/core/tfrt/fallback/BUILD index a7eedfa43bbe89..3ca860b5cb4cd9 100644 --- a/tensorflow/core/tfrt/fallback/BUILD +++ b/tensorflow/core/tfrt/fallback/BUILD @@ -1,5 +1,6 @@ load( "//tensorflow:tensorflow.bzl", + "if_android", "if_mobile", "if_not_mobile", "tf_cc_test", @@ -47,6 +48,7 @@ cc_library( "//tensorflow/core/common_runtime:core_cpu_internal", "//tensorflow/core/common_runtime:device_set", "//tensorflow/core/framework:device_attributes_proto_cc", + "//tensorflow/core/framework:function_proto_cc", "//tensorflow/core/framework:graph_proto_cc", "//tensorflow/core/platform:strcat", "//tensorflow/core/tpu:virtual_device", @@ -73,9 +75,12 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/framework:function_proto_cc", "//tensorflow/core/platform:status_matchers", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/base:nullability", + "@com_google_googletest//:gtest", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -83,13 +88,21 @@ cc_library( name = "op_kernel_runner", srcs = ["op_kernel_runner.cc"], hdrs = ["op_kernel_runner.h"], - features = tf_features_nolayering_check_if_ios(), + features = tf_features_nolayering_check_if_ios() + if_android(["-layering_check"]), visibility = [ # copybara:uncomment "//tensorflow/core/runtime_fallback:internal", "//visibility:public", ], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ] + if_mobile([ "//tensorflow/core:portable_tensorflow_lib_lite", ]) + if_not_mobile([ @@ -109,6 +122,11 @@ cc_library( deps = [ ":op_kernel_runner", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tf_runtime//:hostcontext", ], @@ -124,6 +142,8 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/util:env_var", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", ], ) diff --git a/tensorflow/core/tfrt/fallback/cost_recorder.cc b/tensorflow/core/tfrt/fallback/cost_recorder.cc index f3dad24ba56254..e0552b15c17806 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder.cc +++ b/tensorflow/core/tfrt/fallback/cost_recorder.cc @@ -14,10 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include +#include +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/tfrt/fallback/cost_recorder.h b/tensorflow/core/tfrt/fallback/cost_recorder.h index f1abb352c8c493..e1d1b7f410f2c5 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder.h +++ b/tensorflow/core/tfrt/fallback/cost_recorder.h @@ -18,9 +18,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ #define TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ +#include +#include #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" diff --git a/tensorflow/core/tfrt/fallback/cost_recorder_test.cc b/tensorflow/core/tfrt/fallback/cost_recorder_test.cc index 3292957053de48..ee4b49befbbf06 100644 --- a/tensorflow/core/tfrt/fallback/cost_recorder_test.cc +++ b/tensorflow/core/tfrt/fallback/cost_recorder_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include #include #include diff --git a/tensorflow/core/tfrt/fallback/fallback_state.cc b/tensorflow/core/tfrt/fallback/fallback_state.cc index d44d7ccda523fb..7b4b26505467d9 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state.cc +++ b/tensorflow/core/tfrt/fallback/fallback_state.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -32,6 +31,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/tfrt/fallback/fallback_state.h b/tensorflow/core/tfrt/fallback/fallback_state.h index 90ffb6bceb986d..ffbf0695bafbad 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state.h +++ b/tensorflow/core/tfrt/fallback/fallback_state.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/core/tfrt/fallback/fallback_state_test.cc b/tensorflow/core/tfrt/fallback/fallback_state_test.cc index e76a37401f395c..3546992cfa7614 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state_test.cc +++ b/tensorflow/core/tfrt/fallback/fallback_state_test.cc @@ -19,13 +19,16 @@ limitations under the License. #include #include +#include #include "absl/base/nullability.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/const_op.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/tfrt/fallback/op_kernel_runner.cc b/tensorflow/core/tfrt/fallback/op_kernel_runner.cc index 557ac7c3812054..9f21b2627186a7 100644 --- a/tensorflow/core/tfrt/fallback/op_kernel_runner.cc +++ b/tensorflow/core/tfrt/fallback/op_kernel_runner.cc @@ -14,11 +14,20 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include #include #include #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/fallback/op_kernel_runner.h b/tensorflow/core/tfrt/fallback/op_kernel_runner.h index e969ba63225d67..317d0956b4a247 100644 --- a/tensorflow/core/tfrt/fallback/op_kernel_runner.h +++ b/tensorflow/core/tfrt/fallback/op_kernel_runner.h @@ -18,13 +18,21 @@ limitations under the License. #include #include +#include #include #include #include #include #include +#include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/device.h" diff --git a/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.cc b/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.cc index f2e1074c83f3ae..cd035cf21bad9f 100644 --- a/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.cc +++ b/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.cc @@ -14,12 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h" +#include #include #include #include #include #include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" namespace tensorflow { namespace tfrt_stub { diff --git a/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h b/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h index 22fe5f5c841253..64f1060e53d75b 100644 --- a/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h +++ b/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" #include "tfrt/host_context/location.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h index 721f372cc1af01..f36356b8a4835a 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ #define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ +#include "absl/status/status.h" #include "xla/tsl/framework/serving_device_selector_policies.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/runtime/runtime.h" diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index b45b675f3b7bbd..46089fe7069c2c 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -21,6 +21,7 @@ tf_proto_library( srcs = ["ifrt_config.proto"], protodeps = [ "@local_xla//xla:xla_data_proto", + "//tensorflow/core/framework:tensor_proto", ], visibility = ["//visibility:public"], ) @@ -268,6 +269,7 @@ cc_library( srcs = ["ifrt_model_context.cc"], hdrs = ["ifrt_model_context.h"], deps = [ + ":ifrt_config_proto_cc", ":ifrt_executable_registry", ":ifrt_loaded_variable_registry", ":ifrt_persistent_compilation_cache", @@ -276,6 +278,7 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:tf2hlo", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:core_cpu_base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", @@ -526,6 +529,7 @@ tf_cc_test( tags = ["no_oss"], deps = [ ":ifrt_tensor_utils", + ":pjrt_cpu_client_test_lib", ":sharding_utils", "//tensorflow/core:framework", "//tensorflow/core:test", @@ -544,7 +548,6 @@ tf_cc_test( "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:test_util", - "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", "@local_xla//xla/python/pjrt_ifrt:xla_ifrt", "@local_xla//xla/tsl/concurrency:ref_count", ], @@ -679,3 +682,19 @@ cc_library( "@tf_runtime//:hostcontext", ], ) + +cc_library( + name = "pjrt_cpu_client_test_lib", + testonly = True, + srcs = ["pjrt_cpu_client_test_lib.cc"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@local_xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt", + ], + alwayslink = True, +) diff --git a/tensorflow/core/tfrt/ifrt/ifrt_config.proto b/tensorflow/core/tfrt/ifrt/ifrt_config.proto index 784f3ba0625e9a..daa535f59a8cec 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_config.proto +++ b/tensorflow/core/tfrt/ifrt/ifrt_config.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package tensorflow.ifrt_serving; import "xla/xla_data.proto"; +import "tensorflow/core/framework/tensor.proto"; enum IfrtPjRtServingPlatformType { IFRT_PJRT_SERVING_PLATFORM_TYPE_UNSPECIFIED = 0; @@ -24,3 +25,11 @@ enum IfrtServingCoreSelectionPolicy { // Policy that round robin with local ordinal http://shortn/_7BtVe4dkp5. IFRT_SERVING_CORE_SELECTION_POLICY_LOCAL_ROUND_ROBIN = 1; } + +message DefaultSignatureInputConfig { + message Signature { + map default_inputs = 1; + } + + map signatures = 1; +} diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h index e1eba8c0099abf..7c41a947751827 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h @@ -18,9 +18,11 @@ limitations under the License. #include #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" @@ -30,6 +32,7 @@ limitations under the License. #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" @@ -128,6 +131,15 @@ class IfrtModelContext { checkpoint_loader_queue_ = work_queue; } + void set_default_signature_inputs( + const DefaultSignatureInputConfig& default_signature_inputs) { + default_signature_inputs_ = default_signature_inputs; + } + + const DefaultSignatureInputConfig& default_signature_inputs() const { + return default_signature_inputs_; + } + tsl::protobuf::Message* GetCompilationEnvironmentProto() const { return compilation_environment_proto_.get(); } @@ -164,6 +176,8 @@ class IfrtModelContext { std::vector handles_; + DefaultSignatureInputConfig default_signature_inputs_; + IfrtLoadedVariableRegistry loaded_variable_registry_; IfrtRestoreTensorRegistry restore_tensor_registry_; TfToHloCompiler* tf_to_hlo_compiler_ = nullptr; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 2b8cd64c85a076..9f8490f255bb47 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -414,7 +414,8 @@ absl::StatusOr IfrtServingExecutable::CreateExecutableSynchronously( mlir::OwningOpRef module_copy, const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, - absl::Span dtypes_and_shapes) { + absl::Span dtypes_and_shapes, + absl::Span variable_arg_indices) { TF_ASSIGN_OR_RETURN(auto host_callback_modules, GetHostCallbackModulesAndRemoveHostFuncs(*module_copy)); if (VLOG_IS_ON(1)) { @@ -422,7 +423,9 @@ IfrtServingExecutable::CreateExecutableSynchronously( } Tf2HloArg tf2hlo_arg{ .module = module_copy.get(), - .input_dtypes_and_shapes = dtypes_and_shapes, + .input_dtypes_and_shapes = std::vector( + dtypes_and_shapes.begin(), dtypes_and_shapes.end()), + .variable_arg_indices = variable_arg_indices, .entry_function_name = signature_name(), .compile_metadata = compile_metadata, .shape_representation_fn = shape_representation_fn_, @@ -533,7 +536,8 @@ IfrtServingExecutable::CreateExecutableSynchronously( xla::ifrt::Future IfrtServingExecutable::LookUpOrCreateExecutable( const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, - absl::Span dtypes_and_shapes) { + absl::Span dtypes_and_shapes, + absl::Span variable_arg_indices) { std::vector input_shapes; for (const auto& dtype_and_shape : dtypes_and_shapes) { input_shapes.push_back(dtype_and_shape.shape); @@ -572,7 +576,7 @@ IfrtServingExecutable::LookUpOrCreateExecutable( LOG(INFO) << "Cache missed. Building executable"; absl::StatusOr executable_bundle = CreateExecutableSynchronously(std::move(module_copy), compile_metadata, - dtypes_and_shapes); + dtypes_and_shapes, variable_arg_indices); promise.Set(std::move(executable_bundle)); return future; } @@ -649,10 +653,11 @@ absl::StatusOr> IfrtServingExecutable::Execute( } else { device_list = assigned_device_list_; } - TF_ASSIGN_OR_RETURN(SharedCachedExecutableBundle executable_bundle, - LookUpOrCreateExecutable( - compile_metadata, absl::MakeSpan(dtypes_and_shapes)) - .Await()); + TF_ASSIGN_OR_RETURN( + SharedCachedExecutableBundle executable_bundle, + LookUpOrCreateExecutable(compile_metadata, dtypes_and_shapes, + variable_arg_indices) + .Await()); if (executable_bundle->compile_metadata.args().size() != dtypes_and_shapes.size()) { @@ -700,15 +705,28 @@ absl::StatusOr> IfrtServingExecutable::Execute( args.push_back(std::move(single_array)); variable_index++; } else { + // If the input shape is not the same as the shape after Tf2Hlo + // compilation, reshape the input tensor to the expected shape. Note that + // the tensor assignment here won't create a copy. + tensorflow::Tensor reshaped = inputs[i]; + TF_ASSIGN_OR_RETURN( + tensorflow::TensorShape reshaped_shape, + tensorflow::TensorShape::BuildTensorShape( + executable_bundle->compile_metadata.args()[i].shape())); + if (reshaped.shape() != reshaped_shape && + !reshaped.CopyFrom(inputs[i], reshaped_shape)) { + return absl::InternalError("Failed to reshape tensor"); + } + TF_ASSIGN_OR_RETURN( auto single_array, ConvertTensorToArray( - inputs[i], device_list, + reshaped, device_list, executable_bundle->compile_metadata.args()[i].sharding())); args.push_back(single_array); } } - DCHECK_EQ(args.size(), dtypes_and_shapes.size()); + DCHECK_EQ(args.size(), executable_bundle->compile_metadata.args().size()); VLOG(2) << "Start Execution"; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index 72e8c0b84df782..b9402d25c4e262 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -228,12 +228,14 @@ class IfrtServingExecutable { xla::ifrt::Future LookUpOrCreateExecutable( const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, - absl::Span dtypes_and_shapes); + absl::Span dtypes_and_shapes, + absl::Span variable_arg_indices); absl::StatusOr CreateExecutableSynchronously( mlir::OwningOpRef module_copy, const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, - absl::Span dtypes_and_shapes); + absl::Span dtypes_and_shapes, + absl::Span variable_arg_indices); absl::StatusOr> CreateSharding( int num_devices, const xla::ifrt::Shape& arg_xla_shape, diff --git a/tensorflow/core/tfrt/ifrt/pjrt_cpu_client_test_lib.cc b/tensorflow/core/tfrt/ifrt/pjrt_cpu_client_test_lib.cc new file mode 100644 index 00000000000000..35b2a1bba525fe --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/pjrt_cpu_client_test_lib.cc @@ -0,0 +1,45 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace { + +const bool kUnused = + (test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + xla::CpuClientOptions options; + options.cpu_device_count = 8; + TF_ASSIGN_OR_RETURN(auto pjrt_client, + xla::GetXlaPjrtCpuClient(std::move(options))); + return std::shared_ptr( + PjRtClient::Create(std::move(pjrt_client))); + }), + true); + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.cc b/tensorflow/core/tfrt/ifrt/sharding_utils.cc index 9040c5be7a0002..67ff56dea196eb 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.cc @@ -56,7 +56,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h" #include "tensorflow/core/tpu/kernels/sharding_utils.h" #include "tsl/platform/errors.h" @@ -67,6 +66,15 @@ namespace tensorflow { namespace ifrt_serving { namespace { +struct IndexDomainLexicographicalComparator { + bool operator()(const xla::ifrt::IndexDomain& a, + const xla::ifrt::IndexDomain& b) const { + return std::lexicographical_compare( + a.origin().elements().begin(), a.origin().elements().end(), + b.origin().elements().begin(), b.origin().elements().end()); + } +}; + // Shard the given `input_tensor` into equal shapes of slices. // // `num_paritions_per_axis` specifies the number of partitions along @@ -112,7 +120,7 @@ SplitAndCreateArraysFromHostBuffer( // Fast path for output in the simple no split case. auto assign_or_copy_value_fn = - [&](const tensorflow::Tensor& input) -> Status { + [&](const tensorflow::Tensor& input) -> absl::Status { split_tensors[0] = input; return absl::OkStatus(); }; @@ -256,7 +264,7 @@ absl::StatusOr MakeTensorFromDisassembledTensors( } absl::StatusOr VerifyIndexDomainsAndGetReplicas( - absl::Span index_domains, + absl::Span index_domains, const tensorflow::TensorShape& tensor_shape) { if (index_domains.size() <= 1) { return absl::InvalidArgumentError(absl::StrCat( @@ -286,14 +294,7 @@ absl::StatusOr VerifyIndexDomainsAndGetReplicas( // Verify that each `IndexDomain` appear the same `num_replica` times. Since // shapes are the same for all `IndexDomain`, this also implies each `origin` // appear `num_replica` times. - struct IndexDomainLexicographicalComparator { - bool operator()(const xla::ifrt::IndexDomain& a, - const xla::ifrt::IndexDomain& b) const { - return std::lexicographical_compare( - a.origin().elements().begin(), a.origin().elements().end(), - b.origin().elements().begin(), b.origin().elements().end()); - } - }; + absl::btree_map index_domain_counts; @@ -312,35 +313,9 @@ absl::StatusOr VerifyIndexDomainsAndGetReplicas( } unique_index_domains.push_back(index_domain); } - - // Verify that distances of between origins of neighbouring `IndexDomain` - // bounded by shape. Note that unique_indexx_domains are already in sorted - // order. - auto prev_iter = unique_index_domains.begin(); - auto next_iter = unique_index_domains.begin() + 1; - const auto& bounded_box = first_index_domain->shape(); - while (prev_iter != unique_index_domains.end() && - next_iter != unique_index_domains.end()) { - xla::ifrt::Index offset = next_iter->origin() - prev_iter->origin(); - for (int dim = 0; dim < bounded_box.dims().size(); ++dim) { - if (std::abs(offset.elements()[dim]) != bounded_box.dims()[dim] && - offset.elements()[dim] != 0) { - return absl::FailedPreconditionError(absl::StrCat( - "IndexDomains should not have gap or overlap, but got ", - prev_iter->DebugString(), " and ", next_iter->DebugString(), - " that have offset of ", offset.DebugString())); - } - } - prev_iter = next_iter; - next_iter++; - } - // Verify the last `IndexDomain`'s upper end of the bound matches with the - // tensor shape. Together with the above check, this provides an approximation - // to the following two assumptions: - // 1. the union of all IndexDomain covers the entire global shape array with - // no gaps. - // 2. no two index_domain have any overlap. + // tensor shape. This provides an approximation to the assumptions that the + // union of all IndexDomain covers the entire global shape array with no gaps. std::vector bounded_shape; const auto& last_index_domain = unique_index_domains.back(); bounded_shape.reserve(last_index_domain.shape().dims().size()); @@ -569,17 +544,9 @@ absl::StatusOr> MakeTensorFromArrayHelper( TF_ASSIGN_OR_RETURN(auto index_domains, ifrt_sharding->IndexDomains(ToIfrtShape(tensor_shape))); - TF_ASSIGN_OR_RETURN(int index_domain_replicas, - VerifyIndexDomainsAndGetReplicas( - absl::MakeSpan(index_domains), tensor_shape)); - - if (index_domain_replicas != 1) { - return absl::UnimplementedError(absl::StrCat( - "Subgroup replication is not supported at output. Number " - "of unique index main ", - index_domain_replicas, " is not equal to number of index domains", - index_domains.size())); - } + TF_RETURN_IF_ERROR(VerifyIndexDomainsAndGetReplicas( + absl::MakeSpan(index_domains), tensor_shape) + .status()); TF_ASSIGN_OR_RETURN( std::vector> disassembled_array, @@ -612,11 +579,6 @@ absl::StatusOr> MakeTensorFromArrayHelper( num_slices *= dim_num_concats; num_concats.push_back(dim_num_concats); } - if (num_slices != index_domains.size()) { - return absl::FailedPreconditionError( - absl::StrCat("Expect number of slices is ", index_domains.size(), - " but got ", num_slices)); - } VLOG(2) << "Index domains: "; for (const auto& index_domain : index_domains) { @@ -628,23 +590,17 @@ absl::StatusOr> MakeTensorFromArrayHelper( xla::ifrt::IndexDomain index_domain; tsl::RCReference array; }; - std::vector index_domain_device_arrays; - index_domain_device_arrays.reserve(index_domains.size()); + // `index_domains` could have duplicate index when `replicate_on_last_tile_dim + // is enabled. So, we use the btreemap to remove duplicates and sort the index + // domains lexicographically. + absl::btree_map, + IndexDomainLexicographicalComparator> + index_domain_device_arrays; for (int i = 0; i < index_domains.size(); ++i) { - index_domain_device_arrays.push_back( + index_domain_device_arrays.insert( {index_domains[i], disassembled_array[i]}); } - std::sort( - index_domain_device_arrays.begin(), index_domain_device_arrays.end(), - [](const IndexDomainDeviceArray& a, const IndexDomainDeviceArray& b) { - return std::lexicographical_compare( - a.index_domain.origin().elements().begin(), - a.index_domain.origin().elements().end(), - b.index_domain.origin().elements().begin(), - b.index_domain.origin().elements().end()); - }); - std::vector> arrays_copy_status; std::vector input_tensors; input_tensors.reserve(index_domain_device_arrays.size()); diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc index e7cf58437a6f37..a93e02983dbee9 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc @@ -143,7 +143,7 @@ TEST_P(ReshardToTensorTest, MakeHostTensorFromDeviceArrays) { device_list, thread_pool) .Await()); - EXPECT_THAT(GetParam().expected_out_tensor, TensorEq(output_tensor)); + EXPECT_THAT(output_tensor, TensorEq(GetParam().expected_out_tensor)); } INSTANTIATE_TEST_SUITE_P( @@ -323,6 +323,59 @@ INSTANTIATE_TEST_SUITE_P( .device_indices = {3, 2, 1, 0}, .sharding = Tile({2, 1, 2}), }, + // 2-d sharding with last tile replicated. + { + .split_tensors = + { + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({9, 10, 13, 14}, + TensorShape({2, 2})), + test::AsTensor({9, 10, 13, 14}, + TensorShape({2, 2})), + }, + .expected_out_tensor = test::AsTensor( + {1, 2, 5, 6, 9, 10, 13, 14}, TensorShape({4, 2})), + .device_indices = {0, 1, 2, 3}, + .sharding = PartialTile({2, 1, 2}), + }, + { + .split_tensors = + { + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + }, + .expected_out_tensor = + test::AsTensor({1, 2, 5, 6}, TensorShape({2, 2})), + .device_indices = {0, 1, 2, 3}, + .sharding = PartialTile({1, 1, 4}), + }, + { + .split_tensors = + { + test::AsTensor({1, 2, 5, 6}, + TensorShape({2, 2})), + test::AsTensor({3, 4, 7, 8}, + TensorShape({2, 2})), + test::AsTensor({9, 10, 13, 14}, + TensorShape({2, 2})), + test::AsTensor({11, 12, 15, 16}, + TensorShape({2, 2})), + }, + .expected_out_tensor = test::AsTensor( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + TensorShape({4, 4})), + .device_indices = {0, 1, 2, 3}, + .sharding = PartialTile({2, 2, 1}), + }, })); TEST_P(TensorToArrayTest, MakeArrayFromTensor) { @@ -529,6 +582,29 @@ INSTANTIATE_TEST_SUITE_P( .device_ids = {0, 1, 2, 3}, .sharding = Tile({2, 1, 2}), }, + { + .in_tensor = test::AsTensor( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + TensorShape({4, 1, 6})), + .expected_out_tensors = + { + test::AsTensor({1, 2, 7, 8}, + TensorShape({2, 1, 2})), + test::AsTensor({3, 4, 9, 10}, + TensorShape({2, 1, 2})), + test::AsTensor({5, 6, 11, 12}, + TensorShape({2, 1, 2})), + test::AsTensor({13, 14, 19, 20}, + TensorShape({2, 1, 2})), + test::AsTensor({15, 16, 21, 22}, + TensorShape({2, 1, 2})), + test::AsTensor({17, 18, 23, 24}, + TensorShape({2, 1, 2})), + }, + .device_ids = {0, 1, 2, 3, 4, 5}, + .sharding = Tile({2, 1, 3}), + }, // Partial replication { .in_tensor = test::AsTensor({1, 2, 3, 4}, diff --git a/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h b/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h index 3f03349283cb36..064a43a6b052fc 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h +++ b/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ASYNC_HANDLE_H_ #define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ASYNC_HANDLE_H_ +#include #include +#include #include #include "absl/log/check.h" diff --git a/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc b/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc index 2635e21428d899..8ca71ba8e25b88 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.h" -#include +#include #include #include diff --git a/tensorflow/core/tfrt/mlrt/interpreter/context.h b/tensorflow/core/tfrt/mlrt/interpreter/context.h index 35329ced3b22ab..f4edd6b6bb2f12 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/context.h +++ b/tensorflow/core/tfrt/mlrt/interpreter/context.h @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include #include #include diff --git a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc index 97982e77e8c791..b1019d9041b598 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include #include #include #include diff --git a/tensorflow/core/tfrt/mlrt/interpreter/register_span_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/register_span_test.cc index 301a09517b491f..90ec4a489e689d 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/register_span_test.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/register_span_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/mlrt/interpreter/register_span.h" +#include #include #include diff --git a/tensorflow/core/tfrt/run_handler_thread_pool/BUILD b/tensorflow/core/tfrt/run_handler_thread_pool/BUILD index dfcac86644c0b1..f6b963b1186a45 100644 --- a/tensorflow/core/tfrt/run_handler_thread_pool/BUILD +++ b/tensorflow/core/tfrt/run_handler_thread_pool/BUILD @@ -28,6 +28,7 @@ cc_library( hdrs = ["run_handler_util.h"], deps = [ "//tensorflow/core:lib", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc b/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc index 77d01d3fbdd056..1c5653125e1852 100644 --- a/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc +++ b/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/strings/ascii.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index bbeed18fa5d637..08de6ec50c2baa 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -227,7 +227,7 @@ cc_library( ] + if_google([ "//learning/brain/tfrt/support:export_mlir", "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", - "//learning/brain/tfrt/mlrt/application/pathways:model_config_impl", + "//learning/brain/tfrt/saved_model:model_config_impl", ]), ) diff --git a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model.pyi b/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model.pyi index f6b6465e9383fb..d2a82a9cc34931 100644 --- a/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model.pyi +++ b/tensorflow/core/tfrt/saved_model/python/_pywrap_saved_model.pyi @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -from typing import Any - class GraphExecutionRunOptions: def __init__(self) -> None: ... @@ -26,4 +24,4 @@ class Tensor: def LoadSavedModel(saved_model_dir: str = ..., tags: set[str] = ...) -> SavedModel: ... def Run(saved_model: SavedModel = ..., run_options: GraphExecutionRunOptions = ..., name: str = ..., inputs: list[Tensor] = ..., outputs: list[Tensor] = ...) -> None: ... -def RunConvertor(*args, **kwargs) -> Any: ... +def RunConvertor(*args, **kwargs): ... diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc b/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc index f5b67debacc777..448e05d411d165 100644 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc +++ b/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.cc @@ -79,7 +79,7 @@ std::string PyObject_ToString(PyObject* o, int length = -1) { if (length < 0 || str.size() <= length) { return str; } - tensorflow::StringPiece str_piece(str); + absl::string_view str_piece(str); return tensorflow::strings::StrCat(str_piece.substr(length), "..."); } diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc index b3ddef175e092f..f5efa79c645ce1 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc @@ -165,16 +165,16 @@ class TfrtSession : public tensorflow::Session { backend_compiler_(backend_compiler), device_manager_(std::move(device_manager)) {} - Status Create(const GraphDef& graph) override { + absl::Status Create(const GraphDef& graph) override { return Create(GraphDef(graph)); } - Status Create(GraphDef&& graph) override { + absl::Status Create(GraphDef&& graph) override { absl::MutexLock lock(&session_state_lock_); return CreateLocked(std::move(graph)); } - Status CreateLocked(GraphDef graph) + absl::Status CreateLocked(GraphDef graph) TF_EXCLUSIVE_LOCKS_REQUIRED(session_state_lock_) { if (graph.node_size() == 0) { LOG(ERROR) << "Ignoring empty graph."; @@ -271,16 +271,16 @@ class TfrtSession : public tensorflow::Session { return absl::OkStatus(); } - Status Extend(const GraphDef& graph) override { + absl::Status Extend(const GraphDef& graph) override { return Extend(GraphDef(graph)); } - Status Extend(GraphDef&& graph) override { + absl::Status Extend(GraphDef&& graph) override { absl::MutexLock lock(&session_state_lock_); return ExtendLocked(std::move(graph)); } - Status ExtendLocked(GraphDef graph) + absl::Status ExtendLocked(GraphDef graph) TF_EXCLUSIVE_LOCKS_REQUIRED(session_state_lock_) { if (session_state_ == SessionState::kCreated) { return graph_executor_->Extend(graph); @@ -288,12 +288,13 @@ class TfrtSession : public tensorflow::Session { return CreateLocked(std::move(graph)); } - Status RunInternal(const RunOptions& run_options, - const std::vector>& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, - const thread::ThreadPoolOptions& thread_pool_options) { + absl::Status RunInternal( + const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + const thread::ThreadPoolOptions& thread_pool_options) { { absl::MutexLock lock(&session_state_lock_); if (session_state_ == SessionState::kInitialized) { @@ -354,10 +355,10 @@ class TfrtSession : public tensorflow::Session { return absl::OkStatus(); } - Status Run(const std::vector>& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs) override { + absl::Status Run(const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) override { return RunInternal(RunOptions{}, inputs, output_tensor_names, target_node_names, outputs, {}); } @@ -365,11 +366,12 @@ class TfrtSession : public tensorflow::Session { // TODO(jingdong): run_options and run_metadata are not fully supported for // now. Need to figure out the required features and how to handle them // properly. - Status Run(const RunOptions& run_options, - const std::vector>& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, RunMetadata* run_metadata) override { + absl::Status Run(const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata) override { return Run(run_options, inputs, output_tensor_names, target_node_names, outputs, run_metadata, /*thread_pool_options=*/{}); } @@ -380,12 +382,13 @@ class TfrtSession : public tensorflow::Session { // TODO(jingdong): run_options and run_metadata are not fully supported for // now. Need to figure out the required features and how to handle them // properly. - Status Run(const RunOptions& run_options, - const std::vector>& inputs, - const std::vector& output_tensor_names, - const std::vector& target_tensor_names, - std::vector* outputs, RunMetadata* run_metadata, - const thread::ThreadPoolOptions& thread_pool_options) override { + absl::Status Run( + const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_tensor_names, + std::vector* outputs, RunMetadata* run_metadata, + const thread::ThreadPoolOptions& thread_pool_options) override { return RunInternal(run_options, inputs, output_tensor_names, target_tensor_names, outputs, thread_pool_options); } @@ -393,8 +396,8 @@ class TfrtSession : public tensorflow::Session { /// \brief Creates a `handle` for invoking the subgraph defined by /// `callable_options`. // NOTE: This API is still experimental and may change. - Status MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) override { + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override { absl::MutexLock lock(&callables_lock_); *out_handle = next_callable_handle_++; assert(callables_.find(*out_handle) == callables_.end()); @@ -409,10 +412,10 @@ class TfrtSession : public tensorflow::Session { /// match the order of names in `CallableOptions::feed()` and /// `CallableOptions::fetch()` when this subgraph was created. /// NOTE: This API is still experimental and may change. - Status RunCallable(CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) override { + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override { return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata, {}); } @@ -424,7 +427,7 @@ class TfrtSession : public tensorflow::Session { /// match the order of names in `CallableOptions::feed()` and /// `CallableOptions::fetch()` when this subgraph was created. /// NOTE: This API is still experimental and may change. - Status RunCallable( + absl::Status RunCallable( CallableHandle handle, const std::vector& feed_tensors, std::vector* fetch_tensors, RunMetadata* run_metadata, const thread::ThreadPoolOptions& thread_pool_options) override { @@ -459,7 +462,7 @@ class TfrtSession : public tensorflow::Session { /// \brief Releases resources associated with the given `handle` in this /// session. /// NOTE: This API is still experimental and may change. - Status ReleaseCallable(CallableHandle handle) override { + absl::Status ReleaseCallable(CallableHandle handle) override { absl::MutexLock lock(&callables_lock_); auto it = callables_.find(handle); if (it == callables_.end()) @@ -468,20 +471,20 @@ class TfrtSession : public tensorflow::Session { return absl::OkStatus(); } - Status Close() override { + absl::Status Close() override { absl::MutexLock lock(&session_state_lock_); session_state_ = SessionState::kClosed; return absl::OkStatus(); } - Status ListDevices(std::vector* response) override { + absl::Status ListDevices(std::vector* response) override { return errors::Unimplemented("TfrtSession::ListDevices is Unimplemented."); } - Status LocalDeviceManager(const DeviceMgr** output) override { + absl::Status LocalDeviceManager(const DeviceMgr** output) override { *output = device_manager_.get(); return absl::OkStatus(); } - Status Finalize() override { return absl::OkStatus(); } + absl::Status Finalize() override { return absl::OkStatus(); } private: tfrt::HostContext* GetHostContext() { @@ -519,7 +522,7 @@ class TfrtSession : public tensorflow::Session { return options; } - Status CheckNotClosedLocked() const + absl::Status CheckNotClosedLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(session_state_lock_) { if (session_state_ == SessionState::kClosed) { return errors::Cancelled("Session has been closed."); @@ -773,7 +776,8 @@ void TfrtSessionFactory::RegisterInitializer(RuntimeInitializer initializer) { InitializerRegistry::Get().Register(std::move(initializer)); } -Status TfrtSessionFactory::InitializeLocked(const TfrtSessionOptions& options) { +absl::Status TfrtSessionFactory::InitializeLocked( + const TfrtSessionOptions& options) { mutex_.AssertHeld(); if (options.use_tpu) { DCHECK(!options.backend_compiler); @@ -808,8 +812,8 @@ bool TfrtSessionFactory::AcceptsOptions(const SessionOptions& options) { return false; } -Status TfrtSessionFactory::NewSession(const SessionOptions& options, - Session** out_session) +absl::Status TfrtSessionFactory::NewSession(const SessionOptions& options, + Session** out_session) TF_LOCKS_EXCLUDED(mutex_) { // TODO(b/206499043): `SessionOptions` should be passed to Saved Model to // create `FallbackState`. @@ -856,14 +860,14 @@ tfrt_stub::Runtime* TfrtSessionFactory::GetRuntime() { return session_factory->runtime_; } -Status InitializeTfrtSession(const TfrtSessionOptions& options) { +absl::Status InitializeTfrtSession(const TfrtSessionOptions& options) { DCHECK(session_factory != nullptr); absl::MutexLock lock(&session_factory->mutex_); DCHECK(!session_factory->IsInitialized()); return UpdateTfrtSessionOptionsLocked(options); } -Status UpdateTfrtSessionOptionsLocked(const TfrtSessionOptions& options) { +absl::Status UpdateTfrtSessionOptionsLocked(const TfrtSessionOptions& options) { DCHECK(session_factory != nullptr); session_factory->mutex_.AssertHeld(); return session_factory->InitializeLocked(options); diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session.h b/tensorflow/core/tfrt/tfrt_session/tfrt_session.h index e2ed163d9a6ee6..84de49ebc62048 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session.h +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session.h @@ -66,8 +66,9 @@ class TfrtSessionFactory : public tensorflow::SessionFactory { bool AcceptsOptions(const SessionOptions& options) override; - Status NewSession(const SessionOptions& options, - Session** out_session) override TF_LOCKS_EXCLUDED(mutex_); + absl::Status NewSession(const SessionOptions& options, + Session** out_session) override + TF_LOCKS_EXCLUDED(mutex_); // This should only be used for the sake initializing resources for // Python executables. It should only be called before main. @@ -82,10 +83,10 @@ class TfrtSessionFactory : public tensorflow::SessionFactory { private: class ThreadPoolManager; - friend Status InitializeTfrtSession(const TfrtSessionOptions& options); - friend Status UpdateTfrtSessionOptionsLocked( + friend absl::Status InitializeTfrtSession(const TfrtSessionOptions& options); + friend absl::Status UpdateTfrtSessionOptionsLocked( const TfrtSessionOptions& options); - Status InitializeLocked(const TfrtSessionOptions& options) + absl::Status InitializeLocked(const TfrtSessionOptions& options) TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); bool IsInitialized() const TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { return runtime_ != nullptr; @@ -110,11 +111,11 @@ class TfrtSessionFactory : public tensorflow::SessionFactory { // Configures the TfrtSessionFactory according to `options`. Should not be // called within functions that are passed into // `TfrtSessionFactory::RegisterInitializer`, because it acquires `mutex_`. -Status InitializeTfrtSession(const TfrtSessionOptions& options); +absl::Status InitializeTfrtSession(const TfrtSessionOptions& options); // Version of `InitializeTfrtSession` that can be used within functions passed // into `TfrtSessionFactory::RegisterInitializer`. -Status UpdateTfrtSessionOptionsLocked(const TfrtSessionOptions& options); +absl::Status UpdateTfrtSessionOptionsLocked(const TfrtSessionOptions& options); } // namespace tensorflow #endif // TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_H_ diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h b/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h index e3fdacc6801cd9..7891a0a80c7148 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h @@ -24,7 +24,7 @@ namespace tensorflow { // // TODO(jingdong): Merge this function with the InitializeTfrtSession() in // tfrt_session.h after we decouple TPU logic from TfrtSession. -inline Status InitializeTfrtSession() { +inline absl::Status InitializeTfrtSession() { SetDefaultLocalSessionImpl(LocalSessionImpl::kTfrtSession); return absl::OkStatus(); } diff --git a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.cc b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.cc index 80017a11cd8a36..5b54f19a9b8671 100644 --- a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.cc +++ b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.cc @@ -49,8 +49,8 @@ absl::StatusOr GetDumpDir(absl::string_view dump_dir) { return errors::InvalidArgument("TF_DUMP_GRAPH_PREFIX not specified"); } -Status InsertDumpOpsForNode(Graph& graph, Node& node, - absl::string_view dump_dir) { +absl::Status InsertDumpOpsForNode(Graph& graph, Node& node, + absl::string_view dump_dir) { auto insert = [&](bool is_input, const std::vector edges) { for (const Edge* edge : edges) { if (edge->IsControlEdge()) continue; @@ -85,9 +85,9 @@ Status InsertDumpOpsForNode(Graph& graph, Node& node, } // namespace -Status InsertDumpOps(Graph& graph, - const absl::flat_hash_set& nodes_to_dump, - absl::string_view dump_dir) { +absl::Status InsertDumpOps( + Graph& graph, const absl::flat_hash_set& nodes_to_dump, + absl::string_view dump_dir) { TF_ASSIGN_OR_RETURN(auto dir, GetDumpDir(dump_dir)); auto insert = [&](Graph& graph) { for (Node* node : graph.op_nodes()) { @@ -115,9 +115,10 @@ Status InsertDumpOps(Graph& graph, return absl::OkStatus(); } -Status InsertDumpOps(MetaGraphDef& meta_graph_def, - const absl::flat_hash_set& nodes_to_dump, - absl::string_view dump_dir) { +absl::Status InsertDumpOps( + MetaGraphDef& meta_graph_def, + const absl::flat_hash_set& nodes_to_dump, + absl::string_view dump_dir) { Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR( ConvertGraphDefToGraph({}, meta_graph_def.graph_def(), &graph)); diff --git a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h index 759a6c8f4ed581..068c19ba46962e 100644 --- a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h +++ b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h @@ -29,13 +29,14 @@ namespace tfrt_stub { // Rewrites `graph` by inserting dump nodes for `nodes_to_dump`. During graph // execution, the inputs and outputs of `nodes_to_dump` will be dumped to the // folder specified by env var `TF_DUMP_GRAPH_PREFIX`. -Status InsertDumpOps(Graph& graph, - const absl::flat_hash_set& nodes_to_dump, - absl::string_view dump_dir = ""); +absl::Status InsertDumpOps( + Graph& graph, const absl::flat_hash_set& nodes_to_dump, + absl::string_view dump_dir = ""); // Similar to the above, but rewrites a `meta_graph_def`. -Status InsertDumpOps(MetaGraphDef& meta_graph_def, - const absl::flat_hash_set& nodes_to_dump, - absl::string_view dump_dir = ""); +absl::Status InsertDumpOps( + MetaGraphDef& meta_graph_def, + const absl::flat_hash_set& nodes_to_dump, + absl::string_view dump_dir = ""); } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/utils/error_util.cc b/tensorflow/core/tfrt/utils/error_util.cc index 2530b98f051041..e00a5be8bbe802 100644 --- a/tensorflow/core/tfrt/utils/error_util.cc +++ b/tensorflow/core/tfrt/utils/error_util.cc @@ -20,8 +20,7 @@ limitations under the License. namespace tfrt { -tfrt::ErrorCode ConvertTfErrorCodeToTfrtErrorCode( - const tensorflow::Status& status) { +tfrt::ErrorCode ConvertTfErrorCodeToTfrtErrorCode(const absl::Status& status) { auto tf_error_code = status.code(); switch (tf_error_code) { default: @@ -34,11 +33,11 @@ tfrt::ErrorCode ConvertTfErrorCodeToTfrtErrorCode( } } -tensorflow::Status CreateTfErrorStatus(const DecodedDiagnostic& error) { +absl::Status CreateTfErrorStatus(const DecodedDiagnostic& error) { return error.status; } -tensorflow::Status ToTfStatus(const tfrt::AsyncValue* av) { +absl::Status ToTfStatus(const tfrt::AsyncValue* av) { CHECK(av != nullptr && av->IsAvailable()) // Crash OK << "Expected a ready async value."; if (av->IsError()) { diff --git a/tensorflow/core/tfrt/utils/error_util.h b/tensorflow/core/tfrt/utils/error_util.h index ee7bcd81dd913f..229b854ae3c69c 100644 --- a/tensorflow/core/tfrt/utils/error_util.h +++ b/tensorflow/core/tfrt/utils/error_util.h @@ -24,14 +24,13 @@ limitations under the License. namespace tfrt { class DecodedDiagnostic; -tfrt::ErrorCode ConvertTfErrorCodeToTfrtErrorCode( - const tensorflow::Status& status); +tfrt::ErrorCode ConvertTfErrorCodeToTfrtErrorCode(const absl::Status& status); -tensorflow::Status CreateTfErrorStatus(const DecodedDiagnostic& error); +absl::Status CreateTfErrorStatus(const DecodedDiagnostic& error); -tensorflow::Status ToTfStatus(const AsyncValue* av); +absl::Status ToTfStatus(const AsyncValue* av); -inline std::string MakeStatusString(tensorflow::Status status) { +inline std::string MakeStatusString(absl::Status status) { switch (static_cast(status.code())) { case absl::StatusCode::kOk: return "OK"; @@ -72,7 +71,7 @@ inline std::string MakeStatusString(tensorflow::Status status) { } } -inline llvm::Error MakeStatusError(tensorflow::Status status) { +inline llvm::Error MakeStatusError(absl::Status status) { return MakeStringError(MakeStatusString(status)); } diff --git a/tensorflow/core/tfrt/utils/error_util_test.cc b/tensorflow/core/tfrt/utils/error_util_test.cc index 06edb63c897af4..126a6fcd7b24e5 100644 --- a/tensorflow/core/tfrt/utils/error_util_test.cc +++ b/tensorflow/core/tfrt/utils/error_util_test.cc @@ -35,7 +35,7 @@ TEST(ErrorUtilTest, AllSupportedErrorConversion){ } TEST(ErrorUtilTest, UnsupportedErrorConversion) { - tensorflow::Status status(absl::StatusCode::kUnauthenticated, "error_test"); + absl::Status status(absl::StatusCode::kUnauthenticated, "error_test"); EXPECT_EQ(ConvertTfErrorCodeToTfrtErrorCode(status), tfrt::ErrorCode::kUnknown); } diff --git a/tensorflow/core/tfrt/utils/graph_partition.cc b/tensorflow/core/tfrt/utils/graph_partition.cc index 3d4b8d6871a549..08f5dce6d5734d 100644 --- a/tensorflow/core/tfrt/utils/graph_partition.cc +++ b/tensorflow/core/tfrt/utils/graph_partition.cc @@ -72,7 +72,7 @@ struct OutputNodeInfo { // input/output info for the following processing. // TODO(b/217581711): Consider to use another GraphToFunctionDef() helper which // does not require _Arg and _Retval nodes. -Status PrepareSubgraphForFunctionConversion( +absl::Status PrepareSubgraphForFunctionConversion( const std::vector& inputs, const std::vector& outputs, const Device* host_device, const std::string& func_name, diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc index 877e2dd99f69b3..ce6cc28e141f66 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc @@ -324,7 +324,7 @@ TfrtGraphExecutionState::CreateOptimizedGraph( return result; } -Status TfrtGraphExecutionState::Extend(const GraphDef& graph) { +absl::Status TfrtGraphExecutionState::Extend(const GraphDef& graph) { std::unique_ptr new_state; absl::MutexLock lock(&graph_execution_state_mu_); TF_RETURN_IF_ERROR(graph_execution_state_->Extend(graph, &new_state)); @@ -383,8 +383,8 @@ absl::StatusOr FindLoopCondFromExitNode( } // namespace -Status PruneGraphDef(GraphDef& graph_def, - const CallableOptions& callable_options) { +absl::Status PruneGraphDef(GraphDef& graph_def, + const CallableOptions& callable_options) { // Gather node names and create a map from names to NodeDefs. absl::flat_hash_map name_to_node; // All exit nodes in order to track all while loops. @@ -515,7 +515,8 @@ Status PruneGraphDef(GraphDef& graph_def, return absl::OkStatus(); } -Status EliminateRefVariablesFromV1ControlFlow(tensorflow::GraphDef& graph_def) { +absl::Status EliminateRefVariablesFromV1ControlFlow( + tensorflow::GraphDef& graph_def) { auto* op_factory = OpRegistry::Global(); absl::flat_hash_set ref_nodes; @@ -605,7 +606,7 @@ namespace { // `functions_to_optimize`) using `flib` and `fallback_state`. Each // function is converted to a graph and optimized with Placer and Grappler, then // converted back to a function to replace the old one. -Status OptimizeFunctions( +absl::Status OptimizeFunctions( FunctionDefLibrary& flib_proto, const FunctionLibraryDefinition& flib, const FallbackState& fallback_state, const absl::flat_hash_set& functions_to_optimize) { diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h index 918425d1bda267..2912c2ca57c088 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h @@ -81,7 +81,7 @@ class TfrtGraphExecutionState { tensorflow::GraphImportConfig& graph_import_config); // Extends the current graph by `graph`. - Status Extend(const GraphDef& graph); + absl::Status Extend(const GraphDef& graph); // Return the preprocessed full graph. Note that it does not contain the // function library in the original graph. @@ -127,14 +127,14 @@ class TfrtGraphExecutionState { // pruning (e.g., prunes the input edges to the feed nodes) than // `ComputeTransitiveFanin()` so that the graph can be functionalized properly // later. -Status PruneGraphDef(GraphDef& graph_def, - const CallableOptions& callable_options); +absl::Status PruneGraphDef(GraphDef& graph_def, + const CallableOptions& callable_options); // Eliminates ref variables in V1 control flow, which is required for // functionalization. Current strategy is to insert an identity node between // each ref node and its ref input and in-place update the ref node to its // non-ref counterpart. -Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def); +absl::Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def); // Removes the "_input_shapes" attribute of functions in the graph. void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def); diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc index f22c0982b569c1..026198ebd58ec7 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state_test.cc @@ -97,14 +97,15 @@ TEST_F(PruneGraphDefTest, ConstFeedWithInput) { CompareGraphs(expected, graphdef); } -Status LessThanTenCond(const Scope& scope, const std::vector& inputs, - Output* output) { +absl::Status LessThanTenCond(const Scope& scope, + const std::vector& inputs, + Output* output) { *output = ops::Less(scope, inputs[0], 10); return scope.status(); } -Status AddOneBody(const Scope& scope, const std::vector& inputs, - std::vector* outputs) { +absl::Status AddOneBody(const Scope& scope, const std::vector& inputs, + std::vector* outputs) { outputs->push_back(ops::AddN(scope, {inputs[0], 1})); return scope.status(); } diff --git a/tensorflow/core/tfrt/utils/utils.cc b/tensorflow/core/tfrt/utils/utils.cc index 3cc53af88cc692..e05f86bd1d0b37 100644 --- a/tensorflow/core/tfrt/utils/utils.cc +++ b/tensorflow/core/tfrt/utils/utils.cc @@ -51,9 +51,9 @@ DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype) { } } -tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, - tfrt::BEFFile* bef_file, - absl::string_view fallback_init_func) { +absl::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, + tfrt::BEFFile* bef_file, + absl::string_view fallback_init_func) { auto* host = exec_ctx.host(); auto* func = bef_file->GetFunction( diff --git a/tensorflow/core/tfrt/utils/utils.h b/tensorflow/core/tfrt/utils/utils.h index 3276101c1db970..970de920936393 100644 --- a/tensorflow/core/tfrt/utils/utils.h +++ b/tensorflow/core/tfrt/utils/utils.h @@ -52,9 +52,9 @@ DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype); // // TODO(b/178714905): We should avoid special handling on initialization by // letting compiler to handle it. -tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, - tfrt::BEFFile* bef_file, - absl::string_view fallback_init_func); +absl::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, + tfrt::BEFFile* bef_file, + absl::string_view fallback_init_func); // Creates dummy TF devices from the input device names. Currently this method // is used to create the TPU_SYSTEM device for worker server. diff --git a/tensorflow/core/tpu/ops/BUILD b/tensorflow/core/tpu/ops/BUILD index 3cfd5e82da8fa7..9fb3da59e46f62 100644 --- a/tensorflow/core/tpu/ops/BUILD +++ b/tensorflow/core/tpu/ops/BUILD @@ -174,6 +174,8 @@ cc_library( "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/core/tpu/ops/tpu_embedding_ops.cc b/tensorflow/core/tpu/ops/tpu_embedding_ops.cc index 1e257f9a177325..dc604f83cfc88c 100644 --- a/tensorflow/core/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/core/tpu/ops/tpu_embedding_ops.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" diff --git a/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h b/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h index c36d0c1495b514..1d1e91382d2fa9 100644 --- a/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h +++ b/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/core/tpu/tpu_defs.h b/tensorflow/core/tpu/tpu_defs.h index b5c3668067c3b5..70c8c952f16025 100644 --- a/tensorflow/core/tpu/tpu_defs.h +++ b/tensorflow/core/tpu/tpu_defs.h @@ -51,12 +51,31 @@ extern const char* const kTPUReplicateAttr; extern const char* const kOutsideCompilationAttr; // Supported types for TPUs. -inline constexpr std::array kTpuAllTypes = { - {DT_INT32, DT_UINT32, DT_FLOAT8_E4M3FN, DT_FLOAT8_E5M2, DT_HALF, - DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, DT_COMPLEX64, - DT_INT64, DT_UINT64, DT_QINT8, DT_QUINT8, DT_QINT32, - DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT4, - DT_UINT4}}; +inline constexpr std::array kTpuAllTypes = { + {DT_INT32, + DT_UINT32, + DT_FLOAT8_E4M3FN, + DT_FLOAT8_E5M2, + DT_HALF, + DT_BFLOAT16, + DT_FLOAT, + DT_DOUBLE, + DT_BOOL, + DT_COMPLEX64, + DT_INT64, + DT_UINT64, + DT_QINT8, + DT_QUINT8, + DT_QINT32, + DT_INT8, + DT_UINT8, + DT_INT16, + DT_UINT16, + DT_INT4, + DT_UINT4, + DT_FLOAT8_E4M3FNUZ, + DT_FLOAT8_E4M3B11FNUZ, + DT_FLOAT8_E5M2FNUZ}}; } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_node_device_util.cc b/tensorflow/core/tpu/tpu_node_device_util.cc index d63bebd2aad46d..2a0ca79fd4c982 100644 --- a/tensorflow/core/tpu/tpu_node_device_util.cc +++ b/tensorflow/core/tpu/tpu_node_device_util.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { bool TpuOpFilter(KernelDef* kdef) { - StringPiece op(kdef->op()); + absl::string_view op(kdef->op()); VLOG(2) << "TpuOpFilter " << op; // Enable const string operands to Assert op (b/69167214). if (op == "Const") { diff --git a/tensorflow/core/tpu/virtual_device.cc b/tensorflow/core/tpu/virtual_device.cc index 12ad3c67e9c0ba..3ee148c99c0dce 100644 --- a/tensorflow/core/tpu/virtual_device.cc +++ b/tensorflow/core/tpu/virtual_device.cc @@ -28,7 +28,7 @@ class VirtualDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done, bool sync_dst_compute) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, Tensor* output_tensor, @@ -45,7 +45,7 @@ void VirtualDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void VirtualDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { diff --git a/tensorflow/core/transforms/cf_sink/pass.cc b/tensorflow/core/transforms/cf_sink/pass.cc index c7404925836435..063e7381b27294 100644 --- a/tensorflow/core/transforms/cf_sink/pass.cc +++ b/tensorflow/core/transforms/cf_sink/pass.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/core/transforms/cf_sink/pass.h" -#include #include #include "llvm/ADT/ScopeExit.h" diff --git a/tensorflow/core/transforms/consolidate_attrs/pass.cc b/tensorflow/core/transforms/consolidate_attrs/pass.cc index cce777dc4f6141..13b48acc8c1eb0 100644 --- a/tensorflow/core/transforms/consolidate_attrs/pass.cc +++ b/tensorflow/core/transforms/consolidate_attrs/pass.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/transforms/consolidate_attrs/pass.h" +#include +#include #include #include #include @@ -374,8 +376,7 @@ void ConsolidateAttributesPassImpl::runOnOperation() { patterns.add( RemoveAttributes( &getContext(), {"T"})); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { getOperation()->emitError(getArgument() + " pass failed"); signalPassFailure(); return; @@ -673,8 +674,7 @@ void PrepareAttributesForExportPassImpl::runOnOperation() { ForOp>(patterns, control_type); patterns.insert( &getContext()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { getOperation()->emitError(getArgument() + " pass failed"); signalPassFailure(); return; diff --git a/tensorflow/core/transforms/const_dedupe_hoist/pass.cc b/tensorflow/core/transforms/const_dedupe_hoist/pass.cc index d25282631350ec..712f07371f675e 100644 --- a/tensorflow/core/transforms/const_dedupe_hoist/pass.cc +++ b/tensorflow/core/transforms/const_dedupe_hoist/pass.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/core/transforms/const_dedupe_hoist/pass.h" -#include #include #include diff --git a/tensorflow/core/transforms/constant_folding/pass.cc b/tensorflow/core/transforms/constant_folding/pass.cc index e4c8108772cdff..68f3a0f0a23a65 100644 --- a/tensorflow/core/transforms/constant_folding/pass.cc +++ b/tensorflow/core/transforms/constant_folding/pass.cc @@ -3705,7 +3705,7 @@ void ConstantFolding::runOnOperation() { GraphFuncOp func = getOperation(); // The max iteration is the same as the max default iteration in - // applyPatternsAndFoldGreedily. + // applyPatternsGreedily. constexpr int max_iterations = 10; int iteration = 0; diff --git a/tensorflow/core/transforms/func_to_graph/func_to_graph.cc b/tensorflow/core/transforms/func_to_graph/func_to_graph.cc index 0caddc9000c70a..a2ebcf51ed62ad 100644 --- a/tensorflow/core/transforms/func_to_graph/func_to_graph.cc +++ b/tensorflow/core/transforms/func_to_graph/func_to_graph.cc @@ -31,7 +31,7 @@ limitations under the License. namespace mlir { namespace tfg { -tensorflow::Status FuncToGraph(GraphFuncOp func) { +absl::Status FuncToGraph(GraphFuncOp func) { MLIRContext *context = func->getContext(); auto version = func->getAttrOfType("tfg.lifted_graph_version"); if (!version) { diff --git a/tensorflow/core/transforms/func_to_graph/func_to_graph.h b/tensorflow/core/transforms/func_to_graph/func_to_graph.h index abe97eee71490e..5cab621b5b0f43 100644 --- a/tensorflow/core/transforms/func_to_graph/func_to_graph.h +++ b/tensorflow/core/transforms/func_to_graph/func_to_graph.h @@ -25,7 +25,7 @@ namespace tfg { // Lowers a lifted graph func back to the graph. The uses of function arguments // will be replaced with the associated value according to // `tfg.lifted_value_attr` attribute. -tensorflow::Status FuncToGraph(GraphFuncOp func); +absl::Status FuncToGraph(GraphFuncOp func); } // namespace tfg } // namespace mlir diff --git a/tensorflow/core/transforms/functional_to_region/pass.cc b/tensorflow/core/transforms/functional_to_region/pass.cc index 6d21eb179bc6b4..87dbdd855a6f4a 100644 --- a/tensorflow/core/transforms/functional_to_region/pass.cc +++ b/tensorflow/core/transforms/functional_to_region/pass.cc @@ -50,8 +50,8 @@ struct FunctionalToRegionPass // cause the verifiers, which are implemented recursively, to stack // overflow. Set a relatively low iteration limit. config.maxIterations = 16; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) signalPassFailure(); } }; diff --git a/tensorflow/core/transforms/graph_to_func/graph_to_func.cc b/tensorflow/core/transforms/graph_to_func/graph_to_func.cc index 1bbd9f24df30e5..d3769db8bcdf00 100644 --- a/tensorflow/core/transforms/graph_to_func/graph_to_func.cc +++ b/tensorflow/core/transforms/graph_to_func/graph_to_func.cc @@ -48,9 +48,9 @@ static ArrayAttr createLiftedValueAttr(OpBuilder &builder, OpResult value) { return builder.getArrayAttr(attrs); } -tensorflow::Status GraphToFunc(GraphOp graph, ArrayRef feeds, - ArrayRef fetches, - ArrayRef control_rets) { +absl::Status GraphToFunc(GraphOp graph, ArrayRef feeds, + ArrayRef fetches, + ArrayRef control_rets) { OpBuilder builder(graph); ControlType control_ty = ControlType::get(graph.getContext()); llvm::SmallVector arg_types; diff --git a/tensorflow/core/transforms/graph_to_func/graph_to_func.h b/tensorflow/core/transforms/graph_to_func/graph_to_func.h index 1283b6804e53f2..94723c96d38aa6 100644 --- a/tensorflow/core/transforms/graph_to_func/graph_to_func.h +++ b/tensorflow/core/transforms/graph_to_func/graph_to_func.h @@ -28,17 +28,16 @@ namespace tfg { // function arguments, `fetches` for function returned values, and // `control_rets` for returned control values. The Graph op is replaced in-place // by a GraphFuncOp with a name defined in the dialect. -tensorflow::Status GraphToFunc(GraphOp graph, ArrayRef feeds, - ArrayRef fetches, - ArrayRef control_rets); +absl::Status GraphToFunc(GraphOp graph, ArrayRef feeds, + ArrayRef fetches, ArrayRef control_rets); // Lifts a graph into a function, using the provided array of `feeds` for // function arguments, `fetches` for function returned values, and // `control_rets` for returned control values. The Graph op is replaced in-place // by a GraphFuncOp with a name defined in the dialect. -tensorflow::Status GraphToFunc(GraphOp graph, ArrayRef feeds_names, - ArrayRef fetches_names, - ArrayRef control_rets); +absl::Status GraphToFunc(GraphOp graph, ArrayRef feeds_names, + ArrayRef fetches_names, + ArrayRef control_rets); } // namespace tfg } // namespace mlir diff --git a/tensorflow/core/transforms/region_to_functional/pass.cc b/tensorflow/core/transforms/region_to_functional/pass.cc index 62d7d5061a68af..75e62d9b9cede5 100644 --- a/tensorflow/core/transforms/region_to_functional/pass.cc +++ b/tensorflow/core/transforms/region_to_functional/pass.cc @@ -53,8 +53,8 @@ struct RegionToFunctionalPass // Iterate until all regions have been outlined. This is guaranteed to // terminate because the IR can only hold a finite depth of regions. config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { getOperation()->emitError(getArgument() + " pass failed"); signalPassFailure(); } diff --git a/tensorflow/core/transforms/remapper/pass.cc b/tensorflow/core/transforms/remapper/pass.cc index 06025170061e42..189f2f3a666439 100644 --- a/tensorflow/core/transforms/remapper/pass.cc +++ b/tensorflow/core/transforms/remapper/pass.cc @@ -776,7 +776,7 @@ class Remapper : public impl::RemapperBase { }; void Remapper::runOnOperation() { - if (failed(applyPatternsAndFoldGreedily(getOperation(), final_patterns_))) { + if (failed(applyPatternsGreedily(getOperation(), final_patterns_))) { signalPassFailure(); } } diff --git a/tensorflow/core/transforms/utils/eval_utils.cc b/tensorflow/core/transforms/utils/eval_utils.cc index c70781fc6edade..9002a9b27fe615 100644 --- a/tensorflow/core/transforms/utils/eval_utils.cc +++ b/tensorflow/core/transforms/utils/eval_utils.cc @@ -63,7 +63,7 @@ tensorflow::Allocator *SimpleDevice::GetAllocator( return tensorflow::cpu_allocator(); } -tensorflow::Status SimpleDevice::MakeTensorFromProto( +absl::Status SimpleDevice::MakeTensorFromProto( const tensorflow::TensorProto &tensor_proto, const tensorflow::AllocatorAttributes alloc_attrs, tensorflow::Tensor *tensor) { @@ -111,7 +111,7 @@ LogicalResult EvaluateOperation(tensorflow::DeviceBase *cpu_device, input_tensor_value.tensor = &input_tensor; } - tensorflow::Status status; + absl::Status status; std::unique_ptr op_kernel = tensorflow::CreateOpKernel( tensorflow::DEVICE_CPU, cpu_device, cpu_device->GetAllocator({}), node_def, TF_GRAPH_DEF_VERSION, &status); diff --git a/tensorflow/core/transforms/utils/eval_utils.h b/tensorflow/core/transforms/utils/eval_utils.h index 972ce493b52e99..28128938d358a6 100644 --- a/tensorflow/core/transforms/utils/eval_utils.h +++ b/tensorflow/core/transforms/utils/eval_utils.h @@ -42,7 +42,7 @@ class SimpleDevice : public tensorflow::DeviceBase { SimpleDevice(); ~SimpleDevice() override; - tensorflow::Status MakeTensorFromProto( + absl::Status MakeTensorFromProto( const tensorflow::TensorProto& tensor_proto, const tensorflow::AllocatorAttributes alloc_attrs, tensorflow::Tensor* tensor) override; diff --git a/tensorflow/core/transforms/utils/op_cat_helper.cc b/tensorflow/core/transforms/utils/op_cat_helper.cc index 1347072cd87676..114a2d971da5c0 100644 --- a/tensorflow/core/transforms/utils/op_cat_helper.cc +++ b/tensorflow/core/transforms/utils/op_cat_helper.cc @@ -86,7 +86,7 @@ bool OpCatHelper::IsAggregate(TFOp op) { return !attr || !mlir::isa(attr.getValue()); } const tensorflow::OpDef *op_def = nullptr; - tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef( + absl::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef( op->getName().stripDialect().data(), &op_def); return status.ok() && op_def->is_aggregate(); } @@ -97,7 +97,7 @@ bool OpCatHelper::IsCommutative(TFOp op) { return !attr || !mlir::isa(attr.getValue()); } const tensorflow::OpDef *op_def = nullptr; - tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef( + absl::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef( op->getName().stripDialect().data(), &op_def); return status.ok() && op_def->is_commutative(); } diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index eba72070b7b4cb..9445c3145f9576 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -176,6 +176,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core/platform:status", "//tensorflow/core/platform:str_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor:dnn", diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.h b/tensorflow/core/util/autotune_maps/autotune_serialize.h index 8c8bdc2f7e13a7..745eb1ad61f3de 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.h +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.h @@ -27,6 +27,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/util/debug_events_writer.cc b/tensorflow/core/util/debug_events_writer.cc index 9790422adc2701..7dfbbcf982fec5 100644 --- a/tensorflow/core/util/debug_events_writer.cc +++ b/tensorflow/core/util/debug_events_writer.cc @@ -69,7 +69,7 @@ absl::Status SingleDebugEventFileWriter::Init() { } void SingleDebugEventFileWriter::WriteSerializedDebugEvent( - StringPiece debug_event_str) { + absl::string_view debug_event_str) { if (record_writer_ == nullptr) { if (!Init().ok()) { LOG(ERROR) << "Write failed because file could not be opened."; diff --git a/tensorflow/core/util/debug_events_writer.h b/tensorflow/core/util/debug_events_writer.h index 1fa4718d45e30e..7b1042790d7913 100644 --- a/tensorflow/core/util/debug_events_writer.h +++ b/tensorflow/core/util/debug_events_writer.h @@ -53,7 +53,7 @@ class SingleDebugEventFileWriter { absl::Status Init(); - void WriteSerializedDebugEvent(tensorflow::StringPiece debug_event_str); + void WriteSerializedDebugEvent(absl::string_view debug_event_str); absl::Status Flush(); absl::Status Close(); diff --git a/tensorflow/core/util/dump_graph.cc b/tensorflow/core/util/dump_graph.cc index c8eb3d48060d71..adf49e492e5c33 100644 --- a/tensorflow/core/util/dump_graph.cc +++ b/tensorflow/core/util/dump_graph.cc @@ -121,7 +121,7 @@ class StderrWritableFile : public WritableFile { public: StderrWritableFile() = default; - absl::Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { fprintf(stderr, "%.*s", static_cast(data.size()), data.data()); return absl::OkStatus(); } @@ -133,7 +133,7 @@ class StderrWritableFile : public WritableFile { return absl::OkStatus(); } - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = "stderr"; return absl::OkStatus(); } @@ -200,7 +200,7 @@ absl::Status WriteProtoToUniqueFile(const tensorflow::protobuf::Message& proto, absl ::StrCat("Unknown format: ", format)); } TF_RETURN_IF_ERROR(file->Append(s)); - StringPiece name; + absl::string_view name; TF_RETURN_IF_ERROR(file->Name(&name)); VLOG(5) << name; VLOG(5) << s; @@ -213,7 +213,7 @@ absl::Status WriteProtoToUniqueFile( if (!SerializeToStringDeterministic(proto, &s)) { return errors::Internal("Failed to serialize proto to string."); } - StringPiece name; + absl::string_view name; TF_RETURN_IF_ERROR(file->Name(&name)); VLOG(5) << name; VLOG(5) << s; diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc index 6be31c499d33ae..80aadf73deafe2 100644 --- a/tensorflow/core/util/events_writer.cc +++ b/tensorflow/core/util/events_writer.cc @@ -106,7 +106,7 @@ string EventsWriter::FileName() { return filename_; } -void EventsWriter::WriteSerializedEvent(StringPiece event_str) { +void EventsWriter::WriteSerializedEvent(absl::string_view event_str) { if (recordio_writer_ == nullptr) { if (!InitIfNeeded().ok()) { LOG(ERROR) << "Write failed because file could not be opened."; diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h index 06eaee845eb6a6..a06eac7db5d8ee 100644 --- a/tensorflow/core/util/events_writer.h +++ b/tensorflow/core/util/events_writer.h @@ -68,7 +68,7 @@ class EventsWriter { // Append "event_str", a serialized Event, to the file. // Note that this function does NOT check that de-serializing event_str // results in a valid Event proto. The tensorflow:: bit makes SWIG happy. - void WriteSerializedEvent(tensorflow::StringPiece event_str); + void WriteSerializedEvent(absl::string_view event_str); // EventWriter automatically flushes and closes on destruction, but // these two methods are provided for users who want to write to disk sooner diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index fafafa94ef0bda..b4fac84e7aa017 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -125,7 +125,7 @@ namespace parsed { class Feature { public: Feature() = default; - explicit Feature(StringPiece serialized) : serialized_(serialized) {} + explicit Feature(absl::string_view serialized) : serialized_(serialized) {} absl::Status ParseDataType(DataType* dtype) { DCHECK(dtype != nullptr); @@ -327,14 +327,14 @@ class Feature { return true; } - StringPiece GetSerialized() const { return serialized_; } + absl::string_view GetSerialized() const { return serialized_; } private: // TODO(lew): Pair of uint8* would be more natural. - StringPiece serialized_; + absl::string_view serialized_; }; -using FeatureMapEntry = std::pair; +using FeatureMapEntry = std::pair; using Example = std::vector; } // namespace parsed @@ -364,13 +364,14 @@ inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) { return false; // unrecognized tag type } -bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { +bool ParseString(protobuf::io::CodedInputStream* stream, + absl::string_view* result) { DCHECK(stream != nullptr); DCHECK(result != nullptr); uint32 length; if (!stream->ReadVarint32(&length)) return false; if (length == 0) { - *result = StringPiece(nullptr, 0); + *result = absl::string_view(nullptr, 0); return true; } const void* stream_alias; @@ -379,7 +380,7 @@ bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { return false; } if (static_cast(stream_size) < length) return false; - *result = StringPiece(static_cast(stream_alias), length); + *result = absl::string_view(static_cast(stream_alias), length); stream->Skip(length); return true; } @@ -401,7 +402,7 @@ bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream, break; case kDelimitedTag(2): { - StringPiece feature_string_piece; + absl::string_view feature_string_piece; if (!ParseString(stream, &feature_string_piece)) return false; feature_map_entry->second = parsed::Feature(feature_string_piece); break; @@ -451,7 +452,7 @@ bool ParseExample(protobuf::io::CodedInputStream* stream, return true; } -bool ParseExample(StringPiece serialized, parsed::Example* example) { +bool ParseExample(absl::string_view serialized, parsed::Example* example) { DCHECK(example != nullptr); protobuf::io::CodedInputStream stream( reinterpret_cast(serialized.data()), serialized.size()); @@ -561,13 +562,13 @@ struct SparseBuffer { }; struct SeededHasher { - uint64 operator()(StringPiece s) const { + uint64 operator()(absl::string_view s) const { return Hash64(s.data(), s.size(), seed); } uint64 seed{0xDECAFCAFFE}; }; -void LogDenseFeatureDataLoss(StringPiece feature_name) { +void LogDenseFeatureDataLoss(absl::string_view feature_name) { LOG(WARNING) << "Data loss! Feature '" << feature_name << "' is present in multiple concatenated " "tf.Examples. Ignoring all but last one."; @@ -578,7 +579,7 @@ void LogDenseFeatureDataLoss(StringPiece feature_name) { duplicated_dense_feature->GetCell()->IncrementBy(1); } -void LogSparseFeatureDataLoss(StringPiece feature_name) { +void LogSparseFeatureDataLoss(absl::string_view feature_name) { LOG(WARNING) << "Data loss! Feature '" << feature_name << "' is present in multiple concatenated " "tf.Examples. Ignoring all but last one."; @@ -626,7 +627,7 @@ absl::Status FastParseSerializedExample( parsed::FeatureMapEntry& name_and_feature = parsed_example[parsed_example_size - i - 1]; - const StringPiece feature_name = name_and_feature.first; + const absl::string_view feature_name = name_and_feature.first; parsed::Feature& feature = name_and_feature.second; std::pair d_and_type; @@ -647,7 +648,7 @@ absl::Status FastParseSerializedExample( if (feature_name != config_feature_name) continue; } - auto example_error = [&](StringPiece suffix) { + auto example_error = [&](absl::string_view suffix) { return errors::InvalidArgument("Name: ", example_name, ", Key: ", feature_name, ", Index: ", example_index, ". ", suffix); @@ -690,7 +691,7 @@ absl::Status FastParseSerializedExample( const std::size_t offset = example_index * num_elements; - auto shape_error = [&](size_t size, StringPiece type_str) { + auto shape_error = [&](size_t size, absl::string_view type_str) { return example_error(strings::StrCat( "Number of ", type_str, " values != expected. " @@ -742,7 +743,7 @@ absl::Status FastParseSerializedExample( "Expected type: ", DataTypeString(config.dense[d].dtype))); } - auto shape_error = [&](size_t size, StringPiece type_str) { + auto shape_error = [&](size_t size, absl::string_view type_str) { return example_error(strings::StrCat( "Number of ", type_str, " values is not a multiple of stride length. Saw ", size, @@ -1448,7 +1449,8 @@ absl::Status FastParseExample(const Config& config, } absl::Status FastParseSingleExample(const Config& config, - StringPiece serialized, Result* result) { + absl::string_view serialized, + Result* result) { DCHECK(result != nullptr); // Check config so we can safely CHECK(false) in switches on config.*.dtype TF_RETURN_IF_ERROR(CheckConfigDataTypes(config)); @@ -1555,7 +1557,7 @@ absl::Status FastParseSingleExample(const Config& config, parsed::FeatureMapEntry& name_and_feature = parsed_example[parsed_example_size - i - 1]; - const StringPiece feature_name = name_and_feature.first; + const absl::string_view feature_name = name_and_feature.first; parsed::Feature& feature = name_and_feature.second; std::pair d_and_type; @@ -1576,7 +1578,7 @@ absl::Status FastParseSingleExample(const Config& config, if (feature_name != config_feature_name) continue; } - auto example_error = [feature_name](StringPiece suffix) { + auto example_error = [feature_name](absl::string_view suffix) { return errors::InvalidArgument("Key: ", feature_name, ". ", suffix); }; @@ -1847,7 +1849,7 @@ struct FeatureProtos { // Proto substrings from each serialized SequenceExample that correspond // with this feature. `protos_present` records whether the proto had a // value defined (even if that value is empty). - std::vector protos; + std::vector protos; std::vector protos_present; // Information derived from protos: @@ -1860,7 +1862,7 @@ struct FeatureProtos { }; // Map from feature name to FeatureProtos for that feature. -using FeatureProtosMap = absl::flat_hash_map; +using FeatureProtosMap = absl::flat_hash_map; string ExampleName(const absl::Span example_names, int n) { return example_names.empty() ? "" : example_names[n]; @@ -2132,7 +2134,7 @@ absl::Status ExtractFeaturesFromSequenceExamples( } auto limit = stream.PushLimit(length); while (!stream.ExpectAtEnd()) { - StringPiece key, value; + absl::string_view key, value; uint32 length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&length)) { diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index edc72f47e773ca..6ba6d89ab5aa01 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -42,8 +42,8 @@ namespace example { // in Example. struct FastParseExampleConfig { struct Dense { - Dense(StringPiece feature_name, DataType dtype, PartialTensorShape shape, - Tensor default_value, bool variable_length, + Dense(absl::string_view feature_name, DataType dtype, + PartialTensorShape shape, Tensor default_value, bool variable_length, std::size_t elements_per_stride) : feature_name(feature_name), // TODO(mrry): Switch to preallocated // tstring when this is available. @@ -66,7 +66,7 @@ struct FastParseExampleConfig { }; struct Sparse { - Sparse(StringPiece feature_name, DataType dtype) + Sparse(absl::string_view feature_name, DataType dtype) : feature_name(feature_name), // TODO(mrry): Switch to preallocated // tstring when this is available. dtype(dtype) {} @@ -77,7 +77,8 @@ struct FastParseExampleConfig { }; struct Ragged { - Ragged(StringPiece feature_name, DataType dtype, DataType splits_dtype) + Ragged(absl::string_view feature_name, DataType dtype, + DataType splits_dtype) : feature_name(feature_name), // TODO(mrry): Switch to preallocated // tstring when this is available. dtype(dtype), @@ -143,7 +144,8 @@ absl::Status FastParseExample(const FastParseExampleConfig& config, typedef FastParseExampleConfig FastParseSingleExampleConfig; absl::Status FastParseSingleExample(const FastParseSingleExampleConfig& config, - StringPiece serialized, Result* result); + absl::string_view serialized, + Result* result); // Parses a batch of serialized SequenceExample protos and converts them into // result according to given config. diff --git a/tensorflow/core/util/gpu_solvers.h b/tensorflow/core/util/gpu_solvers.h index cf1689fff55fea..8a84f17c1980c4 100644 --- a/tensorflow/core/util/gpu_solvers.h +++ b/tensorflow/core/util/gpu_solvers.h @@ -365,7 +365,7 @@ class GpuSolver { template Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, - int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT; + int ldvt, int* dev_lapack_info); // QR factorization. // Computes QR factorization A = Q * R. @@ -428,15 +428,14 @@ class GpuSolver { const Scalar* alpha, /* host or device pointer */ const Scalar* A, int lda, const Scalar* beta, /* host or device pointer */ - const Scalar* B, int ldb, Scalar* C, - int ldc) const TF_MUST_USE_RESULT; + const Scalar* B, int ldb, Scalar* C, int ldc) const; // Computes the Cholesky factorization A = L * L^H for a single matrix. // Returns OkStatus() if the kernel was launched successfully. See: // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf template Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda, - int* dev_lapack_info) TF_MUST_USE_RESULT; + int* dev_lapack_info); // Computes the Cholesky factorization A = L * L^H for a batch of small // matrices. @@ -445,21 +444,20 @@ class GpuSolver { template Status PotrfBatched(cublasFillMode_t uplo, int n, const Scalar* const host_a_dev_ptrs[], int lda, - DeviceLapackInfo* dev_lapack_info, - int batch_size) TF_MUST_USE_RESULT; + DeviceLapackInfo* dev_lapack_info, int batch_size); // LU factorization. // Computes LU factorization with partial pivoting P * A = L * U. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf template Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots, - int* dev_lapack_info) TF_MUST_USE_RESULT; + int* dev_lapack_info); // Uses LU factorization to solve A * X = B. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs template Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, const int* pivots, Scalar* B, int ldb, - int* dev_lapack_info) const TF_MUST_USE_RESULT; + int* dev_lapack_info) const; // Computes partially pivoted LU factorizations for a batch of small matrices. // Returns OkStatus() if the kernel was launched successfully. See: @@ -467,7 +465,7 @@ class GpuSolver { template Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots, DeviceLapackInfo* dev_lapack_info, - int batch_size) TF_MUST_USE_RESULT; + int batch_size); // Batched linear solver using LU factorization from getrfBatched. // Notice that lapack_info is returned on the host, as opposed to @@ -477,8 +475,7 @@ class GpuSolver { Status GetrsBatched(cublasOperation_t trans, int n, int nrhs, const Scalar* const dev_Aarray[], int lda, const int* devIpiv, const Scalar* const dev_Barray[], - int ldb, int* host_lapack_info, - int batch_size) TF_MUST_USE_RESULT; + int ldb, int* host_lapack_info, int batch_size); // Computes matrix inverses for a batch of small matrices. Uses the outputs // from GetrfBatched. Returns OkStatus() if the kernel was launched @@ -488,8 +485,7 @@ class GpuSolver { Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, const Scalar* const host_a_inverse_dev_ptrs[], int ldainv, - DeviceLapackInfo* dev_lapack_info, - int batch_size) TF_MUST_USE_RESULT; + DeviceLapackInfo* dev_lapack_info, int batch_size); // Computes matrix inverses for a batch of small matrices with size n < 32. // Returns OkStatus() if the kernel was launched successfully. See: @@ -498,7 +494,7 @@ class GpuSolver { Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda, const Scalar* const host_a_inverse_dev_ptrs[], int ldainv, DeviceLapackInfo* dev_lapack_info, - int batch_size) TF_MUST_USE_RESULT; + int batch_size); // QR factorization. // Computes QR factorization A = Q * R. @@ -506,7 +502,7 @@ class GpuSolver { // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf template Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau, - int* dev_lapack_info) TF_MUST_USE_RESULT; + int* dev_lapack_info); // Overwrite matrix C by product of C and the unitary Householder matrix Q. // The Householder matrix Q is represented by the output from Geqrf in dev_a @@ -519,7 +515,7 @@ class GpuSolver { template Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, int k, const Scalar* dev_a, int lda, const Scalar* dev_tau, - Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT; + Scalar* dev_c, int ldc, int* dev_lapack_info); // Overwrites QR factorization produced by Geqrf by the unitary Householder // matrix Q. On input, the Householder matrix Q is represented by the output @@ -529,7 +525,7 @@ class GpuSolver { // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr template Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda, - const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT; + const Scalar* dev_tau, int* dev_lapack_info); // Hermitian (Symmetric) Eigen decomposition. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd @@ -537,7 +533,7 @@ class GpuSolver { Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar* dev_A, int lda, typename Eigen::NumTraits::Real* dev_W, - int* dev_lapack_info) TF_MUST_USE_RESULT; + int* dev_lapack_info); // Singular value decomposition. // Returns OkStatus() if the kernel was launched successfully. @@ -546,7 +542,7 @@ class GpuSolver { template Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, - int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT; + int ldvt, int* dev_lapack_info); template Status GesvdjBatched(cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda, Scalar* dev_S, Scalar* dev_U, int ldu, diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc index 2dd82aeeff12b8..c3729774d5a07c 100644 --- a/tensorflow/core/util/memmapped_file_system.cc +++ b/tensorflow/core/util/memmapped_file_system.cc @@ -61,21 +61,21 @@ class RandomAccessFileFromMemmapped : public RandomAccessFile { ~RandomAccessFileFromMemmapped() override = default; - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { return errors::Unimplemented( "RandomAccessFileFromMemmapped does not support Name()"); } - absl::Status Read(uint64 offset, size_t to_read, StringPiece* result, + absl::Status Read(uint64 offset, size_t to_read, absl::string_view* result, char* scratch) const override { if (offset >= length_) { - *result = StringPiece(scratch, 0); + *result = absl::string_view(scratch, 0); return absl::Status(absl::StatusCode::kOutOfRange, "Read after file end"); } const uint64 region_left = std::min(length_ - offset, static_cast(to_read)); - *result = - StringPiece(reinterpret_cast(data_) + offset, region_left); + *result = absl::string_view(reinterpret_cast(data_) + offset, + region_left); return (region_left == to_read) ? absl::OkStatus() : absl::Status(absl::StatusCode::kOutOfRange, diff --git a/tensorflow/core/util/memmapped_file_system_test.cc b/tensorflow/core/util/memmapped_file_system_test.cc index 26e15450921e01..9e9bce6a883349 100644 --- a/tensorflow/core/util/memmapped_file_system_test.cc +++ b/tensorflow/core/util/memmapped_file_system_test.cc @@ -93,8 +93,8 @@ TEST(MemmappedFileSystemTest, SimpleTest) { // The memory region can be bigger but not less than Tensor size. ASSERT_GE(memory_region->length(), test_tensor.TotalBytes()); EXPECT_EQ(test_tensor.tensor_data(), - StringPiece(static_cast(memory_region->data()), - test_tensor.TotalBytes())); + absl::string_view(static_cast(memory_region->data()), + test_tensor.TotalBytes())); // Check that GetFileSize works. uint64 file_size = 0; TF_ASSERT_OK(memmapped_env.GetFileSize(kTensor2FileName, &file_size)); diff --git a/tensorflow/core/util/memmapped_file_system_writer.cc b/tensorflow/core/util/memmapped_file_system_writer.cc index 411dbc51733a48..ce5d435b8a7a3f 100644 --- a/tensorflow/core/util/memmapped_file_system_writer.cc +++ b/tensorflow/core/util/memmapped_file_system_writer.cc @@ -80,7 +80,7 @@ absl::Status MemmappedFileSystemWriter::SaveProtobuf( namespace { -StringPiece EncodeUint64LittleEndian(uint64 val, char* output_buffer) { +absl::string_view EncodeUint64LittleEndian(uint64 val, char* output_buffer) { for (unsigned int i = 0; i < sizeof(uint64); ++i) { output_buffer[i] = (val >> i * 8); } @@ -116,7 +116,7 @@ absl::Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) { static constexpr uint64 kFillerBufferSize = 16; const char kFillerBuffer[kFillerBufferSize] = {}; for (uint64 rest = to_write_for_alignment; rest > 0;) { - StringPiece sp(kFillerBuffer, std::min(rest, kFillerBufferSize)); + absl::string_view sp(kFillerBuffer, std::min(rest, kFillerBufferSize)); TF_RETURN_IF_ERROR(output_file_->Append(sp)); rest -= sp.size(); output_file_offset_ += sp.size(); diff --git a/tensorflow/core/util/mirror_pad_mode.cc b/tensorflow/core/util/mirror_pad_mode.cc index 067996c69d07ef..39364886219b29 100644 --- a/tensorflow/core/util/mirror_pad_mode.cc +++ b/tensorflow/core/util/mirror_pad_mode.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -absl::Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, +absl::Status GetNodeAttr(const NodeDef& node_def, absl::string_view attr_name, MirrorPadMode* value) { string str_value; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value)); diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h index 5675a22739cc82..eea7c7415268a9 100644 --- a/tensorflow/core/util/mirror_pad_mode.h +++ b/tensorflow/core/util/mirror_pad_mode.h @@ -45,7 +45,7 @@ string GetMirrorPadModeAttrString(); class NodeDef; // Specialization to parse an attribute directly into a MirrorPadMode enum. -absl::Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, +absl::Status GetNodeAttr(const NodeDef& node_def, absl::string_view attr_name, MirrorPadMode* value); } // end namespace tensorflow diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc index e502d5eafae769..41989d277b55fc 100644 --- a/tensorflow/core/util/padding.cc +++ b/tensorflow/core/util/padding.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -absl::Status GetPaddingFromString(StringPiece str_value, Padding* value) { +absl::Status GetPaddingFromString(absl::string_view str_value, Padding* value) { if (str_value == "SAME") { *value = SAME; } else if (str_value == "VALID") { diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h index 9c0cf543a0dc4f..3c1351df96d929 100644 --- a/tensorflow/core/util/padding.h +++ b/tensorflow/core/util/padding.h @@ -61,7 +61,7 @@ std::string GetPaddingAttrStringWithExplicit(); std::string GetExplicitPaddingsAttrString(); // Sets padding value based on the given string padding value. -absl::Status GetPaddingFromString(StringPiece str_value, Padding* value); +absl::Status GetPaddingFromString(absl::string_view str_value, Padding* value); } // end namespace tensorflow diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc index 6abcf9f25d6951..68690d94bee066 100644 --- a/tensorflow/core/util/reporter_test.cc +++ b/tensorflow/core/util/reporter_test.cc @@ -28,7 +28,7 @@ namespace tensorflow { namespace { // Tests of all the error paths in log_reader.cc follow: -static void ExpectHasSubstr(StringPiece s, StringPiece expected) { +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << s << " does not contain " << expected; } diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 9948930c67c1a8..93c5a7e9818ae2 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -79,8 +79,8 @@ struct StridedSliceDenseSpec { } // namespace template -static absl::Status TF_MUST_USE_RESULT BuildDenseSpec( - const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) { +static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, + StridedSliceDenseSpec* dense) { if (dense->dims < 0) { return errors::InvalidArgument("Unexpected negative dense.dims: %d", dense->dims); diff --git a/tensorflow/core/util/tensor_bundle/byte_swap_tensor.cc b/tensorflow/core/util/tensor_bundle/byte_swap_tensor.cc index 6e04d4eec0893f..903e3592f7e38f 100644 --- a/tensorflow/core/util/tensor_bundle/byte_swap_tensor.cc +++ b/tensorflow/core/util/tensor_bundle/byte_swap_tensor.cc @@ -50,8 +50,8 @@ namespace { // If num_of_elem is -1, this function will calculate // the number of data based on size and dtype. // Returns: OkStatus() on success, -1 otherwise -Status ByteSwapBuffer(char* buff, size_t size, DataType dtype, - int num_of_elem) { +absl::Status ByteSwapBuffer(char* buff, size_t size, DataType dtype, + int num_of_elem) { int array_len = num_of_elem; size_t bytes_per_elem = 0; @@ -155,13 +155,21 @@ bool IsByteSwappable(DataType dtype) { } } -Status ByteSwapTensor(Tensor* t) { +absl::Status ByteSwapTensor(Tensor* t) { char* buff = const_cast((t->tensor_data().data())); return ByteSwapBuffer(buff, t->tensor_data().size(), t->dtype(), t->NumElements()); } -Status ByteSwapTensorContentInNode(NodeDef& node) { +absl::Status ByteSwapTensorProto(TensorProto* tp) { + std::string content_str = std::string(tp->tensor_content()); + char* buff = const_cast(content_str.data()); + TF_RETURN_IF_ERROR(ByteSwapBuffer(buff, content_str.size(), tp->dtype(), -1)); + tp->set_tensor_content(content_str); + return absl::OkStatus(); +} + +absl::Status ByteSwapTensorContentInNode(NodeDef& node) { if (node.op() == "Const") { auto node_iterator = node.mutable_attr()->find("value"); if (node_iterator != node.mutable_attr()->end()) { @@ -201,7 +209,7 @@ Status ByteSwapTensorContentInNode(NodeDef& node) { return absl::OkStatus(); } -Status ByteSwapTensorContentInMetaGraphDef(MetaGraphDef* meta_graph_def) { +absl::Status ByteSwapTensorContentInMetaGraphDef(MetaGraphDef* meta_graph_def) { for (auto& function : *meta_graph_def->mutable_graph_def() ->mutable_library() ->mutable_function()) @@ -210,7 +218,7 @@ Status ByteSwapTensorContentInMetaGraphDef(MetaGraphDef* meta_graph_def) { return absl::OkStatus(); } -Status ByteSwapTensorContentInGraphDef(GraphDef* graph_def) { +absl::Status ByteSwapTensorContentInGraphDef(GraphDef* graph_def) { for (auto& node : *graph_def->mutable_node()) TF_RETURN_IF_ERROR(ByteSwapTensorContentInNode(node)); return absl::OkStatus(); diff --git a/tensorflow/core/util/tensor_bundle/byte_swap_tensor.h b/tensorflow/core/util/tensor_bundle/byte_swap_tensor.h index dbfd63e355c18d..c86ffc26f1aa72 100644 --- a/tensorflow/core/util/tensor_bundle/byte_swap_tensor.h +++ b/tensorflow/core/util/tensor_bundle/byte_swap_tensor.h @@ -34,19 +34,26 @@ bool IsByteSwappable(DataType dtype); // buffer with this one will also end up byte-swapped. // Returns: OkStatus() on success, -1 otherwise // TODO(frreiss): Should this be a member of the Tensor class? -Status ByteSwapTensor(Tensor *t); +absl::Status ByteSwapTensor(Tensor* t); + +// Byte-swap a tensor proto's backing buffer in place. +// +// Args: +// t: TensorProto to be modified IN PLACE. +// Returns: OkStatus() on success, -1 otherwise +absl::Status ByteSwapTensorProto(TensorProto* tp); // Swap tensor_content field of Const Op Tensors in the named functions // in NodeDef -Status ByteSwapTensorContentInNode(NodeDef& node); +absl::Status ByteSwapTensorContentInNode(NodeDef& node); // Swap tensor_content field of Const Op Tensors in the named functions // in MetaGraphDef -Status ByteSwapTensorContentInMetaGraphDef(MetaGraphDef* meta_graph_def); +absl::Status ByteSwapTensorContentInMetaGraphDef(MetaGraphDef* meta_graph_def); // Swap tensor_content field of Const Op Tensors in the named functions // in GraphDef -Status ByteSwapTensorContentInGraphDef(GraphDef* graph_def); +absl::Status ByteSwapTensorContentInGraphDef(GraphDef* graph_def); } // namespace tensorflow diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index c97356202bcd93..7a34e5da1b895b 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -89,9 +89,10 @@ namespace { // // Checksums the string lengths (as restored uint32 or uint64, not varint64 // bytes) and string bytes, and stores it into "actual_crc32c". -Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, - size_t offset, size_t size, tstring* destination, - uint32* actual_crc32c, bool need_to_swap_bytes) { +absl::Status ReadStringTensor(io::InputBuffer* buffered_file, + size_t num_elements, size_t offset, size_t size, + tstring* destination, uint32* actual_crc32c, + bool need_to_swap_bytes) { if (size == 0) return absl::OkStatus(); CHECK_GT(size, 0); @@ -160,8 +161,9 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, return absl::OkStatus(); } -Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret, - size_t offset, size_t size, uint32* actual_crc32c) { +absl::Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret, + size_t offset, size_t size, + uint32* actual_crc32c) { // On-disk format: // [varint64 len1][bytes variant1][4 byte checksum] // .. @@ -233,8 +235,8 @@ tstring* GetStringBackingBuffer(const Tensor& val) { return const_cast(val.flat().data()); } -Status ParseEntryProto(StringPiece key, StringPiece value, - protobuf::MessageLite* out) { +absl::Status ParseEntryProto(StringPiece key, StringPiece value, + protobuf::MessageLite* out) { if (!out->ParseFromArray(value.data(), value.size())) { return errors::DataLoss("Entry for key ", key, " not parseable."); } @@ -245,8 +247,8 @@ Status ParseEntryProto(StringPiece key, StringPiece value, // original content of "bytes_written", and on OK updates it with number of // bytes written. // REQUIRES: val.dtype() != DT_STRING -Status WriteTensor(const Tensor& val, tsl::BufferedWritableFile* out, - size_t* bytes_written) { +absl::Status WriteTensor(const Tensor& val, tsl::BufferedWritableFile* out, + size_t* bytes_written) { DCHECK_NE(val.dtype(), DT_STRING); DCHECK_NE(val.dtype(), DT_VARIANT); *bytes_written = val.TotalBytes(); @@ -260,8 +262,9 @@ Status WriteTensor(const Tensor& val, tsl::BufferedWritableFile* out, // // Checksums all bytes written and stores it into "crc32c". // REQUIRES: val.dtype() == DT_STRING -Status WriteStringTensor(const Tensor& val, tsl::BufferedWritableFile* out, - size_t* bytes_written, uint32* crc32c) { +absl::Status WriteStringTensor(const Tensor& val, + tsl::BufferedWritableFile* out, + size_t* bytes_written, uint32* crc32c) { // On-disk format: // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes] // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes), @@ -312,8 +315,9 @@ Status WriteStringTensor(const Tensor& val, tsl::BufferedWritableFile* out, return absl::OkStatus(); } -Status WriteVariantTensor(const Tensor& val, tsl::BufferedWritableFile* out, - size_t* bytes_written, uint32* crc32c) { +absl::Status WriteVariantTensor(const Tensor& val, + tsl::BufferedWritableFile* out, + size_t* bytes_written, uint32* crc32c) { // On-disk format: // [varint64 len1][bytes variant1][4 byte checksum] // .. @@ -380,8 +384,8 @@ bool IsFullSlice(const TensorSlice& slice_spec, } } -Status CorruptFileError(const Status& in_status, const string& filename, - const string& detail) { +absl::Status CorruptFileError(const absl::Status& in_status, + const string& filename, const string& detail) { if (in_status.ok()) { return errors::Internal("Unable to read file (", filename, "). Perhaps the file is corrupt or was produced by " @@ -389,7 +393,7 @@ Status CorruptFileError(const Status& in_status, const string& filename, "(", detail, ")"); } - return Status( + return absl::Status( in_status.code(), strings::StrCat("Unable to read file (", filename, "). Perhaps the file is corrupt or was produced by a " @@ -410,14 +414,14 @@ table::Options TableBuilderOptions() { // Writes zeros to output buffer to align the next write to the requested // alignment. "size" is the current size of the buffer and is updated to the // new size. -Status PadAlignment(tsl::BufferedWritableFile* out, int alignment, - int64_t* size) { +absl::Status PadAlignment(tsl::BufferedWritableFile* out, int alignment, + int64_t* size) { int bytes_over = *size % alignment; if (bytes_over == 0) { return absl::OkStatus(); } int bytes_to_write = alignment - bytes_over; - Status status = out->Append(string(bytes_to_write, '\0')); + absl::Status status = out->Append(string(bytes_to_write, '\0')); if (status.ok()) { *size += bytes_to_write; } @@ -453,7 +457,7 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options) VLOG(1) << "Writing to file " << data_path_; } -Status BundleWriter::Add(StringPiece key, const Tensor& val) { +absl::Status BundleWriter::Add(StringPiece key, const Tensor& val) { if (!status_.ok()) return status_; CHECK_NE(key, kHeaderEntryKey); const string key_string(key); @@ -490,10 +494,10 @@ Status BundleWriter::Add(StringPiece key, const Tensor& val) { return status_; } -Status BundleWriter::AddSlice(StringPiece full_tensor_key, - const TensorShape& full_tensor_shape, - const TensorSlice& slice_spec, - const Tensor& slice_tensor) { +absl::Status BundleWriter::AddSlice(StringPiece full_tensor_key, + const TensorShape& full_tensor_shape, + const TensorSlice& slice_spec, + const Tensor& slice_tensor) { if (!status_.ok()) return status_; CHECK_NE(full_tensor_key, kHeaderEntryKey); @@ -533,7 +537,7 @@ Status BundleWriter::AddSlice(StringPiece full_tensor_key, // TODO(zongheng): on metadata write failure or !status_.ok(), consider removing // the orphaned data file. -Status BundleWriter::Finish() { +absl::Status BundleWriter::Finish() { if (out_) { status_.Update(out_->Close()); out_ = nullptr; @@ -608,8 +612,8 @@ struct MergeState { // Merges entries of "prefix" into the accumulator state "merge". // Returns OK iff the merge succeeds. -static Status MergeOneBundle(Env* env, StringPiece prefix, - MergeState* merge_state) { +static absl::Status MergeOneBundle(Env* env, StringPiece prefix, + MergeState* merge_state) { VLOG(1) << "Merging bundle:" << prefix; const string filename = MetaFilename(prefix); uint64 file_size; @@ -632,7 +636,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix, "failed to seek to header entry"); } BundleHeaderProto header; - Status s = ParseEntryProto(iter->key(), iter->value(), &header); + absl::Status s = ParseEntryProto(iter->key(), iter->value(), &header); if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header"); merge_state->num_shards += header.num_shards(); @@ -707,12 +711,12 @@ static Status MergeOneBundle(Env* env, StringPiece prefix, return absl::OkStatus(); } -Status MergeBundles(Env* env, absl::Span prefixes, - StringPiece merged_prefix, bool allow_missing_files) { +absl::Status MergeBundles(Env* env, absl::Span prefixes, + StringPiece merged_prefix, bool allow_missing_files) { // Merges all metadata tables. // TODO(zhifengc): KeyValue sorter if it becomes too big. MergeState merge; - Status status = env->CreateDir(string(io::Dirname(merged_prefix))); + absl::Status status = env->CreateDir(string(io::Dirname(merged_prefix))); if (!status.ok() && !errors::IsAlreadyExists(status)) return status; bool atleast_one_file_exists = false; for (auto& prefix : prefixes) { @@ -805,7 +809,7 @@ BundleReader::BundleReader(Env* env, StringPiece prefix, Options options) table::Options o; int64_t cache_size; - Status s = + absl::Status s = ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size); if (s.ok() && cache_size > 0) { index_cache_ = table::NewLRUCache(cache_size << 20); @@ -856,8 +860,8 @@ BundleReader::~BundleReader() { tensor_slices_.clear(); } -Status BundleReader::GetBundleEntryProto(StringPiece key, - BundleEntryProto* entry) { +absl::Status BundleReader::GetBundleEntryProto(StringPiece key, + BundleEntryProto* entry) { entry->Clear(); TF_CHECK_OK(status_); Seek(key); @@ -877,7 +881,8 @@ Status BundleReader::GetBundleEntryProto(StringPiece key, return absl::OkStatus(); } -Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { +absl::Status BundleReader::GetValue(const BundleEntryProto& entry, + Tensor* val) { Tensor* ret = val; const TensorShape stored_shape(TensorShape(entry.shape())); if (val->NumElements() == 0) { @@ -943,7 +948,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { (entry.size() + kMaxFileReadThreads - 1) / kMaxFileReadThreads; } - std::vector statuses(thread_pool_size); + std::vector statuses(thread_pool_size); auto reader_pool = std::make_unique( Env::Default(), "restore_large_tensor", thread_pool_size); @@ -1019,7 +1024,7 @@ Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { return absl::OkStatus(); } -Status BundleReader::Lookup(StringPiece key, Tensor* val) { +absl::Status BundleReader::Lookup(StringPiece key, Tensor* val) { CHECK(val != nullptr); BundleEntryProto entry; TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); @@ -1033,7 +1038,7 @@ Status BundleReader::Lookup(StringPiece key, Tensor* val) { } } -Status BundleReader::ReadCurrent(Tensor* val) { +absl::Status BundleReader::ReadCurrent(Tensor* val) { CHECK(val != nullptr); BundleEntryProto entry; TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry)); @@ -1051,8 +1056,8 @@ Status BundleReader::ReadCurrent(Tensor* val) { } } -Status BundleReader::LookupTensorSlices(StringPiece key, - std::vector* slices) { +absl::Status BundleReader::LookupTensorSlices( + StringPiece key, std::vector* slices) { slices->clear(); BundleEntryProto entry; TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); @@ -1063,17 +1068,18 @@ Status BundleReader::LookupTensorSlices(StringPiece key, return absl::OkStatus(); } -Status BundleReader::LookupSlice(StringPiece full_tensor_key, - const TensorSlice& slice_spec, Tensor* val) { +absl::Status BundleReader::LookupSlice(StringPiece full_tensor_key, + const TensorSlice& slice_spec, + Tensor* val) { CHECK(val != nullptr); BundleEntryProto entry; TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry)); return GetSliceValue(full_tensor_key, entry, slice_spec, val); } -Status BundleReader::GetSliceValue(StringPiece full_tensor_key, - const BundleEntryProto& full_tensor_entry, - const TensorSlice& slice_spec, Tensor* val) { +absl::Status BundleReader::GetSliceValue( + StringPiece full_tensor_key, const BundleEntryProto& full_tensor_entry, + const TensorSlice& slice_spec, Tensor* val) { using checkpoint::RegisterTensorSlice; using checkpoint::TensorSliceSet; DCHECK_GE(full_tensor_entry.slices_size(), 0); @@ -1193,8 +1199,8 @@ bool BundleReader::Contains(StringPiece key) { return Valid() && (this->key() == key); } -Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype, - TensorShape* shape) { +absl::Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype, + TensorShape* shape) { BundleEntryProto entry; TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); *dtype = entry.dtype(); @@ -1202,7 +1208,8 @@ Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype, return absl::OkStatus(); } -Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) { +absl::Status BundleReader::LookupTensorShape(StringPiece key, + TensorShape* shape) { DataType ignored; return LookupDtypeAndShape(key, &ignored, shape); } @@ -1246,7 +1253,8 @@ BundleCache::FileState* BundleCache::EnsureOpened(std::string name) { return f; } -Status BundleCache::GetFile(const std::string& fname, RandomAccessFile** file) { +absl::Status BundleCache::GetFile(const std::string& fname, + RandomAccessFile** file) { FileState* f = EnsureOpened(fname); *file = f->file.get(); return f->open_status; diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index e3d8bb590ce411..a0fcb134fbce17 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -125,7 +125,7 @@ class BundleWriter { // Adds the tensor "val" under key "key". // Across calls "key" must be unique but can be added in any order. - Status Add(absl::string_view key, const Tensor& val); + absl::Status Add(absl::string_view key, const Tensor& val); // Partitioned variables support. // A slice of a full tensor is stored in two entries in the metadata table: @@ -143,14 +143,15 @@ class BundleWriter { // consistent entry for "full_tensor_key" is produced. // // Returns an error if the same slice is added the second time. - Status AddSlice(absl::string_view full_tensor_key, - const TensorShape& full_tensor_shape, - const TensorSlice& slice_spec, const Tensor& slice_tensor); + absl::Status AddSlice(absl::string_view full_tensor_key, + const TensorShape& full_tensor_shape, + const TensorSlice& slice_spec, + const Tensor& slice_tensor); // Finishes the writer and flushes. - Status Finish() TF_MUST_USE_RESULT; + absl::Status Finish(); - Status status() const { return status_; } + absl::Status status() const { return status_; } private: Env* const env_; // Not owned. @@ -162,7 +163,7 @@ class BundleWriter { std::unique_ptr out_; int64_t size_; // Number of bytes written into out_. std::map entries_; - Status status_; + absl::Status status_; BundleWriter(const BundleWriter&) = delete; void operator=(const BundleWriter&) = delete; @@ -190,9 +191,9 @@ class BundleWriter { // // Returns a NotFoundError when "allow_missing_files" is set to false and // any data file named in "prefixes" does not exist. -Status MergeBundles(Env* env, absl::Span prefixes, - absl::string_view merged_prefix, - bool allow_missing_files = false); +absl::Status MergeBundles(Env* env, absl::Span prefixes, + absl::string_view merged_prefix, + bool allow_missing_files = false); class BundleCache; @@ -202,7 +203,7 @@ class BundleCache; // All threads accessing the same BundleReader must synchronize. class BundleReader { public: - BundleReader(Env* const env, absl::string_view prefix, + BundleReader(Env* env, absl::string_view prefix, bool enable_multi_threading_for_testing = false); struct Options { @@ -219,7 +220,7 @@ class BundleReader { // Is ok() iff the reader construction is successful (completed the read of // the metadata). - Status status() const { return status_; } + absl::Status status() const { return status_; } // Queries whether the bundle contains an entry keyed by "key". Calls Seek() // internally, so this call invalidates the reader's current position. @@ -235,20 +236,19 @@ class BundleReader { // // REQUIRES: status().ok() template - Status SortForSequentialAccess( + absl::Status SortForSequentialAccess( std::vector& container, absl::FunctionRef get_key); // Looks up the dtype and the shape of the tensor keyed by "key". // REQUIRES: status().ok() - Status LookupDtypeAndShape(absl::string_view key, DataType* dtype, - TensorShape* shape) TF_MUST_USE_RESULT; + absl::Status LookupDtypeAndShape(absl::string_view key, DataType* dtype, + TensorShape* shape); // Looks up the shape of the tensor keyed by "key". // Clears "shape" if not found. // REQUIRES: status().ok() - Status LookupTensorShape(absl::string_view key, - TensorShape* shape) TF_MUST_USE_RESULT; + absl::Status LookupTensorShape(absl::string_view key, TensorShape* shape); // Looks up the tensor keyed by "key". If "key" refers to a partitioned // tensor, attempts to look up the full contents using all stored slices. @@ -262,7 +262,7 @@ class BundleReader { // // Validates the stored crc32c checksum against the restored bytes. // REQUIRES: status().ok() - Status Lookup(absl::string_view key, Tensor* val) TF_MUST_USE_RESULT; + absl::Status Lookup(absl::string_view key, Tensor* val); // Looks up the tensor pointed to by the internal iterator. // @@ -270,7 +270,7 @@ class BundleReader { // // Validates the stored crc32c checksum against the restored bytes. // REQUIRES: status().ok() && Valid() - Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT; + absl::Status ReadCurrent(Tensor* val); // Looks up the slices of the tensor keyed by "key". On OK, "slices" // is non-empty if and only if the tensor is a partitioned tensor. @@ -279,17 +279,15 @@ class BundleReader { // a slice with a larger start index in some dimension could come before // another slice with a smaller start index in the same dimension. // REQUIRES: status().ok() - Status LookupTensorSlices(absl::string_view key, - std::vector* slices) - TF_MUST_USE_RESULT; + absl::Status LookupTensorSlices(absl::string_view key, + std::vector* slices); // Looks up a specific slice of a partitioned tensor. // It is only required that the stored slices cover the requested slice, // namely "slice_spec" is a subset of the union of the stored slices. // REQUIRES: status().ok() - Status LookupSlice(absl::string_view full_tensor_key, - const TensorSlice& slice_spec, - Tensor* val) TF_MUST_USE_RESULT; + absl::Status LookupSlice(absl::string_view full_tensor_key, + const TensorSlice& slice_spec, Tensor* val); // Seeks to the first position in the bundle whose key is no less than "key". // REQUIRES: status().ok() @@ -314,28 +312,26 @@ class BundleReader { // Seeks for "key" and reads the metadata proto. // On non-OK return, clears "entry" for the caller. // REQUIRES: status().ok() - Status GetBundleEntryProto(absl::string_view key, - BundleEntryProto* entry) TF_MUST_USE_RESULT; + absl::Status GetBundleEntryProto(absl::string_view key, + BundleEntryProto* entry); // Reads the tensor value described by the metadata proto "entry". // Usage for "val" follows the comment of "Lookup()". - Status GetValue(const BundleEntryProto& entry, - Tensor* val) TF_MUST_USE_RESULT; + absl::Status GetValue(const BundleEntryProto& entry, Tensor* val); // Reads the slice described by "slice_spec". The corresponding full tensor // has key "ful_tensor_key" and metadata proto "full_tensor_entry". // REQUIRES: full_tensor_entry.slices_size() > 0 - Status GetSliceValue(absl::string_view full_tensor_key, - const BundleEntryProto& full_tensor_entry, - const TensorSlice& slice_spec, - Tensor* val) TF_MUST_USE_RESULT; + absl::Status GetSliceValue(absl::string_view full_tensor_key, + const BundleEntryProto& full_tensor_entry, + const TensorSlice& slice_spec, Tensor* val); Env* env_; // Not owned. const std::string prefix_; std::unique_ptr owned_cache_; // may be null BundleCache* cache_; // Not owned, or owned_cache_.get() - Status status_; + absl::Status status_; RandomAccessFile* metadata_; // Owned. table::Table* table_; table::Cache* index_cache_; @@ -365,7 +361,7 @@ class BundleReader { }; template -Status BundleReader::SortForSequentialAccess( +absl::Status BundleReader::SortForSequentialAccess( std::vector& container, absl::FunctionRef get_key) { struct FileOffset { @@ -399,7 +395,7 @@ class BundleCache { // Get the underlying file object for fname. The result will remain valid // while the BundleCache lives. - Status GetFile(const std::string& fname, RandomAccessFile** file); + absl::Status GetFile(const std::string& fname, RandomAccessFile** file); private: // State for each opened file (opened on first read). @@ -407,7 +403,7 @@ class BundleCache { absl::once_flag once; // Ensures file is opened exactly once. std::unique_ptr file; - Status open_status; // Records any error encountered on open + absl::Status open_status; // Records any error encountered on open }; FileState* EnsureOpened(std::string name); diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index ac0b15644f106b..cd2b73c1afdfe7 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -149,7 +149,7 @@ std::vector AllTensorKeys(BundleReader* reader) { // Writes out the metadata file of a bundle again, with the endianness marker // bit flipped. -Status FlipEndiannessBit(const string& prefix) { +absl::Status FlipEndiannessBit(const string& prefix) { Env* env = Env::Default(); const string metadata_tmp_path = Prefix("some_tmp_path"); std::unique_ptr metadata_file; @@ -998,7 +998,7 @@ TEST(TensorBundleTest, Checksum) { auto ExpectLookupFails = [](const string& prefix, const string& key, const string& expected_msg, Tensor& val) { BundleReader reader(Env::Default(), Prefix(prefix)); - Status status = reader.Lookup(key, &val); + absl::Status status = reader.Lookup(key, &val); EXPECT_TRUE(errors::IsDataLoss(status)); EXPECT_TRUE(absl::StrContains(status.ToString(), expected_msg)); }; diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc index 6911b58a563a2b..9902cd23d3aa61 100644 --- a/tensorflow/core/util/tensor_slice_reader.cc +++ b/tensorflow/core/util/tensor_slice_reader.cc @@ -56,7 +56,7 @@ class TensorSliceReaderTable : public TensorSliceReader::Table { std::unique_ptr iter(table_->NewIterator()); iter->Seek(key); if (iter->Valid() && iter->key() == key) { - StringPiece v = iter->value(); + absl::string_view v = iter->value(); value->assign(v.data(), v.size()); return true; } else { diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc index 35fd86b5a86af9..884cd0a42d6ce1 100644 --- a/tensorflow/core/util/tensor_slice_writer.cc +++ b/tensorflow/core/util/tensor_slice_writer.cc @@ -41,7 +41,7 @@ class TableBuilder : public TensorSliceWriter::Builder { option.compression = table::kNoCompression; builder_ = std::make_unique(option, f); } - void Add(StringPiece key, StringPiece val) override { + void Add(absl::string_view key, absl::string_view val) override { builder_->Add(key, val); } absl::Status Finish(int64_t* file_size) override { diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h index bd13b55d6de471..dbdfeea0e1392c 100644 --- a/tensorflow/core/util/tensor_slice_writer.h +++ b/tensorflow/core/util/tensor_slice_writer.h @@ -48,7 +48,7 @@ class TensorSliceWriter { class Builder { public: virtual ~Builder() = default; - virtual void Add(StringPiece key, StringPiece value) = 0; + virtual void Add(absl::string_view key, absl::string_view value) = 0; virtual absl::Status Finish(int64_t* file_size) = 0; }; typedef std::function diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc index e197f0cf90c86c..0f7bf624c5dd84 100644 --- a/tensorflow/core/util/util.cc +++ b/tensorflow/core/util/util.cc @@ -23,23 +23,23 @@ limitations under the License. namespace tensorflow { -StringPiece NodeNamePrefix(const StringPiece& op_name) { - StringPiece sp(op_name); +absl::string_view NodeNamePrefix(const absl::string_view& op_name) { + absl::string_view sp(op_name); auto p = sp.find('/'); - if (p == StringPiece::npos || p == 0) { + if (p == absl::string_view::npos || p == 0) { return ""; } else { - return StringPiece(sp.data(), p); + return absl::string_view(sp.data(), p); } } -StringPiece NodeNameFullPrefix(const StringPiece& op_name) { - StringPiece sp(op_name); +absl::string_view NodeNameFullPrefix(const absl::string_view& op_name) { + absl::string_view sp(op_name); auto p = sp.rfind('/'); - if (p == StringPiece::npos || p == 0) { + if (p == absl::string_view::npos || p == 0) { return ""; } else { - return StringPiece(sp.data(), p); + return absl::string_view(sp.data(), p); } } diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h index 701c423045da8f..d3dd88a43fd7d7 100644 --- a/tensorflow/core/util/util.h +++ b/tensorflow/core/util/util.h @@ -26,11 +26,11 @@ namespace tensorflow { // If op_name has '/' in it, then return everything before the first '/'. // Otherwise return empty string. -StringPiece NodeNamePrefix(const StringPiece& op_name); +absl::string_view NodeNamePrefix(const absl::string_view& op_name); // If op_name has '/' in it, then return everything before the last '/'. // Otherwise return empty string. -StringPiece NodeNameFullPrefix(const StringPiece& op_name); +absl::string_view NodeNameFullPrefix(const absl::string_view& op_name); class MovingAverage { public: diff --git a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc index b9e9b0648a7220..e8084b9c33f75b 100644 --- a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc +++ b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc @@ -88,8 +88,8 @@ void DTensorLayoutToXlaShardingOpPass::runOnOperation() { // For BlockArgument, the sharding is already attached to function attribute // by DTensorSetHloShardingPass. No additional tf.XlaSharding is needed. patterns.add(&getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } diff --git a/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc b/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc index f609c1576f72fd..f08908eff9395e 100644 --- a/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc +++ b/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.h" +#include + #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/dtensor/tests/dtensor_operation_test.cc b/tensorflow/dtensor/tests/dtensor_operation_test.cc index bf0d06050396cb..98ceadbc159ddc 100644 --- a/tensorflow/dtensor/tests/dtensor_operation_test.cc +++ b/tensorflow/dtensor/tests/dtensor_operation_test.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/dtensor/cc/dtensor_operation.h" -#include - #include #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc index 475e08c28269f8..0c802ec643947f 100644 --- a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc +++ b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.h" +#include #include #include diff --git a/tensorflow/dtensor/tests/slice_util_test.cc b/tensorflow/dtensor/tests/slice_util_test.cc index ddb034765627fd..9694bb1f6a58d1 100644 --- a/tensorflow/dtensor/tests/slice_util_test.cc +++ b/tensorflow/dtensor/tests/slice_util_test.cc @@ -15,9 +15,7 @@ limitations under the License. #include "tensorflow/dtensor/cc/slice_util.h" -#include -#include -#include +#include #include #include diff --git a/tensorflow/dtensor/tests/tensor_layout_test.cc b/tensorflow/dtensor/tests/tensor_layout_test.cc index 28bcf1c4c94739..3f4f8015944027 100644 --- a/tensorflow/dtensor/tests/tensor_layout_test.cc +++ b/tensorflow/dtensor/tests/tensor_layout_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/dtensor/cc/tensor_layout.h" #include -#include #include #include #include diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD index 7545a2f49ef5cd..a31439134fd64c 100644 --- a/tensorflow/examples/label_image/BUILD +++ b/tensorflow/examples/label_image/BUILD @@ -47,6 +47,7 @@ tf_cc_binary( ], }) + [ "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "//tensorflow/cc:ops", "//tensorflow/cc:scope", "@local_xla//xla/tsl/util:command_line_flags", diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index e6257220d6b6b6..371b54c25827a5 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -35,6 +35,8 @@ limitations under the License. // are supported. #include +#include +#include #include #include #include @@ -42,6 +44,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" diff --git a/tensorflow/examples/multibox_detector/main.cc b/tensorflow/examples/multibox_detector/main.cc index 3ed053a7e58627..b4da66c7215b82 100644 --- a/tensorflow/examples/multibox_detector/main.cc +++ b/tensorflow/examples/multibox_detector/main.cc @@ -79,7 +79,7 @@ Status ReadLocationsFile(const string& file_name, std::vector* result, result->reserve(string_tokens.size()); for (const string& string_token : string_tokens) { float number; - CHECK(tensorflow::strings::safe_strtof(string_token, &number)); + CHECK(absl::SimpleAtof(string_token, &number)); result->push_back(number); } } diff --git a/tensorflow/examples/speech_commands/accuracy_utils.cc b/tensorflow/examples/speech_commands/accuracy_utils.cc index 9a896afd44ba0d..42736f2ca920bd 100644 --- a/tensorflow/examples/speech_commands/accuracy_utils.cc +++ b/tensorflow/examples/speech_commands/accuracy_utils.cc @@ -50,7 +50,7 @@ absl::Status ReadGroundTruthFile( continue; } float timestamp; - if (!tensorflow::strings::safe_strtof(pieces[1], ×tamp)) { + if (!absl::SimpleAtof(pieces[1], ×tamp)) { return tensorflow::errors::InvalidArgument( "Wrong number format at line: ", line); } diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index f6a339989b42df..dd1997b8e2d44e 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -43,32 +43,35 @@ type DataType C.TF_DataType // Types of scalar values in the TensorFlow type system. const ( - Float DataType = C.TF_FLOAT - Double DataType = C.TF_DOUBLE - Int32 DataType = C.TF_INT32 - Uint32 DataType = C.TF_UINT32 - Uint8 DataType = C.TF_UINT8 - Int16 DataType = C.TF_INT16 - Int8 DataType = C.TF_INT8 - String DataType = C.TF_STRING - Complex64 DataType = C.TF_COMPLEX64 - Complex DataType = C.TF_COMPLEX - Int64 DataType = C.TF_INT64 - Uint64 DataType = C.TF_UINT64 - Bool DataType = C.TF_BOOL - Qint8 DataType = C.TF_QINT8 - Quint8 DataType = C.TF_QUINT8 - Qint32 DataType = C.TF_QINT32 - Bfloat16 DataType = C.TF_BFLOAT16 - Qint16 DataType = C.TF_QINT16 - Quint16 DataType = C.TF_QUINT16 - Uint16 DataType = C.TF_UINT16 - Complex128 DataType = C.TF_COMPLEX128 - Half DataType = C.TF_HALF - Float8e5m2 DataType = C.TF_FLOAT8_E5M2 - Float8e4m3fn DataType = C.TF_FLOAT8_E4M3FN - Int4 DataType = C.TF_INT4 - Uint4 DataType = C.TF_UINT4 + Float DataType = C.TF_FLOAT + Double DataType = C.TF_DOUBLE + Int32 DataType = C.TF_INT32 + Uint32 DataType = C.TF_UINT32 + Uint8 DataType = C.TF_UINT8 + Int16 DataType = C.TF_INT16 + Int8 DataType = C.TF_INT8 + String DataType = C.TF_STRING + Complex64 DataType = C.TF_COMPLEX64 + Complex DataType = C.TF_COMPLEX + Int64 DataType = C.TF_INT64 + Uint64 DataType = C.TF_UINT64 + Bool DataType = C.TF_BOOL + Qint8 DataType = C.TF_QINT8 + Quint8 DataType = C.TF_QUINT8 + Qint32 DataType = C.TF_QINT32 + Bfloat16 DataType = C.TF_BFLOAT16 + Qint16 DataType = C.TF_QINT16 + Quint16 DataType = C.TF_QUINT16 + Uint16 DataType = C.TF_UINT16 + Complex128 DataType = C.TF_COMPLEX128 + Half DataType = C.TF_HALF + Float8e5m2 DataType = C.TF_FLOAT8_E5M2 + Float8e4m3fn DataType = C.TF_FLOAT8_E4M3FN + Float8e4m3fnuz DataType = C.TF_FLOAT8_E4M3FNUZ + Float8e4m3b11fnuz DataType = C.TF_FLOAT8_E4M3B11FNUZ + Float8e5m2fnuz DataType = C.TF_FLOAT8_E5M2FNUZ + Int4 DataType = C.TF_INT4 + Uint4 DataType = C.TF_UINT4 ) // Tensor holds a multi-dimensional array of elements of a single data type. @@ -558,7 +561,7 @@ func isTensorSerializable(dataType DataType) error { // serialization and deserialization of Tensors. Till then capitalize // on knowledge of the implementation for numeric types. switch dataType { - case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half, Float8e5m2, Float8e4m3fn, Int4, Uint4: + case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half, Float8e5m2, Float8e4m3fn, Float8e4m3fnuz, Float8e4m3b11fnuz, Float8e5m2fnuz, Int4, Uint4: return nil default: return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType) diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc index c137d0606b31c0..45170ed846fd3d 100644 --- a/tensorflow/js/ops/ts_op_gen_test.cc +++ b/tensorflow/js/ops/ts_op_gen_test.cc @@ -26,12 +26,12 @@ limitations under the License. namespace tensorflow { namespace { -void ExpectContainsStr(StringPiece s, StringPiece expected) { +void ExpectContainsStr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } -void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) { +void ExpectDoesNotContainStr(absl::string_view s, absl::string_view expected) { EXPECT_FALSE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 992c42dd9f6aba..5f918ed9955d45 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -50,20 +50,6 @@ config_setting( }, ) -config_setting( - name = "mips", - values = { - "cpu": "mips", - }, -) - -config_setting( - name = "mips64", - values = { - "cpu": "mips64", - }, -) - # Without "cpu":"k8", when building with --copt=-DTF_LITE_STATIC_MEMORY, we get # the following error: # Multiple matches are not allowed unless one is unambiguously more specialized. @@ -313,6 +299,8 @@ cc_test( ":simple_planner", "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite/core/c:common", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 732e0ececac24e..8b5fc32f4d3f00 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -101,7 +101,7 @@ else() set(FLATC_TARGET "flatbuffers-flatc") endif() -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(_TFLITE_ENABLE_RUY "${TFLITE_ENABLE_RUY}") if("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") @@ -622,7 +622,7 @@ populate_tflite_source_vars("kernels/internal/reference/sparse_ops" ) populate_tflite_source_vars("kernels/internal/optimized/4bit" TFLITE_KERNEL_INTERNAL_OPT_4BIT_SRCS - FILTER "(.*neon.*|.*sse.*)\\.(cc|h)" + FILTER "(.*neon_.*|.*sse_.*)\\.(cc|h)" ) set(TFLITE_PROFILER_SRCS ${TFLITE_SOURCE_DIR}/profiling/platform_profiler.cc diff --git a/tensorflow/lite/c/CMakeLists.txt b/tensorflow/lite/c/CMakeLists.txt index 44876bc437bdfa..931f4372c5f2e2 100644 --- a/tensorflow/lite/c/CMakeLists.txt +++ b/tensorflow/lite/c/CMakeLists.txt @@ -36,7 +36,7 @@ add_subdirectory( EXCLUDE_FROM_ALL ) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) if(CMAKE_SYSTEM_NAME MATCHES "Windows" AND (MSVC AND (CMAKE_SIZEOF_VOID_P EQUAL 4))) @@ -79,6 +79,7 @@ add_library(tensorflowlite_c ${TFLITE_C_LIBTYPE} if (TFLITE_C_BUILD_SHARED_LIBS) if (WIN32) target_compile_definitions(tensorflowlite_c PRIVATE TFL_COMPILE_LIBRARY) + target_compile_definitions(tensorflow-lite PRIVATE TFL_COMPILE_LIBRARY) elseif (APPLE) target_link_options(tensorflowlite_c PRIVATE "-Wl,-exported_symbols_list,${TFLITE_SOURCE_DIR}/c/exported_symbols.lds") else () diff --git a/tensorflow/lite/c/c_api_signature_runner_test.cc b/tensorflow/lite/c/c_api_signature_runner_test.cc index 61af71ffd863a6..f31a7b543d698c 100644 --- a/tensorflow/lite/c/c_api_signature_runner_test.cc +++ b/tensorflow/lite/c/c_api_signature_runner_test.cc @@ -136,6 +136,10 @@ TEST(SignatureRunnerTest, TestMultiSignatures) { ASSERT_EQ(signature_defs[1], "sub"); ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr); + // Test out-of-range values. + ASSERT_EQ(TfLiteInterpreterGetSignatureKey(interpreter, 2), nullptr); + ASSERT_EQ(TfLiteInterpreterGetSignatureKey(interpreter, -1), nullptr); + TfLiteSignatureRunner* add_runner = TfLiteInterpreterGetSignatureRunner( interpreter, signature_defs[0].c_str()); ASSERT_NE(add_runner, nullptr); @@ -170,6 +174,13 @@ TEST(SignatureRunnerTest, TestMultiSignatures) { ASSERT_EQ(TfLiteSignatureRunnerInvoke(add_runner), kTfLiteOk); ASSERT_EQ(add_output->data.f[0], 4); ASSERT_EQ(add_output->data.f[1], 6); + + // Test out-of-range values. + ASSERT_EQ(TfLiteSignatureRunnerGetInputName(add_runner, 1), nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetInputName(add_runner, -1), nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetOutputName(add_runner, 1), nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetOutputName(add_runner, -1), nullptr); + TfLiteSignatureRunnerDelete(add_runner); TfLiteSignatureRunner* sub_runner = diff --git a/tensorflow/lite/c/common_internal.cc b/tensorflow/lite/c/common_internal.cc index 2728fa91a0e66b..b4899a4dbd9f4b 100644 --- a/tensorflow/lite/c/common_internal.cc +++ b/tensorflow/lite/c/common_internal.cc @@ -45,7 +45,7 @@ TfLiteStatus TfLiteDelegateCopyFromBufferHandleInternal( // TfLiteOpaqueContext and TfLiteContext being equivalent, or on // TfLiteOpaqueDelegate and TfLiteDelegate being equivalent. if (TfLiteDelegateHasValidOpaqueDelegateBuilder(delegate) && - tensor->delegate->opaque_delegate_builder->CopyFromBufferHandle) { + delegate->opaque_delegate_builder->CopyFromBufferHandle) { return delegate->opaque_delegate_builder->CopyFromBufferHandle( reinterpret_cast(context), reinterpret_cast(delegate), diff --git a/tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.cc b/tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.cc index 150ca14e9952fb..e58fe4ad499a2a 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.cc @@ -17,8 +17,6 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/c/nnapi_plugin.h" -#include - #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" #include "tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h" diff --git a/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc b/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc index 71ee43e2b5f935..541d681dbafcc2 100644 --- a/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc +++ b/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include +#include #include #include "absl/synchronization/mutex.h" diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h index 8b86801be3d28c..65c641293ac7c9 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h @@ -25,6 +25,7 @@ limitations under the License. // This file provides the NNApiPlugin class, which implements the // TFLite Delegate Plugin for the NNAPI Delegate. +#include #include #include diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc index 57a3042737600a..2a8fc9de5429e8 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h" #include +#include #include #include diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index 376b0eeb6302bf..5bc70cc6deae99 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -42,9 +42,8 @@ class BuiltinDataAllocator { // deallocation. template T* AllocatePOD() { - // TODO(b/154346074): Change this to is_trivially_destructible when all - // platform targets support that properly. - static_assert(std::is_pod::value, "Builtin data structure must be POD."); + static_assert(std::is_trivially_destructible::value, + "Builtin data structure must be POD."); void* allocated_memory = this->Allocate(sizeof(T), alignof(T)); return new (allocated_memory) T(); } diff --git a/tensorflow/lite/core/async/interop/attribute_map_internal_test.cc b/tensorflow/lite/core/async/interop/attribute_map_internal_test.cc index 3f5ee8ca36e965..e58590849f1ecc 100644 --- a/tensorflow/lite/core/async/interop/attribute_map_internal_test.cc +++ b/tensorflow/lite/core/async/interop/attribute_map_internal_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/async/interop/attribute_map_internal.h" -#include +#include #include #include "tensorflow/lite/core/async/interop/c/types.h" diff --git a/tensorflow/lite/core/async/interop/variant.cc b/tensorflow/lite/core/async/interop/variant.cc index 46965ebef37d91..954e81c8e4fe6f 100644 --- a/tensorflow/lite/core/async/interop/variant.cc +++ b/tensorflow/lite/core/async/interop/variant.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/core/async/interop/variant.h" #include -#include namespace tflite { namespace interop { diff --git a/tensorflow/lite/core/async/interop/variant_test.cc b/tensorflow/lite/core/async/interop/variant_test.cc index 3ce5d39048283c..03b59cedd15bbb 100644 --- a/tensorflow/lite/core/async/interop/variant_test.cc +++ b/tensorflow/lite/core/async/interop/variant_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/core/async/interop/variant.h" #include -#include #include #include diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index b8f3b47f4dda9f..1fe66a47c14743 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -61,7 +61,7 @@ extern "C" { #ifdef TFL_COMPILE_LIBRARY #define TFL_CAPI_EXPORT __declspec(dllexport) #else -#define TFL_CAPI_EXPORT __declspec(dllimport) +#define TFL_CAPI_EXPORT #endif // TFL_COMPILE_LIBRARY #else #define TFL_CAPI_EXPORT __attribute__((visibility("default"))) diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index b17e60fb0e33fd..363a2990d3bb22 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -126,6 +126,8 @@ class InterpreterBuilder; // Class for friend declarations. class Interpreter { public: + using Ptr = std::unique_ptr; + // Instantiate an interpreter. All errors associated with reading and // processing this model will be forwarded to the error_reporter object. // diff --git a/tensorflow/lite/core/interpreter_experimental.cc b/tensorflow/lite/core/interpreter_experimental.cc index ff052a1ee81b3f..b0e7b766d0c4aa 100644 --- a/tensorflow/lite/core/interpreter_experimental.cc +++ b/tensorflow/lite/core/interpreter_experimental.cc @@ -17,10 +17,8 @@ limitations under the License. #include #include -#include #include #include -#include #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/async/async_signature_runner.h" diff --git a/tensorflow/lite/core/model_builder.h b/tensorflow/lite/core/model_builder.h index 6a9d33418f6d0f..c53765e3166bc7 100644 --- a/tensorflow/lite/core/model_builder.h +++ b/tensorflow/lite/core/model_builder.h @@ -28,6 +28,8 @@ limitations under the License. #include +#include + #include "tensorflow/compiler/mlir/lite/core/model_builder_base.h" // IWYU pragma: export #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/stderr_reporter.h" @@ -38,6 +40,8 @@ namespace impl { class FlatBufferModel : public FlatBufferModelBase { public: + using Ptr = std::unique_ptr; + // Use stderr_reporter as the default error reporter. static ErrorReporter* GetDefaultErrorReporter() { return DefaultErrorReporter(); diff --git a/tensorflow/lite/core/signature_runner.cc b/tensorflow/lite/core/signature_runner.cc index 5058588f688f29..ea66a9f5521c20 100644 --- a/tensorflow/lite/core/signature_runner.cc +++ b/tensorflow/lite/core/signature_runner.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/core/signature_runner.h" +#include #include #include "tensorflow/lite/c/common.h" diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index dbd250364a3d82..2e1e2575063b0b 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -497,9 +497,12 @@ const char* GetDelegateKernalName(const TfLiteRegistration& registration) { TfLiteStatus Subgraph::PartitionGraph(const TfLiteIntArray* nodes_to_replace, std::vector* node_subsets) { const InterpreterInfo info(this); - return PartitionGraphIntoIndependentNodeSubsets( + // Tensor preservation requires node fusion to be disabled. + const bool disable_node_fusion = ShouldPreserveAllTensors(); + return tflite::PartitionGraphIntoIndependentNodeSubsets( &info, nodes_to_replace, node_subsets, - /*greedily=*/!DisableDelegateClustering(), control_edges_); + /*greedily=*/!DisableDelegateClustering(), control_edges_, + disable_node_fusion); } TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels( @@ -562,9 +565,10 @@ TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels( TFLITE_LOG_PROD(tflite::TFLITE_LOG_VERBOSE, "Replacing %d out of %d node(s) with delegate (%s) node, " "yielding %zu partitions " - "for the whole graph.", + "for subgraph %d.", nodes_to_replace->size, execution_plan_.size(), - GetDelegateKernalName(registration), node_subsets.size()); + GetDelegateKernalName(registration), node_subsets.size(), + subgraph_index_); execution_plan_.clear(); diff --git a/tensorflow/lite/core/tools/verifier_internal_test.cc b/tensorflow/lite/core/tools/verifier_internal_test.cc index d725400d2c346f..3a2a6c34f5baf1 100644 --- a/tensorflow/lite/core/tools/verifier_internal_test.cc +++ b/tensorflow/lite/core/tools/verifier_internal_test.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/tools/verifier_internal.h" -#include +#include +#include #include #include diff --git a/tensorflow/lite/core/tools/verifier_test.cc b/tensorflow/lite/core/tools/verifier_test.cc index b7b8460e198d02..2d4e6a16a832fa 100644 --- a/tensorflow/lite/core/tools/verifier_test.cc +++ b/tensorflow/lite/core/tools/verifier_test.cc @@ -14,6 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/tools/verifier.h" +#include +#include +#include +#include #include #include #include diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 2fe82d4df684db..0e5fda390b754c 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -72,11 +72,11 @@ config_setting( # copybara:uncomment_begin(google-only) # constraint_values = [ # "//third_party/bazel_platforms/os:linux", + # "//third_party/bazel_platforms/cpu:x86_64", # ], # copybara:uncomment_end values = { "copt": "-DTFLITE_GPU_EXTRA_GLES_DEPS", - "cpu": "k8", }, ) diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc index 5bd407de4c8b78..bddc2033547cd2 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "absl/status/status.h" diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc index a7a174f60f54d2..d2cc4f432ff136 100644 --- a/tensorflow/lite/delegates/gpu/common/model.cc +++ b/tensorflow/lite/delegates/gpu/common/model.cc @@ -18,10 +18,8 @@ limitations under the License. #include #include -#include #include #include -#include #include #include diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h index 6e72635db478a5..62c2310880cdd2 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index b81379a909e079..42e7d1cf8058a4 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.cc b/tensorflow/lite/delegates/gpu/common/model_transformer.cc index 361d48fd88f423..24cd4a976f5af5 100644 --- a/tensorflow/lite/delegates/gpu/common/model_transformer.cc +++ b/tensorflow/lite/delegates/gpu/common/model_transformer.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" -#include "absl/strings/str_join.h" #include "tensorflow/lite/delegates/gpu/common/model.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.cc b/tensorflow/lite/delegates/gpu/common/object_reader.cc index d8e0c431c4f909..00a8dc715a721e 100644 --- a/tensorflow/lite/delegates/gpu/common/object_reader.cc +++ b/tensorflow/lite/delegates/gpu/common/object_reader.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.h b/tensorflow/lite/delegates/gpu/common/object_reader.h index 9f5337be7972a8..2dae9af7ecf5a3 100644 --- a/tensorflow/lite/delegates/gpu/common/object_reader.h +++ b/tensorflow/lite/delegates/gpu/common/object_reader.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OBJECT_READER_H_ #include +#include #include #include "fp16.h" // from @FP16 diff --git a/tensorflow/lite/delegates/gpu/common/operation_parser.h b/tensorflow/lite/delegates/gpu/common/operation_parser.h index bc0cb037d91bde..9f21b448b6e4d5 100644 --- a/tensorflow/lite/delegates/gpu/common/operation_parser.h +++ b/tensorflow/lite/delegates/gpu/common/operation_parser.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATION_PARSER_H_ +#include + #include "absl/container/flat_hash_map.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/gpu/common/model.h" diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index 7167813f0fe3e3..78f9627b36c373 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc b/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc index c26d59402d05e6..be8780ec355448 100644 --- a/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc +++ b/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include #include #include #include diff --git a/tensorflow/lite/delegates/gpu/common/shape.cc b/tensorflow/lite/delegates/gpu/common/shape.cc index be3c0a56b7aee8..fcdbd81c8b32b0 100644 --- a/tensorflow/lite/delegates/gpu/common/shape.cc +++ b/tensorflow/lite/delegates/gpu/common/shape.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" diff --git a/tensorflow/lite/delegates/gpu/common/shape.h b/tensorflow/lite/delegates/gpu/common/shape.h index 14b45537926f8d..d337c77a6e69bc 100644 --- a/tensorflow/lite/delegates/gpu/common/shape.h +++ b/tensorflow/lite/delegates/gpu/common/shape.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc index a14d7f24714f23..72ca42de7cd0f2 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc @@ -15,17 +15,13 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/add.h" -#include #include -#include -#include #include #include #include #include #include -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc index 0513c8ec877b20..965d3ca36c7412 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc @@ -15,16 +15,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h" -#include #include -#include -#include #include #include #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc index 8522ea252ed4b3..12e222758d3b97 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc @@ -16,12 +16,12 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h" #include +#include #include #include #include #include -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc index 2b36db572108e3..ac72e8e5e8d2b4 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/converter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/converter.h" +#include #include #include #include diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc index 5f14f093c55eb1..fea5fad1183088 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/converter_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/kernels/converter.h" -#include +#include +#include #include -#include #include #include "absl/types/span.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc index 627aeeec9d2a7e..bca59ab5024cbb 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc index 528d75d656d982..2bc07988d03bc9 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h" +#include +#include + #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h" diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h index e277e45fc2760d..9bf1c4cb921e38 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_ +#include +#include + #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc index 54252dc4fc8afb..5b5f0c1a05ae32 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" +#include + #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h index 4c034b1604fa57..5087bfcaa68add 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_ +#include +#include + #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h index 6053c9e62e2a11..0c23a962eb19c7 100644 --- a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_ +#include + #include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" diff --git a/tensorflow/lite/delegates/hexagon/hexagon_delegate.cc b/tensorflow/lite/delegates/hexagon/hexagon_delegate.cc index 0d257be7777aa7..e3116341d70863 100644 --- a/tensorflow/lite/delegates/hexagon/hexagon_delegate.cc +++ b/tensorflow/lite/delegates/hexagon/hexagon_delegate.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include #include -#include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.h" diff --git a/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc b/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc index e7d11299bd36a5..ceac707b985650 100644 --- a/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc +++ b/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.h" +#include +#include +#include #include #include #include diff --git a/tensorflow/lite/delegates/hexagon/hexagon_implementation.cc b/tensorflow/lite/delegates/hexagon/hexagon_implementation.cc index 26433cee494f94..7cbddd27f93245 100644 --- a/tensorflow/lite/delegates/hexagon/hexagon_implementation.cc +++ b/tensorflow/lite/delegates/hexagon/hexagon_implementation.cc @@ -18,8 +18,6 @@ limitations under the License. #include #include -#include - #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" diff --git a/tensorflow/lite/delegates/hexagon/utils_test.cc b/tensorflow/lite/delegates/hexagon/utils_test.cc index 201a7d0fa9d1b0..83b3eaa02ea1f6 100644 --- a/tensorflow/lite/delegates/hexagon/utils_test.cc +++ b/tensorflow/lite/delegates/hexagon/utils_test.cc @@ -14,9 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/hexagon/utils.h" -#include -#include - #include #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/serialization.h b/tensorflow/lite/delegates/serialization.h index ab214265fa2780..5c3f3255a582aa 100644 --- a/tensorflow/lite/delegates/serialization.h +++ b/tensorflow/lite/delegates/serialization.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_ #define TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_ +#include #include #include #include diff --git a/tensorflow/lite/delegates/serialization_test.cc b/tensorflow/lite/delegates/serialization_test.cc index 15835223356fc2..c18701a92b1210 100644 --- a/tensorflow/lite/delegates/serialization_test.cc +++ b/tensorflow/lite/delegates/serialization_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/delegates/serialization.h" -#include #include #include diff --git a/tensorflow/lite/delegates/telemetry.cc b/tensorflow/lite/delegates/telemetry.cc index 58e22f4db427f6..47cf32641734dc 100644 --- a/tensorflow/lite/delegates/telemetry.cc +++ b/tensorflow/lite/delegates/telemetry.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/delegates/telemetry.h" +#include + #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/delegates/telemetry_test.cc b/tensorflow/lite/delegates/telemetry_test.cc index 192a053c4015d2..72478f6a74de9c 100644 --- a/tensorflow/lite/delegates/telemetry_test.cc +++ b/tensorflow/lite/delegates/telemetry_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include "flatbuffers/buffer.h" // from @flatbuffers diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index c41ec1766730c7..a0905e314a020b 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -1687,6 +1687,7 @@ cc_test( ":xnnpack_delegate_test_mode", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/schema:schema_fbs", + "@XNNPACK", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc index 0e8c4a1c703dea..11351eb19c478c 100644 --- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc @@ -17,8 +17,11 @@ limitations under the License. #include #include +#include #include +#include #include +#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h index 296dbd5d93f110..2f6edf5239d889 100644 --- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h +++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ #include +#include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/concatenation_test.cc b/tensorflow/lite/delegates/xnnpack/concatenation_test.cc index 5a46c46a365946..dd4f8131587a3f 100644 --- a/tensorflow/lite/delegates/xnnpack/concatenation_test.cc +++ b/tensorflow/lite/delegates/xnnpack/concatenation_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc b/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc index d3bf54c145f975..a13a35de3032bb 100644 --- a/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include +#include #include #include +#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/concatenation_tester.h b/tensorflow/lite/delegates/xnnpack/concatenation_tester.h index 2af4638fe52ac3..202dab11d0b2f5 100644 --- a/tensorflow/lite/delegates/xnnpack/concatenation_tester.h +++ b/tensorflow/lite/delegates/xnnpack/concatenation_tester.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_CONCATENATION_TESTER_H_ #define TENSORFLOW_LITE_DELEGATES_XNNPACK_CONCATENATION_TESTER_H_ +#include #include +#include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc index 34664909635a2d..d28cd403a6f90d 100644 --- a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/delegate_test.cc b/tensorflow/lite/delegates/xnnpack/delegate_test.cc index 4a00a77250db3c..fc31ca077c2d3c 100644 --- a/tensorflow/lite/delegates/xnnpack/delegate_test.cc +++ b/tensorflow/lite/delegates/xnnpack/delegate_test.cc @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include -#include #include #include "pthreadpool.h" // from @pthreadpool diff --git a/tensorflow/lite/delegates/xnnpack/depth_to_space_test.cc b/tensorflow/lite/delegates/xnnpack/depth_to_space_test.cc index 213de422a15c48..5b28bed0b41c22 100644 --- a/tensorflow/lite/delegates/xnnpack/depth_to_space_test.cc +++ b/tensorflow/lite/delegates/xnnpack/depth_to_space_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/depth_to_space_tester.cc b/tensorflow/lite/delegates/xnnpack/depth_to_space_tester.cc index d67d75182ccad1..33c246cbe7509f 100644 --- a/tensorflow/lite/delegates/xnnpack/depth_to_space_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/depth_to_space_tester.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include #include +#include #include -#include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc index 383ba67570ffda..db7a2a4f7f80a2 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/dequantize_tester.cc b/tensorflow/lite/delegates/xnnpack/dequantize_tester.cc index faf2fa2e0d0fa6..dc52c896654844 100644 --- a/tensorflow/lite/delegates/xnnpack/dequantize_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/dequantize_tester.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/lite/delegates/xnnpack/dequantize_tester.h" #include +#include #include #include +#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/dequantize_tester.h b/tensorflow/lite/delegates/xnnpack/dequantize_tester.h index b29df24d569248..8e7f80cb7c5498 100644 --- a/tensorflow/lite/delegates/xnnpack/dequantize_tester.h +++ b/tensorflow/lite/delegates/xnnpack/dequantize_tester.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_DEQUANTIZE_TESTER_H_ #include +#include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.h b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.h index e073fb79780f5d..1370d1013d601f 100644 --- a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.h +++ b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_fully_connected_tester.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_DYNAMICALLY_QUANTIZED_FULLY_CONNECTED_TESTER_H_ #include +#include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_transpose_conv_tester.h b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_transpose_conv_tester.h index 6a0e8fbe1cf2dc..3c170523066843 100644 --- a/tensorflow/lite/delegates/xnnpack/dynamically_quantized_transpose_conv_tester.h +++ b/tensorflow/lite/delegates/xnnpack/dynamically_quantized_transpose_conv_tester.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_DYNAMICALLY_QUANTIZED_TRANSPOSE_CONV_TESTER_H_ #define TENSORFLOW_LITE_DELEGATES_XNNPACK_DYNAMICALLY_QUANTIZED_TRANSPOSE_CONV_TESTER_H_ +#include #include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h index da9a4aeea515b5..029fff3657e93f 100644 --- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h +++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_FULLY_CONNECTED_TESTER_H_ #include +#include #include #include diff --git a/tensorflow/lite/delegates/xnnpack/reshape_test.cc b/tensorflow/lite/delegates/xnnpack/reshape_test.cc index 56c252f461eef6..e64dc217448fbd 100644 --- a/tensorflow/lite/delegates/xnnpack/reshape_test.cc +++ b/tensorflow/lite/delegates/xnnpack/reshape_test.cc @@ -16,11 +16,13 @@ limitations under the License. #include #include #include +#include #include #include #include #include +#include "xnnpack.h" // from @XNNPACK #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/reshape_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" @@ -224,5 +226,30 @@ TEST(Reshape, MultiThreading) { .Test(TensorType_FLOAT32, xnnpack_delegate.get()); } +TEST(Reshape, UnsupportedOutputRank) { + std::unique_ptr + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto shape_rng = + std::bind(std::uniform_int_distribution(2, 10), std::ref(rng)); + std::vector input_shape; + std::generate_n(std::back_inserter(input_shape), XNN_MAX_TENSOR_DIMS, + shape_rng); + + // Construct an output shape greater than XNN_MAX_TENSOR_DIMS. This will + // prevent this node from being delegated to XNNPACK. + std::vector output_shape = input_shape; + output_shape.push_back(1); + std::shuffle(output_shape.begin(), output_shape.end(), rng); + + ReshapeTester() + .InputShape(input_shape) + .OutputShape(output_shape) + .Test(TensorType_FLOAT32, xnnpack_delegate.get()); +} + } // namespace xnnpack } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 31a0ed0d863246..99e94b2e24ee6b 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -5423,6 +5423,37 @@ class Subgraph { /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], BuiltinOperator_RESHAPE, node_index)); + const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_STATUS( + CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, + node->outputs->data[0], node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorShape( + logging_context, output_tensor, /*min_num_dims=*/0, + /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->outputs->data[0], + BuiltinOperator_RESHAPE, node_index)); + + if (output_tensor.type == kTfLiteUInt8 || + output_tensor.type == kTfLiteInt8) { + if (input_tensor.params.zero_point != output_tensor.params.zero_point) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "Mismatching quantization zero point across the input " + "(%" PRId32 ") and the output (%" PRId32 + ") for RESHAPE operator #%d", + input_tensor.params.zero_point, output_tensor.params.zero_point, + node_index); + return kTfLiteError; + } + if (input_tensor.params.scale != output_tensor.params.scale) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "Mismatching quantization scale across the input (%f) " + "and the output (%f) for RESHAPE operator #%d", + input_tensor.params.scale, output_tensor.params.scale, node_index); + return kTfLiteError; + } + } + std::array new_shape; int num_new_dimensions; if (node->inputs->size == 2) { @@ -5455,36 +5486,6 @@ class Subgraph { } } - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, output_tensor, /*min_num_dims=*/0, - /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->outputs->data[0], - BuiltinOperator_RESHAPE, node_index)); - - if (output_tensor.type == kTfLiteUInt8 || - output_tensor.type == kTfLiteInt8) { - if (input_tensor.params.zero_point != output_tensor.params.zero_point) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "Mismatching quantization zero point across the input " - "(%" PRId32 ") and the output (%" PRId32 - ") for RESHAPE operator #%d", - input_tensor.params.zero_point, output_tensor.params.zero_point, - node_index); - return kTfLiteError; - } - if (input_tensor.params.scale != output_tensor.params.scale) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "Mismatching quantization scale across the input (%f) " - "and the output (%f) for RESHAPE operator #%d", - input_tensor.params.scale, output_tensor.params.scale, node_index); - return kTfLiteError; - } - } if (subgraph != nullptr) { const xnn_status status = xnn_define_static_reshape( subgraph, num_new_dimensions, new_shape.data(), diff --git a/tensorflow/lite/examples/label_image/CMakeLists.txt b/tensorflow/lite/examples/label_image/CMakeLists.txt index 2fcb09ce96e990..07ab2343ae513f 100644 --- a/tensorflow/lite/examples/label_image/CMakeLists.txt +++ b/tensorflow/lite/examples/label_image/CMakeLists.txt @@ -84,5 +84,5 @@ target_compile_options(label_image target_link_libraries(label_image tensorflow-lite profiling_info_proto - protobuf + libprotobuf ) diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD index be4d974a32df88..f6d986f7f6e2b5 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD @@ -697,6 +697,7 @@ cc_library( visibility = ["@org_tensorflow_lite_support//tensorflow_lite_support/cc:__subpackages__"] + minibenchmark_visibility_allowlist(), deps = [ "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@flatbuffers", @@ -1118,6 +1119,7 @@ cc_test( deps = [ ":embedded_mobilenet_validation_model", ":embedded_nnapi_sl_fake_impl", + ":embedded_validator_runner_entrypoint", ":mini_benchmark_test_helper", ":nnapi_sl_fake_impl_client", ":status_codes", diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_decompress_buffered_struct_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_decompress_buffered_struct_test.cc index b9b6c272b177ce..f3d4906cf8be18 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_decompress_buffered_struct_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_decompress_buffered_struct_test.cc @@ -14,9 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_decompress_buffered_struct.h" -#include - -#include #include namespace tflite { diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.cc index 9af7407eeabe46..97927653ea88c8 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.h" #include -#include #include #include "tensorflow/lite/core/c/c_api_types.h" diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser_test.cc index f3600094ae7840..75db7be9e28d4a 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/acceleration/mini_benchmark/jpeg_header_parser.h" +#include +#include #include #include diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/libc_handle_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/libc_handle_test.cc index c7b9d871204671..f23f0aaeb686b8 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/libc_handle_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/libc_handle_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/acceleration/mini_benchmark/libc_handle.h" -#include #include namespace tflite { diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.cc index 9d431d2689c25f..42cd0b639d76a9 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include #include -#include #include #include +#include #include "absl/strings/match.h" #include "absl/strings/string_view.h" diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder_test.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder_test.cc index 15dc12c4e87ca7..0f17f4769b2c5b 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder_test.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder_test.cc @@ -17,10 +17,12 @@ limitations under the License. #include #include +#include #include #include #include +#include #include #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_status.h" diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_handle_dynamic_link.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_handle_dynamic_link.cc index d8fddc2acbdc8d..ea1ae4b6ee4a9e 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_handle_dynamic_link.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_handle_dynamic_link.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_handle.h" - #include #include #include #include +#include #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_status.h" +#include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_handle.h" namespace tflite { namespace acceleration { diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.cc index d8c36cb6825531..9b0ccecf0a971a 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.cc @@ -14,10 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h" +#include +#include #include #include +#include #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "flatbuffers/flatbuffers.h" // from @flatbuffers namespace tflite { diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h b/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h index a01b94d6397adb..2cc952e1bc0e87 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h @@ -15,12 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MINI_BENCHMARK_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MINI_BENCHMARK_H_ +#include #include #include #include #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" diff --git a/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_implementation.cc b/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_implementation.cc index fcfabdc9b0836f..60dde77f8a889f 100644 --- a/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_implementation.cc +++ b/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_implementation.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include #include #include #include diff --git a/tensorflow/lite/experimental/litert/build_common/BUILD b/tensorflow/lite/experimental/litert/build_common/BUILD index b6b545ed68e824..ff47bd3a762ac3 100644 --- a/tensorflow/lite/experimental/litert/build_common/BUILD +++ b/tensorflow/lite/experimental/litert/build_common/BUILD @@ -17,4 +17,7 @@ package( default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], ) -exports_files(srcs = ["export_litert_only.lds"]) +exports_files(srcs = [ + "export_litert_only_darwin.lds", + "export_litert_only_linux.lds", +]) diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds new file mode 100644 index 00000000000000..a51afcee0a21f0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/build_common/export_litert_only_darwin.lds @@ -0,0 +1,8 @@ +# Compiler Plugin +*LiteRt*CompilerPlugin* + +# Compiled Result +*LiteRt*CompiledResult* + +# Dispatch +*LiteRtDispatch* diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds similarity index 100% rename from tensorflow/lite/experimental/litert/build_common/export_litert_only.lds rename to tensorflow/lite/experimental/litert/build_common/export_litert_only_linux.lds diff --git a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl index 227050f9e9acc3..c49fba756494c3 100644 --- a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl +++ b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl @@ -46,9 +46,6 @@ def _valid_so_name(name): def _make_target_ref(name): return ":{}".format(name) -def _make_script_linkopt(script): - return make_linkopt("--version-script=$(location {})".format(script)) - #################################################################################################### # Explicitly Link System Libraries ("ungrte") @@ -64,8 +61,28 @@ _SYS_ELF_INTERPRETER_LINKOPT_X86_64 = make_linkopt("--dynamic-linker={}".format( #################################################################################################### # Symbol Hiding -_EXPORT_LRT_ONLY_SCRIPT = "//tensorflow/lite/experimental/litert/build_common:export_litert_only.lds" -_EXPORT_LRT_ONLY_LINKOPT = _make_script_linkopt(_EXPORT_LRT_ONLY_SCRIPT) +_EXPORT_LRT_ONLY_SCRIPT_LINUX = "//tensorflow/lite/experimental/litert/build_common:export_litert_only_linux.lds" +_EXPORT_LRT_ONLY_SCRIPT_DARWIN = "//tensorflow/lite/experimental/litert/build_common:export_litert_only_darwin.lds" +_EXPORT_LRT_ONLY_LINKOPT_LINUX = make_linkopt("--version-script=$(location {})".format(_EXPORT_LRT_ONLY_SCRIPT_LINUX)) +_EXPORT_LRT_ONLY_LINKOPT_DARWIN = make_linkopt("-exported_symbols_list,$(location {})".format(_EXPORT_LRT_ONLY_SCRIPT_DARWIN)) + +def export_lrt_only_script(): + return select({ + "//tensorflow:linux_x86_64": [_EXPORT_LRT_ONLY_SCRIPT_LINUX], + "//tensorflow:android": [_EXPORT_LRT_ONLY_SCRIPT_LINUX], + "//tensorflow:macos": [_EXPORT_LRT_ONLY_SCRIPT_DARWIN], + "//tensorflow:ios": [_EXPORT_LRT_ONLY_SCRIPT_DARWIN], + "//conditions:default": [], + }) + +def export_lrt_only_linkopt(): + return select({ + "//tensorflow:linux_x86_64": [_EXPORT_LRT_ONLY_LINKOPT_LINUX], + "//tensorflow:android": [_EXPORT_LRT_ONLY_LINKOPT_LINUX], + "//tensorflow:macos": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN], + "//tensorflow:ios": [_EXPORT_LRT_ONLY_LINKOPT_DARWIN], + "//conditions:default": [], + }) #################################################################################################### # Macros @@ -154,8 +171,8 @@ def litert_bin( if export_litert_only: append_rule_kwargs( cc_bin_kwargs, - linkopts = [_EXPORT_LRT_ONLY_LINKOPT], - deps = [_EXPORT_LRT_ONLY_SCRIPT], + linkopts = export_lrt_only_linkopt(), + deps = export_lrt_only_script(), ) _litert_base( @@ -205,8 +222,8 @@ def litert_dynamic_lib( user_link_flags = [] additional_linker_inputs = [] if export_litert_only: - user_link_flags.append(_EXPORT_LRT_ONLY_LINKOPT) - additional_linker_inputs.append(_EXPORT_LRT_ONLY_SCRIPT) + user_link_flags = export_lrt_only_linkopt() + additional_linker_inputs = export_lrt_only_script() native.cc_shared_library( name = shared_lib_name, diff --git a/tensorflow/lite/experimental/litert/c/BUILD b/tensorflow/lite/experimental/litert/c/BUILD index 9ccb0d1e1314ea..fcb8f2efd51bf5 100644 --- a/tensorflow/lite/experimental/litert/c/BUILD +++ b/tensorflow/lite/experimental/litert/c/BUILD @@ -14,7 +14,11 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + default_visibility = [ + # copybara:uncomment "//third_party/mediapipe/calculators/tensor:__subpackages__", + # copybara:uncomment "//third_party/odml/infra:__subpackages__", + "//tensorflow/lite/experimental/litert:__subpackages__", + ], ) cc_library( @@ -22,6 +26,23 @@ cc_library( hdrs = ["litert_common.h"], ) +cc_library( + name = "litert_any", + hdrs = ["litert_any.h"], +) + +cc_library( + name = "litert_environment", + srcs = ["litert_environment.cc"], + hdrs = ["litert_environment.h"], + deps = [ + ":litert_any", + ":litert_common", + "//tensorflow/lite/experimental/litert/core:environment", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "litert_logging", srcs = [ @@ -71,10 +92,8 @@ cc_library( ":litert_op_code", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_macros", "//tensorflow/lite/experimental/litert/core/model", "//tensorflow/lite/experimental/litert/core/model:model_load", - "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings:string_view", ], ) @@ -87,10 +106,9 @@ cc_test( ":litert_model", ":litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/cc:litert_layout", "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/experimental/litert/test:common", - "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:test_macros", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -126,28 +144,37 @@ cc_test( ], tags = ["no_oss"], deps = [ - ":litert_model", ":litert_options", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:test_macros", "@com_google_googletest//:gtest_main", ], ) +cc_library( + name = "litert_event", + srcs = ["litert_event.cc"], + hdrs = ["litert_event.h"], + deps = [ + ":litert_common", + ":litert_logging", + "//tensorflow/lite/experimental/litert/runtime:event", + ], +) + cc_library( name = "litert_tensor_buffer", srcs = [ - "litert_event.cc", "litert_tensor_buffer.cc", "litert_tensor_buffer_requirements.cc", ], hdrs = [ - "litert_event.h", "litert_tensor_buffer.h", "litert_tensor_buffer_requirements.h", ], deps = [ ":litert_common", + ":litert_event", ":litert_logging", ":litert_model", "//tensorflow/lite/experimental/litert/runtime:tensor_buffer", @@ -211,6 +238,9 @@ cc_library( hdrs = [ "litert_compiled_model_options.h", ], + deps = [ + ":litert_common", + ], ) cc_library( @@ -245,7 +275,11 @@ cc_test( ":litert_compiled_model_options", ":litert_model", ":litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) @@ -269,8 +303,13 @@ cc_test( copts = ["--std=c11"], linkopts = ["-ldl"], deps = [ + ":litert_any", ":litert_common", ":litert_compiled_model", + ":litert_compiled_model_options", + ":litert_dispatch_delegate", + ":litert_event", + ":litert_layout", ":litert_logging", ":litert_model", ":litert_op_code", diff --git a/tensorflow/lite/experimental/litert/c/litert_any.h b/tensorflow/lite/experimental/litert/c/litert_any.h new file mode 100644 index 00000000000000..69a2a8d7acf20d --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_any.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_ + +#include // NOLINT: To use bool type in C +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + kLiteRtAnyTypeNone = 0, + kLiteRtAnyTypeBool = 1, + kLiteRtAnyTypeInt = 2, + kLiteRtAnyTypeReal = 3, + kLiteRtAnyTypeString = 8, + kLiteRtAnyTypeVoidPtr = 9, +} LiteRtAnyType; + +typedef struct { + LiteRtAnyType type; + union { + bool bool_value; + int64_t int_value; + double real_value; + const char* str_value; + const void* ptr_value; + }; +} LiteRtAny; + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ANY_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c b/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c index 4c877406ea87ef..59cef3f76a04d1 100644 --- a/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c +++ b/tensorflow/lite/experimental/litert/c/litert_c_api_common_test.c @@ -20,9 +20,14 @@ // Include all the header files in the litert/c directory. #include "tensorflow/lite/experimental/litert/c/litert_common.h" // NOLINT +#include "tensorflow/lite/experimental/litert/c/litert_any.h" // NOLINT #include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" // NOLINT +#include "tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h" // NOLINT +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" // NOLINT #include "tensorflow/lite/experimental/litert/c/litert_event.h" // NOLINT +#include "tensorflow/lite/experimental/litert/c/litert_layout.h" // NOLINT #include "tensorflow/lite/experimental/litert/c/litert_logging.h" // NOLINT +#include "tensorflow/lite/experimental/litert/c/litert_options.h" // NOLINT #include "tensorflow/lite/experimental/litert/c/litert_model.h" // NOLINT #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" // NOLINT #include "tensorflow/lite/experimental/litert/c/litert_options.h" // NOLINT diff --git a/tensorflow/lite/experimental/litert/c/litert_common.h b/tensorflow/lite/experimental/litert/c/litert_common.h index 0295fb10e86f13..72f089c2aa2af9 100644 --- a/tensorflow/lite/experimental/litert/c/litert_common.h +++ b/tensorflow/lite/experimental/litert/c/litert_common.h @@ -15,18 +15,12 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ -#include // NOLINT: To use bool type in C -#include - #ifdef __cplusplus extern "C" { #endif // __cplusplus // Declares canonical opaque type. #define LITERT_DEFINE_HANDLE(name) typedef struct name##T* name -// Declares an array of references to opaque type. `name` must be -// previously declared opaque type. -#define LITERT_DEFINE_HANDLE_ARRAY(name) typedef name* name##Array #if __ANDROID_API__ >= 26 #define LITERT_HAS_AHWB_SUPPORT 1 @@ -93,25 +87,12 @@ typedef enum { kLiteRtStatusErrorInvalidLegalization = 2001, } LiteRtStatus; -typedef enum { - kLiteRtAnyTypeNone = 0, - kLiteRtAnyTypeBool = 1, - kLiteRtAnyTypeInt = 2, - kLiteRtAnyTypeReal = 3, - kLiteRtAnyTypeString = 8, - kLiteRtAnyTypeVoidPtr = 9, -} LiteRtAnyType; - -typedef struct { - LiteRtAnyType type; - union { - bool bool_value; - int64_t int_value; - double real_value; - const char* str_value; - const void* ptr_value; - }; -} LiteRtAny; +typedef enum : int { + kLiteRtHwAccelatorNone = 0, + kLiteRtHwAccelatorCpu = 1 << 0, + kLiteRtHwAccelatorGpu = 1 << 1, + kLiteRtHwAccelatorNpu = 1 << 2, +} LiteRtHwAccelerators; #ifdef __cplusplus } diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc index ff4bf50ea3cdcc..db675431a8fc76 100644 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc +++ b/tensorflow/lite/experimental/litert/c/litert_compiled_model.cc @@ -25,10 +25,14 @@ #include "tensorflow/lite/experimental/litert/runtime/compiled_model.h" LiteRtStatus LiteRtCreateCompiledModel( - LiteRtModel model, LiteRtComplicationOptions complication_options, + LiteRtModel model, LiteRtCompilationOptions compilation_options, LiteRtCompiledModel* compiled_model) { + if (!model || !compiled_model) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_compiled_model = - LiteRtCompiledModelT::Create(model, complication_options); + LiteRtCompiledModelT::Create(model, compilation_options); if (!created_compiled_model) { LITERT_LOG(LITERT_ERROR, "%s", created_compiled_model.Error().Message().data()); @@ -42,6 +46,10 @@ LiteRtStatus LiteRtGetCompiledModelInputBufferRequirements( LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, LiteRtParamIndex input_index, LiteRtTensorBufferRequirements* buffer_requirements) { + if (!compiled_model || !buffer_requirements) { + return kLiteRtStatusErrorInvalidArgument; + } + auto res = compiled_model->GetInputBufferRequirementsCApi(signature_index, input_index); if (!res) { @@ -56,6 +64,10 @@ LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements( LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index, LiteRtParamIndex output_index, LiteRtTensorBufferRequirements* buffer_requirements) { + if (!compiled_model || !buffer_requirements) { + return kLiteRtStatusErrorInvalidArgument; + } + auto res = compiled_model->GetOutputBufferRequirementsCApi(signature_index, output_index); if (!res) { @@ -72,6 +84,11 @@ LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model, LiteRtTensorBuffer* input_buffers, size_t num_output_buffers, LiteRtTensorBuffer* output_buffers) { + if (!compiled_model || (num_input_buffers > 0 && !input_buffers) || + (num_output_buffers > 0 && !output_buffers)) { + return kLiteRtStatusErrorInvalidArgument; + } + auto res = compiled_model->RunCApi(signature_index, num_input_buffers, input_buffers, num_output_buffers, output_buffers); diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model.h b/tensorflow/lite/experimental/litert/c/litert_compiled_model.h index 76fb2cfac2f78a..10a2d4c3d7eb04 100644 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model.h +++ b/tensorflow/lite/experimental/litert/c/litert_compiled_model.h @@ -49,7 +49,7 @@ LITERT_DEFINE_HANDLE(LiteRtCompiledModel); // The model is loaded into memory and the caller takes ownership of the // returned object. LiteRtStatus LiteRtCreateCompiledModel( - LiteRtModel model, LiteRtComplicationOptions complication_options, + LiteRtModel model, LiteRtCompilationOptions compilation_options, LiteRtCompiledModel* compiled_model); // Returns the buffer requirements for the given n-th input tensor. The returned diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h b/tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h index f95837440e41ce..151aa050616f60 100644 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h +++ b/tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h @@ -15,18 +15,14 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_OPTIONS_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMPILED_MODEL_OPTIONS_H_ +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus // The compilation options for the LiteRtCompiledModel. -// WARNING: This is an experimental and subject to change. -// TODO: b/379317134 - Add GPU support. -typedef enum LiteRtComplicationOptions : int { - kHwAccelDefault = 0, - kHwAccelCpu = 1 << 0, - kHwAccelNpu = 1 << 1, -} LiteRtComplicationOptions; +typedef LiteRtHwAccelerators LiteRtCompilationOptions; #ifdef __cplusplus } diff --git a/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc b/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc index 68ad9512d4bc69..705be3d5ddb791 100644 --- a/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc +++ b/tensorflow/lite/experimental/litert/c/litert_compiled_model_test.cc @@ -14,63 +14,150 @@ #include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" +#include +#include #include +#include #include +#include "absl/log/absl_log.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" + +using testing::FloatNear; +using testing::Pointwise; namespace litert { namespace { -static constexpr absl::string_view kTfliteFile = - "third_party/tensorflow/lite/experimental/litert/test/testdata/" - "simple_model.tflite"; - TEST(CompiledModelTest, Basic) { + auto path = testing::GetTestFilePath(kModelFileName); + LiteRtModel model; - ASSERT_EQ(LiteRtCreateModelFromFile(kTfliteFile.data(), &model), - kLiteRtStatusOk); + ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); + LiteRtCompiledModel compiled_model; - ASSERT_EQ(LiteRtCreateCompiledModel(model, kHwAccelCpu, &compiled_model), - kLiteRtStatusOk); + ASSERT_EQ( + LiteRtCreateCompiledModel(model, kLiteRtHwAccelatorCpu, &compiled_model), + kLiteRtStatusOk); LiteRtSubgraph subgraph; ASSERT_EQ(LiteRtGetModelSubgraph(model, 0, &subgraph), kLiteRtStatusOk); + LiteRtParamIndex num_inputs; - LiteRtTensorArray input_tensors; - ASSERT_EQ(LiteRtGetSubgraphInputs(subgraph, &num_inputs, &input_tensors), - kLiteRtStatusOk); - std::vector input_buffer_requirements; - input_buffer_requirements.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - LiteRtTensorBufferRequirements buffer_requirements; - ASSERT_EQ( - LiteRtGetCompiledModelInputBufferRequirements( - compiled_model, /*signature_index=*/0, i, &buffer_requirements), + ASSERT_EQ(LiteRtGetNumSubgraphInputs(subgraph, &num_inputs), kLiteRtStatusOk); + + std::vector input_tensor_buffers; + input_tensor_buffers.reserve(num_inputs); + for (auto i = 0; i < num_inputs; ++i) { + LiteRtTensorBufferRequirements tensor_buffer_requirements; + ASSERT_EQ(LiteRtGetCompiledModelInputBufferRequirements( + compiled_model, /*signature_index=*/0, i, + &tensor_buffer_requirements), + kLiteRtStatusOk); + LiteRtTensorBufferType tensor_buffer_type; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type), kLiteRtStatusOk); - input_buffer_requirements.push_back(buffer_requirements); + size_t tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + tensor_buffer_requirements, &tensor_buffer_size), + kLiteRtStatusOk); + LiteRtTensorBuffer tensor_buffer; + EXPECT_EQ( + LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType, + tensor_buffer_size, &tensor_buffer), + kLiteRtStatusOk); + input_tensor_buffers.push_back(tensor_buffer); } LiteRtParamIndex num_outputs; - LiteRtTensorArray output_tensors; - ASSERT_EQ(LiteRtGetSubgraphOutputs(subgraph, &num_outputs, &output_tensors), + ASSERT_EQ(LiteRtGetNumSubgraphOutputs(subgraph, &num_outputs), kLiteRtStatusOk); - std::vector output_buffer_requirements; - output_buffer_requirements.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - LiteRtTensorBufferRequirements buffer_requirements; - ASSERT_EQ( - LiteRtGetCompiledModelOutputBufferRequirements( - compiled_model, /*signature_index=*/0, i, &buffer_requirements), + + std::vector output_tensor_buffers; + output_tensor_buffers.reserve(num_outputs); + for (auto i = 0; i < num_outputs; ++i) { + LiteRtTensorBufferRequirements tensor_buffer_requirements; + ASSERT_EQ(LiteRtGetCompiledModelOutputBufferRequirements( + compiled_model, /*signature_index=*/0, i, + &tensor_buffer_requirements), + kLiteRtStatusOk); + LiteRtTensorBufferType tensor_buffer_type; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + tensor_buffer_requirements, /*type_index=*/0, &tensor_buffer_type), + kLiteRtStatusOk); + size_t tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + tensor_buffer_requirements, &tensor_buffer_size), + kLiteRtStatusOk); + LiteRtTensorBuffer tensor_buffer; + EXPECT_EQ( + LiteRtCreateManagedTensorBuffer(tensor_buffer_type, &kInput0TensorType, + tensor_buffer_size, &tensor_buffer), kLiteRtStatusOk); - output_buffer_requirements.push_back(buffer_requirements); + output_tensor_buffers.push_back(tensor_buffer); } + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[0], &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[0]), + kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_tensor_buffers[1], &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_tensor_buffers[1]), + kLiteRtStatusOk); + } + + ASSERT_EQ(LiteRtRunCompiledModel( + compiled_model, /*signature_index=*/0, + input_tensor_buffers.size(), input_tensor_buffers.data(), + output_tensor_buffers.size(), output_tensor_buffers.data()), + kLiteRtStatusOk); + + { + ABSL_LOG(INFO) << "Checking output..."; + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffers[0], &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-3), kTestOutputTensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffers[0]), + kLiteRtStatusOk); + } + LiteRtDestroyCompiledModel(compiled_model); LiteRtDestroyModel(model); + + for (auto tensor_buffer : input_tensor_buffers) { + LiteRtDestroyTensorBuffer(tensor_buffer); + } + for (auto tensor_buffer : output_tensor_buffers) { + LiteRtDestroyTensorBuffer(tensor_buffer); + } } } // namespace diff --git a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h b/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h index e220c23dd4410d..48855e78b80b91 100644 --- a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h +++ b/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h @@ -15,7 +15,7 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ -#include +#include #include "tensorflow/lite/c/c_api_opaque.h" #include "tensorflow/lite/c/c_api_types.h" diff --git a/tensorflow/lite/experimental/litert/c/litert_environment.cc b/tensorflow/lite/experimental/litert/c/litert_environment.cc new file mode 100644 index 00000000000000..c25e9a71e10e48 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_environment.cc @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/core/environment.h" + +LiteRtStatus LiteRtEnvironmentCreate(int num_options, + const LiteRtEnvOption* options) { + if (auto status = litert::internal::Environment::CreateWithOptions( + absl::MakeSpan(options, num_options)); + !status) { + return status.Error().Status(); + } + return kLiteRtStatusOk; +} + +void LiteRtEnvironmentDestroy() { litert::internal::Environment::Destroy(); } diff --git a/tensorflow/lite/experimental/litert/c/litert_environment.h b/tensorflow/lite/experimental/litert/c/litert_environment.h new file mode 100644 index 00000000000000..fce03aee55e392 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_environment.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ + +#include "tensorflow/lite/experimental/litert/c/litert_any.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + kLiteRtEnvOptionTagCompilerPluginLibraryPath = 0, + kLiteRtEnvOptionTagDispatchLibraryPath = 1, +} LiteRtEnvOptionTag; + +typedef struct { + LiteRtEnvOptionTag tag; + LiteRtAny value; +} LiteRtEnvOption; + +// Create a singleton LiteRT environment with options. Returns an error if the +// instance already exists, in which case the specified options have no +// effect. If not created explicitly with options, the environment instance will +// be created (with no options) when needed. +LiteRtStatus LiteRtEnvironmentCreate(int num_options, + const LiteRtEnvOption* options); + +// Destroy the LiteRT environment instance. +void LiteRtEnvironmentDestroy(); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_event.cc b/tensorflow/lite/experimental/litert/c/litert_event.cc index b18f91fc229a78..7d58e7426ae98e 100644 --- a/tensorflow/lite/experimental/litert/c/litert_event.cc +++ b/tensorflow/lite/experimental/litert/c/litert_event.cc @@ -21,23 +21,34 @@ #include #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/runtime/event.h" -#if LITERT_HAS_SYNC_FENCE_SUPPORT LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd, LiteRtEvent* event) { +#if LITERT_HAS_SYNC_FENCE_SUPPORT *event = new LiteRtEventT{.fd = sync_fence_fd, .owns_fd = owns_fd}; return kLiteRtStatusOk; +#else + return kLiteRtStatusErrorUnsupported; +#endif } LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd) { +#if LITERT_HAS_SYNC_FENCE_SUPPORT *sync_fence_fd = event->fd; return kLiteRtStatusOk; -} +#else + return kLiteRtStatusErrorUnsupported; #endif +} LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms) { - return event->Wait(timeout_in_ms); + if (auto status = event->Wait(timeout_in_ms); !status) { + LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().data()); + return status.Error().Status(); + } + return kLiteRtStatusOk; } void LiteRtDestroyEvent(LiteRtEvent event) { delete event; } diff --git a/tensorflow/lite/experimental/litert/c/litert_event.h b/tensorflow/lite/experimental/litert/c/litert_event.h index a3bca94436b81a..20a42738a822b5 100644 --- a/tensorflow/lite/experimental/litert/c/litert_event.h +++ b/tensorflow/lite/experimental/litert/c/litert_event.h @@ -15,6 +15,7 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ +#include // NOLINT: To use bool type in C #include #include "tensorflow/lite/experimental/litert/c/litert_common.h" @@ -25,12 +26,10 @@ extern "C" { LITERT_DEFINE_HANDLE(LiteRtEvent); -#if LITERT_HAS_SYNC_FENCE_SUPPORT LiteRtStatus LiteRtCreateEventFromSyncFenceFd(int sync_fence_fd, bool owns_fd, LiteRtEvent* event); LiteRtStatus LiteRtGetEventSyncFenceFd(LiteRtEvent event, int* sync_fence_fd); -#endif // LITERT_HAS_SYNC_FENCE_SUPPORT // Pass -1 for timeout_in_ms for indefinite wait. LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms); diff --git a/tensorflow/lite/experimental/litert/c/litert_model.cc b/tensorflow/lite/experimental/litert/c/litert_model.cc index 4c48e657dff152..2cfa9264351c3b 100644 --- a/tensorflow/lite/experimental/litert/c/litert_model.cc +++ b/tensorflow/lite/experimental/litert/c/litert_model.cc @@ -22,12 +22,8 @@ #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/schema/schema_generated.h" - -static const char* LiteRtDefaultSignatureKey = LITERT_DEFAULT_SIGNATURE_KEY; // // Model @@ -65,22 +61,23 @@ LiteRtStatus LiteRtCreateModelFromBuffer(const void* buffer_addr, LiteRtStatus LiteRtGetNumModelSubgraphs(LiteRtModel model, LiteRtParamIndex* num_subgraphs) { - if (!model || !num_subgraphs) { + if (model == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *num_subgraphs = model->subgraphs.size(); + *num_subgraphs = model->Subgraphs().size(); return kLiteRtStatusOk; } LiteRtStatus LiteRtGetModelSubgraph(LiteRtModel model, LiteRtParamIndex subgraph_index, LiteRtSubgraph* subgraph) { - if (!model) { + if (model == nullptr) { return kLiteRtStatusErrorInvalidArgument; - } else if (subgraph_index >= model->subgraphs.size()) { + } + if (subgraph_index >= model->Subgraphs().size()) { return kLiteRtStatusErrorIndexOOB; } - *subgraph = model->subgraphs.data() + subgraph_index; + *subgraph = &model->Subgraph(subgraph_index); return kLiteRtStatusOk; } @@ -89,7 +86,7 @@ LiteRtStatus LiteRtGetMainModelSubgraphIndex( if (!model || !main_subgraph_index) { return kLiteRtStatusErrorInvalidArgument; } - *main_subgraph_index = model->MainSubgraphIndex(); + *main_subgraph_index = LiteRtModelT::kMainSubgraphIndex; return kLiteRtStatusOk; } @@ -113,7 +110,7 @@ LiteRtStatus LiteRtGetNumModelSignatures(LiteRtModel model, if (!model || !num_signatures) { return kLiteRtStatusErrorInvalidArgument; } - *num_signatures = model->signatures.size(); + *num_signatures = model->Signatures().size(); return kLiteRtStatusOk; } @@ -123,10 +120,11 @@ LiteRtStatus LiteRtGetModelSignature(LiteRtModel model, LiteRtSignature* signature) { if (!model || !signature) { return kLiteRtStatusErrorInvalidArgument; - } else if (signature_index >= model->signatures.size()) { + } + if (signature_index >= model->Signatures().size()) { return kLiteRtStatusErrorIndexOOB; } - *signature = model->signatures[signature_index].get(); + *signature = model->Signatures().at(signature_index); return kLiteRtStatusOk; } @@ -148,7 +146,7 @@ LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key) { if (!signature_key) { return kLiteRtStatusErrorInvalidArgument; } - *signature_key = LiteRtDefaultSignatureKey; + *signature_key = LiteRtSignatureT::kDefaultSignatureKey.data(); return kLiteRtStatusOk; } @@ -157,13 +155,16 @@ LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature, if (!signature || !signature_key) { return kLiteRtStatusErrorInvalidArgument; } - *signature_key = signature->key.data(); + *signature_key = signature->Key().data(); return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetSignatureSubgraphIndex(LiteRtSignature signature, - LiteRtParamIndex* subgraph_index) { - *subgraph_index = signature->subgraph_index; +LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature, + LiteRtSubgraph* subgraph) { + if (signature == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *subgraph = &signature->GetSubgraph(); return kLiteRtStatusOk; } @@ -172,7 +173,7 @@ LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature, if (!signature || !num_inputs) { return kLiteRtStatusErrorInvalidArgument; } - *num_inputs = signature->input_names.size(); + *num_inputs = signature->InputNames().size(); return kLiteRtStatusOk; } @@ -181,10 +182,11 @@ LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature, const char** input_name) { if (!signature || !input_name) { return kLiteRtStatusErrorInvalidArgument; - } else if (input_idx >= signature->input_names.size()) { + } + if (input_idx >= signature->InputNames().size()) { return kLiteRtStatusErrorIndexOOB; } - *input_name = signature->input_names[input_idx].data(); + *input_name = signature->InputNames().at(input_idx).data(); return kLiteRtStatusOk; } @@ -193,7 +195,7 @@ LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature, if (!signature || !num_outputs) { return kLiteRtStatusErrorInvalidArgument; } - *num_outputs = signature->output_names.size(); + *num_outputs = signature->OutputNames().size(); return kLiteRtStatusOk; } @@ -202,10 +204,11 @@ LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature, const char** output_name) { if (!signature || !output_name) { return kLiteRtStatusErrorInvalidArgument; - } else if (output_idx >= signature->output_names.size()) { + } + if (output_idx >= signature->OutputNames().size()) { return kLiteRtStatusErrorIndexOOB; } - *output_name = signature->output_names[output_idx].data(); + *output_name = signature->OutputNames().at(output_idx).data(); return kLiteRtStatusOk; } @@ -213,36 +216,65 @@ LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature, // Subgraph // -LiteRtStatus LiteRtGetSubgraphInputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_inputs, - LiteRtTensorArray* inputs) { - if (!subgraph || !num_inputs || !inputs) { +LiteRtStatus LiteRtGetNumSubgraphInputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_inputs) { + if (!subgraph || !num_inputs) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_inputs = subgraph->Inputs().size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetSubgraphInput(LiteRtSubgraph subgraph, + LiteRtParamIndex input_index, + LiteRtTensor* input) { + if (!subgraph || !input) { + return kLiteRtStatusErrorInvalidArgument; + } else if (input_index < 0 || input_index >= subgraph->Inputs().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *input = subgraph->Inputs()[input_index]; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumSubgraphOutputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_outputs) { + if (!subgraph || !num_outputs) { return kLiteRtStatusErrorInvalidArgument; } - *num_inputs = subgraph->inputs.size(); - *inputs = subgraph->inputs.data(); + *num_outputs = subgraph->Outputs().size(); return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetSubgraphOutputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_outputs, - LiteRtTensorArray* outputs) { - if (!subgraph || !num_outputs || !outputs) { +LiteRtStatus LiteRtGetSubgraphOutput(LiteRtSubgraph subgraph, + LiteRtParamIndex output_index, + LiteRtTensor* output) { + if (!subgraph || !output) { return kLiteRtStatusErrorInvalidArgument; + } else if (output_index < 0 || output_index >= subgraph->Outputs().size()) { + return kLiteRtStatusErrorIndexOOB; } - *num_outputs = subgraph->outputs.size(); - *outputs = subgraph->outputs.data(); + *output = subgraph->Outputs()[output_index]; return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetSubgraphOps(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_ops, - LiteRtOpArray* ops) { - if (!subgraph || !num_ops || !ops) { +LiteRtStatus LiteRtGetNumSubgraphOps(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_ops) { + if (!subgraph || !num_ops) { return kLiteRtStatusErrorInvalidArgument; } - *num_ops = subgraph->ops.size(); - *ops = subgraph->ops.data(); + *num_ops = subgraph->Ops().size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetSubgraphOp(LiteRtSubgraph subgraph, + LiteRtParamIndex op_index, LiteRtOp* op) { + if (!subgraph || !op) { + return kLiteRtStatusErrorInvalidArgument; + } else if (op_index < 0 || op_index >= subgraph->Ops().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *op = subgraph->Ops()[op_index]; return kLiteRtStatusOk; } @@ -250,36 +282,54 @@ LiteRtStatus LiteRtGetSubgraphOps(LiteRtSubgraph subgraph, // Op // -LiteRtStatus LiteRtGetOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs, - LiteRtTensorArray* outputs) { - if (!op || !num_outputs || !outputs) { +LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code) { + if (!op || !code) { return kLiteRtStatusErrorInvalidArgument; } - *num_outputs = op->outputs.size(); - *outputs = op->outputs.data(); + *code = op->OpCode(); return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs, - LiteRtTensorArray* inputs) { - if (!op || !num_inputs || !inputs) { +LiteRtStatus LiteRtGetNumOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs) { + if (!op || !num_inputs) { return kLiteRtStatusErrorInvalidArgument; } - *num_inputs = op->inputs.size(); - *inputs = op->inputs.data(); + *num_inputs = op->Inputs().size(); return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code) { - if (!op || !code) { +LiteRtStatus LiteRtGetOpInput(LiteRtOp op, LiteRtParamIndex input_index, + LiteRtTensor* input) { + if (!op || !input) { + return kLiteRtStatusErrorInvalidArgument; + } else if (input_index < 0 || input_index >= op->Inputs().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *input = op->Inputs()[input_index]; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs) { + if (!op || !num_outputs) { return kLiteRtStatusErrorInvalidArgument; } - *code = op->op_code; + *num_outputs = op->Outputs().size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetOpOutput(LiteRtOp op, LiteRtParamIndex output_index, + LiteRtTensor* output) { + if (!op || !output) { + return kLiteRtStatusErrorInvalidArgument; + } else if (output_index < 0 || output_index >= op->Outputs().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *output = op->Outputs()[output_index]; return kLiteRtStatusOk; } // -// Tensor +// Weights // LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr, @@ -287,35 +337,43 @@ LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr, if (!weights || !addr || !size) { return kLiteRtStatusErrorInvalidArgument; } - if (weights->fb_buffer == nullptr) { - *addr = nullptr; - *size = 0; - } else { - *addr = weights->fb_buffer->data.data(); - *size = weights->fb_buffer->data.size(); - } + *addr = weights->Buf().Data(); + *size = weights->Buf().Size(); return kLiteRtStatusOk; } +// +// Tensor +// + LiteRtStatus LiteRtGetTensorWeights(LiteRtTensor tensor, LiteRtWeights* weights) { if (!tensor || !weights) { return kLiteRtStatusErrorInvalidArgument; } - *weights = &tensor->weights; + *weights = &tensor->Weights(); return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetTensorUses(LiteRtTensor tensor, - LiteRtParamIndex* num_uses, - LiteRtOpArray* use_users, - LiteRtParamIndex** use_user_arg_inds) { - if (!tensor || !num_uses || !use_users || !use_user_arg_inds) { +LiteRtStatus LiteRtGetNumTensorUses(LiteRtTensor tensor, + LiteRtParamIndex* num_uses) { + if (!tensor || !num_uses) { return kLiteRtStatusErrorInvalidArgument; } - *num_uses = tensor->users.size(); - *use_users = tensor->users.data(); - *use_user_arg_inds = tensor->user_arg_inds.data(); + *num_uses = tensor->Users().size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorUse(LiteRtTensor tensor, LiteRtParamIndex use_index, + LiteRtOp* user, + LiteRtParamIndex* user_arg_index) { + if (!tensor || !user || !user_arg_index) { + return kLiteRtStatusErrorInvalidArgument; + } else if (use_index < 0 || use_index >= tensor->Users().size()) { + return kLiteRtStatusErrorIndexOOB; + } + *user = tensor->Users()[use_index]; + *user_arg_index = tensor->UserArgInds()[use_index]; return kLiteRtStatusOk; } @@ -326,10 +384,10 @@ LiteRtStatus LiteRtGetTensorDefiningOp(LiteRtTensor tensor, if (!tensor || !has_defining_op || !defining_op) { return kLiteRtStatusErrorInvalidArgument; } - if (tensor->defining_op != nullptr) { + if (tensor->DefiningOp() != nullptr) { *has_defining_op = true; - defining_op->op = tensor->defining_op; - defining_op->op_output_index = tensor->defining_op_out_ind; + defining_op->op = tensor->DefiningOp(); + defining_op->op_output_index = tensor->DefiningOpOutInd(); } else { *has_defining_op = false; } @@ -341,7 +399,7 @@ LiteRtStatus LiteRtGetTensorTypeId(LiteRtTensor tensor, if (!tensor || !type_id) { return kLiteRtStatusErrorInvalidArgument; } - *type_id = tensor->type_id; + *type_id = tensor->Type().first; return kLiteRtStatusOk; } @@ -349,10 +407,10 @@ LiteRtStatus LiteRtGetUnrankedTensorType( LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type) { if (!tensor || !unranked_tensor_type) { return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->type_id != kLiteRtUnrankedTensorType) { + } else if (tensor->Type().first != kLiteRtUnrankedTensorType) { return kLiteRtStatusErrorInvalidIrType; } - *unranked_tensor_type = tensor->type_detail.unranked_tensor_type; + *unranked_tensor_type = tensor->Type().second.unranked_tensor_type; return kLiteRtStatusOk; } @@ -360,10 +418,10 @@ LiteRtStatus LiteRtGetRankedTensorType( LiteRtTensor tensor, LiteRtRankedTensorType* ranked_tensor_type) { if (!tensor || !ranked_tensor_type) { return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->type_id != kLiteRtRankedTensorType) { + } else if (tensor->Type().first != kLiteRtRankedTensorType) { return kLiteRtStatusErrorInvalidIrType; } - *ranked_tensor_type = tensor->type_detail.ranked_tensor_type; + *ranked_tensor_type = tensor->Type().second.ranked_tensor_type; return kLiteRtStatusOk; } @@ -371,7 +429,7 @@ LiteRtStatus LiteRtGetTensorName(LiteRtTensor tensor, const char** name) { if (!tensor || !name) { return kLiteRtStatusErrorInvalidArgument; } - *name = tensor->name.data(); + *name = tensor->Name().data(); return kLiteRtStatusOk; } @@ -380,7 +438,7 @@ LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor, if (!tensor || !q_type_id) { return kLiteRtStatusErrorInvalidArgument; } - *q_type_id = tensor->q_type_id; + *q_type_id = tensor->Qparams().first; return kLiteRtStatusOk; } @@ -388,9 +446,28 @@ LiteRtStatus LiteRtGetPerTensorQuantization( LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization) { if (!tensor || !per_tensor_quantization) { return kLiteRtStatusErrorInvalidArgument; - } else if (tensor->q_type_id != kLiteRtQuantizationPerTensor) { + } else if (tensor->Qparams().first != kLiteRtQuantizationPerTensor) { + return kLiteRtStatusErrorInvalidIrType; + } + auto& per_tensor = tensor->Qparams().second.per_tensor; + per_tensor_quantization->scale = per_tensor.scale; + per_tensor_quantization->zero_point = per_tensor.zero_point; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetPerChannelQuantization( + LiteRtTensor tensor, + LiteRtQuantizationPerChannel* per_channel_quantization) { + if (!tensor || !per_channel_quantization) { + return kLiteRtStatusErrorInvalidArgument; + } else if (tensor->Qparams().first != kLiteRtQuantizationPerChannel) { return kLiteRtStatusErrorInvalidIrType; } - *per_tensor_quantization = tensor->q_type_detail.per_tensor; + auto& per_channel = tensor->Qparams().second.per_channel; + per_channel_quantization->scales = per_channel.scales; + per_channel_quantization->zero_points = per_channel.zero_points; + per_channel_quantization->num_channels = per_channel.num_channels; + per_channel_quantization->quantized_dimension = + per_channel.quantized_dimension; return kLiteRtStatusOk; } diff --git a/tensorflow/lite/experimental/litert/c/litert_model.h b/tensorflow/lite/experimental/litert/c/litert_model.h index 8431561158be96..0cae98e0d4e9bd 100644 --- a/tensorflow/lite/experimental/litert/c/litert_model.h +++ b/tensorflow/lite/experimental/litert/c/litert_model.h @@ -15,6 +15,7 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ +#include // NOLINT: To use bool type in C #include #include @@ -36,15 +37,12 @@ LITERT_DEFINE_HANDLE(LiteRtWeights); // Values/edges of the models graph. LITERT_DEFINE_HANDLE(LiteRtTensor); -LITERT_DEFINE_HANDLE_ARRAY(LiteRtTensor); // Operations/nodes of the models graph. LITERT_DEFINE_HANDLE(LiteRtOp); -LITERT_DEFINE_HANDLE_ARRAY(LiteRtOp); // Fundamental block of program, i.e. a function body. LITERT_DEFINE_HANDLE(LiteRtSubgraph); -LITERT_DEFINE_HANDLE_ARRAY(LiteRtSubgraph); // Signature of the model. LITERT_DEFINE_HANDLE(LiteRtSignature); @@ -56,7 +54,7 @@ LITERT_DEFINE_HANDLE(LiteRtModel); LITERT_DEFINE_HANDLE(LiteRtOpList); // For indexing into litert collections or counting litert things. -typedef uint64_t LiteRtParamIndex; +typedef size_t LiteRtParamIndex; // // LiteRtTensor + Types @@ -139,6 +137,14 @@ typedef struct { int64_t zero_point; } LiteRtQuantizationPerTensor; +// Schema for tensors quantized with one set of q-params per channel. +typedef struct { + int32_t quantized_dimension; + uint64_t num_channels; + float* scales; + int64_t* zero_points; +} LiteRtQuantizationPerChannel; + // The identifier for quantization scheme type union. typedef enum { // Tag for tensors without quantization. @@ -162,6 +168,11 @@ LiteRtStatus LiteRtGetQuantizationTypeId(LiteRtTensor tensor, LiteRtStatus LiteRtGetPerTensorQuantization( LiteRtTensor tensor, LiteRtQuantizationPerTensor* per_tensor_quantization); +// Get the per-channel quantization information for a given tensor if it has it. +LiteRtStatus LiteRtGetPerChannelQuantization( + LiteRtTensor tensor, + LiteRtQuantizationPerChannel* per_channel_quantization); + // EDGES // Information about the about that defines a tensor. @@ -183,10 +194,11 @@ typedef struct LiteRtTensorUserOp { } LiteRtTensorUserOp; // Get all the ops that reference given tensor, and at what operand index. -LiteRtStatus LiteRtGetTensorUses(LiteRtTensor tensor, - LiteRtParamIndex* num_uses, - LiteRtOpArray* users, - LiteRtParamIndex** user_arg_inds); +LiteRtStatus LiteRtGetNumTensorUses(LiteRtTensor tensor, + LiteRtParamIndex* num_uses); +LiteRtStatus LiteRtGetTensorUse(LiteRtTensor tensor, LiteRtParamIndex use_index, + LiteRtOp* user, + LiteRtParamIndex* user_arg_index); // Get the op that defines this tensor and the corresponding output index. If // tensor is a subgraph input, has_defining_op will be false. @@ -217,31 +229,38 @@ LiteRtStatus LiteRtGetWeightsBytes(LiteRtWeights weights, const void** addr, LiteRtStatus LiteRtGetOpCode(LiteRtOp op, LiteRtOpCode* code); // Get input tensors of given op. -LiteRtStatus LiteRtGetOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs, - LiteRtTensorArray* inputs); +LiteRtStatus LiteRtGetNumOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs); +LiteRtStatus LiteRtGetOpInput(LiteRtOp op, LiteRtParamIndex input_index, + LiteRtTensor* input); // Get output tensors of given op. -LiteRtStatus LiteRtGetOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs, - LiteRtTensorArray* outputs); +LiteRtStatus LiteRtGetNumOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs); +LiteRtStatus LiteRtGetOpOutput(LiteRtOp op, LiteRtParamIndex output_index, + LiteRtTensor* output); // // LiteRtSubgraph // // Get input tensors for given subgraph. -LiteRtStatus LiteRtGetSubgraphInputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_inputs, - LiteRtTensorArray* inputs); +LiteRtStatus LiteRtGetNumSubgraphInputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_inputs); +LiteRtStatus LiteRtGetSubgraphInput(LiteRtSubgraph subgraph, + LiteRtParamIndex input_index, + LiteRtTensor* input); // Get output tensors for given subgraph. -LiteRtStatus LiteRtGetSubgraphOutputs(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_outputs, - LiteRtTensorArray* outputs); +LiteRtStatus LiteRtGetNumSubgraphOutputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_outputs); +LiteRtStatus LiteRtGetSubgraphOutput(LiteRtSubgraph subgraph, + LiteRtParamIndex output_index, + LiteRtTensor* output); // Get all ops in given subgraph in a topological order. -LiteRtStatus LiteRtGetSubgraphOps(LiteRtSubgraph subgraph, - LiteRtParamIndex* num_ops, - LiteRtOpArray* ops); +LiteRtStatus LiteRtGetNumSubgraphOps(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_ops); +LiteRtStatus LiteRtGetSubgraphOp(LiteRtSubgraph subgraph, + LiteRtParamIndex op_index, LiteRtOp* op); // // LiteRtSignature @@ -255,9 +274,9 @@ LiteRtStatus LiteRtGetDefaultSignatureKey(const char** signature_key); LiteRtStatus LiteRtGetSignatureKey(LiteRtSignature signature, const char** signature_key); -// Get the associated subgraph index for the given signature. -LiteRtStatus LiteRtGetSignatureSubgraphIndex(LiteRtSignature signature, - LiteRtParamIndex* subgraph_index); +// Get the associated subgraph for the given signature. +LiteRtStatus LiteRtGetSignatureSubgraph(LiteRtSignature signature, + LiteRtSubgraph* subgraph); // Get the number of inputs for the given signature. LiteRtStatus LiteRtGetNumSignatureInputs(LiteRtSignature signature, diff --git a/tensorflow/lite/experimental/litert/c/litert_model_test.cc b/tensorflow/lite/experimental/litert/c/litert_model_test.cc index baed7ac98db33c..e910786553de26 100644 --- a/tensorflow/lite/experimental/litert/c/litert_model_test.cc +++ b/tensorflow/lite/experimental/litert/c/litert_model_test.cc @@ -17,7 +17,8 @@ #include #include #include -#include +#include +#include #include #include @@ -26,31 +27,16 @@ #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/cc/litert_layout.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/experimental/litert/test/common.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" namespace { using ::litert::BufferRef; +using ::litert::internal::MakeTflBuffer; using ::testing::ElementsAreArray; -template -LiteRtWeightsT MakeWeights(std::initializer_list data, size_t offset = 0) { - LiteRtWeightsT weights; - weights.fb_buffer = std::make_unique(); - weights.fb_buffer->data.resize(data.size() * sizeof(T)); - auto data_it = data.begin(); - for (int i = 0; i < data.size(); ++i) { - *(reinterpret_cast(weights.fb_buffer->data.data()) + i) = *data_it; - ++data_it; - } - weights.fb_buffer->size = weights.fb_buffer->data.size(); - weights.fb_buffer->offset = offset; - return weights; -} - TEST(LiteRtWeightsTest, GetNullWeights) { LiteRtWeightsT weights = {}; @@ -63,7 +49,8 @@ TEST(LiteRtWeightsTest, GetNullWeights) { } TEST(LiteRtWeightsTest, GetWeights) { - auto weights = MakeWeights({1, 2, 3}); + LiteRtWeightsT weights; + detail::SetTflBuffer(weights, MakeTflBuffer({1, 2, 3})); const void* addr; size_t size; @@ -77,34 +64,39 @@ TEST(LiteRtWeightsTest, GetWeights) { } TEST(LiteRtTensorTest, GetUnrankedType) { + static constexpr auto kElementType = kLiteRtElementTypeFloat32; + static constexpr auto kId = kLiteRtUnrankedTensorType; + + TensorType type; + type.first = kId; + type.second.unranked_tensor_type.element_type = kElementType; + LiteRtTensorT tensor; - tensor.type_id = kLiteRtUnrankedTensorType; - tensor.type_detail.unranked_tensor_type.element_type = - kLiteRtElementTypeFloat32; + tensor.SetType(std::move(type)); LiteRtTensorTypeId id; LITERT_ASSERT_STATUS_OK(LiteRtGetTensorTypeId(&tensor, &id)); - ASSERT_EQ(id, kLiteRtUnrankedTensorType); + ASSERT_EQ(id, kId); LiteRtUnrankedTensorType unranked; LITERT_ASSERT_STATUS_OK(LiteRtGetUnrankedTensorType(&tensor, &unranked)); - EXPECT_EQ(unranked.element_type, kLiteRtElementTypeFloat32); + EXPECT_EQ(unranked.element_type, kElementType); } TEST(LiteRtTensorTest, GetRankedTensorType) { + static constexpr auto kElementType = kLiteRtElementTypeFloat32; + static constexpr auto kId = kLiteRtRankedTensorType; + LiteRtTensorT tensor; - tensor.type_id = kLiteRtRankedTensorType; - tensor.type_detail.ranked_tensor_type.element_type = - kLiteRtElementTypeFloat32; - tensor.type_detail.ranked_tensor_type.layout = ::litert::BuildLayout({3, 3}); + tensor.SetType(MakeRankedTensorType(kElementType, {3, 3})); LiteRtTensorTypeId id; LITERT_ASSERT_STATUS_OK(LiteRtGetTensorTypeId(&tensor, &id)); - ASSERT_EQ(id, kLiteRtRankedTensorType); + ASSERT_EQ(id, kId); LiteRtRankedTensorType ranked; LITERT_ASSERT_STATUS_OK(LiteRtGetRankedTensorType(&tensor, &ranked)); - EXPECT_EQ(ranked.element_type, kLiteRtElementTypeFloat32); + EXPECT_EQ(ranked.element_type, kElementType); ASSERT_EQ(ranked.layout.rank, 2); EXPECT_THAT(absl::MakeConstSpan(ranked.layout.dimensions, 2), ElementsAreArray({3, 3})); @@ -114,31 +106,35 @@ TEST(LiteRtTensorTest, GetUses) { LiteRtTensorT tensor; LiteRtOpT user; - tensor.users.push_back(&user); - tensor.user_arg_inds.push_back(0); + tensor.Users().push_back(&user); + tensor.UserArgInds().push_back(0); LiteRtOpT other_user; - tensor.users.push_back(&other_user); - tensor.user_arg_inds.push_back(1); + tensor.Users().push_back(&other_user); + tensor.UserArgInds().push_back(1); LiteRtParamIndex num_uses; - LiteRtOpArray actual_users; - LiteRtParamIndex* user_arg_inds; - LITERT_ASSERT_STATUS_OK( - LiteRtGetTensorUses(&tensor, &num_uses, &actual_users, &user_arg_inds)); - + LITERT_ASSERT_STATUS_OK(LiteRtGetNumTensorUses(&tensor, &num_uses)); ASSERT_EQ(num_uses, 2); - EXPECT_THAT(absl::MakeConstSpan(actual_users, 2), - ElementsAreArray({&user, &other_user})); - EXPECT_THAT(absl::MakeConstSpan(user_arg_inds, 2), ElementsAreArray({0, 1})); + + LiteRtOp actual_user; + LiteRtParamIndex actual_user_arg_index; + LITERT_ASSERT_STATUS_OK(LiteRtGetTensorUse( + &tensor, /*use_index=*/0, &actual_user, &actual_user_arg_index)); + ASSERT_EQ(actual_user, &user); + ASSERT_EQ(actual_user_arg_index, 0); + + LITERT_ASSERT_STATUS_OK(LiteRtGetTensorUse( + &tensor, /*use_index=*/1, &actual_user, &actual_user_arg_index)); + ASSERT_EQ(actual_user, &other_user); + ASSERT_EQ(actual_user_arg_index, 1); } TEST(LiteRtTensorTest, GetDefiningOp) { LiteRtTensorT tensor; LiteRtOpT def_op; - tensor.defining_op = &def_op; - tensor.defining_op_out_ind = 0; + tensor.SetDefiningOp(def_op, 0); LiteRtTensorDefiningOp actual_def_op; bool has_defining_op; @@ -160,18 +156,18 @@ TEST(LiteRtTensorTest, NoDefiningOp) { } TEST(LiteRtTensorTest, Name) { - static constexpr absl::string_view kName = "foo"; + static constexpr const char kName[] = "foo"; + LiteRtTensorT tensor; - tensor.name = kName; + tensor.SetName(std::string(kName)); const char* name; LITERT_ASSERT_STATUS_OK(LiteRtGetTensorName(&tensor, &name)); - EXPECT_STREQ(name, kName.data()); + EXPECT_STREQ(name, kName); } TEST(LiteRtTensorTest, QuantizationNone) { LiteRtTensorT tensor; - tensor.q_type_id = kLiteRtQuantizationNone; LiteRtQuantizationTypeId q_type_id; LITERT_ASSERT_STATUS_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); @@ -187,8 +183,7 @@ TEST(LiteRtTensorTest, QuantizationPerTensor) { static constexpr auto kZeroPoint = 1; LiteRtTensorT tensor; - tensor.q_type_id = kLiteRtQuantizationPerTensor; - tensor.q_type_detail.per_tensor = {kScale, kZeroPoint}; + tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); LiteRtQuantizationTypeId q_type_id; LITERT_ASSERT_STATUS_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); @@ -202,13 +197,47 @@ TEST(LiteRtTensorTest, QuantizationPerTensor) { EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); } +TEST(LiteRtTensorTest, QuantizationPerChannel) { + static constexpr size_t kNumChannels = 2; + static constexpr size_t kQuantizedDimension = 0; + static constexpr float kScales[kNumChannels] = {1.0, 2.0}; + static constexpr int64_t kZps[kNumChannels] = {2, 3}; + + LiteRtTensorT tensor; + + { + auto per_channel = + MakePerChannelQuantization(kScales, kZps, kQuantizedDimension, tensor); + tensor.SetQarams(per_channel); + } + + LiteRtQuantizationTypeId q_type_id; + LITERT_ASSERT_STATUS_OK(LiteRtGetQuantizationTypeId(&tensor, &q_type_id)); + ASSERT_EQ(q_type_id, kLiteRtQuantizationPerChannel); + + LiteRtQuantizationPerChannel per_channel_quantization; + LITERT_ASSERT_STATUS_OK( + LiteRtGetPerChannelQuantization(&tensor, &per_channel_quantization)); + + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), + testing::ElementsAreArray(kScales)); + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), + testing::ElementsAreArray(kZps)); + ASSERT_EQ(per_channel_quantization.num_channels, kNumChannels); + ASSERT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); +} + TEST(LiteRtOpTest, GetOpCode) { + static constexpr auto kCode = kLiteRtOpCodeTflCustom; + LiteRtOpT op; - op.op_code = kLiteRtOpCodeTflCustom; + op.SetOpCode(kCode); LiteRtOpCode code; LITERT_ASSERT_STATUS_OK(LiteRtGetOpCode(&op, &code)); - EXPECT_EQ(code, kLiteRtOpCodeTflCustom); + EXPECT_EQ(code, kCode); } TEST(LiteRtOpTest, GetInputs) { @@ -216,15 +245,21 @@ TEST(LiteRtOpTest, GetInputs) { LiteRtTensorT input2; LiteRtOpT op; - op.inputs.push_back(&input1); - op.inputs.push_back(&input2); + op.Inputs().push_back(&input1); + op.Inputs().push_back(&input2); - LiteRtTensorArray inputs; LiteRtParamIndex num_inputs; - LITERT_ASSERT_STATUS_OK(LiteRtGetOpInputs(&op, &num_inputs, &inputs)); + LITERT_ASSERT_STATUS_OK(LiteRtGetNumOpInputs(&op, &num_inputs)); ASSERT_EQ(num_inputs, 2); - EXPECT_THAT(absl::MakeConstSpan(inputs, num_inputs), - ElementsAreArray({&input1, &input2})); + + LiteRtTensor actual_input; + LITERT_ASSERT_STATUS_OK( + LiteRtGetOpInput(&op, /*input_index=*/0, &actual_input)); + EXPECT_EQ(actual_input, &input1); + + LITERT_ASSERT_STATUS_OK( + LiteRtGetOpInput(&op, /*input_index=*/1, &actual_input)); + EXPECT_EQ(actual_input, &input2); } TEST(LiteRtOpTest, GetOutputs) { @@ -232,15 +267,21 @@ TEST(LiteRtOpTest, GetOutputs) { LiteRtTensorT output2; LiteRtOpT op; - op.outputs.push_back(&output1); - op.outputs.push_back(&output2); + op.Outputs().push_back(&output1); + op.Outputs().push_back(&output2); - LiteRtTensorArray outputs; LiteRtParamIndex num_outputs; - LITERT_ASSERT_STATUS_OK(LiteRtGetOpOutputs(&op, &num_outputs, &outputs)); + LITERT_ASSERT_STATUS_OK(LiteRtGetNumOpOutputs(&op, &num_outputs)); ASSERT_EQ(num_outputs, 2); - EXPECT_THAT(absl::MakeConstSpan(outputs, num_outputs), - ElementsAreArray({&output1, &output2})); + + LiteRtTensor actual_output; + LITERT_ASSERT_STATUS_OK( + LiteRtGetOpOutput(&op, /*output_index=*/0, &actual_output)); + EXPECT_EQ(actual_output, &output1); + + LITERT_ASSERT_STATUS_OK( + LiteRtGetOpOutput(&op, /*output_index=*/1, &actual_output)); + EXPECT_EQ(actual_output, &output2); } TEST(LiteRtSubgraphTest, GetInputs) { @@ -248,16 +289,20 @@ TEST(LiteRtSubgraphTest, GetInputs) { LiteRtTensorT input2; LiteRtSubgraphT subgraph; - subgraph.inputs.push_back(&input1); - subgraph.inputs.push_back(&input2); + subgraph.Inputs().push_back(&input1); + subgraph.Inputs().push_back(&input2); - LiteRtTensorArray inputs; LiteRtParamIndex num_inputs; + LITERT_ASSERT_STATUS_OK(LiteRtGetNumSubgraphInputs(&subgraph, &num_inputs)); + + LiteRtTensor actual_input; LITERT_ASSERT_STATUS_OK( - LiteRtGetSubgraphInputs(&subgraph, &num_inputs, &inputs)); - ASSERT_EQ(num_inputs, 2); - EXPECT_THAT(absl::MakeConstSpan(inputs, num_inputs), - ElementsAreArray({&input1, &input2})); + LiteRtGetSubgraphInput(&subgraph, /*input_index=*/0, &actual_input)); + EXPECT_EQ(actual_input, &input1); + + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubgraphInput(&subgraph, /*input_index=*/1, &actual_input)); + EXPECT_EQ(actual_input, &input2); } TEST(LiteRtSubgraphTest, GetOutputs) { @@ -265,51 +310,58 @@ TEST(LiteRtSubgraphTest, GetOutputs) { LiteRtTensorT output2; LiteRtSubgraphT subgraph; - subgraph.outputs.push_back(&output1); - subgraph.outputs.push_back(&output2); + subgraph.Outputs().push_back(&output1); + subgraph.Outputs().push_back(&output2); - LiteRtTensorArray outputs; LiteRtParamIndex num_outputs; + LITERT_ASSERT_STATUS_OK(LiteRtGetNumSubgraphOutputs(&subgraph, &num_outputs)); + + LiteRtTensor actual_output; LITERT_ASSERT_STATUS_OK( - LiteRtGetSubgraphOutputs(&subgraph, &num_outputs, &outputs)); - ASSERT_EQ(num_outputs, 2); - EXPECT_THAT(absl::MakeConstSpan(outputs, num_outputs), - ElementsAreArray({&output1, &output2})); + LiteRtGetSubgraphOutput(&subgraph, /*output_index=*/0, &actual_output)); + EXPECT_EQ(actual_output, &output1); + + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubgraphOutput(&subgraph, /*output_index=*/1, &actual_output)); + EXPECT_EQ(actual_output, &output2); } TEST(LiteRtSubgraphTest, GetOps) { - LiteRtOpT op1; - LiteRtOpT op2; - LiteRtSubgraphT subgraph; - subgraph.ops.push_back(&op1); - subgraph.ops.push_back(&op2); + auto& op1 = subgraph.EmplaceOp(); + auto& op2 = subgraph.EmplaceOp(); - LiteRtOpArray ops; LiteRtParamIndex num_ops; - LITERT_ASSERT_STATUS_OK(LiteRtGetSubgraphOps(&subgraph, &num_ops, &ops)); + LITERT_ASSERT_STATUS_OK(LiteRtGetNumSubgraphOps(&subgraph, &num_ops)); ASSERT_EQ(num_ops, 2); - EXPECT_THAT(absl::MakeConstSpan(ops, num_ops), - ElementsAreArray({&op1, &op2})); + + LiteRtOp actual_op; + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubgraphOp(&subgraph, /*op_index=*/0, &actual_op)); + ASSERT_EQ(actual_op, &op1); + + LITERT_ASSERT_STATUS_OK( + LiteRtGetSubgraphOp(&subgraph, /*op_index=*/1, &actual_op)); + ASSERT_EQ(actual_op, &op2); } TEST(LiteRtModelTest, GetMetadata) { + static constexpr absl::string_view kKey = "KEY"; + static constexpr absl::string_view kData = "DATA"; + LiteRtModelT model; - model.flatbuffer_model = std::make_unique(); - litert::OwningBufferRef buf("Bar"); - model.PushMetadata("Foo", buf); + model.PushMetadata(kKey, kData); const void* metadata; size_t metadata_size; LITERT_ASSERT_STATUS_OK( - LiteRtGetModelMetadata(&model, "Foo", &metadata, &metadata_size)); - ASSERT_EQ(metadata_size, 3); - EXPECT_EQ(BufferRef(metadata, metadata_size).StrView(), "Bar"); + LiteRtGetModelMetadata(&model, kKey.data(), &metadata, &metadata_size)); + EXPECT_EQ(BufferRef(metadata, metadata_size).StrView(), kData); } TEST(LiteRtModelTest, GetSubgraph) { LiteRtModelT model; - auto& subgraph = model.subgraphs.emplace_back(); + auto& subgraph = model.EmplaceSubgraph(); LiteRtSubgraph actual_subgraph; LITERT_ASSERT_STATUS_OK(LiteRtGetModelSubgraph(&model, 0, &actual_subgraph)); diff --git a/tensorflow/lite/experimental/litert/c/litert_options.cc b/tensorflow/lite/experimental/litert/c/litert_options.cc index a0e64052239318..b34651b4e4eea7 100644 --- a/tensorflow/lite/experimental/litert/c/litert_options.cc +++ b/tensorflow/lite/experimental/litert/c/litert_options.cc @@ -26,212 +26,303 @@ LiteRtStatus LiteRtGetAddFusedActivationOption(LiteRtOp op, uint32_t* fused_activation) { - if (op->op_code != kLiteRtOpCodeTflAdd) { + if (op->OpCode() != kLiteRtOpCodeTflAdd) { return kLiteRtStatusErrorInvalidArgument; } - *fused_activation = op->option.AsAddOptions()->fused_activation_function; + const auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorNotFound; + } + *fused_activation = opts.AsAddOptions()->fused_activation_function; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetBatchMatmulAdjXOption(LiteRtOp op, bool* adj_x) { - if (op->op_code != kLiteRtOpCodeTflBatchMatmul) { + if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *adj_x = op->option.AsBatchMatMulOptions()->adj_x; + *adj_x = opts.AsBatchMatMulOptions()->adj_x; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetBatchMatmulAdjYOption(LiteRtOp op, bool* adj_y) { - if (op->op_code != kLiteRtOpCodeTflBatchMatmul) { + if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { return kLiteRtStatusErrorInvalidArgument; } - *adj_y = op->option.AsBatchMatMulOptions()->adj_y; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *adj_y = opts.AsBatchMatMulOptions()->adj_y; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetBatchMatmulAsymmetricQuantizeInputOption( LiteRtOp op, bool* asymmetric_quantize_input) { - if (op->op_code != kLiteRtOpCodeTflBatchMatmul) { + if (op->OpCode() != kLiteRtOpCodeTflBatchMatmul) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } *asymmetric_quantize_input = - op->option.AsBatchMatMulOptions()->asymmetric_quantize_inputs; + opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetConcatenationFusedActivationOption( LiteRtOp op, uint32_t* fused_activation) { - if (op->op_code != kLiteRtOpCodeTflConcatenation) { + if (op->OpCode() != kLiteRtOpCodeTflConcatenation) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *fused_activation = - op->option.AsConcatenationOptions()->fused_activation_function; + *fused_activation = opts.AsConcatenationOptions()->fused_activation_function; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetConcatenationAxisOption(LiteRtOp op, int32_t* axis) { - if (op->op_code != kLiteRtOpCodeTflConcatenation) { + if (op->OpCode() != kLiteRtOpCodeTflConcatenation) { return kLiteRtStatusErrorInvalidArgument; } - *axis = op->option.AsConcatenationOptions()->axis; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *axis = opts.AsConcatenationOptions()->axis; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetDivFusedActivationOption(LiteRtOp op, uint32_t* fused_activation) { - if (op->op_code != kLiteRtOpCodeTflDiv) { + if (op->OpCode() != kLiteRtOpCodeTflDiv) { return kLiteRtStatusErrorInvalidArgument; } - *fused_activation = op->option.AsDivOptions()->fused_activation_function; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = opts.AsDivOptions()->fused_activation_function; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetFullyConnectedFusedActivationOption( LiteRtOp op, uint32_t* fused_activation) { - if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *fused_activation = - op->option.AsFullyConnectedOptions()->fused_activation_function; + *fused_activation = opts.AsFullyConnectedOptions()->fused_activation_function; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetFullyConnectedKeepNumDimsOption(LiteRtOp op, bool* keep_num_dims) { - if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { return kLiteRtStatusErrorInvalidArgument; } - *keep_num_dims = op->option.AsFullyConnectedOptions()->keep_num_dims; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *keep_num_dims = opts.AsFullyConnectedOptions()->keep_num_dims; return kLiteRtStatusOk; } LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( LiteRtOp op, uint32_t* quantized_bias_type) { - if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { return kLiteRtStatusErrorInvalidArgument; } - *quantized_bias_type = - op->option.AsFullyConnectedOptions()->quantized_bias_type; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *quantized_bias_type = opts.AsFullyConnectedOptions()->quantized_bias_type; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( LiteRtOp op, bool* asymmetric_quantize_input) { - if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } *asymmetric_quantize_input = - op->option.AsFullyConnectedOptions()->asymmetric_quantize_inputs; + opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetFullyConnectedWeightsFormatOption( LiteRtOp op, uint32_t* weights_format) { - if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + if (op->OpCode() != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *weights_format = op->option.AsFullyConnectedOptions()->weights_format; + *weights_format = opts.AsFullyConnectedOptions()->weights_format; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetMulFusedActivationOption(LiteRtOp op, uint32_t* fused_activation) { - if (op->op_code != kLiteRtOpCodeTflMul) { + if (op->OpCode() != kLiteRtOpCodeTflMul) { return kLiteRtStatusErrorInvalidArgument; } - *fused_activation = op->option.AsMulOptions()->fused_activation_function; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = opts.AsMulOptions()->fused_activation_function; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetSoftmaxBetaOption(LiteRtOp op, float* beta) { - if (op->op_code != kLiteRtOpCodeTflSoftmax) { + if (op->OpCode() != kLiteRtOpCodeTflSoftmax) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *beta = op->option.AsSoftmaxOptions()->beta; + *beta = opts.AsSoftmaxOptions()->beta; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetStridedSliceBeginMaskOption(LiteRtOp op, int32_t* begin_mask) { - if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *begin_mask = op->option.AsStridedSliceOptions()->begin_mask; + *begin_mask = opts.AsStridedSliceOptions()->begin_mask; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetStridedSliceEndMaskOption(LiteRtOp op, int32_t* end_mask) { - if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { return kLiteRtStatusErrorInvalidArgument; } - *end_mask = op->option.AsStridedSliceOptions()->end_mask; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *end_mask = opts.AsStridedSliceOptions()->end_mask; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetStridedSliceEllipsisMaskOption(LiteRtOp op, int32_t* ellipsis_mask) { - if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *ellipsis_mask = op->option.AsStridedSliceOptions()->ellipsis_mask; + *ellipsis_mask = opts.AsStridedSliceOptions()->ellipsis_mask; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetStridedSliceNewAxisMaskOption(LiteRtOp op, int32_t* new_axis_mask) { - if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *new_axis_mask = op->option.AsStridedSliceOptions()->new_axis_mask; + *new_axis_mask = opts.AsStridedSliceOptions()->new_axis_mask; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetStridedSliceShrinkAxisMaskOption( LiteRtOp op, int32_t* shrink_axis_mask) { - if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { return kLiteRtStatusErrorInvalidArgument; } - *shrink_axis_mask = op->option.AsStridedSliceOptions()->shrink_axis_mask; + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + *shrink_axis_mask = opts.AsStridedSliceOptions()->shrink_axis_mask; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetStridedSliceOffsetOption(LiteRtOp op, bool* offset) { - if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + if (op->OpCode() != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *offset = op->option.AsStridedSliceOptions()->offset; + *offset = opts.AsStridedSliceOptions()->offset; return kLiteRtStatusOk; } LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op, uint32_t* fused_activation) { - if (op->op_code != kLiteRtOpCodeTflSub) { + if (op->OpCode() != kLiteRtOpCodeTflSub) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } - *fused_activation = op->option.AsSubOptions()->fused_activation_function; + *fused_activation = opts.AsSubOptions()->fused_activation_function; return kLiteRtStatusOk; } -LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, int32_t** new_shape, +LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, + const int32_t** new_shape, int32_t* new_shape_size) { - if (op->op_code != kLiteRtOpCodeTflReshape) { + if (op->OpCode() != kLiteRtOpCodeTflReshape) { return kLiteRtStatusErrorInvalidArgument; } - if (op->option.AsReshapeOptions() == nullptr) { + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { + *new_shape_size = -1; + return kLiteRtStatusErrorInvalidArgument; + } + if (opts.AsReshapeOptions() == nullptr) { *new_shape_size = -1; return kLiteRtStatusOk; } else { - *new_shape = op->option.AsReshapeOptions()->new_shape.data(); - *new_shape_size = op->option.AsReshapeOptions()->new_shape.size(); + *new_shape = opts.AsReshapeOptions()->new_shape.data(); + *new_shape_size = opts.AsReshapeOptions()->new_shape.size(); } return kLiteRtStatusOk; } LiteRtStatus LiteRtGetSumKeepDimsOption(LiteRtOp op, bool* keepdims) { - if (op->op_code != kLiteRtOpCodeTflSum) { + if (op->OpCode() != kLiteRtOpCodeTflSum) { + return kLiteRtStatusErrorInvalidArgument; + } + auto& opts = detail::GetTflOptions(*op); + if (opts.value == nullptr) { return kLiteRtStatusErrorInvalidArgument; } // Sum OP options is stored as ReducerOptions. - *keepdims = op->option.AsReducerOptions()->keep_dims; + *keepdims = opts.AsReducerOptions()->keep_dims; return kLiteRtStatusOk; } diff --git a/tensorflow/lite/experimental/litert/c/litert_options.h b/tensorflow/lite/experimental/litert/c/litert_options.h index 5ac05cccf33b9e..4fd2da625f2430 100644 --- a/tensorflow/lite/experimental/litert/c/litert_options.h +++ b/tensorflow/lite/experimental/litert/c/litert_options.h @@ -15,6 +15,7 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ +#include // NOLINT: To use bool type in C #include #include "tensorflow/lite/experimental/litert/c/litert_common.h" @@ -153,7 +154,8 @@ LiteRtStatus LiteRtGetSubFusedActivationOption(LiteRtOp op, // - new_shape : int32_t[] // //============================================================================== -LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, int32_t** new_shape, +LiteRtStatus LiteRtGetReshapeNewShapeOption(LiteRtOp op, + const int32_t** new_shape, int32_t* new_shape_size); //============================================================================== diff --git a/tensorflow/lite/experimental/litert/c/litert_options_test.cc b/tensorflow/lite/experimental/litert/c/litert_options_test.cc index cce5216bd161f7..a2ad861a565fdd 100644 --- a/tensorflow/lite/experimental/litert/c/litert_options_test.cc +++ b/tensorflow/lite/experimental/litert/c/litert_options_test.cc @@ -19,6 +19,7 @@ #include #include "tensorflow/lite/experimental/litert/c/litert_options.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" namespace { TEST(GetOpOptionTest, TestGetAddOptions) { @@ -123,7 +124,7 @@ TEST(GetOpOptionTest, TestGetFullyConnectedOptions) { bool asymmetric_quantize_input; LITERT_ASSERT_STATUS_OK(LiteRtGetFullyConnectedAsymmetricQuantizeInputOption( op, &asymmetric_quantize_input)); - ASSERT_EQ(asymmetric_quantize_input, false); + ASSERT_EQ(asymmetric_quantize_input, true); } TEST(GetOpOptionTest, TestGetMulOptions) { @@ -205,7 +206,7 @@ TEST(GetOpOptionTest, TestGetSubOptions) { ASSERT_EQ(fused_activation, 0); } -TEST(GetOpOptionTest, TestGetReshapeOptions) { +TEST(GetOpOptionTest, TestGetNullReshapeOptions) { auto model = litert::testing::LoadTestFileModel("simple_reshape_op.tflite"); auto subgraph = model.MainSubgraph(); EXPECT_TRUE(subgraph); @@ -213,10 +214,11 @@ TEST(GetOpOptionTest, TestGetReshapeOptions) { auto ops = subgraph->Ops(); auto op = ops.front().Get(); - int32_t* new_shape = nullptr; + const int32_t* new_shape = nullptr; int32_t new_shape_size; - LITERT_ASSERT_STATUS_OK( - LiteRtGetReshapeNewShapeOption(op, &new_shape, &new_shape_size)); + + LITERT_ASSERT_STATUS_HAS_CODE( + LiteRtGetReshapeNewShapeOption(op, &new_shape, &new_shape_size), 1); ASSERT_EQ(new_shape_size, -1); } diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc index cf5068575e225d..4e6cbd5d8132f3 100644 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc @@ -277,6 +277,46 @@ LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, return kLiteRtStatusOk; } +LiteRtStatus LiteRtHasTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, + bool* has_event) { + if (!tensor_buffer || !has_event) { + return kLiteRtStatusErrorInvalidArgument; + } + *has_event = tensor_buffer->HasEvent(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, + LiteRtEvent* event) { + if (!tensor_buffer || !event) { + return kLiteRtStatusErrorInvalidArgument; + } + auto result = tensor_buffer->GetEvent(); + if (!result) { + LITERT_LOG(LITERT_ERROR, "%s", result.Error().Message().data()); + return result.Error().Status(); + } + *event = *result; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtSetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, + LiteRtEvent event) { + if (!tensor_buffer || !event) { + return kLiteRtStatusErrorInvalidArgument; + } + tensor_buffer->SetEvent(event); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtClearTensorBufferEvent(LiteRtTensorBuffer tensor_buffer) { + if (!tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + tensor_buffer->ClearEvent(); + return kLiteRtStatusOk; +} + LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, void** host_mem_addr, LiteRtEvent event) { if (!tensor_buffer || !host_mem_addr) { diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h index 54479ba23a3e98..2adbb49856d2b9 100644 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h @@ -170,6 +170,17 @@ LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, size_t* offset); +LiteRtStatus LiteRtHasTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, + bool* has_event); + +LiteRtStatus LiteRtGetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, + LiteRtEvent* event); + +LiteRtStatus LiteRtSetTensorBufferEvent(LiteRtTensorBuffer tensor_buffer, + LiteRtEvent event); + +LiteRtStatus LiteRtClearTensorBufferEvent(LiteRtTensorBuffer tensor_buffer); + // Lock a tensor buffer and map it to host memory, optionally synchronizing on a // given input event (parameter `event` can be NULL). LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc index c962abf2b5dce2..e4cb5aa1c8e0ec 100644 --- a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc @@ -19,10 +19,12 @@ #include // NOLINT: Need when ANDROID_API_LEVEL >= 26 #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_layout.h" #include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" // IWYU pragma: keep #include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/runtime/event.h" #include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" // IWYU pragma: keep #include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" // IWYU pragma: keep @@ -294,3 +296,40 @@ TEST(TensorBuffer, FastRpc) { LiteRtDestroyTensorBuffer(tensor_buffer); } + +TEST(TensorBuffer, Event) { + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + bool has_event = true; + ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), + kLiteRtStatusOk); + EXPECT_FALSE(has_event); + + LiteRtEventT event; + ASSERT_EQ(LiteRtSetTensorBufferEvent(tensor_buffer, &event), kLiteRtStatusOk); + + has_event = false; + ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), + kLiteRtStatusOk); + EXPECT_TRUE(has_event); + + LiteRtEvent actual_event; + ASSERT_EQ(LiteRtGetTensorBufferEvent(tensor_buffer, &actual_event), + kLiteRtStatusOk); + ASSERT_EQ(actual_event, &event); + + ASSERT_EQ(LiteRtClearTensorBufferEvent(tensor_buffer), kLiteRtStatusOk); + ASSERT_EQ(actual_event, &event); + + has_event = true; + ASSERT_EQ(LiteRtHasTensorBufferEvent(tensor_buffer, &has_event), + kLiteRtStatusOk); + EXPECT_FALSE(has_event); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} diff --git a/tensorflow/lite/experimental/litert/cc/BUILD b/tensorflow/lite/experimental/litert/cc/BUILD index d5284e03370b71..253792fc28f01b 100644 --- a/tensorflow/lite/experimental/litert/cc/BUILD +++ b/tensorflow/lite/experimental/litert/cc/BUILD @@ -14,13 +14,46 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + default_visibility = [ + # copybara:uncomment "//third_party/mediapipe/calculators/tensor:__subpackages__", + # copybara:uncomment "//third_party/odml/infra:__subpackages__", + "//tensorflow/lite/experimental/litert:__subpackages__", + ], +) + +cc_library( + name = "litert_environment", + hdrs = ["litert_environment.h"], + deps = [ + ":litert_any", + ":litert_expected", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_environment", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "litert_event", + hdrs = ["litert_event.h"], + deps = [ + ":litert_expected", + ":litert_handle", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_event", + "@com_google_absl//absl/types:span", + ], ) cc_library( name = "litert_any", hdrs = ["litert_any.h"], - deps = ["//tensorflow/lite/experimental/litert/c:litert_common"], + deps = [ + ":litert_expected", + "//tensorflow/lite/experimental/litert/c:litert_any", + "//tensorflow/lite/experimental/litert/c:litert_common", + "@com_google_absl//absl/strings:string_view", + ], ) cc_test( @@ -28,6 +61,10 @@ cc_test( srcs = [ "litert_any_test.cc", ], + linkopts = select({ + "//tensorflow:android": ["-llog"], + "//conditions:default": [], + }), deps = [ ":litert_any", "//tensorflow/lite/experimental/litert/c:litert_common", @@ -38,7 +75,10 @@ cc_test( cc_library( name = "litert_model", srcs = ["litert_model.cc"], - hdrs = ["litert_model.h"], + hdrs = [ + "litert_consts.h", + "litert_model.h", + ], deps = [ ":litert_buffer_ref", ":litert_detail", @@ -48,7 +88,7 @@ cc_library( ":litert_layout", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_model", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], @@ -90,10 +130,12 @@ cc_library( ], deps = [ ":litert_detail", + ":litert_event", ":litert_handle", ":litert_model", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_event", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/cc:litert_expected", @@ -204,7 +246,6 @@ cc_library( hdrs = ["litert_detail.h"], deps = [ "//tensorflow/lite/experimental/litert/c:litert_common", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:absl_check", ], ) @@ -308,6 +349,7 @@ cc_library( "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", ], ) @@ -321,14 +363,19 @@ cc_test( deps = [ ":litert_compiled_model", ":litert_model", + ":litert_tensor_buffer", "//tensorflow/lite:framework", "//tensorflow/lite/c:c_api_opaque", "//tensorflow/lite/c:common", "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) + +exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/cc/litert_any.h b/tensorflow/lite/experimental/litert/cc/litert_any.h index 4f724f85f52935..7b95e65e809cad 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_any.h +++ b/tensorflow/lite/experimental/litert/cc/litert_any.h @@ -16,8 +16,12 @@ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ #include +#include +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_any.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" namespace litert { @@ -45,6 +49,68 @@ inline std::any ToStdAny(LiteRtAny litert_any) { return res; } +inline Expected ToLiteRtAny(const std::any& any) { + LiteRtAny result; + if (!any.has_value()) { + result.type = kLiteRtAnyTypeNone; + return result; + + } else if (any.type() == typeid(LiteRtAny::bool_value)) { + result.type = kLiteRtAnyTypeBool; + result.bool_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(int8_t)) { + result.type = kLiteRtAnyTypeInt; + result.int_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(int16_t)) { + result.type = kLiteRtAnyTypeInt; + result.int_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(int32_t)) { + result.type = kLiteRtAnyTypeInt; + result.int_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(int64_t)) { + result.type = kLiteRtAnyTypeInt; + result.int_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(float)) { + result.type = kLiteRtAnyTypeReal; + result.real_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(double)) { + result.type = kLiteRtAnyTypeReal; + result.real_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(LiteRtAny::str_value)) { + result.type = kLiteRtAnyTypeString; + result.str_value = std::any_cast(any); + return result; + + } else if (any.type() == typeid(absl::string_view)) { + result.type = kLiteRtAnyTypeString; + result.str_value = std::any_cast(any).data(); + return result; + + } else if (any.type() == typeid(LiteRtAny::ptr_value)) { + result.type = kLiteRtAnyTypeVoidPtr; + result.ptr_value = std::any_cast(any); + return result; + + } else { + return Error(kLiteRtStatusErrorInvalidArgument, + "Invalid argument for ToLiteRtAny"); + } +} + } // namespace litert #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_any_test.cc b/tensorflow/lite/experimental/litert/cc/litert_any_test.cc index 0d3b4db29537c9..c6640ab8060c1c 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_any_test.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_any_test.cc @@ -22,6 +22,8 @@ TEST(Any, ConversionNone) { EXPECT_FALSE( litert::ToStdAny(LiteRtAny{/*.type=*/kLiteRtAnyTypeNone}).has_value()); + + ASSERT_EQ(litert::ToLiteRtAny(std::any())->type, kLiteRtAnyTypeNone); } TEST(Any, ConversionBool) { @@ -31,6 +33,11 @@ TEST(Any, ConversionBool) { ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/false}})), false); + + ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->type, kLiteRtAnyTypeBool); + ASSERT_EQ(litert::ToLiteRtAny(std::any(true))->bool_value, true); + ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->type, kLiteRtAnyTypeBool); + ASSERT_EQ(litert::ToLiteRtAny(std::any(false))->bool_value, false); } TEST(Any, ConversionInt) { @@ -38,6 +45,26 @@ TEST(Any, ConversionInt) { litert_any.type = kLiteRtAnyTypeInt; litert_any.int_value = 1234; ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 1234); + + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(12)))->type, + kLiteRtAnyTypeInt); + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(12)))->int_value, + 12); + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, + kLiteRtAnyTypeInt); + ASSERT_EQ( + litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, + 1234); + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, + kLiteRtAnyTypeInt); + ASSERT_EQ( + litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, + 1234); + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1234)))->type, + kLiteRtAnyTypeInt); + ASSERT_EQ( + litert::ToLiteRtAny(std::any(static_cast(1234)))->int_value, + 1234); } TEST(Any, ConversionReal) { @@ -45,6 +72,17 @@ TEST(Any, ConversionReal) { litert_any.type = kLiteRtAnyTypeReal; litert_any.real_value = 123.4; ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 123.4); + + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1.2)))->type, + kLiteRtAnyTypeReal); + EXPECT_NEAR( + litert::ToLiteRtAny(std::any(static_cast(1.2)))->real_value, 1.2, + 1e-7); + ASSERT_EQ(litert::ToLiteRtAny(std::any(static_cast(1.2)))->type, + kLiteRtAnyTypeReal); + EXPECT_NEAR( + litert::ToLiteRtAny(std::any(static_cast(1.2)))->real_value, 1.2, + 1e-7); } TEST(Any, ConversionString) { @@ -54,6 +92,9 @@ TEST(Any, ConversionString) { litert_any.str_value = kTestString; ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), kTestString); + + ASSERT_EQ(litert::ToLiteRtAny(std::any("test"))->type, kLiteRtAnyTypeString); + EXPECT_STREQ(litert::ToLiteRtAny(std::any("test"))->str_value, "test"); } TEST(Any, ConversionPtr) { @@ -62,4 +103,8 @@ TEST(Any, ConversionPtr) { litert_any.type = kLiteRtAnyTypeVoidPtr; litert_any.ptr_value = kTestPtr; ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), kTestPtr); + + ASSERT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->type, + kLiteRtAnyTypeVoidPtr); + EXPECT_EQ(litert::ToLiteRtAny(std::any(kTestPtr))->ptr_value, kTestPtr); } diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc index a41c315200459c..f8cb51097be5b7 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc @@ -19,6 +19,8 @@ #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" @@ -34,7 +36,7 @@ Expected> CompiledModel::CreateInputBuffers( if (!signature) { return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); } - auto subgraph = model_->Subgraph(signature->SubgraphIndex()); + auto subgraph = model_->Subgraph(signature->Key()); if (!subgraph) { return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); } @@ -49,19 +51,35 @@ Expected> CompiledModel::CreateInputBuffers( return Unexpected(kLiteRtStatusErrorRuntimeFailure, input_buffer_requirements.Error().Message()); } + + auto supported_types = input_buffer_requirements->SupportedTypes(); + if (!supported_types) { + return supported_types.Error(); + } + if (supported_types->empty()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Input doesn't support any tensor buffer types"); + } + // For simplicity we just pick the first supported tensor buffer type. + LiteRtTensorBufferType tensor_buffer_type = (*supported_types)[0]; + auto tensor_type = input_tensors[i].RankedTensorType(); - LiteRtTensorBufferType tensor_buffer_type = - (*(*input_buffer_requirements).SupportedTypes())[0]; + if (!tensor_type) { + return tensor_type.Error(); + } + auto input_buffer = TensorBuffer::CreateManaged( - tensor_buffer_type, tensor_type, + tensor_buffer_type, *tensor_type, (*input_buffer_requirements).BufferSize().Value()); if (!input_buffer) { return Unexpected(kLiteRtStatusErrorRuntimeFailure, input_buffer.Error().Message()); } + input_buffers.push_back(std::move(*input_buffer)); } - return std::move(input_buffers); + + return input_buffers; } Expected> CompiledModel::CreateOutputBuffers( @@ -70,13 +88,16 @@ Expected> CompiledModel::CreateOutputBuffers( if (!signature) { return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); } - auto subgraph = model_->Subgraph(signature->SubgraphIndex()); + auto subgraph = model_->Subgraph(signature->Key()); if (!subgraph) { return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); } - std::vector output_buffers; + auto output_tensors = subgraph->Outputs(); + + std::vector output_buffers; output_buffers.reserve(output_tensors.size()); + for (int i = 0; i < output_tensors.size(); ++i) { auto output_buffer_requirements = GetOutputBufferRequirements(signature_index, i); @@ -84,11 +105,26 @@ Expected> CompiledModel::CreateOutputBuffers( return Unexpected(kLiteRtStatusErrorRuntimeFailure, output_buffer_requirements.Error().Message()); } + + auto supported_types = output_buffer_requirements->SupportedTypes(); + if (!supported_types) { + return supported_types.Error(); + } + if (supported_types->empty()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Output doesn't support any tensor buffer types"); + } + + // For simplicity we just pick the first supported tensor buffer type. + LiteRtTensorBufferType tensor_buffer_type = (*supported_types)[0]; + auto tensor_type = output_tensors[i].RankedTensorType(); - LiteRtTensorBufferType tensor_buffer_type = - (*(*output_buffer_requirements).SupportedTypes())[0]; + if (!tensor_type) { + return tensor_type.Error(); + } + auto output_buffer = TensorBuffer::CreateManaged( - tensor_buffer_type, tensor_type, + tensor_buffer_type, *tensor_type, (*output_buffer_requirements).BufferSize().Value()); if (!output_buffer.HasValue()) { return Unexpected(kLiteRtStatusErrorRuntimeFailure, @@ -96,7 +132,8 @@ Expected> CompiledModel::CreateOutputBuffers( } output_buffers.push_back(std::move(*output_buffer)); } - return std::move(output_buffers); + + return output_buffers; } Expected CompiledModel::Run( @@ -121,4 +158,50 @@ Expected CompiledModel::Run( return {}; } +Expected CompiledModel::Run( + absl::string_view signature_key, + const absl::flat_hash_map& input_map, + const absl::flat_hash_map& output_map) { + auto signature_index = model_->GetSignatureIndex(signature_key); + if (!signature_index) { + return Unexpected(kLiteRtStatusErrorNotFound, + "Failed to get signature_index"); + } + auto subgraph = model_->Subgraph(signature_key); + if (!subgraph) { + return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); + } + auto input_tensors = subgraph->Inputs(); + size_t num_inputs = input_tensors.size(); + auto input_buffers_ptr = std::make_unique(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + absl::string_view input_name = input_tensors[i].Name(); + auto it = input_map.find(input_name); + if (it == input_map.end()) { + return Unexpected(kLiteRtStatusErrorNotFound, + "The given map is missing some input TensorBuffers"); + } + input_buffers_ptr[i] = it->second.Get(); + } + auto output_tensors = subgraph->Outputs(); + size_t num_outputs = output_tensors.size(); + auto output_buffers_ptr = std::make_unique(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + absl::string_view output_name = output_tensors[i].Name(); + auto it = output_map.find(output_name); + if (it == output_map.end()) { + return Unexpected(kLiteRtStatusErrorNotFound, + "The given map is missing some output TensorBuffers"); + } + output_buffers_ptr[i] = it->second.Get(); + } + if (auto status = LiteRtRunCompiledModel(Get(), *signature_index, num_inputs, + input_buffers_ptr.get(), num_outputs, + output_buffers_ptr.get()); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to invoke the compiled model"); + } + return {}; +} + } // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h index b2215973b7c018..9b9499eef65022 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h +++ b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h @@ -19,6 +19,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" @@ -68,10 +69,10 @@ class CompiledModel // returned object. static Expected Create( litert::Model& model, - LiteRtComplicationOptions complication_options = kHwAccelDefault) { + LiteRtCompilationOptions compilation_options = kLiteRtHwAccelatorCpu) { LiteRtCompiledModel compiled_model; if (auto status = LiteRtCreateCompiledModel( - model.Get(), complication_options, &compiled_model); + model.Get(), compilation_options, &compiled_model); status != kLiteRtStatusOk) { return Unexpected(status, "Failed to create compiled model"); } @@ -118,12 +119,19 @@ class CompiledModel Expected> CreateOutputBuffers( size_t signature_index); - // Runs the model of the given signature with the provided input/output + // Runs the model of the given signature index with the provided input/output // TensorBuffers. Expected Run(size_t signature_index, const std::vector& input_buffers, const std::vector& output_buffers); + // Runs the model of the given signature key with the provided input/output + // TensorBuffer map. + Expected Run( + absl::string_view signature_key, + const absl::flat_hash_map& input_map, + const absl::flat_hash_map& output_map); + private: Model* model_; }; diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc index 7c304b1dfbad22..874fe895c3a241 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_compiled_model_test.cc @@ -15,41 +15,36 @@ #include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/absl_log.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" -constexpr const float kTestInput0Tensor[] = {1, 2}; -constexpr const size_t kTestInput0Size = - sizeof(kTestInput0Tensor) / sizeof(kTestInput0Tensor[0]); -constexpr const float kTestInput1Tensor[] = {10, 20}; -constexpr const size_t kTestInput1Size = - sizeof(kTestInput1Tensor) / sizeof(kTestInput1Tensor[0]); -constexpr const float kTestOutputTensor[] = {11, 22}; -constexpr const size_t kTestOutputSize = - sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); +using testing::FloatNear; +using testing::Pointwise; namespace litert { namespace { -using ::testing::FloatNear; -using ::testing::Pointwise; - -static constexpr absl::string_view kTfliteFile = "simple_model.tflite"; - TEST(CompiledModelTest, Basic) { - auto model = testing::LoadTestFileModel(kTfliteFile); + auto model = testing::LoadTestFileModel(kModelFileName); ASSERT_TRUE(model); + auto res_compiled_model = CompiledModel::Create(model); ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + auto& compiled_model = *res_compiled_model; auto signatures = model.GetSignatures().Value(); EXPECT_EQ(signatures.size(), 1); + auto signature_key = signatures[0].Key(); EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); size_t signature_index = 0; @@ -79,14 +74,74 @@ TEST(CompiledModelTest, Basic) { auto output_names = signatures[0].OutputNames(); EXPECT_EQ(output_names.size(), 1); EXPECT_EQ(output_names.at(0), "tfl.add"); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - ASSERT_TRUE(output_buffers[0].Read(output_span)); - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; + { + auto lock_and_addr = + litert::TensorBufferScopedLock::Create(output_buffers[0]); + ASSERT_TRUE(lock_and_addr); + auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); + } +} + +TEST(CompiledModelTest, RunWithInputOutputMap) { + auto model = testing::LoadTestFileModel(kModelFileName); + ASSERT_TRUE(model); + + auto res_compiled_model = CompiledModel::Create(model); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + + auto& compiled_model = *res_compiled_model; + auto signatures = model.GetSignatures().Value(); + EXPECT_EQ(signatures.size(), 1); + + auto signature_key = signatures[0].Key(); + EXPECT_EQ(signature_key, Model::DefaultSignatureKey()); + size_t signature_index = 0; + + auto input_buffers_res = compiled_model.CreateInputBuffers(signature_index); + EXPECT_TRUE(input_buffers_res); + auto& input_buffers = *input_buffers_res; + + auto output_buffers_res = compiled_model.CreateOutputBuffers(signature_index); + EXPECT_TRUE(output_buffers_res); + auto& output_buffers = *output_buffers_res; + + // Fill model inputs. + auto input_names = signatures[0].InputNames(); + EXPECT_EQ(input_names.size(), 2); + EXPECT_EQ(input_names.at(0), "arg0"); + EXPECT_EQ(input_names.at(1), "arg1"); + ASSERT_TRUE(input_buffers[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE(input_buffers[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); + absl::flat_hash_map input_map; + input_map["arg0"] = std::move(input_buffers[0]); + input_map["arg1"] = std::move(input_buffers[1]); + + auto output_names = signatures[0].OutputNames(); + EXPECT_EQ(output_names.size(), 1); + EXPECT_EQ(output_names.at(0), "tfl.add"); + absl::flat_hash_map output_map; + output_map["tfl.add"] = std::move(output_buffers[0]); + + // Execute model. + compiled_model.Run(signature_key, input_map, output_map); + + // Check model output. + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create( + output_map["tfl.add"]); + ASSERT_TRUE(lock_and_addr); + auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); } } // namespace diff --git a/tensorflow/lite/experimental/litert/cc/litert_consts.h b/tensorflow/lite/experimental/litert/cc/litert_consts.h new file mode 100644 index 00000000000000..14ac9a0b00e832 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_consts.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_ + +#include + +namespace litert { + +// The following constants are used to properly size absl::InlinedVector<> +// uses used in the LiteRT code. Their values don't need to be exact; they +// are just optimization hints. +static constexpr size_t kExpectedMaxTensorRank = 6; +static constexpr size_t kExpectedMaxNumOfTensorUses = 8; +static constexpr size_t kExpectedMaxNumOfOpInputs = 4; +static constexpr size_t kExpectedMaxNumOfOpOutputs = 8; +static constexpr size_t kExpectedMaxNumOfSubgraphInputs = 4; +static constexpr size_t kExpectedMaxNumOfSubgraphOutputs = 4; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_CONSTS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_detail.h b/tensorflow/lite/experimental/litert/cc/litert_detail.h index 8153629bf7202f..566d8468fa8148 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_detail.h +++ b/tensorflow/lite/experimental/litert/cc/litert_detail.h @@ -17,31 +17,26 @@ #include #include +#include +#include -#include "absl/container/inlined_vector.h" #include "absl/log/absl_check.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" namespace litert { -// Expected size for inlined vectors for things like the input/outputs of ops or -// subgraphs. -static constexpr size_t kTensorVecSize = 8; -template -using SmallVec = absl::InlinedVector; - // See "std::construct_at" from C++20. template -inline T* ConstructAt(T* p, Args&&... args) { +T* ConstructAt(T* p, Args&&... args) { return ::new (static_cast(p)) T(std::forward(args)...); } // Reduce all over zipped iters of same size. template -inline bool AllZip(const LeftVals& lhs, const RightVals& rhs, - std::function - bin_pred) { +bool AllZip(const LeftVals& lhs, const RightVals& rhs, + std::function + bin_pred) { if (lhs.size() != rhs.size()) { return false; } @@ -55,14 +50,33 @@ inline bool AllZip(const LeftVals& lhs, const RightVals& rhs, // Reduce any over zipped iters of same size. template -inline bool AnyZip(const LeftVals& lhs, const RightVals& rhs, - std::function - bin_pred) { +bool AnyZip(const LeftVals& lhs, const RightVals& rhs, + std::function + bin_pred) { auto neg = [&](const auto& l, const auto& r) { return !bin_pred(l, r); }; return !(AllZip(lhs, rhs, neg)); } +// Does element exist in range. +template +bool Contains(It begin, It end, const T& val) { + return std::find(begin, end, val) != end; +} + +// Does element exist in range satisfying pred. +template +bool ContainsIf(It begin, It end, UPred u_pred) { + return std::find_if(begin, end, u_pred) != end; +} + +// Get the ind of the given element if it is present. +template +std::optional FindInd(It begin, It end, T val) { + auto it = std::find(begin, end, val); + return (it == end) ? std::nullopt : std::make_optional(it - begin); +} + namespace internal { // Call function "get" and assert it returns value equal to given expected diff --git a/tensorflow/lite/experimental/litert/cc/litert_element_type.h b/tensorflow/lite/experimental/litert/cc/litert_element_type.h index 3f2b49b9df8155..84b032b3820a7a 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_element_type.h +++ b/tensorflow/lite/experimental/litert/cc/litert_element_type.h @@ -87,10 +87,13 @@ inline constexpr size_t GetByteWidth() { return byte_width.value(); } +template +constexpr bool dependent_false = false; // workaround before CWG2518/P2593R1 + // Get the litert::ElementType associated with given C++ type. template inline constexpr ElementType GetElementType() { - static_assert(false, "Uknown C++ type"); + static_assert(dependent_false, "Uknown C++ type"); return ElementType::None; } diff --git a/tensorflow/lite/experimental/litert/cc/litert_environment.h b/tensorflow/lite/experimental/litert/cc/litert_environment.h new file mode 100644 index 00000000000000..4910abd89b27c7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_environment.h @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" +#include "tensorflow/lite/experimental/litert/cc/litert_any.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" + +namespace litert { + +class Environment { + public: + enum class OptionTag { + CompilerPluginLibraryPath = kLiteRtEnvOptionTagCompilerPluginLibraryPath, + DispatchLibraryPath = kLiteRtEnvOptionTagDispatchLibraryPath, + }; + + struct Option { + OptionTag tag; + std::any value; + }; + + static Expected Create(absl::Span options) { + auto c_options = ConvertOptions(options); + if (!c_options) { + return c_options.Error(); + } + if (auto status = + LiteRtEnvironmentCreate(c_options->size(), c_options->data()); + status != kLiteRtStatusOk) { + return Error(status); + } else { + return {}; + } + } + + static void Destroy() { LiteRtEnvironmentDestroy(); } + + private: + static Expected> ConvertOptions( + absl::Span options) { + std::vector c_options; + c_options.reserve(options.size()); + + for (auto& option : options) { + auto litert_any = ToLiteRtAny(option.value); + if (!litert_any) { + return litert_any.Error(); + } + + LiteRtEnvOption c_option = { + /*.tag=*/static_cast(option.tag), + /*.value=*/*litert_any, + }; + c_options.push_back(c_option); + } + + return c_options; + } +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_event.h b/tensorflow/lite/experimental/litert/cc/litert_event.h new file mode 100644 index 00000000000000..a618d3e8e4787c --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_event.h @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" + +namespace litert { + +class Event : public internal::Handle { + public: + // Parameter `owned` indicates if the created TensorBufferRequirements object + // should take ownership of the provided `requirements` handle. + explicit Event(LiteRtEvent event, bool owned = true) + : internal::Handle(event, owned) {} + + static Expected CreateFromSyncFenceFd(int sync_fence_fd, + bool owns_fd) { + LiteRtEvent event; + if (auto status = + LiteRtCreateEventFromSyncFenceFd(sync_fence_fd, owns_fd, &event); + status != kLiteRtStatusOk) { + return Error(status, "Failed to create event from sync fence fd"); + } + return Event(event); + } + + Expected GetSyncFenceFd(LiteRtEvent event) { + int fd; + if (auto status = LiteRtGetEventSyncFenceFd(Get(), &fd); + status != kLiteRtStatusOk) { + return Error(status, "Failed to get sync fence fd from event"); + } + return fd; + } + + // Pass -1 for timeout_in_ms for indefinite wait. + Expected Wait(int64_t timeout_in_ms) { + if (auto status = LiteRtEventWait(Get(), timeout_in_ms); + status != kLiteRtStatusOk) { + return Error(status, "Failed to wait on event"); + } + return {}; + } +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_expected.h b/tensorflow/lite/experimental/litert/cc/litert_expected.h index 1526e7c0ec8092..01a481812b8e51 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_expected.h +++ b/tensorflow/lite/experimental/litert/cc/litert_expected.h @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -50,6 +51,10 @@ class Error { // Get the error message, empty string if none was attached. constexpr absl::string_view Message() const { return message_; } + friend std::ostream& operator<<(std::ostream& stream, const Error& error) { + return stream << error.Message(); + } + private: LiteRtStatus status_; absl::string_view message_; diff --git a/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc b/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc index 6dea4ecfb09b26..415bb389c54846 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_expected_test.cc @@ -16,10 +16,12 @@ #include #include +#include #include #include #include +#include #include #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" @@ -186,6 +188,13 @@ TEST(ExpectedWithNoValue, WithError) { EXPECT_EQ(expected.Error().Message(), "MESSAGE"); } +TEST(ExpectedWithNoValue, OStreamOutput) { + Expected expected(Unexpected(kErrorStatus, "MESSAGE")); + std::ostringstream oss; + oss << expected.Error(); + EXPECT_THAT(oss.str(), testing::HasSubstr("MESSAGE")); +} + } // namespace } // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_layout.h b/tensorflow/lite/experimental/litert/cc/litert_layout.h index e455a95924d985..a928e34c543a9f 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_layout.h +++ b/tensorflow/lite/experimental/litert/cc/litert_layout.h @@ -35,7 +35,7 @@ static constexpr size_t kTensorMaxRank = LITERT_TENSOR_MAX_RANK; template inline constexpr LiteRtLayout BuildLayout(Begin begin, End end, const uint32_t* strides = nullptr) { - LiteRtLayout res(end - begin, {}, strides); + LiteRtLayout res{static_cast(end - begin), {}, strides}; auto i = 0; for (auto* it = begin; it < end && i < kTensorMaxRank; ++it) { diff --git a/tensorflow/lite/experimental/litert/cc/litert_model.cc b/tensorflow/lite/experimental/litert/cc/litert_model.cc index 49f1f18d25f855..c5b943879d2c53 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_model.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_model.cc @@ -14,6 +14,8 @@ #include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include + #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_detail.h" @@ -29,17 +31,84 @@ bool Tensor::IsConstant() const { return HasWeights() && !DefiningOp().has_value(); } -SmallVec Tensor::Uses() const { +Tensor::TensorUses Tensor::Uses() const { LiteRtParamIndex num_uses; - LiteRtOpArray users; - LiteRtParamIndex* user_arg_inds; - litert::internal::AssertOk(LiteRtGetTensorUses, Get(), &num_uses, &users, - &user_arg_inds); - SmallVec res; - for (int i = 0; i < num_uses; ++i) { - res.push_back(Tensor::TensorUse{Op(users[i]), user_arg_inds[i]}); // NOLINT + litert::internal::AssertOk(LiteRtGetNumTensorUses, Get(), &num_uses); + + TensorUses uses; + for (auto i = 0; i < num_uses; ++i) { + LiteRtOp user; + LiteRtParamIndex user_arg_index; + litert::internal::AssertOk(LiteRtGetTensorUse, Get(), i, &user, + &user_arg_index); + uses.emplace_back(TensorUse{Op(user), user_arg_index}); + } + return uses; +} + +OpInputs Op::Inputs() const { + LiteRtParamIndex num_inputs; + internal::AssertOk(LiteRtGetNumOpInputs, Get(), &num_inputs); + + OpInputs inputs; + for (auto i = 0; i < num_inputs; ++i) { + LiteRtTensor input; + internal::AssertOk(LiteRtGetOpInput, Get(), i, &input); + inputs.emplace_back(Tensor(input)); + } + return inputs; +} + +OpOutputs Op::Outputs() const { + LiteRtParamIndex num_outputs; + internal::AssertOk(LiteRtGetNumOpOutputs, Get(), &num_outputs); + + OpOutputs outputs; + for (auto i = 0; i < num_outputs; ++i) { + LiteRtTensor output; + internal::AssertOk(LiteRtGetOpOutput, Get(), i, &output); + outputs.emplace_back(Tensor(output)); + } + return outputs; +} + +SubgraphInputs Subgraph::Inputs() const { + LiteRtParamIndex num_inputs; + internal::AssertOk(LiteRtGetNumSubgraphInputs, Get(), &num_inputs); + + SubgraphInputs inputs; + for (auto i = 0; i < num_inputs; ++i) { + LiteRtTensor input; + internal::AssertOk(LiteRtGetSubgraphInput, Get(), i, &input); + inputs.emplace_back(Tensor(input)); + } + return inputs; +} + +SubgraphOutputs Subgraph::Outputs() const { + LiteRtParamIndex num_outputs; + internal::AssertOk(LiteRtGetNumSubgraphOutputs, Get(), &num_outputs); + + SubgraphOutputs outputs; + for (auto i = 0; i < num_outputs; ++i) { + LiteRtTensor output; + internal::AssertOk(LiteRtGetSubgraphOutput, Get(), i, &output); + outputs.emplace_back(Tensor(output)); + } + return outputs; +} + +std::vector Subgraph::Ops() const { + LiteRtParamIndex num_ops; + internal::AssertOk(LiteRtGetNumSubgraphOps, Get(), &num_ops); + + std::vector ops; + for (auto i = 0; i < num_ops; ++i) { + LiteRtOp op; + litert::internal::AssertOk(LiteRtGetSubgraphOp, Get(), i, &op); + ops.emplace_back(Op(op)); } - return res; + return ops; } } // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_model.h b/tensorflow/lite/experimental/litert/cc/litert_model.h index 4158b43fa3fa65..77c76afd8e06c5 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_model.h +++ b/tensorflow/lite/experimental/litert/cc/litert_model.h @@ -24,12 +24,13 @@ #include #include -#include "absl/strings/str_format.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" +#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" #include "tensorflow/lite/experimental/litert/cc/litert_detail.h" #include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" @@ -38,11 +39,14 @@ namespace litert { +using Dimensions = absl::InlinedVector; +using Strides = absl::InlinedVector; + // Tensor layout. C++ equivalent to LiteRtLayout. class Layout { public: - explicit Layout(SmallVec&& dimensions, - SmallVec&& strides = SmallVec()) + explicit Layout(litert::Dimensions&& dimensions, + litert::Strides&& strides = litert::Strides()) : dimensions_(std::move(dimensions)), strides_(std::move(strides)) {} explicit Layout(const LiteRtLayout& layout) @@ -85,8 +89,8 @@ class Layout { } private: - SmallVec dimensions_; - SmallVec strides_; + litert::Dimensions dimensions_; + litert::Strides strides_; }; // Type for tensors with known dimensions. C++ equivalent to @@ -141,22 +145,36 @@ class Tensor : public internal::NonOwnedHandle { explicit Tensor(LiteRtTensor tensor) : internal::NonOwnedHandle(tensor) {} + enum ElementType ElementType() const { + if (TypeId() == kLiteRtUnrankedTensorType) { + return static_cast(UnrankedTensorType()->element_type); + } else { + return RankedTensorType()->ElementType(); + } + } + LiteRtTensorTypeId TypeId() const { LiteRtTensorTypeId type_id; internal::AssertOk(LiteRtGetTensorTypeId, Get(), &type_id); return type_id; } - LiteRtUnrankedTensorType UnrankedTensorType() const { - internal::AssertEq([&]() { return TypeId(); }, kLiteRtUnrankedTensorType); + Expected UnrankedTensorType() const { + if (TypeId() != kLiteRtUnrankedTensorType) { + return Error(kLiteRtStatusErrorInvalidArgument, + "Not an unranked invalid tensor"); + } LiteRtUnrankedTensorType unranked_tensor_type; internal::AssertOk(LiteRtGetUnrankedTensorType, Get(), &unranked_tensor_type); return unranked_tensor_type; } - class RankedTensorType RankedTensorType() const { - internal::AssertEq([&]() { return TypeId(); }, kLiteRtRankedTensorType); + Expected RankedTensorType() const { + if (TypeId() != kLiteRtRankedTensorType) { + return Error(kLiteRtStatusErrorInvalidArgument, + "Not a ranked tensor type"); + } LiteRtRankedTensorType ranked_tensor_type; internal::AssertOk(LiteRtGetRankedTensorType, Get(), &ranked_tensor_type); return litert::RankedTensorType(ranked_tensor_type); @@ -179,6 +197,15 @@ class Tensor : public internal::NonOwnedHandle { return per_tensor_quantization; } + LiteRtQuantizationPerChannel PerChannelQuantization() const { + internal::AssertEq([&]() { return QTypeId(); }, + kLiteRtQuantizationPerChannel); + LiteRtQuantizationPerChannel per_channel_quantization; + internal::AssertOk(LiteRtGetPerChannelQuantization, Get(), + &per_channel_quantization); + return per_channel_quantization; + } + bool HasWeights() const { auto weights = Weights(); return !weights.Bytes().empty(); @@ -197,11 +224,19 @@ class Tensor : public internal::NonOwnedHandle { } struct TensorUse; - SmallVec Uses() const; + using TensorUses = + absl::InlinedVector; + + TensorUses Uses() const; template Expected> WeightsData() const { - const ElementType ty = RankedTensorType().ElementType(); + auto ranked_tensor_type = RankedTensorType(); + if (!ranked_tensor_type) { + return ranked_tensor_type.Error(); + } + + const enum ElementType ty = ranked_tensor_type->ElementType(); if (ty != GetElementType()) { return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); } @@ -211,7 +246,7 @@ class Tensor : public internal::NonOwnedHandle { } const absl::Span weights = Weights().Bytes(); - auto num_elements = RankedTensorType().Layout().NumElements(); + auto num_elements = ranked_tensor_type->Layout().NumElements(); if (!num_elements.has_value()) { return litert::Unexpected(kLiteRtStatusErrorInvalidArgument); } @@ -245,6 +280,9 @@ class Tensor : public internal::NonOwnedHandle { bool IsConstant() const; }; +using OpInputs = absl::InlinedVector; +using OpOutputs = absl::InlinedVector; + // Operator. C++ equivalent of LiteRtOp. class Op : public internal::NonOwnedHandle { public: @@ -257,19 +295,8 @@ class Op : public internal::NonOwnedHandle { return opcode; } - SmallVec Inputs() const { - LiteRtParamIndex num_inputs; - LiteRtTensorArray inputs; - internal::AssertOk(LiteRtGetOpInputs, Get(), &num_inputs, &inputs); - return SmallVec(inputs, inputs + num_inputs); - } - - SmallVec Outputs() const { - LiteRtParamIndex num_outputs; - LiteRtTensorArray outputs; - internal::AssertOk(LiteRtGetOpOutputs, Get(), &num_outputs, &outputs); - return SmallVec(outputs, outputs + num_outputs); - } + OpInputs Inputs() const; + OpOutputs Outputs() const; }; struct Tensor::TensorUse { @@ -277,6 +304,11 @@ struct Tensor::TensorUse { LiteRtParamIndex user_arg_ind; }; +using SubgraphInputs = + absl::InlinedVector; +using SubgraphOutputs = + absl::InlinedVector; + // Model subgraph. C++ equivalent of LiteRtSubgraph. class Subgraph : public internal::NonOwnedHandle { public: @@ -284,26 +316,9 @@ class Subgraph : public internal::NonOwnedHandle { explicit Subgraph(LiteRtSubgraph subgraph) : internal::NonOwnedHandle(subgraph) {} - SmallVec Inputs() const { - LiteRtParamIndex num_inputs; - LiteRtTensorArray inputs; - internal::AssertOk(LiteRtGetSubgraphInputs, Get(), &num_inputs, &inputs); - return SmallVec(inputs, inputs + num_inputs); - } - - SmallVec Outputs() const { - LiteRtParamIndex num_outputs; - LiteRtTensorArray outputs; - internal::AssertOk(LiteRtGetSubgraphOutputs, Get(), &num_outputs, &outputs); - return SmallVec(outputs, outputs + num_outputs); - } - - std::vector Ops() const { - LiteRtParamIndex num_ops; - LiteRtOpArray ops; - internal::AssertOk(LiteRtGetSubgraphOps, Get(), &num_ops, &ops); - return std::vector(ops, ops + num_ops); - } + SubgraphInputs Inputs() const; + SubgraphOutputs Outputs() const; + std::vector Ops() const; }; // Model signature. C++ equivalent of LiteRtSignature. @@ -319,10 +334,10 @@ class Signature : public internal::NonOwnedHandle { return key; } - int SubgraphIndex() const { - LiteRtParamIndex subgraph_index; - internal::AssertOk(LiteRtGetSignatureSubgraphIndex, Get(), &subgraph_index); - return subgraph_index; + LiteRtSubgraph Subgraph() const { + LiteRtSubgraph subgraph; + internal::AssertOk(LiteRtGetSignatureSubgraph, Get(), &subgraph); + return subgraph; } std::vector InputNames() const { @@ -422,7 +437,13 @@ class Model : public internal::Handle { if (!signature) { return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); } - return Subgraph(signature->SubgraphIndex()); + return litert::Subgraph(signature->Subgraph()); + } + + size_t GetNumSignatures() const { + LiteRtParamIndex num_signatures; + internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); + return num_signatures; } // Returns the list of signatures defined in the model. @@ -448,6 +469,23 @@ class Model : public internal::Handle { return Signature(lite_rt_signature); } + // Returns the signature index for the given signature key. + Expected GetSignatureIndex(absl::string_view signature_key) const { + LiteRtParamIndex num_signatures; + internal::AssertOk(LiteRtGetNumModelSignatures, Get(), &num_signatures); + for (int i = 0; i < num_signatures; ++i) { + LiteRtSignature lite_rt_signature; + internal::AssertOk(LiteRtGetModelSignature, Get(), i, &lite_rt_signature); + const char* key_cstr; + internal::AssertOk(LiteRtGetSignatureKey, lite_rt_signature, &key_cstr); + if (absl::string_view(key_cstr) == signature_key) { + return i; + } + } + return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); + } + + // Returns the Signature object for the given signature key. Expected FindSignature( absl::string_view signature_key) const { LiteRtParamIndex num_signatures; diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc b/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc index b19545a029d016..18efea56f7ffa4 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_model_predicates.cc @@ -80,7 +80,12 @@ bool MatchOpType( if (!expected.has_value()) { return true; } - return MatchRankedTensorType(actual.RankedTensorType(), expected.value()); + auto actual_ranked_tensor_type = actual.RankedTensorType(); + // Don't return a match if the tensor is unranked. + if (!actual_ranked_tensor_type) { + return false; + } + return MatchRankedTensorType(*actual_ranked_tensor_type, expected.value()); }; const bool inputs_match = AllZip(absl::MakeConstSpan(op.Inputs()), diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc b/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc index 3e7f4ac7c72312..f16bc764e560c4 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_model_predicates_test.cc @@ -38,8 +38,10 @@ TEST(MatchRankedTensorTypeTest, HasAll) { auto ops = subgraph->Ops(); const auto inputs = ops.front().Inputs(); const auto& input = inputs.front(); + auto input_tensor_type = input.RankedTensorType(); + EXPECT_TRUE(input_tensor_type); EXPECT_TRUE(MatchRankedTensorType( - input.RankedTensorType(), TensorTypeInfo(ElementType::Float32, {2, 2}))); + *input_tensor_type, TensorTypeInfo(ElementType::Float32, {2, 2}))); } TEST(MatchRankedTensorTypeTest, NoMatch) { @@ -49,8 +51,10 @@ TEST(MatchRankedTensorTypeTest, NoMatch) { auto ops = subgraph->Ops(); const auto inputs = ops.front().Inputs(); const auto& input = inputs.front(); + auto input_tensor_type = input.RankedTensorType(); + EXPECT_TRUE(input_tensor_type); EXPECT_FALSE(MatchRankedTensorType( - input.RankedTensorType(), TensorTypeInfo(ElementType::Float32, {3, 2}))); + *input_tensor_type, TensorTypeInfo(ElementType::Float32, {3, 2}))); } TEST(MatchRankedTensorTypeTest, AnyDims) { @@ -60,7 +64,9 @@ TEST(MatchRankedTensorTypeTest, AnyDims) { auto ops = subgraph->Ops(); const auto inputs = ops.front().Inputs(); const auto& input = inputs.front(); - EXPECT_TRUE(MatchRankedTensorType(input.RankedTensorType(), + auto input_tensor_type = input.RankedTensorType(); + EXPECT_TRUE(input_tensor_type); + EXPECT_TRUE(MatchRankedTensorType(*input_tensor_type, TensorTypeInfo(ElementType::Float32))); } @@ -71,8 +77,10 @@ TEST(MatchRankedTensorTypeTest, AnyElementType) { auto ops = subgraph->Ops(); const auto inputs = ops.front().Inputs(); const auto& input = inputs.front(); + auto input_tensor_type = input.RankedTensorType(); + EXPECT_TRUE(input_tensor_type); EXPECT_TRUE( - MatchRankedTensorType(input.RankedTensorType(), TensorTypeInfo({2, 2}))); + MatchRankedTensorType(*input_tensor_type, TensorTypeInfo({2, 2}))); } TEST(MatchOpTypeTest, HasAll) { diff --git a/tensorflow/lite/experimental/litert/cc/litert_model_test.cc b/tensorflow/lite/experimental/litert/cc/litert_model_test.cc index 5760cf41be546a..0084119013ce5a 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_model_test.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_model_test.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -210,7 +211,8 @@ TEST(CcTensorTest, SimpleModel) { ASSERT_EQ(input_tensor.TypeId(), kLiteRtRankedTensorType); auto input_ranked_tensor_type = input_tensor.RankedTensorType(); - ASSERT_EQ(input_ranked_tensor_type.ElementType(), ElementType::Float32); + EXPECT_TRUE(input_ranked_tensor_type); + ASSERT_EQ(input_ranked_tensor_type->ElementType(), ElementType::Float32); EXPECT_FALSE(input_tensor.HasWeights()); @@ -249,7 +251,7 @@ TEST(CcTensorTest, WeightsData) { TEST(CcTensorTest, Name) { static constexpr absl::string_view kName = "foo"; LiteRtTensorT tensor; - tensor.name = kName; + tensor.SetName(std::string(kName)); Tensor cc_tensor(&tensor); EXPECT_EQ(cc_tensor.Name(), kName); @@ -257,7 +259,7 @@ TEST(CcTensorTest, Name) { TEST(CcTensorTest, QuantizationNone) { LiteRtTensorT litert_tensor; - litert_tensor.q_type_id = kLiteRtQuantizationNone; + litert_tensor.Qparams().first = kLiteRtQuantizationNone; Tensor tensor(&litert_tensor); EXPECT_EQ(tensor.QTypeId(), kLiteRtQuantizationNone); @@ -269,8 +271,7 @@ TEST(CcTensorTest, QuantizationPerTensor) { static constexpr auto kZeroPoint = 1; LiteRtTensorT litert_tensor; - litert_tensor.q_type_id = kLiteRtQuantizationPerTensor; - litert_tensor.q_type_detail.per_tensor = {kScale, kZeroPoint}; + litert_tensor.SetQarams(MakePerTensorQuantization(kScale, kZeroPoint)); Tensor tensor(&litert_tensor); ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerTensor); @@ -281,6 +282,32 @@ TEST(CcTensorTest, QuantizationPerTensor) { EXPECT_EQ(per_tensor_quantization.zero_point, kZeroPoint); } +TEST(CcTensorTest, QuantizationPerChannel) { + static constexpr auto kNumChannels = 2; + static constexpr auto kQuantizedDimension = 0; + static constexpr float kScales[kNumChannels] = {1.0, 2.0}; + static constexpr int64_t kZeroPoints[kNumChannels] = {0, 0}; + + LiteRtTensorT litert_tensor; + auto per_channel = MakePerChannelQuantization( + kScales, kZeroPoints, kQuantizedDimension, litert_tensor); + litert_tensor.SetQarams(per_channel); + + Tensor tensor(&litert_tensor); + ASSERT_EQ(tensor.QTypeId(), kLiteRtQuantizationPerChannel); + ASSERT_TRUE(tensor.HasQuantization()); + + const auto per_channel_quantization = tensor.PerChannelQuantization(); + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.scales, kNumChannels), + ::testing::ElementsAreArray(kScales)); + EXPECT_THAT( + absl::MakeConstSpan(per_channel_quantization.zero_points, kNumChannels), + ::testing::ElementsAreArray(kZeroPoints)); + EXPECT_EQ(per_channel_quantization.num_channels, kNumChannels); + EXPECT_EQ(per_channel_quantization.quantized_dimension, kQuantizedDimension); +} + //===----------------------------------------------------------------------===// // CC Subgraph // //===----------------------------------------------------------------------===// diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h index ddf1e566eb07be..1c044dd7e8ce0f 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h @@ -24,6 +24,8 @@ #include "tensorflow/lite/experimental/litert/c/litert_event.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" +#include "tensorflow/lite/experimental/litert/cc/litert_event.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_handle.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" @@ -70,6 +72,53 @@ class TensorBuffer return TensorBuffer(tensor_buffer); } + // Creates a TensorBuffer object that wraps the provided host memory. + // The provided host memory is not owned by the TensorBuffer object and must + // outlive the TensorBuffer object. + static Expected CreateFromHostMemory( + const RankedTensorType& tensor_type, void* host_mem_addr, + size_t buffer_size) { + LiteRtTensorBuffer tensor_buffer; + auto litert_tensor_type = static_cast(tensor_type); + + if (auto status = LiteRtCreateTensorBufferFromHostMemory( + &litert_tensor_type, host_mem_addr, buffer_size, + /*deallocator=*/nullptr, &tensor_buffer); + status != kLiteRtStatusOk) { + return Unexpected(status, + "Failed to create tensor buffer from host memory"); + } + return TensorBuffer(tensor_buffer); + } + + // Creates a TensorBuffer object that wraps an Android Hardware Buffer. Note + // that the provided AHardwareBuffer is not owned by the TensorBuffer object + // and must outlive the TensorBuffer object. The `ahwb_offset` parameter + // specifies the offset in bytes from the start of the AHardwareBuffer where + // the tensor data starts. + static Expected CreateFromAhwb( + const RankedTensorType& tensor_type, AHardwareBuffer* ahwb, + size_t ahwb_offset) { +#if LITERT_HAS_AHWB_SUPPORT + LiteRtTensorBuffer tensor_buffer; + auto litert_tensor_type = static_cast(tensor_type); + + if (auto status = LiteRtCreateTensorBufferFromAhwb( + &litert_tensor_type, ahwb, ahwb_offset, + /*deallocator=*/nullptr, &tensor_buffer); + status != kLiteRtStatusOk) { + return Unexpected( + status, + "Failed to create tensor buffer from Android Hardware Buffer"); + } + return TensorBuffer(tensor_buffer); +#else + return litert::Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "AHardwareBuffer is not supported on this platform"); +#endif + } + litert::Expected GetAhwb() const { #if LITERT_HAS_AHWB_SUPPORT AHardwareBuffer* ahwb; @@ -87,6 +136,27 @@ class TensorBuffer #endif } + struct DmaBuf { + void* addr; + int fd; + }; + + litert::Expected GetDmaBuf() const { +#if LITERT_HAS_DMABUF_SUPPORT + DmaBuf dma_buf; + if (LiteRtGetTensorBufferDmaBufBuffer(Get(), &dma_buf.addr, &dma_buf.fd) == + kLiteRtStatusOk) { + return dma_buf; + } else { + return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to get DMA-BUF from tensor buffer"); + } +#else + return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, + "DMA-BUF is not supported on this platform"); +#endif + } + Expected BufferType() const { LiteRtTensorBufferType tensor_buffer_type; if (auto status = LiteRtGetTensorBufferType(Get(), &tensor_buffer_type); @@ -123,6 +193,37 @@ class TensorBuffer return offset; } + bool HasEvent() const { + bool has_event; + internal::AssertOk(LiteRtHasTensorBufferEvent, Get(), &has_event); + return has_event; + } + + Expected GetEvent() const { + LiteRtEvent event; + if (auto status = LiteRtGetTensorBufferEvent(Get(), &event); + status != kLiteRtStatusOk) { + return Error(status, "Failed to get tensor buffer event"); + } + return Event(event, /*owned=*/false); + } + + Expected SetEvent(Event e) { + if (auto status = LiteRtSetTensorBufferEvent(Get(), e.Get()); + status != kLiteRtStatusOk) { + return Error(status, "Failed to set tensor buffer event"); + } + return {}; + } + + Expected ClearEvent() { + if (auto status = LiteRtClearTensorBufferEvent(Get()); + status != kLiteRtStatusOk) { + return Error(status, "Failed to clear tensor buffer event"); + } + return {}; + } + Expected Lock(LiteRtEvent event = nullptr) { void* host_mem_addr; if (auto status = LiteRtLockTensorBuffer(Get(), &host_mem_addr, event); @@ -194,21 +295,34 @@ class TensorBuffer class TensorBufferScopedLock { public: - ~TensorBufferScopedLock() { (void)tensor_buffer_.Unlock(); } + TensorBufferScopedLock(const TensorBufferScopedLock& arg) = delete; + TensorBufferScopedLock(TensorBufferScopedLock&& arg) = default; + ~TensorBufferScopedLock() { (void)LiteRtUnlockTensorBuffer(tensor_buffer_); } - static Expected> Create( + template + static Expected> Create( TensorBuffer& tensor_buffer, LiteRtEvent event = nullptr) { - auto addr = tensor_buffer.Lock(event); - if (!addr) { - return addr.Error(); + return Create(tensor_buffer.Get(), event); + } + + template + static Expected> Create( + LiteRtTensorBuffer tensor_buffer, LiteRtEvent event = nullptr) { + void* host_mem_addr; + if (auto status = + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, event); + status != kLiteRtStatusOk) { + return Unexpected(status, "Failed to lock the tensor buffer"); } - return std::make_pair(TensorBufferScopedLock(tensor_buffer), *addr); + return std::make_pair(TensorBufferScopedLock(tensor_buffer), + static_cast(host_mem_addr)); } private: - explicit TensorBufferScopedLock(TensorBuffer& tensor_buffer) + explicit TensorBufferScopedLock(LiteRtTensorBuffer& tensor_buffer) : tensor_buffer_(tensor_buffer) {} - TensorBuffer& tensor_buffer_; + + LiteRtTensorBuffer tensor_buffer_; }; } // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc index ce6cefdce637ff..607e36fe1d4024 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_test.cc @@ -14,8 +14,11 @@ #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include #include +#include #include +#include #include // NOLINT: Need when ANDROID_API_LEVEL >= 26 #include "absl/types/span.h" @@ -30,6 +33,10 @@ #include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" // IWYU pragma: keep #include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" +#if LITERT_HAS_AHWB_SUPPORT +#include +#endif // LITERT_HAS_AHWB_SUPPORT + namespace { constexpr const float kTensorData[] = {10, 20, 30, 40}; @@ -302,6 +309,84 @@ TEST(TensorBuffer, NotOwned) { LiteRtDestroyTensorBuffer(litert_tensor_buffer); } +TEST(TensorBuffer, ExternalHostMemory) { + // Allocate a tensor buffer with host memory. + const int kTensorBufferSize = + std::max(sizeof(kTensorData), LITERT_HOST_MEMORY_BUFFER_ALIGNMENT); + const litert::RankedTensorType kTensorType(::kTensorType); + void* host_memory_ptr; + ASSERT_EQ( + ::posix_memalign(&host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, + kTensorBufferSize), + 0); + + std::memcpy(host_memory_ptr, kTensorData, sizeof(kTensorData)); + + // Create a tensor buffer that wraps the host memory. + auto tensor_buffer_from_external_memory = + litert::TensorBuffer::CreateFromHostMemory(kTensorType, host_memory_ptr, + kTensorBufferSize); + + auto lock_and_addr_external_memory = litert::TensorBufferScopedLock::Create( + *tensor_buffer_from_external_memory); + ASSERT_TRUE(lock_and_addr_external_memory); + ASSERT_EQ(std::memcmp(lock_and_addr_external_memory->second, kTensorData, + sizeof(kTensorData)), + 0); + + free(host_memory_ptr); +} + +#if LITERT_HAS_AHWB_SUPPORT +TEST(TensorBuffer, FromAhwb) { + AHardwareBuffer* ahw_buffer = nullptr; + if (__builtin_available(android 26, *)) { + int error = 0; + AHardwareBuffer_Desc desc = { + .width = LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, + .height = 1, + .layers = 1, + .format = AHARDWAREBUFFER_FORMAT_BLOB, + .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY | + AHARDWAREBUFFER_USAGE_CPU_READ_RARELY}; + error = AHardwareBuffer_allocate(&desc, &ahw_buffer); + ASSERT_EQ(error, 0); + + void* host_memory_ptr = nullptr; + error = + AHardwareBuffer_lock(ahw_buffer, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, + -1, nullptr, &host_memory_ptr); + ASSERT_EQ(error, 0); + + std::memcpy(host_memory_ptr, kTensorData, sizeof(kTensorData)); + + int fence_file_descriptor = -1; + error = AHardwareBuffer_unlock(ahw_buffer, &fence_file_descriptor); + ASSERT_EQ(error, 0); + } else { + GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " + "skipping the test"; + } + + // Create a tensor buffer that wraps the AHardwareBuffer. + const litert::RankedTensorType kTensorType(::kTensorType); + auto tensor_buffer_from_external_memory = + litert::TensorBuffer::CreateFromAhwb(kTensorType, ahw_buffer, + /*ahwb_offset=*/0); + + auto lock_and_addr_external_memory = litert::TensorBufferScopedLock::Create( + *tensor_buffer_from_external_memory); + ASSERT_TRUE(lock_and_addr_external_memory); + ASSERT_EQ(std::memcmp(lock_and_addr_external_memory->second, kTensorData, + sizeof(kTensorData)), + 0); + + if (__builtin_available(android 26, *)) { + AHardwareBuffer_release(ahw_buffer); + } +} +#endif // LITERT_HAS_AHWB_SUPPORT + TEST(TensorBuffer, Duplicate) { LiteRtTensorBuffer litert_tensor_buffer; ASSERT_EQ(LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory, diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/BUILD b/tensorflow/lite/experimental/litert/compiler/plugin/BUILD index eb0e2142156d09..87a913d838e3fa 100644 --- a/tensorflow/lite/experimental/litert/compiler/plugin/BUILD +++ b/tensorflow/lite/experimental/litert/compiler/plugin/BUILD @@ -23,6 +23,7 @@ cc_library( hdrs = ["compiler_plugin.h"], deps = [ ":algo", + "//tensorflow/lite/experimental/litert/c:litert_any", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", @@ -32,10 +33,16 @@ cc_library( "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/core:byte_code_util", "//tensorflow/lite/experimental/litert/core:dynamic_loading", + "//tensorflow/lite/experimental/litert/core:environment", "//tensorflow/lite/experimental/litert/core:filesystem", "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/core/model:ir_allocator", + "//tensorflow/lite/experimental/litert/core/model:model_serialize", "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], @@ -58,12 +65,14 @@ cc_library( # deps = [ # ":compiler_plugin", # "@com_google_googletest//:gtest_main", -# "//testing/base/public:unique-test-directory", # "@com_google_absl//absl/strings:string_view", # "//tensorflow/lite/experimental/litert/c:litert_common", # "//tensorflow/lite/experimental/litert/c:litert_op_code", +# "//tensorflow/lite/experimental/litert/cc:litert_environment", +# "//tensorflow/lite/experimental/litert/core:byte_code_util", # "//tensorflow/lite/experimental/litert/core:filesystem", # "//tensorflow/lite/experimental/litert/test:common", +# "//tensorflow/lite/experimental/litert/test:test_macros", # "//tensorflow/lite/experimental/litert/tools:dump", # ], # ) @@ -74,11 +83,11 @@ cc_library( srcs = ["algo.cc"], hdrs = ["algo.h"], deps = [ - "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/core/model", - "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/experimental/litert/core/model:model_graph", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@llvm-project//llvm:Support", ], ) @@ -91,11 +100,11 @@ cc_test( ], deps = [ ":algo", - "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/core/model:graph_validation", "//tensorflow/lite/experimental/litert/test:common", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc b/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc index e760dd2214f91a..dfa16d9c36e80b 100644 --- a/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc +++ b/tensorflow/lite/experimental/litert/compiler/plugin/algo.cc @@ -14,25 +14,29 @@ #include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" -#include -#include -#include -#include #include #include #include #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "llvm/ADT/MapVector.h" -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" -#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" namespace litert::internal { namespace { +void MakeDispatchOp(LiteRtOpT& op) { + ABSL_DCHECK(op.Inputs().empty()); + ABSL_DCHECK(op.Outputs().empty()); + op.SetOpCode(kLiteRtOpCodeTflCustom); + detail::SetTflOpCodeInd(op, detail::kDispatchOpCodeTflInd); + op.ClearCustomOptions(); +} + // // flatlist to partition(s) //===----------------------------------------------------------------------===// @@ -51,8 +55,7 @@ class DisjointSets { // NOLINTEND }; -inline std::vector> -DisjointSets::GetPartitionsFromFlatList( +std::vector> DisjointSets::GetPartitionsFromFlatList( const std::vector& flat_op_list) { DisjointSets disjoint_sets; for (auto* op : flat_op_list) { @@ -60,8 +63,8 @@ DisjointSets::GetPartitionsFromFlatList( } for (auto* op : flat_op_list) { - for (auto* output : op->outputs) { - for (auto* user : output->users) { + for (auto* output : op->Outputs()) { + for (auto* user : output->Users()) { if (disjoint_sets.map_.count(user) == 0) { continue; } @@ -73,7 +76,7 @@ DisjointSets::GetPartitionsFromFlatList( return disjoint_sets.GetBuckets(); } -inline void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) { +void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) { auto* parent_bucket = GetBucket(parent); auto* op_bucket = GetBucket(op); if (op_bucket == parent_bucket) { @@ -83,7 +86,7 @@ inline void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) { } // Get all disjoint sets. -inline std::vector> DisjointSets::GetBuckets() { +std::vector> DisjointSets::GetBuckets() { // NOLINTBEGIN std::unordered_map> invert_map; // NOLINTEND @@ -109,7 +112,7 @@ inline std::vector> DisjointSets::GetBuckets() { // Gets the pointer which serves as the key for given ops bucket. Collapses // paths to amortize. -inline LiteRtOp DisjointSets::GetBucket(LiteRtOp op) { +LiteRtOp DisjointSets::GetBucket(LiteRtOp op) { auto* parent = map_[op]; if (op != parent) { parent = GetBucket(parent); @@ -122,150 +125,6 @@ inline LiteRtOp DisjointSets::GetBucket(LiteRtOp op) { // slice partitions out of a subgraph (into new subgraphs) //===----------------------------------------------------------------------===// -// TODO: b/365339578 - Move helpers from algo.h to the internal model library. - -inline void CloneOpData(const LiteRtOpT& old_op, LiteRtOpT& new_op) { - new_op.op_code = old_op.op_code; - new_op.option = old_op.option; -} - -inline void CloneTensorData(const LiteRtTensorT& old_tensor, - LiteRtTensorT& new_tensor) { - new_tensor.type_id = old_tensor.type_id; - new_tensor.type_detail = old_tensor.type_detail; - // Copy weights buffer from old tensor to new tensor. - new_tensor.weights.fb_buffer = - std::make_unique(*old_tensor.weights.fb_buffer); -} - -inline std::optional FindUseInd(LiteRtTensor tensor, - LiteRtOp user) { - for (LiteRtParamIndex i = 0; i < tensor->users.size(); ++i) { - if (tensor->users[i] == user) { - return i; - } - } - return std::nullopt; -} - -inline void EraseUse(LiteRtTensor tensor, LiteRtParamIndex use_ind) { - if (use_ind < 0 || use_ind >= tensor->users.size()) { - return; - } - tensor->users[use_ind] = tensor->users.back(); - tensor->users.pop_back(); - tensor->user_arg_inds[use_ind] = tensor->user_arg_inds.back(); - tensor->user_arg_inds.pop_back(); -} - -inline void EraseUse(LiteRtTensor tensor, LiteRtOp user) { - auto use_ind = FindUseInd(tensor, user); - if (!use_ind.has_value()) { - LITERT_LOG(LITERT_WARNING, "Trying to erase from tensor that doesn't use."); - return; - } - EraseUse(tensor, use_ind.value()); -} - -// Push tensor to the end of ops arguments. -inline void AddUse(LiteRtTensorT& tensor, LiteRtOpT& op) { - op.inputs.push_back(&tensor); - tensor.users.push_back(&op); - tensor.user_arg_inds.push_back(op.inputs.size() - 1); -} - -inline void AddOutput(LiteRtOpT& op, LiteRtTensorT& tensor) { - op.outputs.push_back(&tensor); - tensor.defining_op = &op; - tensor.defining_op_out_ind = op.outputs.size() - 1; -} - -inline LiteRtTensor RequestNewTensor(LiteRtSubgraph subgraph, - const LiteRtTensorT& like) { - auto& new_tensor = subgraph->tensors_storage.emplace_back(); - CloneTensorData(like, new_tensor); - return &new_tensor; -} - -inline LiteRtTensor RequestNewInput(LiteRtSubgraph subgraph, - const LiteRtTensorT& like) { - auto new_tensor = RequestNewTensor(subgraph, like); - subgraph->inputs.push_back(new_tensor); - return new_tensor; -} - -inline LiteRtOp RequestNewOp(LiteRtSubgraph subgraph, const LiteRtOpT& like) { - auto& new_op = subgraph->ops_storage.emplace_back(); - CloneOpData(like, new_op); - return &new_op; -} - -inline void AddOutput(LiteRtSubgraph subgraph, LiteRtTensor tensor) { - subgraph->outputs.push_back(tensor); -} - -inline bool IsOutput(const LiteRtSubgraphT& subgraph, LiteRtTensor tensor) { - return std::count(subgraph.outputs.begin(), subgraph.outputs.end(), tensor) > - 0; -} - -inline void UpdateReferences(LiteRtSubgraphT& subgraph) { - subgraph.tensors.clear(); - subgraph.ops.clear(); - for (auto& tensor : subgraph.tensors_storage) { - subgraph.tensors.push_back(&tensor); - } - for (auto& op : subgraph.ops_storage) { - subgraph.ops.push_back(&op); - } -} - -inline void Drop(LiteRtOpT& op) { - for (auto tensor : op.inputs) { - EraseUse(tensor, &op); - } - op.inputs.clear(); - for (auto tensor : op.outputs) { - tensor->defining_op = nullptr; - } - op.outputs.clear(); -} - -// TODO expand dead code elimination to work recursively. This is a very simple. -inline void DCE(LiteRtSubgraphT& subgraph) { - auto& ops = subgraph.ops_storage; - for (auto it = ops.begin(); it != ops.end();) { - if (it->inputs.empty() && it->outputs.empty()) { - it = ops.erase(it); - } else { - ++it; - } - } - - // NOLINTBEGIN - std::set inputs(subgraph.inputs.begin(), subgraph.inputs.end()); - std::set outputs(subgraph.outputs.begin(), - subgraph.outputs.end()); - // NOLINTEND - - auto& tensors = subgraph.tensors_storage; - for (auto it = tensors.begin(); it != tensors.end();) { - auto* tensor = &*it; - - const bool not_in = inputs.find(tensor) == inputs.end(); - const bool not_out = outputs.find(tensor) == outputs.end(); - const bool dead = tensor->defining_op == nullptr && tensor->users.empty(); - - if (not_in && not_out && dead) { - it = tensors.erase(it); - } else { - ++it; - } - } - - UpdateReferences(subgraph); -} - class GraphSlicer { public: // Slices "partitions" from "root" into the empty subgraph "slice". Assumes @@ -287,10 +146,10 @@ class GraphSlicer { // NOLINTBEGIN llvm::MapVector tensor_map_; // NOLINTEND - LiteRtOp hal_cal_op_ = nullptr; + LiteRtOp dispatch_op_ = nullptr; }; -inline LiteRtOp GraphSlicer::SlicePartitionFromGraph( +LiteRtOp GraphSlicer::SlicePartitionFromGraph( LiteRtSubgraphT& root, LiteRtSubgraph slice, std::vector& partition) { GraphSlicer slicer(slice); @@ -300,13 +159,15 @@ inline LiteRtOp GraphSlicer::SlicePartitionFromGraph( // later outlined custom op is the same as the order of input tensors of the // GraphInputs. absl::flat_hash_set used_tensors; + // Get all tensors used in the partition. for (auto* op : partition) { - used_tensors.insert(op->inputs.begin(), op->inputs.end()); + used_tensors.insert(op->Inputs().cbegin(), op->Inputs().cend()); } - for (auto* old_input : root.inputs) { + for (auto* old_input : root.Inputs()) { if (used_tensors.contains(old_input)) { - LiteRtTensor new_input = RequestNewInput(slicer.slice_, *old_input); + auto* new_input = &MakeClone(*slicer.slice_, *old_input); + slicer.slice_->Inputs().push_back(new_input); slicer.tensor_map_.insert({old_input, new_input}); } } @@ -321,60 +182,62 @@ inline LiteRtOp GraphSlicer::SlicePartitionFromGraph( // Reuse the storage from the last op in partition to maintain // toplogical order. - slicer.hal_cal_op_ = partition.back(); - slicer.hal_cal_op_->op_code = kLiteRtOpCodeTflCustom; + slicer.dispatch_op_ = partition.back(); - UpdateReferences(*slicer.slice_); + MakeDispatchOp(*slicer.dispatch_op_); slicer.RerouteTensorsThroughCustomOp(root); + DCE(root); - return slicer.hal_cal_op_; + return slicer.dispatch_op_; } -inline void GraphSlicer::RerouteTensorsThroughCustomOp( - const LiteRtSubgraphT& root) { +void GraphSlicer::RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root) { for (auto& [old_tensor, new_tensor] : tensor_map_) { // Reroute tensors which need to be passed into the scope of the new // subgraph to inputs of the custom op. - if (new_tensor->defining_op == nullptr) { - AddUse(*old_tensor, *hal_cal_op_); + if (new_tensor->DefiningOp() == nullptr && !IsConstant(*new_tensor)) { + AttachInput(old_tensor, *dispatch_op_); continue; } // Reroute custom op as the definer of tensors within the removed partition // and referenced later in the root graph. - if (!old_tensor->users.empty() || IsOutput(root, old_tensor)) { - AddOutput(*hal_cal_op_, *old_tensor); - AddOutput(slice_, new_tensor); + if ((!old_tensor->Users().empty() && !IsConstant(*old_tensor)) || + FindOutput(root, *old_tensor)) { + AttachOutput(old_tensor, *dispatch_op_); + slice_->Outputs().push_back(new_tensor); } } } -inline void GraphSlicer::CloneInto(const LiteRtOpT& old_op) { - auto& new_op = *RequestNewOp(slice_, old_op); +void GraphSlicer::CloneInto(const LiteRtOpT& old_op) { + auto& new_op = MakeClone(*slice_, old_op); - for (int i = 0; i < old_op.inputs.size(); ++i) { - auto old_input = old_op.inputs[i]; + for (auto i = 0; i < old_op.NumInputs(); ++i) { + auto* old_input = old_op.Inputs().at(i); LiteRtTensor new_input; - if (tensor_map_.contains(old_input)) { // If old_input is already in the map then map[input] is its cloned // counterpart in the new graph. new_input = tensor_map_[old_input]; } else { - // Otherwise, it must be a new subgraph input. - new_input = RequestNewInput(slice_, *old_input); + // Otherwise, it must be a new subgraph input (or constant). + new_input = &MakeClone(*slice_, *old_input); + if (!IsConstant(*new_input)) { + slice_->Inputs().push_back(new_input); + } + tensor_map_.insert({old_input, new_input}); } - AddUse(*new_input, new_op); + AttachInput(new_input, new_op); } - for (int i = 0; i < old_op.outputs.size(); ++i) { - auto old_output = old_op.outputs[i]; - - auto new_output = RequestNewTensor(slice_, *old_output); - AddOutput(new_op, *new_output); + for (int i = 0; i < old_op.NumOutputs(); ++i) { + auto* old_output = old_op.Outputs().at(i); + auto* new_output = &MakeClone(*slice_, *old_output); + AttachOutput(new_output, new_op); // Update the values defined in scope of the new subgraph. tensor_map_.insert({old_output, new_output}); diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc index 967ec4a6a5366e..c93e268e00a34e 100644 --- a/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc +++ b/tensorflow/lite/experimental/litert/compiler/plugin/algo_test.cc @@ -14,76 +14,20 @@ #include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" -#include -#include #include #include -#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" +#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/test/common.h" namespace litert::internal { namespace { -// NOLINTBEGIN -bool HasValidGeneralTopology(LiteRtSubgraph subgraph) { - if (!testing::ValidateTopology(Subgraph(subgraph).Ops())) { - LITERT_LOG(LITERT_ERROR, "Invalid topology."); - return false; - } - - std::unordered_set implied_subgraph_outs; - for (auto tensor : subgraph->tensors) { - if (tensor->users.empty()) { - implied_subgraph_outs.insert(tensor); - } - } - - if (implied_subgraph_outs.size() != subgraph->outputs.size()) { - LITERT_LOG(LITERT_ERROR, - "Output size mismatch: %d (Actual) != %d (Expected).", - implied_subgraph_outs.size(), subgraph->outputs.size()); - return false; - } - - for (auto tensor : subgraph->outputs) { - if (implied_subgraph_outs.find(tensor) == implied_subgraph_outs.end()) { - LITERT_LOG(LITERT_ERROR, "Output not found."); - return false; - } - } - - std::unordered_set implied_subgraph_ins; - for (auto tensor : subgraph->tensors) { - if (tensor->defining_op == nullptr && - tensor->weights.fb_buffer->data.empty()) { - implied_subgraph_ins.insert(tensor); - } - } - - if (implied_subgraph_ins.size() != subgraph->inputs.size()) { - LITERT_LOG(LITERT_ERROR, - "Input size mismatch: %d (Actual) != %d (Expected).", - implied_subgraph_ins.size(), subgraph->inputs.size()); - return false; - } - - for (auto tensor : subgraph->inputs) { - if (implied_subgraph_ins.find(tensor) == implied_subgraph_ins.end()) { - LITERT_LOG(LITERT_ERROR, "Input not found."); - return false; - } - } - - return true; -} -// NOLINTEND - TEST(TestPartitionsFromFlatList, SimpleMultiOp) { auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); auto subgraph = model.MainSubgraph(); @@ -121,8 +65,8 @@ TEST(TestPartitionsFromFlatList, SimpleMultiOp) { ASSERT_EQ(partitions.front().size(), 1); ASSERT_EQ(partitions.back().size(), 1); - auto p1_op_code = partitions.front().front()->op_code; - auto p2_op_code = partitions.back().front()->op_code; + auto p1_op_code = partitions.front().front()->OpCode(); + auto p2_op_code = partitions.back().front()->OpCode(); ASSERT_TRUE((p1_op_code == kLiteRtOpCodeTflMul && p2_op_code == kLiteRtOpCodeTflAdd) || @@ -173,12 +117,14 @@ TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { partition.push_back(ops.at(1).Get()); partition.push_back(ops.at(2).Get()); - auto sliced_graph = litert::Subgraph(&model.Get()->subgraphs.emplace_back()); - auto* hal_cal_op = + auto sliced_graph = litert::Subgraph(&model.Get()->EmplaceSubgraph()); + auto* dispatch_op = OutlinePartition(*subgraph->Get(), sliced_graph.Get(), partition); - ASSERT_TRUE(HasValidGeneralTopology(sliced_graph.Get())); - ASSERT_TRUE(HasValidGeneralTopology(subgraph->Get())); + const auto& internal_sliced = *sliced_graph.Get(); + ASSERT_TRUE(ValidateSubgraphIO(internal_sliced)); + ASSERT_TRUE(ValidateLocalTopology(internal_sliced.Ops().cbegin(), + internal_sliced.Ops().cend())); auto edited_subgraph_ops = subgraph->Ops(); @@ -193,15 +139,15 @@ TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { ASSERT_EQ(sliced_subgraph_ops[0].Code(), kLiteRtOpCodeTflMul); ASSERT_EQ(sliced_subgraph_ops[1].Code(), kLiteRtOpCodeTflMul); - ASSERT_EQ(hal_cal_op, edited_subgraph_ops.at(1).Get()); - const Op hal_call(hal_cal_op); + ASSERT_EQ(dispatch_op, edited_subgraph_ops.at(1).Get()); + const Op hal_call(dispatch_op); { - const auto hal_cal_op_ins = hal_call.Inputs(); + const auto dispatch_op_ins = hal_call.Inputs(); - ASSERT_EQ(hal_cal_op_ins.size(), 1); + ASSERT_EQ(dispatch_op_ins.size(), 1); - auto hal_input_defining_op = hal_cal_op_ins.front().DefiningOp(); + auto hal_input_defining_op = dispatch_op_ins.front().DefiningOp(); ASSERT_EQ(hal_input_defining_op->op, edited_subgraph_ops.at(0).Get()); ASSERT_EQ(hal_input_defining_op->op_output_index, 0); @@ -253,23 +199,25 @@ TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { std::vector partition_1; partition_1.push_back(ops.at(0).Get()); - auto sliced_graph_1 = - litert::Subgraph(&model.Get()->subgraphs.emplace_back()); + auto sliced_graph_1 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); OutlinePartition(*(subgraph->Get()), sliced_graph_1.Get(), partition_1); - ASSERT_TRUE(HasValidGeneralTopology(sliced_graph_1.Get())); - ASSERT_TRUE(HasValidGeneralTopology(subgraph->Get())); + const auto& internal_slice_1 = *sliced_graph_1.Get(); + ASSERT_TRUE(ValidateSubgraphIO(internal_slice_1)); + ASSERT_TRUE(ValidateLocalTopology(internal_slice_1.Ops().cbegin(), + internal_slice_1.Ops().cend())); std::vector partition_2; partition_2.push_back(ops.at(2).Get()); partition_2.push_back(ops.at(3).Get()); - auto sliced_graph_2 = - litert::Subgraph(&model.Get()->subgraphs.emplace_back()); + auto sliced_graph_2 = litert::Subgraph(&model.Get()->EmplaceSubgraph()); OutlinePartition(*(subgraph->Get()), sliced_graph_2.Get(), partition_2); - ASSERT_TRUE(HasValidGeneralTopology(sliced_graph_2.Get())); - ASSERT_TRUE(HasValidGeneralTopology(subgraph->Get())); + const auto& internal_slice_2 = *sliced_graph_2.Get(); + ASSERT_TRUE(ValidateSubgraphIO(internal_slice_2)); + ASSERT_TRUE(ValidateLocalTopology(internal_slice_2.Ops().cbegin(), + internal_slice_2.Ops().cend())); auto edited_subgraph_ops = subgraph->Ops(); diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc index c802c9410dfe94..57b078ab8aadf5 100644 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc +++ b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc @@ -14,18 +14,24 @@ #include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" +#include +#include #include #include #include -#include -#include #include #include #include +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_any.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" @@ -36,8 +42,11 @@ #include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" #include "tensorflow/lite/experimental/litert/core/byte_code_util.h" #include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" +#include "tensorflow/lite/experimental/litert/core/environment.h" #include "tensorflow/lite/experimental/litert/core/filesystem.h" +#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" @@ -50,29 +59,49 @@ namespace litert::internal { Expected> CompiledResult::ByteCode() const { const void* data; size_t size; - LITERT_EXPECT_OK(allocating_plugin_api_.get_compiled_result_byte_code( + LITERT_EXPECT_OK(parent_.get_compiled_result_byte_code( compiled_result_handle_, &data, &size)); return BufferRef(data, size); } Expected CompiledResult::NumCalls() const { LiteRtParamIndex call_idx; - LITERT_EXPECT_OK(allocating_plugin_api_.get_compiled_result_num_calls( + LITERT_EXPECT_OK(parent_.get_compiled_result_num_calls( compiled_result_handle_, &call_idx)); return call_idx; } -Expected CompiledResult::CallInfo( +Expected CompiledResult::CallInfo( LiteRtParamIndex call_idx) const { const void* data; size_t size; - LITERT_EXPECT_OK(allocating_plugin_api_.get_compiled_result_call_info( + LITERT_EXPECT_OK(parent_.get_compiled_result_call_info( compiled_result_handle_, call_idx, &data, &size)); - return std::string(reinterpret_cast(data), size); + return absl::string_view(reinterpret_cast(data), size); } CompiledResult::~CompiledResult() { - allocating_plugin_api_.destroy_compiled_result(compiled_result_handle_); + if (compiled_result_handle_ != nullptr) { + parent_.destroy_compiled_result(compiled_result_handle_); + } +} + +CompiledResult::CompiledResult(CompiledResult&& other) + : parent_(other.parent_), + compiled_result_handle_(other.compiled_result_handle_) { + other.parent_ = {}; + other.compiled_result_handle_ = nullptr; +} + +CompiledResult& CompiledResult::operator=(CompiledResult&& other) { + if (this != &other) { + parent_ = other.parent_; + other.parent_ = {}; + + compiled_result_handle_ = other.compiled_result_handle_; + other.compiled_result_handle_ = nullptr; + } + return *this; } // @@ -87,38 +116,41 @@ namespace { LiteRtStatus ResolvePluginApi(void* lib_handle, LiteRtCompilerPluginApi& result) { - RESOLVE_API_FUNC("LiteRtGetCompilerPluginVersion", + RESOLVE_API_FUNC(kLiteRtGetCompilerPluginVersion, result.get_compiler_plugin_version); - RESOLVE_API_FUNC("LiteRtGetCompilerPluginSocManufacturer", + RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedHardware, + result.get_compiler_plugin_supported_hardware); + RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSocManufacturer, result.get_compiler_plugin_soc_manufacturer); - RESOLVE_API_FUNC("LiteRtGetNumCompilerPluginSupportedSocModels", + RESOLVE_API_FUNC(kLiteRtGetNumCompilerPluginSupportedSocModels, result.get_num_compiler_plugin_supported_models); - RESOLVE_API_FUNC("LiteRtGetCompilerPluginSupportedSocModel", + RESOLVE_API_FUNC(kLiteRtGetCompilerPluginSupportedSocModel, result.get_compiler_plugin_supported_soc_model); - RESOLVE_API_FUNC("LiteRtCreateCompilerPlugin", result.create_compiler_plugin); - RESOLVE_API_FUNC("LiteRtDestroyCompilerPlugin", + RESOLVE_API_FUNC(kLiteRtCreateCompilerPlugin, result.create_compiler_plugin); + RESOLVE_API_FUNC(kLiteRtDestroyCompilerPlugin, result.destroy_compiler_plugin); - RESOLVE_API_FUNC("LiteRtCompilerPluginPartitionModel", - result.compiler_plugin_partition_model); - RESOLVE_API_FUNC("LiteRtCompilerPluginCompile", + RESOLVE_API_FUNC(kLiteRtCompilerPluginPartition, + result.compiler_plugin_partition); + RESOLVE_API_FUNC(kLiteRtCompilerPluginCompile, result.compiler_plugin_compile); - RESOLVE_API_FUNC("LiteRtDestroyCompiledResult", + RESOLVE_API_FUNC(kLiteRtDestroyCompiledResult, result.destroy_compiled_result); - RESOLVE_API_FUNC("LiteRtGetCompiledResultByteCode", + RESOLVE_API_FUNC(kLiteRtGetCompiledResultByteCode, result.get_compiled_result_byte_code); - RESOLVE_API_FUNC("LiteRtGetCompiledResultCallInfo", + RESOLVE_API_FUNC(kLiteRtGetCompiledResultCallInfo, result.get_compiled_result_call_info); - RESOLVE_API_FUNC("LiteRtGetNumCompiledResultCalls", + RESOLVE_API_FUNC(kLiteRtGetNumCompiledResultCalls, result.get_compiled_result_num_calls); + return kLiteRtStatusOk; } -Expected> GetSocModels( +Expected> GetSocModels( const LiteRtCompilerPluginApi& api, LiteRtCompilerPlugin plugin_handle) { - SmallVec soc_models; + std::vector soc_models; LiteRtParamIndex num_models; LITERT_EXPECT_OK( @@ -136,6 +168,28 @@ Expected> GetSocModels( return soc_models; } +// Sort plugins so that we first apply those supporting NPU, then those +// supporting GPU, and finally those supporting CPU. +void SortPlugins(std::vector& compiler_plugins) { + std::sort(compiler_plugins.begin(), compiler_plugins.end(), + [](auto& x, auto& y) { + auto x_supported_hardware = x.SupportedHardware(); + auto y_supported_hardware = y.SupportedHardware(); + if (x_supported_hardware && y_supported_hardware) { + bool x_npu = (*x_supported_hardware & kLiteRtHwAccelatorNpu); + bool x_gpu = (*x_supported_hardware & kLiteRtHwAccelatorGpu); + bool x_cpu = (*x_supported_hardware & kLiteRtHwAccelatorCpu); + bool y_npu = (*y_supported_hardware & kLiteRtHwAccelatorNpu); + bool y_gpu = (*y_supported_hardware & kLiteRtHwAccelatorGpu); + bool y_cpu = (*y_supported_hardware & kLiteRtHwAccelatorCpu); + int x_score = 100 * x_npu + 10 * x_gpu + x_cpu; + int y_score = 100 * y_npu + 10 * y_gpu + y_cpu; + return x_score < y_score; + } + return true; + }); +} + } // namespace Expected CompilerPlugin::LoadPlugin( @@ -180,7 +234,7 @@ Expected CompilerPlugin::LoadPlugin( return plugin; } -Expected> CompilerPlugin::LoadPlugins( +Expected> CompilerPlugin::LoadPlugins( absl::Span lib_search_paths) { std::vector plugin_lib_paths; for (auto search_path : lib_search_paths) { @@ -190,7 +244,7 @@ Expected> CompilerPlugin::LoadPlugins( } } - SmallVec loaded_plugins; + std::vector loaded_plugins; loaded_plugins.reserve(lib_search_paths.size()); for (const auto& lib_path : plugin_lib_paths) { @@ -202,31 +256,17 @@ Expected> CompilerPlugin::LoadPlugins( loaded_plugins.push_back(std::move(plugin.Value())); } - return loaded_plugins; -} - -Expected CompilerPlugin::LoadPlugin( - absl::Span lib_search_paths, - absl::string_view soc_manufacturer) { - auto compiler_plugins = LoadPlugins(lib_search_paths); - if (!compiler_plugins) { - return compiler_plugins.Error(); - } - - for (auto& plugin : *compiler_plugins) { - if (plugin.SocManufacturer() == soc_manufacturer) { - return std::move(plugin); - } - } + // Sort plugins. + SortPlugins(loaded_plugins); - return Error(kLiteRtStatusErrorNotFound); + return loaded_plugins; } CompilerPlugin::CompilerPlugin(CompilerPlugin&& other) : soc_models_(std::move(other.soc_models_)), - lib_handle_(other.lib_handle_), + lib_handle_(std::move(other.lib_handle_)), plugin_api_(std::move(other.plugin_api_)), - plugin_handle_(other.plugin_handle_) { + plugin_handle_(std::move(other.plugin_handle_)) { other.soc_models_ = {}; other.plugin_api_ = {}; other.lib_handle_ = nullptr; @@ -235,17 +275,10 @@ CompilerPlugin::CompilerPlugin(CompilerPlugin&& other) CompilerPlugin& CompilerPlugin::operator=(CompilerPlugin&& other) { if (this != &other) { - soc_models_ = std::move(other.soc_models_); - other.soc_models_ = {}; - - lib_handle_ = other.lib_handle_; - other.lib_handle_ = nullptr; - - plugin_api_ = std::move(other.plugin_api_); - other.plugin_api_ = {}; - - plugin_handle_ = other.plugin_handle_; - other.plugin_handle_ = nullptr; + std::swap(soc_models_, other.soc_models_); + std::swap(lib_handle_, other.lib_handle_); + std::swap(plugin_api_, other.plugin_api_); + std::swap(plugin_handle_, other.plugin_handle_); } return *this; } @@ -261,144 +294,232 @@ CompilerPlugin::~CompilerPlugin() { } } +std::string CompilerPlugin::DebugString() const { + std::string version_str = "?"; + if (auto version = ApiVersion(); version) { + version_str = absl::StrFormat("%d.%d.%d", version->major, version->minor, + version->patch); + } + return absl::StrFormat("%s compiler plugin (ver %s)", SocManufacturer(), + version_str); +} + Expected CompilerPlugin::ApiVersion() const { LiteRtApiVersion api_version; LITERT_EXPECT_OK(plugin_api_.get_compiler_plugin_version(&api_version)); return api_version; } -Expected> CompilerPlugin::PartitionModel( - const Model& model) { +Expected CompilerPlugin::SupportedHardware() const { + LiteRtHwAccelerators supported_hardware; + LITERT_EXPECT_OK(plugin_api_.get_compiler_plugin_supported_hardware( + plugin_handle_, &supported_hardware)); + return supported_hardware; +} + +Expected> CompilerPlugin::Partition( + const Subgraph& subgraph) { LiteRtOpListT ops; - LiteRtModel model_handle = model.Get(); - LITERT_EXPECT_OK(plugin_api_.compiler_plugin_partition_model( - plugin_handle_, model_handle, &ops)); + LITERT_EXPECT_OK(plugin_api_.compiler_plugin_partition(plugin_handle_, + subgraph.Get(), &ops)); return ops.Vec(); } -LiteRtStatus CompilerPlugin::Compile( - std::optional soc_model, - const std::vector& partitions, std::ostream& byte_code_out, - std::vector& call_info_out) { +Expected CompilerPlugin::Compile( + absl::Span partitions, absl::string_view soc_model) { CompiledResult result = MakeResult(); + // If the user has passed an soc_model, then we use it; otherwise we let the + // backend pick the appropriate one by passing nullptr as soc_model. This is + // important for on-device compilation, where the backend must determine the + // SoC model based on the user device. + const char* soc_model_str = !soc_model.empty() ? soc_model.data() : nullptr; + LITERT_EXPECT_OK(plugin_api_.compiler_plugin_compile( + plugin_handle_, soc_model_str, partitions.data(), partitions.size(), + &result.compiled_result_handle_)); + return result; +} + +namespace { - const char* soc_model_str = soc_model ? soc_model->data() : nullptr; - - // Compile given partitions into result. - // TODO: Use const where appropriate in the C compiler plugin api. - LiteRtSubgraphArray partitions_arr = - const_cast(partitions.data()); - if (auto stat = plugin_api_.compiler_plugin_compile( - plugin_handle_, soc_model_str, partitions_arr, partitions.size(), - &result.compiled_result_handle_); - stat != kLiteRtStatusOk) { - return stat; +LiteRtStatus PartitionSubgraph(CompilerPlugin& compiler_plugin, + LiteRtSubgraphT& subgraph, + PartitionResult& result) { + // Get selected ops from plugin. + auto selected_ops = compiler_plugin.Partition(Subgraph(&subgraph)); + if (!selected_ops) { + LITERT_LOG(LITERT_ERROR, "Failed to get partitions from plugin"); + return selected_ops.Error().Status(); } - // Parse call info from the result. - { - auto num_call = result.NumCalls(); - if (!num_call) { - return num_call.Error().Status(); - } - if (num_call.Value() != partitions.size()) { - LITERT_LOG( - LITERT_ERROR, "%s", - "Plugin didn't return call info for each partition compiled.\n"); - return kLiteRtStatusErrorRuntimeFailure; - } - for (int i = 0; i < num_call.Value(); ++i) { - auto call_info = result.CallInfo(i); - if (!call_info) { - return call_info.Error().Status(); - } - call_info_out.emplace_back() = *call_info; - } + // Group selected ops into connected islands. + auto islands = GroupPartitions(*selected_ops); + if (islands.empty()) { + LITERT_LOG(LITERT_ERROR, "Failed to group partitions"); + return kLiteRtStatusErrorRuntimeFailure; } - // Parse byte code from result. - { - auto byte_code = result.ByteCode(); - if (!byte_code) { - return byte_code.Error().Status(); - } - LITERT_LOG(LITERT_INFO, "Compiled %d partitions in %lu bytes", - partitions.size(), byte_code->Size()); - byte_code->WriteStr(byte_code_out); + // For each connected island, slice into new subgraph and replace use with + // single dispatch op. + for (auto& island : islands) { + auto& new_subgraph = result.second.EmplaceBack(); + auto* dispatch_op = OutlinePartition(subgraph, &new_subgraph, island); + result.first.push_back(dispatch_op); } return kLiteRtStatusOk; } -Expected> ApplyPlugin( - CompilerPlugin& compiler_plugin, Model& model, - std::optional soc_model) { - // Get selected ops from plugin. - auto partition = compiler_plugin.PartitionModel(model); - if (!partition) { - LITERT_LOG(LITERT_ERROR, "Failed to get partitions from plugin"); - return Error(kLiteRtStatusErrorRuntimeFailure); +} // namespace + +Expected PartitionModel(CompilerPlugin& compiler_plugin, + LiteRtModelT& model) { + // Accumulate partition results for each subgraph in model. + PartitionResult result; + for (auto* subgraph : model.Subgraphs()) { + LITERT_EXPECT_OK(PartitionSubgraph(compiler_plugin, *subgraph, result)); } + ABSL_DCHECK_EQ(result.first.size(), result.second.Size()); + return result; +} - // Group selected ops into partitions. - auto grouped_partitions = GroupPartitions(*partition); - if (grouped_partitions.empty()) { - LITERT_LOG(LITERT_ERROR, "Failed to group partitions"); - return Error(kLiteRtStatusErrorRuntimeFailure); +Expected ApplyPlugin(CompilerPlugin& compiler_plugin, LiteRtModelT& model, + absl::string_view soc_model, + Serialization serialization) { + // Collect partitions to pass to compilation. + auto partitions = PartitionModel(compiler_plugin, model); + if (!partitions) { + return partitions.Error(); } - if (grouped_partitions.size() > 1) { - LITERT_LOG(LITERT_ERROR, "Apply on multiple partitions not supported yet."); - return Error(kLiteRtStatusErrorUnsupported); + auto& dispatch_ops = partitions->first; + auto& subgraphs = partitions->second; + + // Pass sliced subgraphs to plugin for compilation. + auto compiled_result = + compiler_plugin.Compile(subgraphs.Elements(), soc_model); + if (!compiled_result) { + return compiled_result.Error(); + } + + // Attach per-partition call info to the respective op. + // This data may be adjusted during serialization. Just passthrough for now. + for (auto i = 0; i < dispatch_ops.size(); ++i) { + auto call_info = compiled_result->CallInfo(i); + if (!call_info) { + return call_info.Error(); + } + auto exec_info = MakeExecInfo(*call_info, kByteCodeMetadataKey); + if (!exec_info) { + return exec_info.Error(); + } + dispatch_ops.at(i)->SetCustomOptions(std::move(*exec_info)); } - // Outline the partitions into new subgraphs. - std::vector custom_ops; - for (auto& partition : grouped_partitions) { - auto custom_op = - OutlinePartition(model.Get()->subgraphs.front(), - &model.Get()->subgraphs.emplace_back(), partition); - custom_ops.push_back(custom_op); + // Store the byte code in a metadata buffer. This data may be adjusted during + // serialization. Just passthrough for now. + auto byte_code = compiled_result->ByteCode(); + if (!byte_code) { + return byte_code.Error(); } + model.PushMetadata(kByteCodeMetadataKey, byte_code->StrView()); - // Pass new subgraphs to the plugin for compilation. - std::vector compilation_input; - for (auto it = model.Get()->subgraphs.begin() + 1; - it < model.Get()->subgraphs.end(); ++it) { - compilation_input.push_back(&*it); + // Tag the model with make/model from the plugin. + auto build_stamp = MakeBuildStamp(compiler_plugin.SocManufacturer(), + soc_model, serialization); + if (!build_stamp) { + return build_stamp.Error(); } - // Compile partitions with plugin. - std::stringstream byte_code; - std::vector exec_info; - if (auto status = compiler_plugin.Compile(soc_model, compilation_input, - byte_code, exec_info); + if (auto status = + model.PushMetadata(kLiteRtBuildStampKey, std::move(*build_stamp)); status != kLiteRtStatusOk) { - LITERT_LOG(LITERT_ERROR, "Failed to compile partitions."); return Error(status); } - if (exec_info.size() != custom_ops.size()) { - LITERT_LOG(LITERT_ERROR, - "Compilation did not return exec_info for every partition"); - return Error(kLiteRtStatusErrorRuntimeFailure); + return {}; +} + +Expected ApplyPlugins( + LiteRtModel model, LiteRtHwAccelerators selected_hw_accelerators) { + auto environment = litert::internal::Environment::Instance(); + if (!environment) { + return environment.Error(); + } + + std::string compiler_plugin_lib_path = "."; + auto option = + (*environment)->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryPath); + if (option.has_value() && option->type == kLiteRtAnyTypeString) { + compiler_plugin_lib_path = option->str_value; } - model.Get()->custom_op_code = kLiteRtDispatchOpCustomCode; + const std::array + compiler_plugin_lib_search_paths = {compiler_plugin_lib_path}; - // Attach entry point info to the custom ops. - auto custom_op_it = custom_ops.begin(); - auto exec_info_it = exec_info.begin(); - for (; custom_op_it < custom_ops.end(); custom_op_it++, exec_info_it++) { - LiteRtOp custom_op = *custom_op_it; - const auto& exec_info = *exec_info_it; - custom_op->custom_options = OwningBufferRef(exec_info.data()); + auto compiler_plugins = litert::internal::CompilerPlugin::LoadPlugins( + compiler_plugin_lib_search_paths); + if (!compiler_plugins) { + return compiler_plugins.Error(); + } + if (compiler_plugins->empty()) { + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "No compiler plugin found"); } - const auto byte_code_str = byte_code.str(); - return OwningBufferRef( - reinterpret_cast(byte_code_str.data()), - byte_code_str.size()); + OwningBufferRef new_flatbuffer; + std::vector success_messages; + std::vector error_messages; + + ApplyPluginsResult result; + result.num_applied_plugins = 0; + for (auto& compiler_plugin : *compiler_plugins) { + auto plugin_name = compiler_plugin.DebugString(); + + auto plugin_supported_hardware = compiler_plugin.SupportedHardware(); + if (!plugin_supported_hardware) { + error_messages.push_back(absl::StrCat( + plugin_name, " ", plugin_supported_hardware.Error().Message())); + continue; + } + + if (*plugin_supported_hardware & selected_hw_accelerators) { + // FIXME: the following code is quite inefficient and convoluted. We + // shouldn't be needing to serialize a model to then read it again from + // the serialized buffer when applying a compiler plugin. + if (auto status = ApplyPlugin(compiler_plugin, *model); !status) { + error_messages.push_back( + absl::StrCat(plugin_name, " ", status.Error().Message())); + continue; + } + + auto serialized_model = + litert::internal::SerializeModel(std::move(*model)); + if (!serialized_model) { + error_messages.push_back( + absl::StrCat(plugin_name, " ", serialized_model.Error().Message())); + continue; + } + + auto new_model = litert::Model::CreateFromBuffer(*serialized_model); + if (!new_model) { + error_messages.push_back( + absl::StrCat(plugin_name, " ", new_model.Error().Message())); + continue; + } + + new_flatbuffer = std::move(*serialized_model); + *model = std::move(*new_model->Get()); + + success_messages.push_back(absl::StrCat(plugin_name)); + result.num_applied_plugins++; + } + } + + result.new_flatbuffer = std::move(new_flatbuffer); + result.success_message = absl::StrJoin(success_messages, ", "); + result.error_message = absl::StrJoin(error_messages, ", "); + + return result; } } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h index 259722dc9e3be8..f3b93293a60f76 100644 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h +++ b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h @@ -15,9 +15,8 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ -#include -#include #include +#include #include #include "absl/strings/string_view.h" @@ -28,46 +27,60 @@ #include "tensorflow/lite/experimental/litert/cc/litert_detail.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/byte_code_util.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" +// C++ wrappers and high-level functions for managing compiler plugins +// and applying them to models. + namespace litert::internal { +// Wraps vendor compiled result. Must be outlived by the CompilerPlugin +// the generated it. class CompiledResult { + public: friend class CompilerPlugin; + // Get the single module of compiled byte code. This contains the // compilation result for all entry points. Expected> ByteCode() const; // Get information regarding the "ith" entry points in the compiled module. // There will be oe entry point for each subgraph compiled for. - Expected CallInfo(LiteRtParamIndex call_idx) const; + Expected CallInfo(LiteRtParamIndex call_idx) const; // Get the number of entry points in the compiled module. This will be equal // to the number of subgraphs passed to the compilation step. Expected NumCalls() const; - explicit CompiledResult(const LiteRtCompilerPluginApi& allocating_plugin_api) - : allocating_plugin_api_(allocating_plugin_api) {} + explicit CompiledResult(const LiteRtCompilerPluginApi& parent) + : parent_(parent) {} - CompiledResult(CompiledResult&& other) = default; - CompiledResult& operator=(CompiledResult&& other) = default; + CompiledResult(CompiledResult&& other); + CompiledResult& operator=(CompiledResult&& other); CompiledResult(const CompiledResult& other) = delete; CompiledResult& operator=(const CompiledResult& other) = delete; ~CompiledResult(); - LiteRtCompilerPluginApi allocating_plugin_api_; + private: + LiteRtCompilerPluginApi parent_; LiteRtCompiledResult compiled_result_handle_ = nullptr; }; -// Syntatic sugar around dynamically loaded LiteRtCompilerPlugin libraries. -// TODO turn this into a general C++ wraper for the whole compiler plugin api. +// Wraps vendor compiler plugin. class CompilerPlugin { public: + std::string DebugString() const; + // Get the compiler plugin's API version. Expected ApiVersion() const; + // Get the supported HW accelerators (e.g., GPU, NPU). + Expected SupportedHardware() const; + // Get the manufacturer associated with this plugin. NOTE: SocManufacturer // string returned by the underlying plugin are expected to have static // lifetime. @@ -76,37 +89,24 @@ class CompilerPlugin { } // Get list of unique soc models targetable by this plugin. - const SmallVec& SocModels() const { return soc_models_; } + const std::vector& SocModels() const { return soc_models_; } // Selects ops for the plugin to compile. - Expected> PartitionModel(const Model& model); - - // Compile given LiteRtSubgraphs. Write compiled byte code to the given - // stream. For each given subgraph, write opaque data about the corresponding - // entry point to the given "call_info_out". Parameter "soc_model" is optional - // and can be set to specify the target SoC; for on-device compilation it - // should be left unspecified so as to let the underlying logic pick the - // architecture that matches the SoC on the user device. - LiteRtStatus Compile(std::optional soc_model, - const std::vector& partitions, - std::ostream& byte_code_out, - std::vector& call_info_out); + Expected> Partition(const Subgraph& subgraph); + + // Compile given LiteRtSubgraphs. Result object must be outlived by + // this CompilerPlugin. + Expected Compile(absl::Span partitions, + absl::string_view soc_model = ""); // Search for shared library files with prefix "libLiteRtCompilerPlugin" in // the directories passed through "lib_search_paths". Populates // "loaded_plugins" with resolved plugin apis for each found library that can // be succesfully loaded. Additionally initializes the compiler plugin // instances and stores handle. - static Expected> LoadPlugins( + static Expected> LoadPlugins( absl::Span lib_search_paths); - // Search for shared library files with prefix "libLiteRtCompilerPlugin" in - // the directories passed through "lib_search_paths" and return a compiler - // plugin instance for a given manufactured, if one is found. - static Expected LoadPlugin( - absl::Span lib_search_paths, - absl::string_view soc_manufacturer); - CompilerPlugin(CompilerPlugin&& other); CompilerPlugin& operator=(CompilerPlugin&& other); CompilerPlugin(const CompilerPlugin& other) = delete; @@ -120,7 +120,7 @@ class CompilerPlugin { static Expected LoadPlugin(absl::string_view lib_path); CompilerPlugin() = default; - SmallVec soc_models_; + std::vector soc_models_; void* lib_handle_ = nullptr; LiteRtCompilerPluginApi plugin_api_ = {}; LiteRtCompilerPlugin plugin_handle_ = nullptr; @@ -130,15 +130,43 @@ class CompilerPlugin { CompiledResult MakeResult() const { return CompiledResult(plugin_api_); } }; -// Applies the plugin's "partition" and "compile" steps to the given model. -// Returns the serialized model with NPU code appended to the back. Parameter -// "soc_model" is optional and can be set to specify the target SoC; for -// on-device compilation it should be left unspecified so as to let the -// underlying logic pick the architecture that matches the SoC on the user -// device -Expected> ApplyPlugin( - CompilerPlugin& compiler_plugin, Model& model, - std::optional soc_model = std::nullopt); +// Higher level functions for applying plugin to graph. +//===--------------------------------------------------------------------------- + +// Dispatch op references and their subgraph to be compiled. +using PartitionResult = + std::pair, typename LiteRtSubgraphT::Alloc>; + +// Applies just the partition phase of the plugin on the model. Returns +// references newly allocated subgraphs removed from input and their +// corresponding dispatch ops in the input. +Expected PartitionModel(CompilerPlugin& compiler_plugin, + LiteRtModelT& model); + +// Applies both the partition and compile steps to the model. Generated +// byte_code will be internalized within the model for later serialization. +// The serialization parameter refers to the strategy used to pack the byte code +// during future serialization. +Expected ApplyPlugin( + CompilerPlugin& compiler_plugin, LiteRtModelT& model, + absl::string_view soc_model = "", + Serialization serialization = Serialization::kAppend); + +// Apply all available plugins providing the selected HW accelerators to the +// given model, modify the model accordingly, and return (1) the number of +// compiler plugins succesfully applied, (2) a new flatbuffer backing the +// modified model, (3) a string listing the compiler plugins that were +// succesfully applied, and (4) a string listing the compiler plugins that +// failed to apply with an associated error message. +struct ApplyPluginsResult { + size_t num_applied_plugins; + OwningBufferRef new_flatbuffer; + std::string success_message; + std::string error_message; +}; + +Expected ApplyPlugins( + LiteRtModel model, LiteRtHwAccelerators selected_hw_accelerators); } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc index ef51003be27403..40edd8ecf9c72d 100644 --- a/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc +++ b/tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc @@ -14,6 +14,7 @@ #include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" +#include #include #include #include @@ -21,17 +22,21 @@ #include #include -#include "testing/base/public/unique-test-directory.h" #include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" +#include "tensorflow/lite/experimental/litert/core/byte_code_util.h" #include "tensorflow/lite/experimental/litert/core/filesystem.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" #include "tensorflow/lite/experimental/litert/tools/dump.h" namespace litert::internal { namespace { -using ::testing::UniqueTestDirectory; +using ::testing::HasSubstr; +using testing::UniqueTestDirectory; constexpr absl::string_view kTestPluginSearchPath = "third_party/tensorflow/lite/experimental/litert/vendors/examples"; @@ -49,8 +54,9 @@ TEST(CompilerPluginTest, LoadTestPlugin) { } TEST(CompilerPluginTest, LoadTestPluginWithMalformed) { - const auto dir = UniqueTestDirectory(); - Touch(Join({dir, "notLibLiteRt.so"})); + const auto dir = UniqueTestDirectory::Create(); + ASSERT_TRUE(dir); + Touch(Join({dir->Str(), "notLibLiteRt.so"})); auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); @@ -98,32 +104,40 @@ TEST(CompilerPluginTest, SocModels) { ::testing::ElementsAreArray({kTestModels})); } -TEST(CompilerPluginTest, PartitionModel) { +TEST(CompilerPluginTest, Partition) { auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); ASSERT_EQ(plugins->size(), 1); EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); auto model = testing::LoadTestFileModel("mul_simple.tflite"); auto subgraph = model.MainSubgraph(); + auto ops = plugins->front().Partition(*subgraph); + ASSERT_TRUE(ops); - EXPECT_EQ(subgraph->Ops().size(), 2); + EXPECT_EQ(ops->size(), 2); } -TEST(CompilerPluginTest, CompileModel) { +TEST(CompilerPluginTest, Compile) { auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); ASSERT_EQ(plugins->size(), 1); EXPECT_EQ(plugins->front().SocManufacturer(), kTestManufacturer); - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - auto subgraph = model.MainSubgraph(); + auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); + auto& model = *model_wrap.Get(); + + auto result = plugins->front().Compile(model.Subgraphs()); + ASSERT_TRUE(result); - std::ostringstream byte_code_out; - std::vector call_info_out; - LITERT_ASSERT_STATUS_OK(plugins->front().Compile( - kTestModels, {subgraph->Get()}, byte_code_out, call_info_out)); + auto byte_code = result->ByteCode(); + ASSERT_TRUE(byte_code && byte_code->Size() > 0); - EXPECT_GT(byte_code_out.str().size(), 0); - EXPECT_EQ(call_info_out.size(), 1); + auto num_calls = result->NumCalls(); + ASSERT_TRUE(num_calls); + ASSERT_EQ(*num_calls, 1); + + auto call_info = result->CallInfo(0); + ASSERT_TRUE(call_info); + ASSERT_FALSE(call_info->empty()); } TEST(CompilerPluginTest, Dump) { @@ -138,20 +152,131 @@ TEST(CompilerPluginTest, Dump) { "ExampleSocModel }\n"); } -TEST(ApplyPluginTest, ApplyPlugin) { +TEST(PartitionModelTest, Simple) { + auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); + auto& model = *model_wrap.Get(); + auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); ASSERT_EQ(plugins->size(), 1); - auto model = testing::LoadTestFileModel("mul_simple.tflite"); - ASSERT_TRUE(model); + auto& plugin = plugins->front(); + + auto partition_result = PartitionModel(plugin, model); + ASSERT_TRUE(partition_result); + ASSERT_EQ(model.NumSubgraphs(), 1); + + const auto& [ops, subgraphs] = *partition_result; + + EXPECT_EQ(ops.size(), 1); + EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); + + EXPECT_EQ(subgraphs.Size(), 1); + EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 2); +} + +TEST(PartitionModelTest, MultiSubgraph) { + auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); + auto& model = *model_wrap.Get(); + + auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); + ASSERT_EQ(plugins->size(), 1); + auto& plugin = plugins->front(); + + auto partition_result = PartitionModel(plugin, model); + ASSERT_TRUE(partition_result); + ASSERT_EQ(model.NumSubgraphs(), 2); + + const auto& [ops, subgraphs] = *partition_result; + + EXPECT_EQ(ops.size(), 2); + EXPECT_EQ(ops.front()->OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_EQ(ops.back()->OpCode(), kLiteRtOpCodeTflCustom); + + EXPECT_EQ(subgraphs.Size(), 2); + EXPECT_EQ(subgraphs.Elements().front()->Ops().size(), 1); + EXPECT_EQ(subgraphs.Elements().back()->Ops().size(), 1); +} + +TEST(ApplyTest, Simple) { + auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); + ASSERT_EQ(plugins->size(), 1); + auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); + ASSERT_TRUE(model_wrap); + auto& model = *model_wrap.Get(); + + ASSERT_TRUE(ApplyPlugin(plugins->front(), model)); + ASSERT_EQ(model.NumSubgraphs(), 1); + + auto& subgraph = *model.MainSubgraph(); + ASSERT_EQ(subgraph.Ops().size(), 1); + + EXPECT_EQ(subgraph.Op(0).OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_THAT(subgraph.Op(0).CustomOptions().StrView(), + HasSubstr(kByteCodeMetadataKey)); + + EXPECT_TRUE(model.FindMetadata(kByteCodeMetadataKey)); + EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); +} + +TEST(ApplyTest, MultiSubgraph) { + auto plugins = CompilerPlugin::LoadPlugins({kTestPluginSearchPath}); + ASSERT_EQ(plugins->size(), 1); + auto model_wrap = testing::LoadTestFileModel("multi_subgraph_mul.tflite"); + ASSERT_TRUE(model_wrap); + auto& model = *model_wrap.Get(); + + ASSERT_TRUE(ApplyPlugin(plugins->front(), model)); + ASSERT_EQ(model.NumSubgraphs(), 2); + + auto& subgraph = model.Subgraph(0); + ASSERT_EQ(subgraph.Ops().size(), 1); + EXPECT_EQ(subgraph.Op(0).OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_THAT(subgraph.Op(0).CustomOptions().StrView(), + HasSubstr(kByteCodeMetadataKey)); + + auto& subgraph2 = model.Subgraph(1); + ASSERT_EQ(subgraph2.Ops().size(), 1); + EXPECT_EQ(subgraph2.Op(0).OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_THAT(subgraph2.Op(0).CustomOptions().StrView(), + HasSubstr(kByteCodeMetadataKey)); + + EXPECT_TRUE(model.FindMetadata(kByteCodeMetadataKey)); + EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); +} + +TEST(ApplyTest, ApplyPlugins) { + litert::Environment::Destroy(); + + auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite"); + ASSERT_TRUE(model_wrap); + auto& model = *model_wrap.Get(); + + const std::array environment_options = { + litert::Environment::Option{ + /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryPath, + /*.value=*/kTestPluginSearchPath, + }, + }; + ASSERT_TRUE(litert::Environment::Create(environment_options)); + + LiteRtHwAccelerators compilation_options = static_cast( + kLiteRtHwAccelatorCpu | kLiteRtHwAccelatorGpu | kLiteRtHwAccelatorNpu); + auto new_flatbuffer = + litert::internal::ApplyPlugins(&model, compilation_options); + ASSERT_TRUE(new_flatbuffer); + + ASSERT_EQ(model.NumSubgraphs(), 1); + + auto& subgraph = *model.MainSubgraph(); + ASSERT_EQ(subgraph.Ops().size(), 1); + + EXPECT_EQ(subgraph.Op(0).OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_THAT(subgraph.Op(0).CustomOptions().StrView(), + HasSubstr(kByteCodeMetadataKey)); - auto npu_code = ApplyPlugin(plugins->front(), model); - ASSERT_TRUE(npu_code); - EXPECT_GT(npu_code->Size(), 0); + EXPECT_TRUE(model.FindMetadata(kByteCodeMetadataKey)); + EXPECT_TRUE(model.FindMetadata(kLiteRtBuildStampKey)); - auto ops = model.MainSubgraph()->Ops(); - ASSERT_EQ(ops.size(), 1); - EXPECT_EQ(ops.front().Code(), kLiteRtOpCodeTflCustom); - EXPECT_EQ(ops.front().Get()->custom_options.StrView(), "Partition_0"); + litert::Environment::Destroy(); } } // namespace diff --git a/tensorflow/lite/experimental/litert/core/BUILD b/tensorflow/lite/experimental/litert/core/BUILD index c72934358a7c7e..e58e80fdd58404 100644 --- a/tensorflow/lite/experimental/litert/core/BUILD +++ b/tensorflow/lite/experimental/litert/core/BUILD @@ -58,13 +58,40 @@ cc_library( deps = [ "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_logging", # buildcleaner: keep - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", ], ) +cc_library( + name = "environment", + srcs = ["environment.cc"], + hdrs = [ + "environment.h", + "//tensorflow/lite/experimental/litert/c:litert_environment.h", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_any", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_any", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "environment_test", + srcs = ["environment_test.cc"], + deps = [ + ":environment", + "//tensorflow/lite/experimental/litert/c:litert_any", + "//tensorflow/lite/experimental/litert/cc:litert_any", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "filesystem", srcs = ["filesystem.cc"], @@ -104,11 +131,10 @@ cc_test( # ":dynamic_loading", # ":filesystem", # "@com_google_googletest//:gtest_main", -# "//testing/base/public:unique-test-directory", -# "@com_google_absl//absl/strings:str_format", # "@com_google_absl//absl/strings:string_view", # "//tensorflow/lite/experimental/litert/c:litert_logging", # buildcleaner: keep # "//tensorflow/lite/experimental/litert/test:common", +# "//tensorflow/lite/experimental/litert/test:test_macros", # ], # ) # copybara:uncomment_end diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading.cc index dfc3fe4567144e..a5fc1053827a94 100644 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc +++ b/tensorflow/lite/experimental/litert/core/dynamic_loading.cc @@ -22,7 +22,6 @@ #endif #endif -#include #include // NOLINT #include #include @@ -31,11 +30,22 @@ #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" namespace litert::internal { -LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle) { +LiteRtStatus OpenLib(const std::vector& so_paths, + void** lib_handle) { + for (const auto& so_path : so_paths) { + if (OpenLib(so_path, lib_handle, /*log_failure=*/false) == + kLiteRtStatusOk) { + return kLiteRtStatusOk; + } + } + return kLiteRtStatusErrorDynamicLoading; +} + +LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle, + bool log_failure) { #ifdef RTLD_DEEPBIND void* res = ::dlopen(so_path.data(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND); #else @@ -43,10 +53,11 @@ LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle) { #endif if (res == nullptr) { - LITERT_LOG(LITERT_ERROR, "Failed to load .so at path: %s\n", - so_path.data()); - LogDlError(); - + if (log_failure) { + LITERT_LOG(LITERT_ERROR, "Failed to load .so at path: %s\n", + so_path.data()); + LogDlError(); + } return kLiteRtStatusErrorDynamicLoading; } *lib_handle = res; diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.h b/tensorflow/lite/experimental/litert/core/dynamic_loading.h index 2b7c1aaf3a3b4c..d02756740ed250 100644 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading.h +++ b/tensorflow/lite/experimental/litert/core/dynamic_loading.h @@ -38,8 +38,16 @@ inline void LogDlError() { LITERT_LOG(LITERT_WARNING, "::dlerror() : %s", err); } -// Loads shared library at given path. -LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle); +// Probes for a list of shared library at given paths and returns when the first +// one is found. Returns kLiteRtStatusErrorDynamicLoading if none of the shared +// libraries are found. +LiteRtStatus OpenLib(const std::vector& so_paths, + void** lib_handle); + +// Loads shared library at given path. Logging can be disabled to probe for +// shared libraries. +LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle, + bool log_failure = true); // Closes reference to loaded shared library held by lib_handle. LiteRtStatus CloseLib(void* lib_handle); diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc index e0eb68e6971ab0..d0dbe40449b87b 100644 --- a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc +++ b/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc @@ -19,50 +19,56 @@ #include #include -#include "testing/base/public/unique-test-directory.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/core/filesystem.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" namespace litert::internal { namespace { +using litert::testing::UniqueTestDirectory; using ::testing::Contains; using ::testing::HasSubstr; -using ::testing::UniqueTestDirectory; constexpr absl::string_view kNotLiteRtSo = "notLibLiteRt.so"; constexpr absl::string_view kLiteRtSo1 = "libLiteRtCompilerPlugin_1.so"; constexpr absl::string_view kLiteRtSo2 = "libLiteRtCompilerPlugin_2.so"; TEST(TestDynamicLoading, GlobNoMatch) { - const auto dir = UniqueTestDirectory(); - Touch(Join({dir, kNotLiteRtSo})); + const auto dir = UniqueTestDirectory::Create(); + ASSERT_TRUE(dir); + Touch(Join({dir->Str(), kNotLiteRtSo})); std::vector results; - LITERT_ASSERT_STATUS_OK(litert::internal::FindLiteRtSharedLibs(dir, results)); + LITERT_ASSERT_STATUS_OK( + litert::internal::FindLiteRtSharedLibs(dir->Str(), results)); EXPECT_EQ(results.size(), 0); } TEST(TestDynamicLoading, GlobOneMatch) { - const auto dir = UniqueTestDirectory(); - Touch(Join({dir, kLiteRtSo1})); - Touch(Join({dir, kNotLiteRtSo})); + const auto dir = UniqueTestDirectory::Create(); + ASSERT_TRUE(dir); + Touch(Join({dir->Str(), kLiteRtSo1})); + Touch(Join({dir->Str(), kNotLiteRtSo})); std::vector results; - LITERT_ASSERT_STATUS_OK(litert::internal::FindLiteRtSharedLibs(dir, results)); + LITERT_ASSERT_STATUS_OK( + litert::internal::FindLiteRtSharedLibs(dir->Str(), results)); ASSERT_EQ(results.size(), 1); EXPECT_TRUE(absl::string_view(results.front()).ends_with(kLiteRtSo1)); } TEST(TestDynamicLoading, GlobMultiMatch) { - const auto dir = UniqueTestDirectory(); - Touch(Join({dir, kLiteRtSo1})); - Touch(Join({dir, kLiteRtSo2})); - Touch(Join({dir, kNotLiteRtSo})); + const auto dir = UniqueTestDirectory::Create(); + ASSERT_TRUE(dir); + Touch(Join({dir->Str(), kLiteRtSo1})); + Touch(Join({dir->Str(), kLiteRtSo2})); + Touch(Join({dir->Str(), kNotLiteRtSo})); std::vector results; - LITERT_ASSERT_STATUS_OK(litert::internal::FindLiteRtSharedLibs(dir, results)); + LITERT_ASSERT_STATUS_OK( + litert::internal::FindLiteRtSharedLibs(dir->Str(), results)); ASSERT_EQ(results.size(), 2); EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo1))); EXPECT_THAT(results, Contains(HasSubstr(kLiteRtSo2))); diff --git a/tensorflow/lite/experimental/litert/core/environment.cc b/tensorflow/lite/experimental/litert/core/environment.cc new file mode 100644 index 00000000000000..8cf6e20c918f9b --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/environment.cc @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/environment.h" + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" + +namespace litert::internal { + +Environment* Environment::the_instance_ = nullptr; + +Expected Environment::CreateWithOptions( + absl::Span options) { + LITERT_LOG(LITERT_INFO, "Environment::CreateWithOptions the_instance_=%p", + the_instance_); + if (the_instance_) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "LiteRT environment cannot be created with options, it has " + "already been created"); + } + LITERT_LOG(LITERT_INFO, "Creating LiteRT environment with options"); + the_instance_ = new Environment(); + for (auto& option : options) { + the_instance_->options_[option.tag] = option.value; + } + return {}; +} + +void Environment::Destroy() { + delete the_instance_; + the_instance_ = nullptr; +} + +Expected Environment::Instance() { + if (!the_instance_) { + LITERT_LOG(LITERT_INFO, "Creating LiteRT environment with no options"); + the_instance_ = new Environment(); + } + return the_instance_; +} + +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/environment.h b/tensorflow/lite/experimental/litert/core/environment.h new file mode 100644 index 00000000000000..23fe16db009396 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/environment.h @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" + +namespace litert::internal { + +// A singleton class that contains global LiteRT environment options. +class Environment { + public: + // Create the singleton environment instance with options. Returns an error if + // the instance already exists, in which case the specified options have no + // effect. + static Expected CreateWithOptions( + absl::Span options); + + // Return the envirnment instance and, if not yet created, creates one with no + // options. + static Expected Instance(); + + // Destroy the environment instance. + static void Destroy(); + + std::optional GetOption(LiteRtEnvOptionTag tag) const { + auto i = options_.find(tag); + if (i != options_.end()) { + return i->second; + } else { + return std::nullopt; + } + } + + private: + std::map options_; + + static Environment* the_instance_; +}; + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ENVIRONMENT_H_ diff --git a/tensorflow/lite/experimental/litert/core/environment_test.cc b/tensorflow/lite/experimental/litert/core/environment_test.cc new file mode 100644 index 00000000000000..ffba092420bf7a --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/environment_test.cc @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/environment.h" + +#include +#include + +#include +#include "tensorflow/lite/experimental/litert/c/litert_any.h" +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" +#include "tensorflow/lite/experimental/litert/cc/litert_any.h" + +namespace litert::internal { +namespace { + +TEST(Environment, CreateWithNoOption) { + ASSERT_TRUE(Environment::Instance()); + Environment::Destroy(); +} + +TEST(Environment, CreateWithOptions) { + const std::array environment_options = { + LiteRtEnvOption{ + kLiteRtEnvOptionTagCompilerPluginLibraryPath, + *ToLiteRtAny(std::any("sample path")), + }, + }; + ASSERT_TRUE(Environment::CreateWithOptions(environment_options)); + + auto env = Environment::Instance(); + ASSERT_TRUE(env); + + auto option = (*env)->GetOption(kLiteRtEnvOptionTagCompilerPluginLibraryPath); + ASSERT_TRUE(option.has_value()); + ASSERT_EQ(option->type, kLiteRtAnyTypeString); + ASSERT_STREQ(option->str_value, "sample path"); + + Environment::Destroy(); +} + +TEST(Environment, CreateWithOptionsFailure) { + // This will create an environment without options. + auto env = Environment::Instance(); + ASSERT_TRUE(env); + + const std::array environment_options = { + LiteRtEnvOption{ + kLiteRtEnvOptionTagCompilerPluginLibraryPath, + *ToLiteRtAny(std::any("sample path")), + }, + }; + ASSERT_FALSE(Environment::CreateWithOptions(environment_options)); + + Environment::Destroy(); +} + +} // namespace +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/filesystem.cc b/tensorflow/lite/experimental/litert/core/filesystem.cc index c3744239520254..50df1174723cd0 100644 --- a/tensorflow/lite/experimental/litert/core/filesystem.cc +++ b/tensorflow/lite/experimental/litert/core/filesystem.cc @@ -18,6 +18,8 @@ #include #include // NOLINT #include +#include +#include #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" @@ -63,7 +65,7 @@ LiteRtStatus StdIFRead(const StdPath& std_path, char* data, size_t size) { void Touch(absl::string_view path) { std::ofstream(MakeStdPath(path)); } -std::string Join(const SmallVec& paths) { +std::string Join(const std::vector& paths) { StdPath std_path; for (auto subpath : paths) { std_path /= MakeStdPath(subpath); @@ -76,7 +78,7 @@ bool Exists(absl::string_view path) { return StdExists(MakeStdPath(path)); } Expected Size(absl::string_view path) { auto std_path = MakeStdPath(path); if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorNotFound); + return Error(kLiteRtStatusErrorNotFound, "File not found"); } return StdSize(std_path); } @@ -85,7 +87,7 @@ Expected> LoadBinaryFile(absl::string_view path) { auto std_path = MakeStdPath(path); if (!StdExists(std_path)) { - return Error(kLiteRtStatusErrorFileIO); + return Error(kLiteRtStatusErrorFileIO, "File not found"); } OwningBufferRef buf(StdSize(std_path)); diff --git a/tensorflow/lite/experimental/litert/core/filesystem.h b/tensorflow/lite/experimental/litert/core/filesystem.h index 6dd3ae1f237664..87146d68029cbe 100644 --- a/tensorflow/lite/experimental/litert/core/filesystem.h +++ b/tensorflow/lite/experimental/litert/core/filesystem.h @@ -17,6 +17,8 @@ #include #include +#include +#include #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" @@ -29,7 +31,7 @@ namespace litert::internal { // Append all given subpaths together (e.g. os.path.join). -std::string Join(const SmallVec& paths); +std::string Join(const std::vector& paths); // Make a new empty file at the given path. void Touch(absl::string_view path); diff --git a/tensorflow/lite/experimental/litert/core/model/BUILD b/tensorflow/lite/experimental/litert/core/model/BUILD index a81af12d7c897b..72f9e11e19e6dc 100644 --- a/tensorflow/lite/experimental/litert/core/model/BUILD +++ b/tensorflow/lite/experimental/litert/core/model/BUILD @@ -19,16 +19,13 @@ package( cc_library( name = "model", - srcs = [ - "model.cc", - "//tensorflow/lite/experimental/litert/c:litert_model_srcs", - ], + srcs = ["model.cc"], hdrs = [ "model.h", - "model_load.h", "//tensorflow/lite/experimental/litert/c:litert_model_hdrs", ], deps = [ + ":ir_allocator", "//tensorflow/compiler/mlir/lite/core:model_builder_base", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/experimental/litert/c:litert_common", @@ -36,10 +33,13 @@ cc_library( "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/core:byte_code_util", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -51,12 +51,13 @@ cc_test( ], deps = [ ":model", - ":model_load", - "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", - "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:test_macros", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) @@ -68,9 +69,9 @@ cc_library( deps = [ ":flatbuffer_to_litert", ":model", + ":model_graph", "//tensorflow/compiler/mlir/lite/core:model_builder_base", "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_layout", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", @@ -78,6 +79,7 @@ cc_library( "//tensorflow/lite/experimental/litert/cc:litert_macros", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", ], ) @@ -90,21 +92,25 @@ cc_test( "//tensorflow/lite/experimental/litert/test:tflite_test_data", ], deps = [ + ":graph_validation", ":model", ":model_file_test_util", + ":model_load", ":model_serialize", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/cc:litert_element_type", + "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_macros", "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/core:byte_code_util", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "//tensorflow/lite/experimental/litert/test:common", "//tensorflow/lite/experimental/litert/test:test_macros", "//tensorflow/lite/experimental/litert/test:test_models", - "//tensorflow/lite/experimental/litert/tools:dump", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], @@ -118,17 +124,14 @@ cc_library( ":litert_to_flatbuffer", ":model", "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/core:byte_code_util", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/strings:string_view", - "@flatbuffers//:runtime_cc", ], ) @@ -140,7 +143,6 @@ cc_library( ":model", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_logging", - "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_layout", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", @@ -167,8 +169,6 @@ cc_library( deps = [ ":model", "//tensorflow/lite/experimental/litert/c:litert_common", - "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "//tensorflow/lite/schema:schema_fbs", @@ -197,17 +197,12 @@ cc_library( ":model", ":model_load", ":model_serialize", - "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_op_code", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_macros", - "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/core:byte_code_util", "//tensorflow/lite/experimental/litert/core:filesystem", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -222,12 +217,72 @@ cc_library( ":model", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/cc:litert_detail", - "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "@com_google_absl//absl/types:span", ], ) +cc_library( + name = "ir_allocator", + hdrs = ["ir_allocator.h"], + deps = ["@com_google_absl//absl/types:span"], +) + +cc_test( + name = "ir_allocator_test", + srcs = ["ir_allocator_test.cc"], + deps = [ + ":ir_allocator", + ":model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "model_graph", + srcs = ["model_graph.cc"], + hdrs = [ + "model_graph.h", + "//tensorflow/lite/experimental/litert/cc:litert_consts.h", + ], + deps = [ + ":model", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", + "//tensorflow/lite/experimental/litert/cc:litert_detail", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "@com_google_absl//absl/log:absl_check", + ], +) + +cc_library( + name = "graph_validation", + srcs = ["graph_validation.cc"], + hdrs = ["graph_validation.h"], + deps = [ + ":model", + ":model_graph", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_detail", + ], +) + +cc_test( + name = "model_graph_test", + srcs = ["model_graph_test.cc"], + deps = [ + ":graph_validation", + ":model", + ":model_graph", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "model_buffer_test", srcs = ["model_buffer_test.cc"], diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc index ad4dadba531872..762ed10ec71b73 100644 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc +++ b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.cc @@ -118,33 +118,31 @@ Expected MapTensorType(const TflTensorType& tfl_tensor_type) { return Error(kLiteRtStatusErrorUnsupported); } - LiteRtTypeDetail detail; + TensorTypeDetail detail; detail.ranked_tensor_type.element_type = litert_element_type; detail.ranked_tensor_type.layout = BuildLayout(*ranked_shape); return std::make_pair(kLiteRtRankedTensorType, detail); } -Expected MapQuantization( - const TflQuantization* tfl_quantization) { +Expected MapQuantization(const TflQuantization* tfl_quantization, + BufferProvider buffer_provider) { if (!IsQuantized(tfl_quantization)) { - return std::make_pair(kLiteRtQuantizationNone, - LiteRtQuantizationTypeDetail()); + return MakeEmptyQuantization(); } - auto per_tensor_qparams = AsPerTensorQparams(tfl_quantization); - if (!per_tensor_qparams) { - LITERT_LOG(LITERT_ERROR, - "Only per tensor quantization currently supported"); - return Error(kLiteRtStatusErrorUnsupported); + if (auto tfl_qparams = AsPerTensorQparams(tfl_quantization)) { + return MakePerTensorQuantization(tfl_qparams->second, tfl_qparams->first); } - auto [zero_point, scale] = *per_tensor_qparams; - LiteRtQuantizationTypeDetail detail; - detail.per_tensor.scale = scale; - detail.per_tensor.zero_point = zero_point; + if (auto tfl_qparams = AsPerChannelQparams(tfl_quantization)) { + [[maybe_unused]] const auto& [quantized_dimension, num_channels, + zero_points, scales] = *tfl_qparams; + return MakePerChannelQuantization(scales, zero_points, quantized_dimension, + buffer_provider); + } - return std::make_pair(kLiteRtQuantizationPerTensor, detail); + LITERT_LOG(LITERT_ERROR, "Uknown tfl quantization type"); + return Error(kLiteRtStatusErrorUnsupported); } - } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h index 9f9124777dc19e..033f6cddf19f81 100644 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h +++ b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h @@ -16,10 +16,10 @@ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_FLATBUFFER_TO_LITERT_H_ #include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace litert::internal { @@ -35,7 +35,8 @@ LiteRtElementType MapElementType(TflElementType element_type); Expected MapTensorType(const TflTensorType& tfl_tensor_type); -Expected MapQuantization(const TflQuantization* tfl_quantization); +Expected MapQuantization(const TflQuantization* tfl_quantization, + BufferProvider buffer_provider); } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc index 13aa9d05efc7b5..2ff1cb18ffa8a4 100644 --- a/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc +++ b/tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert_test.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" @@ -26,6 +27,8 @@ namespace litert::internal { namespace { +using ::testing::ElementsAreArray; + TEST(FlatbufferToLiteRtTest, MapStaticTensorType) { static constexpr int32_t kDims[] = {2, 2}; static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); @@ -57,7 +60,8 @@ TEST(FlatbufferToLiteRtTest, MapDynamicTensorType) { } TEST(FlatbufferToLiteRtTest, MapNoQuantization) { - auto q = MapQuantization(nullptr); + LiteRtTensorT tensor; + auto q = MapQuantization(nullptr, tensor); ASSERT_TRUE(q); ASSERT_EQ(q->first, kLiteRtQuantizationNone); } @@ -70,7 +74,8 @@ TEST(FlatbufferToLiteRtTest, MapPerTensorQuantization) { tfl_q.scale.assign({kScale}); tfl_q.zero_point.assign({kZp}); - auto q = MapQuantization(&tfl_q); + LiteRtTensorT tensor; + auto q = MapQuantization(&tfl_q, tensor); ASSERT_TRUE(q); ASSERT_EQ(q->first, kLiteRtQuantizationPerTensor); EXPECT_EQ(q->second.per_tensor.scale, kScale); @@ -88,8 +93,17 @@ TEST(FlatbufferToLiteRtTest, MapPerChannelQuantization) { tfl_q.zero_point.assign(kZps, kZps + kRank); tfl_q.quantized_dimension = kQDim; - auto q = MapQuantization(&tfl_q); - ASSERT_FALSE(q); + LiteRtTensorT tensor; + auto q = MapQuantization(&tfl_q, tensor); + ASSERT_TRUE(q); + ASSERT_EQ(q->first, kLiteRtQuantizationPerChannel); + EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.scales, kRank), + ElementsAreArray(kScales)); + + EXPECT_THAT(absl::MakeConstSpan(q->second.per_channel.zero_points, kRank), + ElementsAreArray(kZps)); + EXPECT_EQ(q->second.per_channel.quantized_dimension, kQDim); + EXPECT_EQ(q->second.per_channel.num_channels, kRank); } } // namespace diff --git a/tensorflow/lite/experimental/litert/core/model/graph_validation.cc b/tensorflow/lite/experimental/litert/core/model/graph_validation.cc new file mode 100644 index 00000000000000..a9a942c1bfaa14 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/graph_validation.cc @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" + +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" + +namespace litert::internal { + +bool ValidateLocalTopology(const LiteRtOpT& litert_op) { + // Check number of in edges equals number of inputs and each input index + // appears on an in edge. + for (auto i = 0; i < litert_op.Inputs().size(); ++i) { + const auto& litert_tensor = litert_op.Input(i); + + auto input_use = + GetTensorUses(litert_tensor, FindUseInds(litert_tensor, litert_op)); + + if (!ContainsIf(input_use.cbegin(), input_use.cend(), + [i](auto u) { return u.second == i; })) { + LITERT_LOG(LITERT_WARNING, + "Input tensor %d not connected to op on correct index.", i); + return false; + } + } + + // Similar to above for outputs. + for (auto i = 0; i < litert_op.Outputs().size(); ++i) { + const auto& litert_tensor = litert_op.Output(i); + + if (litert_tensor.DefiningOp() != &litert_op) { + LITERT_LOG(LITERT_WARNING, "Output back edge doesn't refer to this op."); + return false; + } + + if (litert_tensor.DefiningOpOutInd() != i) { + LITERT_LOG(LITERT_WARNING, "Output back edge ind is incorrect."); + return false; + } + } + + return true; +} + +bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph) { + auto num_implied_inputs = 0; + auto num_implied_outputs = 0; + for (auto* tensor : litert_subgraph.Tensors()) { + const auto implied_out = tensor->NumUses() == 0; + const auto implied_in = + !IsConstant(*tensor) && tensor->DefiningOp() == nullptr; + + if (implied_out && implied_in) { + LITERT_LOG(LITERT_WARNING, "Graph contains a dead tensor"); + return false; + } + + const auto is_io = IsIO(litert_subgraph, *tensor); + + if (implied_in) { + if (!is_io) { + LITERT_LOG(LITERT_WARNING, + "Implied input not reflected in subgraph io %lu", + tensor - litert_subgraph.Tensors().at(0)); + return false; + } + ++num_implied_inputs; + } + + if (implied_out) { + if (!is_io) { + LITERT_LOG(LITERT_WARNING, + "Implied output not reflected in subgraph io"); + return false; + } + ++num_implied_outputs; + } + } + + if (num_implied_inputs != litert_subgraph.NumInputs()) { + LITERT_LOG( + LITERT_WARNING, + "Number of implied %lu inputs not equal to number of actual inputs %lu", + num_implied_inputs, litert_subgraph.NumInputs()); + return false; + } + + if (num_implied_outputs != litert_subgraph.NumOutputs()) { + LITERT_LOG(LITERT_WARNING, + "Number of implied %lu outputs not equal to number of actual " + "outputs %lu", + num_implied_outputs, litert_subgraph.NumOutputs()); + return false; + } + + return true; +} + +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/graph_validation.h b/tensorflow/lite/experimental/litert/core/model/graph_validation.h new file mode 100644 index 00000000000000..c0a199294f8677 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/graph_validation.h @@ -0,0 +1,47 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" + +// Helper functions for validating the structure of IR graphs. + +namespace litert::internal { + +// Checks the double-linked edges to immediate neighbors are valid. +bool ValidateLocalTopology(const LiteRtOpT& litert_op); + +// Runs ValidateLocalTopology across given LiteRtOp iterator. +template +bool ValidateLocalTopology(OpIt start, OpIt end) { + return std::all_of(start, end, + [](const auto* op) { return ValidateLocalTopology(*op); }); +} + +// Checks the following are bijections: +// * non-const tensor with no defining op <-> subgraph input +// * tensor with no users <-> subgraph output (assuming no side effect ops) +// These are used to figure out the i/o signatures when building a subgraph +// from scratch. +bool ValidateSubgraphIO(const LiteRtSubgraphT& litert_subgraph); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_GRAPH_VALIDATION_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/ir_allocator.h b/tensorflow/lite/experimental/litert/core/model/ir_allocator.h new file mode 100644 index 00000000000000..4e0a575a105e88 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/ir_allocator.h @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" + +namespace litert::internal { + +// A list of IR objects scoped to the same block (subgraph) that provides +// pointer stability. Facilitates management of memory and c-like access +// to elements. +template +class IrAllocator { + private: + using Storage = std::list; + using Refs = std::vector; + + public: + // Emplace a new element onto the list. + template + Ir& EmplaceBack(Args&&... args) { + auto& emp = storage_.emplace_back(std::forward(args)...); + refs_->push_back(&emp); + return emp; + } + + // Get the array of (stable) pointers to underlying elements. Suitable + // for passing through c-like interface. Consituent pointers are always + // guarateed to be stable (unless explicitly erased). The array of pointers + // itself is guaranteed to be stable so long as no length-changing operations + // occur, moving this class does not invalidate pointers or array. + absl::Span Elements() const { + return absl::MakeSpan(refs_->data(), refs_->size()); + } + + // Remove elements from the allocator if they match the predicate. + // Returns the number of elements removed. + size_t RemoveIf(std::function pred) { + auto ref_it = refs_->begin(); + for (auto it = storage_.begin(); it != storage_.end();) { + if (!pred(*it)) { + *ref_it = &*it; + ++ref_it; + ++it; + continue; + } + it = storage_.erase(it); + } + const size_t removed = refs_->end() - ref_it; + refs_->resize(refs_->size() - removed); + return removed; + } + + // Cuts all but the first `size` elements from storage. Does nothing if `size` + // is greater or equal to current size. + void ResizeDown(size_t size) { + if (size >= Size()) { + return; + } + storage_.resize(size); + refs_->resize(size); + } + + // Transfers the ownership of given allocator to this one. + void Transfer(IrAllocator&& other) { + storage_.splice(storage_.cend(), other.storage_); + refs_->insert(refs_->end(), other.refs_->cbegin(), other.refs_->cend()); + } + + // Number of elements stored by this allocator. + size_t Size() const { return storage_.size(); } + + IrAllocator() { refs_ = std::make_unique(); } + + // IR is generally semantically movable (without reference invalidation) + // but not copyable. IrAllocators reflect that, note moving lists + // does not invalidate references. + IrAllocator(const IrAllocator& other) = delete; + IrAllocator& operator=(const IrAllocator& other) = delete; + IrAllocator(IrAllocator&& other) = default; + IrAllocator& operator=(IrAllocator&& other) = default; + + private: + Storage storage_; + std::unique_ptr refs_; +}; + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_IR_ALLOCATOR_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc b/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc new file mode 100644 index 00000000000000..1923b70cc6ae3c --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/ir_allocator_test.cc @@ -0,0 +1,108 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" + +#include + +#include +#include +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" + +namespace litert::internal { +namespace { + +using ::testing::ElementsAreArray; + +static constexpr auto kCustomOpCode = kLiteRtOpCodeTflCustom; +static constexpr auto kNonCustomOpCode = kLiteRtOpCodeTflSoftmax; + +TEST(IrAllocatorTest, EmplaceBack) { + IrAllocator ops; + + LiteRtOpT my_op; + my_op.SetOpCode(kCustomOpCode); + + ops.EmplaceBack(std::move(my_op)); + ASSERT_EQ(ops.Elements().size(), 1); + EXPECT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); +} + +TEST(IrAllocatorTest, RemoveIf) { + IrAllocator ops; + + LiteRtOpT my_op; + my_op.SetOpCode(kNonCustomOpCode); + ops.EmplaceBack(std::move(my_op)); + + LiteRtOpT my_op2; + my_op2.SetOpCode(kCustomOpCode); + ops.EmplaceBack(std::move(my_op2)); + + LiteRtOpT my_op3; + my_op3.SetOpCode(kCustomOpCode); + ops.EmplaceBack(std::move(my_op3)); + + LiteRtOpT my_op4; + my_op4.SetOpCode(kNonCustomOpCode); + ops.EmplaceBack(std::move(my_op4)); + + auto pred = [](const auto& op) { return op.OpCode() != kCustomOpCode; }; + ASSERT_EQ(ops.RemoveIf(pred), 2); + + ASSERT_EQ(ops.Elements().size(), 2); + ASSERT_EQ(ops.Elements().at(0)->OpCode(), kCustomOpCode); + ASSERT_EQ(ops.Elements().at(1)->OpCode(), kCustomOpCode); +} + +TEST(IrAllocatorTest, ResizeDown) { + IrAllocator ops; + + LiteRtOp op1 = nullptr; + { + LiteRtOpT my_op; + my_op.SetOpCode(kNonCustomOpCode); + op1 = &ops.EmplaceBack(std::move(my_op)); + } + + { + LiteRtOpT my_op2; + my_op2.SetOpCode(kCustomOpCode); + ops.EmplaceBack(std::move(my_op2)); + } + + ops.ResizeDown(1); + + ASSERT_EQ(ops.Size(), 1); + EXPECT_EQ(ops.Elements().at(0), op1); +} + +TEST(IrAllocatorTest, Transfer) { + IrAllocator ops; + auto& op1 = ops.EmplaceBack(); + auto& op2 = ops.EmplaceBack(); + + IrAllocator other_ops; + auto& other_op1 = other_ops.EmplaceBack(); + auto& other_op2 = other_ops.EmplaceBack(); + + ops.Transfer(std::move(other_ops)); + + EXPECT_THAT(ops.Elements(), + ElementsAreArray({&op1, &op2, &other_op1, &other_op2})); +} + +} // namespace +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc index 3c4e3d29661a76..9bec2f4c1ce3fe 100644 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc +++ b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.cc @@ -83,6 +83,21 @@ Expected MapQuantizationDetail( return tfl_quantization; } +template <> +Expected +MapQuantizationDetail( + const LiteRtQuantizationPerChannel& litert_quantization) { + auto tfl_quantization = std::make_unique(); + + for (int i = 0; i < litert_quantization.num_channels; ++i) { + tfl_quantization->scale.push_back(litert_quantization.scales[i]); + tfl_quantization->zero_point.push_back(litert_quantization.zero_points[i]); + } + tfl_quantization->quantized_dimension = + litert_quantization.quantized_dimension; + return tfl_quantization; +} + } // namespace Expected MapTensorType(const TensorType& litert_tensor_type) { @@ -101,6 +116,8 @@ Expected MapQuantization( return TflQuantizationPtr(nullptr); case kLiteRtQuantizationPerTensor: return MapQuantizationDetail(litert_quantization.second.per_tensor); + case kLiteRtQuantizationPerChannel: + return MapQuantizationDetail(litert_quantization.second.per_channel); default: return Error(kLiteRtStatusErrorUnsupported); } diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h index 9b4d1cea239195..4fbe51bf9d3a0b 100644 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h +++ b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h @@ -16,11 +16,9 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_LITERT_TO_FLATBUFFER_H_ -#include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" namespace litert::internal { diff --git a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc index 8314c7c540eef7..3f5c8fdf101fa1 100644 --- a/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc +++ b/tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer_test.cc @@ -15,6 +15,7 @@ #include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" +#include #include #include @@ -52,6 +53,25 @@ TEST(LiteRtToFlatbufferTest, MapPerTensorQuantization) { EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray({kZp})); } +TEST(LiteRtToFlatbufferTest, MapPerChannelQuantization) { + static constexpr size_t kRank = 2; + static constexpr size_t kQuantizedDimension = 1; + static constexpr float kScales[kRank] = {1.0, 2.0}; + static constexpr int64_t kZps[kRank] = {2, 3}; + + Quantization q; + q.first = kLiteRtQuantizationPerChannel; + q.second.per_channel.scales = const_cast(kScales); + q.second.per_channel.zero_points = const_cast(kZps); + q.second.per_channel.num_channels = kRank; + q.second.per_channel.quantized_dimension = kQuantizedDimension; + + auto tfl_q = MapQuantization(q); + ASSERT_TRUE(tfl_q); + EXPECT_THAT(tfl_q->get()->scale, ElementsAreArray(kScales)); + EXPECT_THAT(tfl_q->get()->zero_point, ElementsAreArray(kZps)); +} + TEST(LiteRtToFlatbufferTest, MapDynamicTensorType) { static constexpr int32_t kDims[] = {-1, 2}; diff --git a/tensorflow/lite/experimental/litert/core/model/model.cc b/tensorflow/lite/experimental/litert/core/model/model.cc index d5cc23d867d218..4549f008bc4fd8 100644 --- a/tensorflow/lite/experimental/litert/core/model/model.cc +++ b/tensorflow/lite/experimental/litert/core/model/model.cc @@ -14,50 +14,123 @@ #include "tensorflow/lite/experimental/litert/core/model/model.h" -#include +#include #include -#include +#include +#include +#include +#include "absl/log/absl_check.h" #include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_layout.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" -#include "tensorflow/lite/schema/schema_generated.h" using ::litert::BufferRef; -using ::litert::Expected; -using ::litert::Unexpected; +using ::litert::internal::TflBuffer; +using ::litert::internal::TflBufferPtr; +using ::litert::internal::TflOpCode; +using ::litert::internal::TflOpCodePtr; +using ::litert::internal::TflOptions; -Expected> LiteRtModelT::FindMetadata( - const absl::string_view key) const { - return ::litert::internal::GetMetadata(key, *flatbuffer_model); +TensorType MakeRankedTensorType(LiteRtElementType element_type, + absl::Span dims) { + TensorType tensor_type; + tensor_type.first = kLiteRtRankedTensorType; + auto& ranked = tensor_type.second.ranked_tensor_type; + ranked.element_type = element_type; + ABSL_DCHECK_LE(dims.size(), LITERT_TENSOR_MAX_RANK); + ranked.layout.rank = dims.size(); + std::copy(dims.begin(), dims.end(), ranked.layout.dimensions); + // Strides not yet supported. + ranked.layout.strides = nullptr; + return tensor_type; } -LiteRtStatus LiteRtModelT::PushMetadata(absl::string_view key, - BufferRef data) { - return ::litert::internal::PushMetadata(key, *flatbuffer_model, data); +Quantization MakePerTensorQuantization(float scale, int64_t zero_point) { + Quantization quantization; + quantization.first = kLiteRtQuantizationPerTensor; + quantization.second.per_tensor.scale = scale; + quantization.second.per_tensor.zero_point = zero_point; + return quantization; } -litert::Expected LiteRtModelT::FindSignature( - absl::string_view signature_key) const { - for (auto& signature : signatures) { - if (signature->key == signature_key) { - return signature.get(); - } - } - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); +LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph) { + auto tensor_name = [](auto* tensor) { return std::string(tensor->Name()); }; + + auto in_start = subgraph->Inputs().cbegin(); + auto in_end = subgraph->Inputs().cend(); + std::vector input_names(subgraph->NumInputs()); + std::transform(in_start, in_end, input_names.begin(), tensor_name); + + auto out_start = subgraph->Outputs().cbegin(); + auto out_end = subgraph->Outputs().cend(); + std::vector output_names(subgraph->NumOutputs()); + std::transform(out_start, out_end, output_names.begin(), tensor_name); + + std::string name(LiteRtSignatureT::kDefaultSignatureKey); + return LiteRtSignatureT(subgraph, std::move(input_names), + std::move(output_names), std::move(name)); } -litert::Expected LiteRtModelT::FindSubgraph( - absl::string_view signature_key) const { - for (auto& signature : signatures) { - if (signature->key == signature_key) { - return &(subgraphs[signature->subgraph_index]); - } +::litert::Expected LookupSubgraph( + const LiteRtModelT& model, absl::string_view signature_key) { + auto sig = model.FindSignature(signature_key); + if (!sig) { + return sig.Error(); } - return Unexpected(kLiteRtStatusErrorNotFound, "Signature not found"); + return &sig->get().GetSubgraph(); +} + +namespace detail { + +void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind) { + litert_op.tfl_op_code_ind_ = tfl_op_code_ind; +} + +int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op) { + return litert_op.tfl_op_code_ind_; +} + +const TflOptions& GetTflOptions(const LiteRtOpT& litert_op) { + return litert_op.tfl_option_; +} + +TflOptions&& TakeTflOptions(LiteRtOpT& litert_op) { + return std::move(litert_op.tfl_option_); +} + +const TflBuffer& GetTflBuffer(const LiteRtWeightsT& litert_weights) { + return *litert_weights.tfl_buf_; +} + +TflBufferPtr TakeTflBuffer(LiteRtWeightsT& litert_weights) { + return std::move(litert_weights.tfl_buf_); +} + +void SetTflBuffer(LiteRtWeightsT& litert_weights, TflBufferPtr tfl_buffer) { + litert_weights.tfl_buf_ = std::move(tfl_buffer); +} + +const std::vector& GetTflOpCodes( + const LiteRtModelT& litert_model) { + return litert_model.tfl_operator_codes_; } + +std::vector&& TakeTflOpCodes(LiteRtModelT& litert_model) { + return std::move(litert_model.tfl_operator_codes_); +} + +void SetTflInitFlatbuffer(LiteRtModelT& litert_model, + BufferRef init_flatbuffer) { + litert_model.tfl_init_flatbuffer_ = init_flatbuffer; +} + +BufferRef GetTflInitFlatbuffer(const LiteRtModelT& litert_model) { + return litert_model.tfl_init_flatbuffer_; +} + +} // namespace detail diff --git a/tensorflow/lite/experimental/litert/core/model/model.h b/tensorflow/lite/experimental/litert/core/model/model.h index 0bb49a37e7a278..40a60af52fef0f 100644 --- a/tensorflow/lite/experimental/litert/core/model/model.h +++ b/tensorflow/lite/experimental/litert/core/model/model.h @@ -15,217 +15,781 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ +#include #include #include #include +#include #include #include +#include +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" // IWYU pragma: export #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/core/byte_code_util.h" +#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" #include "tensorflow/lite/schema/schema_generated.h" +//////////////////////////////////////////////////////////////////////////////// +// Internal LiteRtIR // -// Tensor +// These are the backing definitions for the opaque types in the c api +// (c/litert_model.h). +// +// < STORAGE DETAIL > +// +// Unless deleted as a result of calls c api client, the lifetime of all "IR +// Objects" (definitions of opaque types) are designed to be transitively owned +// by the LiteRtModelT which is generally the longset living object. See various +// "Emplace" methods. +// +// Since c api clients interface with pointers to IR Ojbects, a form of pointer +// stability is desirable. Classes in this file enforce that pointers to IR +// Objects are valid for their entire life time. Thus a c api client may store +// pointers and depend on referential equality of IR Objects thoughout different +// calls. This also facilitates storing edge/parent-references as pointers +// within IR Objects. +// +// Direct copying is generally not allowed for IR Objects since copying +// instances of mutually recursive types is not entirely well-defined. +// +// IR Objects are generally default constructible to facilitate stable storage +// and iterative construction. // +// < EXPOSING TFLITE SCHEMA > +// +// Direct access to tflite schema types is limited to the "detail" namespace. +// This indicates that encapsulating all the details of the flatbuffer is a WIP. +// Future implementations may use different data forms (new litert serialized +// format, tflite runtime types etc). +// +// < USAGE NOTE > +// +// The classes here contain only simple getters & setters. Care should be taken +// to leave the IR in a valid state when using setters since the graph is +// doubly-linked. Higher-level functionality for correct graph mutation can be +// found in "model_graph.h". +//////////////////////////////////////////////////////////////////////////////// -struct LiteRtWeightsT { - std::unique_ptr fb_buffer = nullptr; -}; +// All tflite schema type usage. +namespace detail { + +// OP + +// Placeholder for the ind of the dispatch op code added during serialization. +static constexpr auto kDispatchOpCodeTflInd = -1; + +void SetTflOpCodeInd(LiteRtOpT& litert_op, int32_t tfl_op_code_ind); + +int32_t GetTflOpCodeInd(const LiteRtOpT& litert_op); + +template +void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg); + +const ::litert::internal::TflOptions& GetTflOptions(const LiteRtOpT& litert_op); + +::litert::internal::TflOptions&& TakeTflOptions(LiteRtOpT& litert_op); + +// WEIGHT + +const ::litert::internal::TflBuffer& GetTflBuffer( + const LiteRtWeightsT& litert_weights); + +litert::internal::TflBufferPtr TakeTflBuffer(LiteRtWeightsT& litert_weights); + +void SetTflBuffer(LiteRtWeightsT& litert_weights, + litert::internal::TflBufferPtr tfl_buffer); +// MODEL + +const std::vector<::litert::internal::TflOpCodePtr>& GetTflOpCodes( + const LiteRtModelT& litert_model); + +template +void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg); + +std::vector<::litert::internal::TflOpCodePtr>&& TakeTflOpCodes( + LiteRtModelT& litert_model); + +void SetTflInitFlatbuffer(LiteRtModelT& litert_model, + ::litert::BufferRef init_flatbuffer); + +::litert::BufferRef GetTflInitFlatbuffer( + const LiteRtModelT& litert_model); + +} // namespace detail + +// +// Helpers for conceptual unions from C api. +// + +// // For requesting opaque data stored within IR. +using BufferProvider = std::function; + +// TENSOR TYPE + +// Detail convenience type for tensor type union. typedef union { LiteRtUnrankedTensorType unranked_tensor_type; LiteRtRankedTensorType ranked_tensor_type; -} LiteRtTypeDetail; +} TensorTypeDetail; + +// Union and identifier for tensor types. +using TensorType = std::pair; + +// Construct tensor type union as ranked tensor. NOTE: Copies data in `dims`. +TensorType MakeRankedTensorType(LiteRtElementType element_type, + absl::Span dims); -using TensorType = std::pair; +// QUANTIZATION TYPE +// Detail convenience type for quantization type union. typedef union { LiteRtQuantizationPerTensor per_tensor; -} LiteRtQuantizationTypeDetail; + LiteRtQuantizationPerChannel per_channel; +} QuantizationDetail; + +// Union and identifier for quantization types. +using Quantization = std::pair; + +// Make default type with quantization info. +inline Quantization MakeEmptyQuantization() { + return Quantization(kLiteRtQuantizationNone, QuantizationDetail()); +} + +// Construct quantization type as per tensor. +Quantization MakePerTensorQuantization(float scale, int64_t zero_point); + +// Construct quantization type as per channel, requires buffer callback to +// store data. +template +Quantization MakePerChannelQuantization(const Scales& scales, + const ZeroPoints& zero_points, + int32_t quantized_dim, + BufferProvider buffer_provider) { + const auto size = std::size(scales); + ABSL_DCHECK_EQ(size, std::size(zero_points)); + + Quantization res; + res.first = kLiteRtQuantizationPerChannel; + + res.second.per_channel.num_channels = size; + res.second.per_channel.quantized_dimension = quantized_dim; + + const size_t scales_buf_size = size * sizeof(float); + const size_t zeros_buf_size = size * sizeof(int64_t); + auto* scales_buf = reinterpret_cast(buffer_provider(scales_buf_size)); + auto* zeros_buf = reinterpret_cast(buffer_provider(zeros_buf_size)); + std::copy(std::cbegin(scales), std::cend(scales), scales_buf); + std::copy(std::cbegin(zero_points), std::cend(zero_points), zeros_buf); -using Quantization = - std::pair; + res.second.per_channel.scales = scales_buf; + res.second.per_channel.zero_points = zeros_buf; -struct LiteRtTensorT { + return res; +} + +// +// Tensor +// + +// Constant data associated with a tensor. +class LiteRtWeightsT { + private: + using OwnedBuffer = ::litert::OwningBufferRef; + + public: + // Underlying data. + ::litert::BufferRef Buf() const { + return ::litert::BufferRef(tfl_buf_->data.data(), + tfl_buf_->data.size()); + } + + // Set weights via copied data. + void SetFromBuf(::litert::BufferRef buf) { + tfl_buf_->data.assign(buf.Data(), buf.Data() + buf.Size()); + } + + // Set via copied vec. + void SetFromVec(const std::vector& vec) { tfl_buf_->data = vec; } + + // IR is generally, default constructible and movable but not copyable. + LiteRtWeightsT() + : tfl_buf_(std::make_unique<::litert::internal::TflBuffer>()) {} + LiteRtWeightsT(const LiteRtWeightsT&) = delete; + LiteRtWeightsT(LiteRtWeightsT&&) = default; + LiteRtWeightsT& operator=(const LiteRtWeightsT&) = delete; + LiteRtWeightsT& operator=(LiteRtWeightsT&&) = default; + + // Friendship for internal tflite details. + friend const ::litert::internal::TflBuffer& detail::GetTflBuffer( + const LiteRtWeightsT& litert_weights); + + friend litert::internal::TflBufferPtr detail::TakeTflBuffer( + LiteRtWeightsT& litert_weights); + + friend void detail::SetTflBuffer(LiteRtWeightsT& litert_weights, + litert::internal::TflBufferPtr tfl_buffer); + + private: + // TFLITE + ::litert::internal::TflBufferPtr tfl_buf_; +}; + +// Fundamental value in a litert program, "edges" in the graph. +class LiteRtTensorT { + private: + using UserData = std::unique_ptr; + + public: using Ref = std::reference_wrapper; + using Use = std::pair; + using UseVec = std::vector; + using Alloc = ::litert::internal::IrAllocator; + + // The ops that take this tensor as input. + const std::vector& Users() const { return users_; } + std::vector& Users() { return users_; } + + // Which operand index users take this tensor on, respects the ordering of + // users.. + const std::vector& UserArgInds() const { + return user_arg_inds_; + } + std::vector& UserArgInds() { return user_arg_inds_; } - // Empty if subgraph output. This is a reference. - std::vector users; + // Number of uses, same as number of user arg inds. + size_t NumUses() const { return users_.size(); } - // Which arg number for user i. - std::vector user_arg_inds; + // Get the ith use. + Use GetUse(size_t ind) const { + return {users_.at(ind), user_arg_inds_.at(ind)}; + } - // Null if subgraph input or constant. This is a reference. - LiteRtOp defining_op = nullptr; + // Remove the use at the given index. + void RemoveUse(size_t ind) { + users_.erase(users_.begin() + ind); + user_arg_inds_.erase(user_arg_inds_.begin() + ind); + } - // Which output ind from defining op made this tensor. - LiteRtParamIndex defining_op_out_ind; + // Get the op that outputs this tensor, null if constant or subgraph input. + LiteRtOp DefiningOp() const { return defining_op_; } - // Not a reference. - LiteRtWeightsT weights; + // Get the output index of the op that defines this tensor, only meaningful + // if it has a defining op. + LiteRtParamIndex DefiningOpOutInd() const { return defining_op_out_ind_; } - // Id for union tensor type. - LiteRtTensorTypeId type_id; + // Update the defining op of this tensor. The caller is required to update the + // given op's output if not already correct. + void SetDefiningOp(LiteRtOpT& defining_op, LiteRtParamIndex out_ind) { + defining_op_ = &defining_op; + defining_op_out_ind_ = out_ind; + } - // Union tensor type. - LiteRtTypeDetail type_detail; + // Set the defining op to none. + void ClearDefiningOp() { + defining_op_ = nullptr; + defining_op_out_ind_ = 0; + } + + // Any constant data associated with this tensor. + const LiteRtWeightsT& Weights() const { return weights_; } + LiteRtWeightsT& Weights() { return weights_; } + + // Authored name associated with this tensor. May be empty. + absl::string_view Name() const { return name_; } - // Id for union quantization type. - LiteRtQuantizationTypeId q_type_id = kLiteRtQuantizationNone; + // Update the name associated with this tensor. + void SetName(std::string name) { name_ = std::move(name); } - // Union quantization type. - LiteRtQuantizationTypeDetail q_type_detail; + // Get quantization information for this tensor. + const Quantization& Qparams() const { return quantization_; } + Quantization& Qparams() { return quantization_; } + + // Set quantization information. + template + void SetQarams(Arg&& arg) { + quantization_ = std::forward(arg); + } - // Authored name of tensor, may be empty. - std::string name; + // Get the tensor type of this tensor. + const TensorType& Type() const { return tensor_type_; } + TensorType& Type() { return tensor_type_; } + + // Set the tensor type. + template + void SetType(Arg&& arg) { + tensor_type_ = std::forward(arg); + } + + // Get a new buffer that will live as long as this tensor. Used for storing + // various buffers passed through c-api (dims, quantization etc). + uint8_t* RequestBuffer(size_t size) { + user_data_.push_back(std::make_unique(size)); + return user_data_.back().get(); + } + + // Allow for implicit conversion to bufer provider. + // NOLINTNEXTLINE + operator BufferProvider() & { + return [this](auto s) { return this->RequestBuffer(s); }; + } + + // IR is generally, default constructible and movable but not copyable. + LiteRtTensorT() = default; + LiteRtTensorT(const LiteRtTensorT&) = delete; + LiteRtTensorT(LiteRtTensorT&&) = default; + LiteRtTensorT& operator=(const LiteRtTensorT&) = delete; + LiteRtTensorT& operator=(LiteRtTensorT&&) = default; private: - // TODO Unify mangement of dims and clean this up. - litert::SmallVec dims; + std::vector users_; + std::vector user_arg_inds_; + + LiteRtOp defining_op_ = nullptr; + LiteRtParamIndex defining_op_out_ind_; + + LiteRtWeightsT weights_; + Quantization quantization_; + TensorType tensor_type_; + + std::string name_; + + std::vector user_data_; }; +// Helper to get multiple uses at once. +template +LiteRtTensorT::UseVec GetTensorUses(const LiteRtTensorT& tensor, + const Inds& inds) { + auto start = std::cbegin(inds); + auto end = std::cend(inds); + LiteRtTensorT::UseVec uses(end - start); + auto get = [&tensor = std::as_const(tensor)](auto i) { + return tensor.GetUse(i); + }; + std::transform(start, end, uses.begin(), get); + return uses; +} + // // Op // -struct LiteRtOpT { - // These are references. - std::vector inputs; +// Fundamental unit of compute of a litert program, or "nodes" in the graph. +class LiteRtOpT { + public: + using Ref = std::reference_wrapper; + using Alloc = ::litert::internal::IrAllocator; + + // Input tensors for this op. + const std::vector& Inputs() const { return inputs_; } + std::vector& Inputs() { return inputs_; } + + // Access input at given ind. + LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); } + const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); } + + // Number of input tensors. + size_t NumInputs() const { return inputs_.size(); } - // These are references. - std::vector outputs; + // Output tensors for this op. + const std::vector& Outputs() const { return outputs_; } + std::vector& Outputs() { return outputs_; } - LiteRtOpCode op_code; + // Number of output tensors. + size_t NumOutputs() const { return outputs_.size(); } - litert::OwningBufferRef custom_options; + // Access output at given ind. + LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); } + const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); } - tflite::BuiltinOptionsUnion option; + // Remove the ith entry of input list. + void RemoveInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); } - // Add a new input to this op and updating given tensors users. - void AddInput(LiteRtTensorT& input_tensor) { - input_tensor.users.push_back(this); - input_tensor.user_arg_inds.push_back(inputs.size()); - inputs.push_back(&input_tensor); + // Remove the ith entry of output list. + void RemoveOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); } + + // Get any custom options attached to this op. Empty if there are none. + litert::BufferRef CustomOptions() const { return custom_options_; } + + // Attach custom opaque optins to this op. + template + void SetCustomOptions(Args&&... args) { + custom_options_ = + ::litert::OwningBufferRef(std::forward(args)...); } - // Add a new output to this op and update given tensors defining op. - void AddOutput(LiteRtTensorT& output_tensor) { - output_tensor.defining_op_out_ind = outputs.size(); - output_tensor.defining_op = this; - outputs.push_back(&output_tensor); + // Sets the custom options to zero length buffer. + void ClearCustomOptions() { custom_options_.Reset(); } + + // Get the op code. + LiteRtOpCode OpCode() const { return litert_op_code_; } + + // Set the op code. + void SetOpCode(LiteRtOpCode litert_op_code) { + litert_op_code_ = litert_op_code; } + + // IR is generally, default constructible and movable but not copyable. + LiteRtOpT() = default; + LiteRtOpT(const LiteRtOpT&) = delete; + LiteRtOpT(LiteRtOpT&&) = default; + LiteRtOpT& operator=(const LiteRtOpT&) = delete; + LiteRtOpT& operator=(LiteRtOpT&&) = default; + + // Friendship for internal tflite details. + friend void detail::SetTflOpCodeInd(LiteRtOpT& litert_op, + int32_t tfl_op_code_ind); + + friend int32_t detail::GetTflOpCodeInd(const LiteRtOpT& litert_op); + + template + friend void detail::SetTflOptions(LiteRtOpT& litert_op, Arg&& arg); + + friend const ::litert::internal::TflOptions& detail::GetTflOptions( + const LiteRtOpT& litert_op); + + friend ::litert::internal::TflOptions&& detail::TakeTflOptions( + LiteRtOpT& litert_op); + + private: + LiteRtOpCode litert_op_code_; + + ::litert::OwningBufferRef custom_options_; + + std::vector inputs_; + std::vector outputs_; + + // TFLITE + int32_t tfl_op_code_ind_ = detail::kDispatchOpCodeTflInd; + ::litert::internal::TflOptions tfl_option_; }; // // Subgraph // -struct LiteRtSubgraphT { - // Storage and views of tensors. Clients are only shown views. Facilitates - // efficient topological mutation. - std::list tensors_storage; - std::vector tensors; +// Fundamental block of a litert program. Manages the storage of all +// ops and tensor within. +class LiteRtSubgraphT { + public: + using Ref = std::reference_wrapper; + using Alloc = ::litert::internal::IrAllocator; + + // Get a stable pointer for all of the tensors in this subgraph. + absl::Span Tensors() { return tensors_.Elements(); } + absl::Span Tensors() const { return tensors_.Elements(); } + + // Access the tensor at given ind. + LiteRtTensorT& Tensor(size_t ind) { return *Tensors().at(ind); } + const LiteRtTensorT& Tensor(size_t ind) const { return *Tensors().at(ind); } + + // Get a stable pointer for all of the ops in this subgraph. Will + // be a valid toplological order. + absl::Span Ops() { return ops_.Elements(); } + absl::Span Ops() const { return ops_.Elements(); } + + // Access op at the given ind. + LiteRtOpT& Op(size_t ind) { return *Ops().at(ind); } + const LiteRtOpT& Op(size_t ind) const { return *Ops().at(ind); } + + // All the subgraph input tensors, these also exist in Tensors. + const std::vector& Inputs() const { return inputs_; } + std::vector& Inputs() { return inputs_; } - // Storage and vies of ops. - std::list ops_storage; - std::vector ops; + // Number of inputs tensors. + size_t NumInputs() const { return inputs_.size(); } - // Shared view of initial flatbuffer data. - std::shared_ptr flatbuffer_subgraph; + // Access the subgraph input at given ind. + LiteRtTensorT& Input(size_t ind) { return *Inputs().at(ind); } + const LiteRtTensorT& Input(size_t ind) const { return *Inputs().at(ind); } - // These are references and a subset of `tensors`. - std::vector inputs; + // All the subgraph output tensors, these also exist in Tensors. + const std::vector& Outputs() const { return outputs_; } + std::vector& Outputs() { return outputs_; } - // These are references and a subset of `tensors`. - std::vector outputs; + // Number of outputs tensors. + size_t NumOutputs() const { return outputs_.size(); } + + // Access the subgraph output at given ind. + LiteRtTensorT& Output(size_t ind) { return *Outputs().at(ind); } + const LiteRtTensorT& Output(size_t ind) const { return *Outputs().at(ind); } + + // Clear the entry for the ith input. + void ClearInput(size_t ind) { inputs_.erase(inputs_.begin() + ind); } + + // Clear the entry for the ith output. + void ClearOutput(size_t ind) { outputs_.erase(outputs_.begin() + ind); } + + // Construct a new tensor which will be owned by this subgraph and get a + // reference to it. + template + LiteRtTensorT& EmplaceTensor(Args&&... args) { + return tensors_.EmplaceBack(std::forward(args)...); + } - LiteRtTensorT& EmplaceTensor() { - auto& tensor = tensors_storage.emplace_back(); - tensors.push_back(&tensor); - return tensor; + // Construct a new op which will be owned by this subgraph and get a + // reference to it. + template + LiteRtOpT& EmplaceOp(Args&&... args) { + return ops_.EmplaceBack(std::forward(args)...); } - LiteRtOpT& EmplaceOp() { - auto& op = ops_storage.emplace_back(); - ops.push_back(&op); - return op; + // De-allocates ops that pass given predicate. Returns number of ops removed. + size_t RemoveOpIf(std::function pred) { + return ops_.RemoveIf(pred); } + + // De-allocates tensors that pass given predicate. Returns number of tensors + // removed. + size_t RemoveTensorIf(std::function pred) { + return tensors_.RemoveIf(pred); + } + + // IR is generally, default constructible and movable but not copyable. + LiteRtSubgraphT() = default; + LiteRtSubgraphT(const LiteRtSubgraphT&) = delete; + LiteRtSubgraphT(LiteRtSubgraphT&&) = default; + LiteRtSubgraphT& operator=(const LiteRtSubgraphT&) = delete; + LiteRtSubgraphT& operator=(LiteRtSubgraphT&&) = default; + + private: + LiteRtTensorT::Alloc tensors_; + + LiteRtOpT::Alloc ops_; + + std::vector inputs_; + std::vector outputs_; }; // // Signature // -#define LITERT_DEFAULT_SIGNATURE_KEY "" +class LiteRtSignatureT { + private: + using StrVec = std::vector; -struct LiteRtSignatureT { + public: using Ptr = std::unique_ptr; - absl::string_view key; - int subgraph_index; - std::vector input_names; - std::vector output_names; + using Ref = std::reference_wrapper; + using Alloc = ::litert::internal::IrAllocator; + + static constexpr absl::string_view kDefaultSignatureKey = + ""; + + LiteRtSignatureT(LiteRtSubgraph subgraph, StrVec input_names, + StrVec output_names, std::string key) + : key_(std::move(key)), + subgraph_(subgraph), + input_names_(std::move(input_names)), + output_names_(std::move(output_names)) {} + + // String named inputs for called subgraph. + const StrVec& InputNames() const { return input_names_; } + + // String named outputs for called subgraph. + const StrVec& OutputNames() const { return output_names_; } + + // Get the callable subgraph. + const LiteRtSubgraphT& GetSubgraph() const { return *subgraph_; } + LiteRtSubgraphT& GetSubgraph() { return *subgraph_; } + + // Name of the callable signature. + absl::string_view Key() const { return key_; } + + bool operator==(const LiteRtSignatureT& other) const { + const auto key_eq = key_ == other.key_; + const auto subgraph_eq = subgraph_ == other.subgraph_; + const auto input_names_eq = input_names_ == other.input_names_; + const auto output_names_eq = output_names_ == other.output_names_; + return key_eq && subgraph_eq && input_names_eq && output_names_eq; + } + + // IR is generally, default constructible and movable but not copyable. + LiteRtSignatureT() = default; + LiteRtSignatureT(const LiteRtSignatureT&) = delete; + LiteRtSignatureT(LiteRtSignatureT&&) = default; + LiteRtSignatureT& operator=(const LiteRtSignatureT&) = delete; + LiteRtSignatureT& operator=(LiteRtSignatureT&&) = default; + + private: + std::string key_; + + LiteRtSubgraph subgraph_; + + StrVec input_names_; + StrVec output_names_; }; +// Make a basic signature from information in the given subgraph. Used with the +// main subgraph when no explicit signatures have been authored. +LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph); + // // Model // -// A (partial) unpacking of the flatbuffer model into a list of subgraphs. -// Keeps a reference to the flatbuffer model. Lifetimes of all storage -// are linked to the containing model. -struct LiteRtModelT { +// Root-level graph object for litert programs. Manages the storage +// of all litert graph objects within. +class LiteRtModelT { + private: + using MetadataMap = + absl::flat_hash_map>; + + public: using Ref = std::reference_wrapper; + using Ptr = std::unique_ptr; + using TflOpCodes = std::vector; - // Subgraphs that have been unpacked into usable types. - std::vector subgraphs; + // TODO replace this with the index of the default signature. + static constexpr const size_t kMainSubgraphIndex = 0; - // Initial flatbuffer loaded in. "Subgraphs" field has been invalidated. - std::unique_ptr flatbuffer_model; + // OBSERVERS - // The buffer information when the model was loaded from a buffer. - const void* model_buffer = nullptr; - size_t model_buffer_size = 0; + // Get a stable pointer for all of the subgraphs within this model. + absl::Span Subgraphs() { return subgraphs_.Elements(); } + absl::Span Subgraphs() const { + return subgraphs_.Elements(); + } - // Custom code associated with all customs ops emitted during - // re-serialization. - std::string custom_op_code; + // Access subgraph at given ind. + LiteRtSubgraphT& Subgraph(size_t ind) { return *Subgraphs().at(ind); } + const LiteRtSubgraphT& Subgraph(size_t ind) const { + return *Subgraphs().at(ind); + } - // Signature definitions. - std::vector> signatures; + // Number of subraphs. + size_t NumSubgraphs() const { return subgraphs_.Elements().size(); } + + // Default entry point of this model. + const LiteRtSubgraphT* MainSubgraph() const { + return &Subgraph(kMainSubgraphIndex); + } + LiteRtSubgraph MainSubgraph() { return &Subgraph(kMainSubgraphIndex); } + + // Look up signature by key. + litert::Expected FindSignature( + absl::string_view signature_key) const { + for (LiteRtSignature sig : signatures_.Elements()) { + if (sig->Key() == signature_key) { + return std::ref(*sig); + } + } + return ::litert::Error(kLiteRtStatusErrorNotFound, "Signature not found"); + } + + // All signatures registered with this model. + absl::Span Signatures() const { + return signatures_.Elements(); + } // Look up metadata by key, getting a view of its buffer as a string // if it exists. litert::Expected> FindMetadata( - absl::string_view key) const; + absl::string_view key) const { + if (auto it = metadata_.find(key); it != metadata_.end()) { + return it->second; + } + return ::litert::Error(kLiteRtStatusErrorNotFound); + } - // Adds a new metadata buffer to the model. Fails if it already exists. - LiteRtStatus PushMetadata(absl::string_view key, - litert::BufferRef data); + // Metadata key-val pair iterator. + MetadataMap::iterator MetadataBegin() { return metadata_.begin(); } + MetadataMap::iterator MetadataEnd() { return metadata_.end(); } + + // Remvoe and take ownership of the metadata under given key if it exists. + litert::Expected> PopMetadata( + absl::string_view key) { + if (auto it = metadata_.find(key); it != metadata_.end()) { + return metadata_.extract(it).mapped(); + } + return ::litert::Error(kLiteRtStatusErrorNotFound); + } - // Look up signature by key. - litert::Expected FindSignature( - absl::string_view signature_key) const; + // BUILDERS + + // Build a new subgraph and get a stable reference to it. + template + LiteRtSubgraphT& EmplaceSubgraph(Args&&... args) { + return subgraphs_.EmplaceBack(std::forward(args)...); + } - // Look up subgraph by key. - litert::Expected FindSubgraph( - absl::string_view signature_key) const; + // Transfers given subgraphs into this model. + void TransferSubgraphs(LiteRtSubgraphT::Alloc&& subgraphs) { + subgraphs_.Transfer(std::move(subgraphs)); + } + + // Cut all by the first `size` subgraphs. Does nothing if given size is + // greater or equal to current. + void ResizeSubgraphsDown(size_t size) { subgraphs_.ResizeDown(size); } - size_t MainSubgraphIndex() const { - // TODO replace this with the index of the default signature. - return 0; + // Adds a new metadata buffer to the model. Fails if it already exists. + template + LiteRtStatus PushMetadata(absl::string_view key, Args&&... args) { + if (metadata_.contains(key)) { + return kLiteRtStatusErrorInvalidArgument; + } + metadata_.insert( + {std::string(key.begin(), key.end()), + ::litert::OwningBufferRef(std::forward(args)...)}); + return kLiteRtStatusOk; } - const LiteRtSubgraphT& MainSubgraph() const { - return subgraphs[MainSubgraphIndex()]; + // Construct a new signature for this model. + template + LiteRtSignatureT& EmplaceSignature(Args&&... args) { + return signatures_.EmplaceBack(std::forward(args)...); } + + // IR is generally, default constructible and movable but not copyable. + LiteRtModelT() = default; + LiteRtModelT(const LiteRtModelT&) = delete; + LiteRtModelT(LiteRtModelT&&) = default; + LiteRtModelT& operator=(const LiteRtModelT&) = delete; + LiteRtModelT& operator=(LiteRtModelT&&) = default; + + // Friendship for internal tflite details. + friend const TflOpCodes& detail::GetTflOpCodes( + const LiteRtModelT& litert_model); + + template + friend void detail::SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg); + + friend TflOpCodes&& detail::TakeTflOpCodes(LiteRtModelT& litert_model); + + friend void detail::SetTflInitFlatbuffer( + LiteRtModelT& litert_model, ::litert::BufferRef init_flatbuffer); + + friend ::litert::BufferRef detail::GetTflInitFlatbuffer( + const LiteRtModelT& litert_model); + + private: + LiteRtSubgraphT::Alloc subgraphs_; + LiteRtSignatureT::Alloc signatures_; + + MetadataMap metadata_; + + // TFLITE + TflOpCodes tfl_operator_codes_; + litert::BufferRef tfl_init_flatbuffer_; }; +// Lookup subgraph by signature name. +::litert::Expected LookupSubgraph( + const LiteRtModelT& model, absl::string_view signature_key); + // // Utils // @@ -243,11 +807,22 @@ class LiteRtOpListT { } private: - // NOTE: This was originally a vector. Was encountering really odd - // segfaults when freeing after code on another side of a compilation boundary - // was doing pushes that resized. A list+copy to vector is not optimimal, - // revisit if bottleneck. + // Investigate if this is possible with vector (hit some issues). std::list ops_; }; +namespace detail { + +template +void SetTflOptions(LiteRtOpT& litert_op, Arg&& arg) { + litert_op.tfl_option_ = std::forward(arg); +} + +template +void SetTflOpCodes(LiteRtModelT& litert_model, Arg&& arg) { + litert_model.tfl_operator_codes_ = std::forward(arg); +} + +} // namespace detail + #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_buffer.cc b/tensorflow/lite/experimental/litert/core/model/model_buffer.cc index 37c53889a90431..1f739c5c8def44 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_buffer.cc +++ b/tensorflow/lite/experimental/litert/core/model/model_buffer.cc @@ -15,18 +15,9 @@ #include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" #include -#include -#include // NOLINT -#include -#include #include -#include -#include "absl/log/absl_check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" @@ -42,42 +33,27 @@ namespace internal { Expected> GetModelBufWithByteCode( LiteRtModelT&& model, BufferRef npu_byte_code) { - LITERT_EXPECT_OK( - model.PushMetadata(kByteCodeMetadataKey, MakeByteCodePlaceholder())); + LITERT_EXPECT_OK(model.PushMetadata( + kByteCodeMetadataKey, npu_byte_code.Data(), npu_byte_code.Size())); - for (auto& subgraph : model.subgraphs) { - for (auto& op : subgraph.ops) { - if (op->op_code != kLiteRtOpCodeTflCustom) { + for (auto* subgraph : model.Subgraphs()) { + for (auto* op : subgraph->Ops()) { + if (op->OpCode() != kLiteRtOpCodeTflCustom) { continue; } auto exec_info = - MakeExecInfo(op->custom_options.StrView(), kByteCodeMetadataKey); + MakeExecInfo(op->CustomOptions().StrView(), kByteCodeMetadataKey); if (!exec_info) { return exec_info.Error(); } - op->custom_options = std::move(*exec_info); + op->SetCustomOptions(std::move(*exec_info)); } } - model.custom_op_code = kLiteRtDispatchOpCustomCode; + auto build_stamp = MakeBuildStamp("", "", Serialization::kAppend); + LITERT_EXPECT_OK(model.PushMetadata(kLiteRtBuildStampKey, *build_stamp)); - auto serialized = SerializeModel(std::move(model)); - if (!serialized) { - return serialized; - } - - LITERT_EXPECT_OK( - FinishByteCodePlaceholders(*serialized, npu_byte_code.Size())); - - OwningBufferRef with_append(serialized->Size() + - npu_byte_code.Size()); - - uint8_t* write = with_append.Data(); - std::memcpy(write, serialized->Data(), serialized->Size()); - write += serialized->Size(); - std::memcpy(write, npu_byte_code.Data(), npu_byte_code.Size()); - - return with_append; + return SerializeModel(std::move(model)); } Expected> GetModelBufWithByteCode( diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test.cc b/tensorflow/lite/experimental/litert/core/model/model_file_test.cc index 9a127bd2fa140d..736468c254b29f 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test.cc +++ b/tensorflow/lite/experimental/litert/core/model/model_file_test.cc @@ -12,25 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include // NOLINT #include +#include #include -#include #include #include #include // IWYU pragma: keep #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" +#include "tensorflow/lite/experimental/litert/core/byte_code_util.h" +#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/model/model_file_test_util.h" #include "tensorflow/lite/experimental/litert/core/model/model_load.h" @@ -39,30 +44,59 @@ #include "tensorflow/lite/experimental/litert/test/common.h" #include "tensorflow/lite/experimental/litert/test/test_macros.h" #include "tensorflow/lite/experimental/litert/test/test_models.h" -#include "tensorflow/lite/experimental/litert/tools/dump.h" namespace litert::internal { namespace { -using ::litert::testing::ValidateTopology; - -Model LoadModelThroughRoundTrip(std::string_view path) { - auto model = litert::testing::LoadTestFileModel(path); +using ::litert::testing::GetTestFilePath; +using ::testing::Each; +using ::testing::ElementsAreArray; +using ::testing::FloatEq; +using ::testing::Values; + +using ModelFactory = std::function()>; + +static constexpr absl::string_view kAddSimple = "add_simple.tflite"; +static constexpr absl::string_view kAddCst = "add_cst.tflite"; +static constexpr absl::string_view kDynamicShapeModel = + "dynamic_shape_tensor.tflite"; +static constexpr absl::string_view kSimpleMultiOp = "simple_multi_op.tflite"; +static constexpr absl::string_view kOneMul = "one_mul.tflite"; +static constexpr absl::string_view kSimpleMultiSubgraph = + "multi_subgraph.tflite"; +static constexpr absl::string_view kCstMultiSubgraph = + "cst_multi_subgraph.tflite"; + +// Load a model, then serialize and re-load. Used to test serialization. +Expected LoadModelThroughRoundTrip(absl::string_view filename) { + auto model = Model::CreateFromFile(GetTestFilePath(filename)); + if (!model) { + return model.Error(); + } OwningBufferRef buf; auto [data, size, offset] = buf.GetWeak(); - LITERT_CHECK_STATUS_OK( - LiteRtSerializeModel(model.Release(), &data, &size, &offset)); + LITERT_EXPECT_OK( + LiteRtSerializeModel(model->Release(), &data, &size, &offset)); // Reload model. LiteRtModel result = nullptr; - LITERT_CHECK_STATUS_OK( + LITERT_EXPECT_OK( LiteRtCreateModelFromBuffer(buf.Data(), buf.Size(), &result)); return Model::CreateFromOwnedHandle(result); } +ModelFactory MakeRoundTripFactory(absl::string_view filename) { + return [=]() { return LoadModelThroughRoundTrip(filename); }; +} + +ModelFactory MakeLoadFactory(absl::string_view filename) { + return [=]() { return Model::CreateFromFile(GetTestFilePath(filename)); }; +} + +// Test fixture parameterized by a file path to test model. class TestWithModelPath : public ::testing::TestWithParam { protected: std::string GetTestModelPath() const { @@ -70,28 +104,22 @@ class TestWithModelPath : public ::testing::TestWithParam { } }; -class TopologyTest : public ::testing::TestWithParam { - public: - static std::vector MakeTestModels( - const std::vector& paths) { - std::vector result; - - for (auto p : paths) { - result.push_back(litert::testing::LoadTestFileModel(p).Release()); - result.push_back(LoadModelThroughRoundTrip(p).Release()); - } - - return result; - } +// Test fixture pareterized by a function that loads a model. +class TestWithModelFactory : public ::testing::TestWithParam { + protected: + Expected LoadModel() { return GetParam()(); } }; -TEST(LiteRtModelTest, TestLoadTestDataBadFilepath) { +// Simple tests +//===--------------------------------------------------------------------------- + +TEST(ModelLoadTest, BadFilepath) { LiteRtModel model = nullptr; LITERT_ASSERT_STATUS_HAS_CODE(LiteRtCreateModelFromFile("bad_path", &model), kLiteRtStatusErrorFileIO); } -TEST(LiteRtModelTest, TestLoadTestDataBadFileData) { +TEST(ModelLoadTest, BadFileData) { // NOLINTBEGIN #ifndef NDEBUG // In debug mode, flatbuffers will `assert` while verifying. This will @@ -113,16 +141,35 @@ TEST(LiteRtModelTest, TestLoadTestDataBadFileData) { // NOLINTEND } -TEST(TestSerializeModel, TestMetadata) { - auto model = litert::testing::LoadTestFileModel("add_simple.tflite"); +TEST(ModelLoadTest, WithMetadata) { + constexpr static absl::string_view kMetadataName = "an_soc_manufacturer"; + constexpr static absl::string_view kMetadataData = "My_Meta_Data"; + + auto flatbuffer = + FlatbufferWrapper::CreateFromTflFile(GetTestFilePath(kAddSimple)); + auto tfl_model = flatbuffer->get()->Unpack(); + PushMetadata(kMetadataName, *tfl_model, + BufferRef(kMetadataData.data(), kMetadataData.size())); + auto serialialized = SerializeFlatbuffer(*tfl_model); + + auto litert_model = LoadModelFromBuffer(serialialized); + ASSERT_TRUE(litert_model); + + auto metadata = litert_model->get()->FindMetadata(kMetadataName); + ASSERT_TRUE(metadata); + EXPECT_EQ(metadata->StrView(), kMetadataData); +} - constexpr static std::string_view kMetadataName = "an_soc_manufacturer"; - constexpr static std::string_view kMetadataData = "My_Meta_Data"; +TEST(ModelSerializeTest, WithMetadata) { + auto model = litert::testing::LoadTestFileModel(kAddSimple); + + constexpr static absl::string_view kMetadataName = "an_soc_manufacturer"; + constexpr static absl::string_view kMetadataData = "My_Meta_Data"; LITERT_ASSERT_STATUS_OK(model.Get()->PushMetadata( kMetadataName, OwningBufferRef(kMetadataData))); - auto serialized = SerializeModel(std::move(model)); + auto serialized = SerializeModel(std::move(*model.Get())); EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); auto re_loaded = LoadModelFromBuffer(*serialized); @@ -130,17 +177,138 @@ TEST(TestSerializeModel, TestMetadata) { EXPECT_EQ(metadata->StrView(), kMetadataData); } -using AddSimpleTest = TopologyTest; +TEST(ModelLoadTest, WithSignature) { + auto model = litert::testing::LoadTestFileModel(kAddSimple); + auto& litert_model = *model.Get(); + + auto signature = + litert_model.FindSignature(LiteRtSignatureT::kDefaultSignatureKey); + ASSERT_TRUE(signature); + + EXPECT_EQ(signature->get().InputNames().size(), 1); + EXPECT_EQ(signature->get().OutputNames().size(), 1); + EXPECT_EQ(&signature->get().GetSubgraph(), litert_model.MainSubgraph()); +} + +TEST(ModelSerializeTest, WithSignature) { + auto model = litert::testing::LoadTestFileModel(kAddSimple); + auto& litert_model = *model.Get(); -TEST_P(AddSimpleTest, TestBuildModelAddSimple) { - Model model = Model::CreateFromOwnedHandle(GetParam()); + static constexpr char kInput[] = "foo"; + static constexpr char kOutput[] = "bar"; + static constexpr char kKey[] = "newKey"; + + LiteRtSignatureT signature(litert_model.MainSubgraph(), {kInput}, {kOutput}, + kKey); + litert_model.EmplaceSignature(std::move(signature)); + + auto serialized = SerializeModel(std::move(*model.Get())); + EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); + + auto re_loaded = LoadModelFromBuffer(*serialized); + auto re_loaded_signature = re_loaded->get()->FindSignature(kKey); + ASSERT_TRUE(re_loaded_signature); + const auto& sig = re_loaded_signature->get(); + + const auto& inputs = sig.InputNames(); + const auto& outputs = sig.OutputNames(); + EXPECT_THAT(inputs, ElementsAreArray({kInput})); + EXPECT_THAT(outputs, ElementsAreArray({kOutput})); + EXPECT_EQ(&sig.GetSubgraph(), re_loaded->get()->MainSubgraph()); +} + +TEST(ModelSerializeTest, WithMetadataByteCode) { + auto model = litert::testing::LoadTestFileModel(kAddSimple); + auto& litert_model = *model.Get(); + + static constexpr absl::string_view kManufacturer = "Dodge"; + static constexpr absl::string_view kModel = "Dart"; + static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; + static constexpr auto kSerialization = Serialization::kMetadata; + + // TODO(@lukeboyer) consider wrapping the tag & push metadata for npu + // in a helper function somewhere. + { + auto build_stamp = MakeBuildStamp(kManufacturer, kModel, kSerialization); + litert_model.PushMetadata(kLiteRtBuildStampKey, *build_stamp); + litert_model.PushMetadata(kByteCodeMetadataKey, kByteCode); + } + + auto serialized = SerializeModel(std::move(*model.Get())); + EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); + auto re_loaded = LoadModelFromBuffer(*serialized); + ASSERT_TRUE(re_loaded); + auto& re_loaded_model = **re_loaded; + + auto build_stamp = + ParseBuildStamp(*re_loaded_model.FindMetadata(kLiteRtBuildStampKey)); + ASSERT_TRUE(build_stamp); + + EXPECT_EQ(std::get<0>(*build_stamp), kManufacturer); + EXPECT_EQ(std::get<1>(*build_stamp), kModel); + EXPECT_EQ(std::get<2>(*build_stamp), kSerialization); + + auto byte_code = re_loaded_model.FindMetadata(kByteCodeMetadataKey); + ASSERT_TRUE(byte_code); + EXPECT_EQ(byte_code->StrView(), kByteCode); +} + +TEST(ModelSerializeTest, WithAppendByteCode) { + auto model = litert::testing::LoadTestFileModel(kAddSimple); + auto& litert_model = *model.Get(); + + static constexpr absl::string_view kManufacturer = "Honda"; + static constexpr absl::string_view kModel = "Civic"; + static constexpr absl::string_view kByteCode = "SOME_BYTE_CODE"; + static constexpr auto kSerialization = Serialization::kAppend; + + { + auto build_stamp = MakeBuildStamp(kManufacturer, kModel, kSerialization); + litert_model.PushMetadata(kLiteRtBuildStampKey, *build_stamp); + litert_model.PushMetadata(kByteCodeMetadataKey, kByteCode); + } + + auto serialized = SerializeModel(std::move(*model.Get())); + EXPECT_TRUE(VerifyFlatbuffer(serialized->Span())); + auto re_loaded = LoadModelFromBuffer(*serialized); + ASSERT_TRUE(re_loaded); + auto& re_loaded_model = **re_loaded; + + auto build_stamp = + ParseBuildStamp(*re_loaded_model.FindMetadata(kLiteRtBuildStampKey)); + ASSERT_TRUE(build_stamp); + + EXPECT_EQ(std::get<0>(*build_stamp), kManufacturer); + EXPECT_EQ(std::get<1>(*build_stamp), kModel); + EXPECT_EQ(std::get<2>(*build_stamp), kSerialization); + + auto byte_code_metadata = re_loaded_model.FindMetadata(kByteCodeMetadataKey); + ASSERT_TRUE(byte_code_metadata); + auto byte_code_offset = ParseByteCodePlaceholder(*byte_code_metadata); + ASSERT_TRUE(byte_code_offset); + + const auto offset = std::get<0>(*byte_code_offset); + const auto size = std::get<1>(*byte_code_offset); + + ASSERT_EQ(offset + size, serialized->Size()); + EXPECT_EQ(serialized->StrView().substr(offset, size), kByteCode); +} + +// Tests that explicitly check litert graph structure. +//===--------------------------------------------------------------------------- + +using AddSimpleTest = TestWithModelFactory; + +TEST_P(AddSimpleTest, CheckGraph) { + auto model = LoadModel(); + ASSERT_TRUE(model); // func(arg0) // output = tfl.add(arg0, arg0) // return(output) // - auto subgraph = model.MainSubgraph(); + auto subgraph = model->MainSubgraph(); const auto subgraph_inputs = subgraph->Inputs(); const auto subgraph_outputs = subgraph->Outputs(); const auto ops = subgraph->Ops(); @@ -148,7 +316,10 @@ TEST_P(AddSimpleTest, TestBuildModelAddSimple) { ASSERT_EQ(subgraph_inputs.size(), 1); ASSERT_EQ(subgraph_outputs.size(), 1); - ASSERT_TRUE(ValidateTopology(ops)); + const auto& internal_ops = subgraph->Get()->Ops(); + ASSERT_TRUE( + ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); + ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); ASSERT_EQ(ops.size(), 1); const auto& op = ops.front(); @@ -171,14 +342,17 @@ TEST_P(AddSimpleTest, TestBuildModelAddSimple) { ASSERT_FALSE(subgraph_inputs.front().IsConstant()); } -INSTANTIATE_TEST_SUITE_P( - AddSimpleTests, AddSimpleTest, - ::testing::ValuesIn(TopologyTest::MakeTestModels({"add_simple.tflite"}))); +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddSimpleTest, + Values(MakeLoadFactory(kAddSimple))); -using AddCstTest = TopologyTest; +INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddSimpleTest, + Values(MakeRoundTripFactory(kAddSimple))); -TEST_P(AddCstTest, TestBuildModelAddCst) { - Model model = Model::CreateFromOwnedHandle(GetParam()); +using AddCstTest = TestWithModelFactory; + +TEST_P(AddCstTest, CheckGraph) { + auto model = LoadModel(); + ASSERT_TRUE(model); // func(arg0) // cst = ConstantTensor([1, 2, 3, 4]) @@ -186,7 +360,7 @@ TEST_P(AddCstTest, TestBuildModelAddCst) { // return(output) // - auto subgraph = model.MainSubgraph(); + auto subgraph = model->MainSubgraph(); const auto subgraph_inputs = subgraph->Inputs(); const auto subgraph_outputs = subgraph->Outputs(); const auto ops = subgraph->Ops(); @@ -194,7 +368,10 @@ TEST_P(AddCstTest, TestBuildModelAddCst) { ASSERT_EQ(subgraph_inputs.size(), 1); ASSERT_EQ(subgraph_outputs.size(), 1); - ASSERT_TRUE(ValidateTopology(ops)); + const auto& internal_ops = subgraph->Get()->Ops(); + ASSERT_TRUE( + ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); + ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); ASSERT_EQ(ops.size(), 1); const auto& op = ops.front(); @@ -218,14 +395,17 @@ TEST_P(AddCstTest, TestBuildModelAddCst) { ASSERT_FALSE(subgraph_inputs.front().IsConstant()); } -INSTANTIATE_TEST_SUITE_P( - AddCstTests, AddCstTest, - ::testing::ValuesIn(TopologyTest::MakeTestModels({"add_cst.tflite"}))); +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, AddCstTest, + Values(MakeLoadFactory(kAddCst))); -using SimpleMultiOpTest = TopologyTest; +INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, AddCstTest, + Values(MakeRoundTripFactory(kAddCst))); -TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { - Model model = Model::CreateFromOwnedHandle(GetParam()); +using SimpleMultiOpTest = TestWithModelFactory; + +TEST_P(SimpleMultiOpTest, CheckGraph) { + auto model = LoadModel(); + ASSERT_TRUE(model); // func.func @main(arg0) // 0 = tfl.add arg0, arg0 @@ -234,7 +414,7 @@ TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { // 3 = tfl.add 2, 2 // return 3 - auto subgraph = model.MainSubgraph(); + auto subgraph = model->MainSubgraph(); const auto subgraph_inputs = subgraph->Inputs(); const auto subgraph_outputs = subgraph->Outputs(); const auto ops = subgraph->Ops(); @@ -242,7 +422,11 @@ TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { ASSERT_EQ(subgraph_inputs.size(), 1); ASSERT_EQ(subgraph_outputs.size(), 1); - ASSERT_TRUE(ValidateTopology(ops)); + const auto& internal_ops = subgraph->Get()->Ops(); + ASSERT_TRUE( + ValidateLocalTopology(internal_ops.cbegin(), internal_ops.cend())); + ASSERT_TRUE(ValidateSubgraphIO(*subgraph->Get())); + ASSERT_EQ(ops.size(), 4); for (const auto& op : ops) { @@ -258,26 +442,130 @@ TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { EXPECT_EQ(ops.at(2).Code(), kLiteRtOpCodeTflMul); } -INSTANTIATE_TEST_SUITE_P(SimpleMultiOpTests, SimpleMultiOpTest, - ::testing::ValuesIn(TopologyTest::MakeTestModels( - {"simple_multi_op.tflite"}))); +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiOpTest, + Values(MakeLoadFactory(kSimpleMultiOp))); + +INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiOpTest, + Values(MakeRoundTripFactory(kSimpleMultiOp))); + +using SimpleMultiSubgraphTest = TestWithModelFactory; + +TEST_P(SimpleMultiSubgraphTest, CheckGraph) { + auto model_wrap = LoadModel(); + ASSERT_TRUE(model_wrap); + auto& model = *model_wrap->Get(); + + ASSERT_EQ(model.NumSubgraphs(), 3); + + { + auto& main = *model.MainSubgraph(); + EXPECT_EQ(main.NumInputs(), 1); + EXPECT_EQ(main.NumOutputs(), 1); + EXPECT_EQ(main.Ops().size(), 1); + EXPECT_EQ(main.Tensors().size(), 3); + auto& op = main.Op(0); + auto* cst = op.Inputs().back(); + auto data = Tensor(cst).WeightsData(); + ASSERT_TRUE(data); + EXPECT_THAT(*data, Each(FloatEq(-1.0))); + EXPECT_TRUE(ValidateLocalTopology(main.Ops().cbegin(), main.Ops().cend())); + EXPECT_TRUE(ValidateSubgraphIO(main)); + } + + { + auto& func1 = model.Subgraph(1); + EXPECT_EQ(func1.NumInputs(), 1); + EXPECT_EQ(func1.NumOutputs(), 1); + EXPECT_EQ(func1.Ops().size(), 1); + EXPECT_EQ(func1.Tensors().size(), 3); + auto& op = func1.Op(0); + auto* cst = op.Inputs().back(); + auto data = Tensor(cst).WeightsData(); + ASSERT_TRUE(data); + EXPECT_THAT(*data, Each(FloatEq(1.0))); + EXPECT_TRUE( + ValidateLocalTopology(func1.Ops().cbegin(), func1.Ops().cend())); + EXPECT_TRUE(ValidateSubgraphIO(func1)); + } + + { + auto& func2 = model.Subgraph(2); + EXPECT_EQ(func2.NumInputs(), 1); + EXPECT_EQ(func2.NumOutputs(), 1); + EXPECT_EQ(func2.Ops().size(), 1); + EXPECT_EQ(func2.Tensors().size(), 3); + auto& op = func2.Op(0); + auto* cst = op.Inputs().back(); + auto data = Tensor(cst).WeightsData(); + ASSERT_TRUE(data); + EXPECT_THAT(*data, Each(FloatEq(2.0))); + EXPECT_TRUE( + ValidateLocalTopology(func2.Ops().cbegin(), func2.Ops().cend())); + EXPECT_TRUE(ValidateSubgraphIO(func2)); + } +} + +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, SimpleMultiSubgraphTest, + Values(MakeLoadFactory(kSimpleMultiSubgraph))); + +INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, SimpleMultiSubgraphTest, + Values(MakeRoundTripFactory(kSimpleMultiSubgraph))); + +// Test when flatbuffer export has optimized multiple tensors to share the +// same buffer. +using MultiSubgraphDupeConstTest = TestWithModelFactory; + +TEST_P(MultiSubgraphDupeConstTest, CheckGraph) { + static constexpr std::array kWeights = {1.0, 2.0, 3.0, 4.0}; + + auto model_wrap = LoadModel(); + ASSERT_TRUE(model_wrap); + auto& model = *model_wrap->Get(); + + ASSERT_EQ(model.NumSubgraphs(), 2); + + { + ASSERT_EQ(model.Subgraph(0).Ops().size(), 1); + ASSERT_EQ(model.Subgraph(0).Tensors().size(), 3); + auto& cst = model.Subgraph(0).Op(0).Input(1); + Tensor t(&cst); + EXPECT_THAT(*t.WeightsData(), ElementsAreArray(kWeights)); + } + + { + ASSERT_EQ(model.Subgraph(1).Ops().size(), 1); + ASSERT_EQ(model.Subgraph(1).Tensors().size(), 3); + auto& cst = model.Subgraph(1).Op(0).Input(1); + Tensor t(&cst); + EXPECT_THAT(*t.WeightsData(), ElementsAreArray(kWeights)); + } +} + +INSTANTIATE_TEST_SUITE_P(ModelLoadTests, MultiSubgraphDupeConstTest, + Values(MakeLoadFactory(kCstMultiSubgraph))); + +INSTANTIATE_TEST_SUITE_P(ModelSerializeTests, MultiSubgraphDupeConstTest, + Values(MakeRoundTripFactory(kCstMultiSubgraph))); + +// Tests that programatically check litert against tflite models. +//===--------------------------------------------------------------------------- using ModelLoadOpCheckTest = TestWithModelPath; TEST_P(ModelLoadOpCheckTest, CheckOps) { const auto model_path = GetTestModelPath(); - auto expected_fb = FlatbufferWrapper::CreateFromTflFile(model_path); - ASSERT_TRUE(expected_fb); + auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(model_path); + ASSERT_TRUE(flatbuffer); + auto expected_fb = flatbuffer->get()->Unpack(); auto model = LoadModelFromFile(model_path); ASSERT_TRUE(model); - const auto& subgraph = model->get()->MainSubgraph(); - const auto& ops = subgraph.ops; + const auto* subgraph = model->get()->MainSubgraph(); + const auto& ops = subgraph->Ops(); - const auto& fb_subgraph = - *expected_fb->get()->UnpackedModel().subgraphs.front(); + const auto& fb_subgraph = *expected_fb->subgraphs.front(); const auto& fb_ops = fb_subgraph.operators; const auto& fb_tensors = fb_subgraph.tensors; @@ -288,7 +576,6 @@ TEST_P(ModelLoadOpCheckTest, CheckOps) { }; for (auto i = 0; i < ops.size(); ++i) { - Dump(*ops.at(i)); ASSERT_TRUE(EqualsFbOp(*ops.at(i), *fb_ops.at(i), get_tfl_tensor)); } } @@ -297,35 +584,32 @@ INSTANTIATE_TEST_SUITE_P(ModelLoadQuantizedOpCheckTest, ModelLoadOpCheckTest, ::testing::ValuesIn(kAllQModels)); INSTANTIATE_TEST_SUITE_P(ModelLoadDynamicOpCheckTest, ModelLoadOpCheckTest, - ::testing::ValuesIn({static_cast( - "dynamic_shape_tensor.tflite")})); - -INSTANTIATE_TEST_SUITE_P( - ModelLoadStaticOpCheckTest, ModelLoadOpCheckTest, - ::testing::ValuesIn({static_cast("one_mul.tflite")})); + ::testing::ValuesIn({kDynamicShapeModel})); using ModelSerializeOpCheckTest = TestWithModelPath; TEST_P(ModelSerializeOpCheckTest, CheckOps) { const auto model_path = GetTestModelPath(); - auto expected_fb = FlatbufferWrapper::CreateFromTflFile(model_path); - ASSERT_TRUE(expected_fb); + // Save the initial fb for comparison. + auto expected_fb_data = FlatbufferWrapper::CreateFromTflFile(model_path); + ASSERT_TRUE(expected_fb_data); + auto expected_fb = expected_fb_data->get()->Unpack(); + // Round trip the model. auto model = LoadModelFromFile(model_path); ASSERT_TRUE(model); - auto serialized = SerializeModel(std::move(**model)); - auto actual_fb = FlatbufferWrapper::CreateFromBuffer(*serialized); - ASSERT_TRUE(actual_fb); - const auto& expected_fb_subgraph = - *expected_fb->get()->UnpackedModel().subgraphs.front(); + auto actual_fb_data = FlatbufferWrapper::CreateFromBuffer(*serialized); + ASSERT_TRUE(actual_fb_data); + auto actual_fb = actual_fb_data->get()->Unpack(); + + const auto& expected_fb_subgraph = *expected_fb->subgraphs.front(); const auto& expected_fb_ops = expected_fb_subgraph.operators; const auto& expected_fb_tensors = expected_fb_subgraph.tensors; - const auto& actual_fb_subgraph = - *actual_fb->get()->UnpackedModel().subgraphs.front(); + const auto& actual_fb_subgraph = *actual_fb->subgraphs.front(); const auto& actual_fb_ops = actual_fb_subgraph.operators; const auto& actual_fb_tensors = actual_fb_subgraph.tensors; @@ -363,14 +647,8 @@ TEST_P(ModelSerializeOpCheckTest, CheckOps) { } } -INSTANTIATE_TEST_SUITE_P( - ModelSerializeStaticOpCheckTest, ModelSerializeOpCheckTest, - ::testing::ValuesIn({static_cast("one_mul.tflite")})); - -INSTANTIATE_TEST_SUITE_P(ModelSerializeDynamicOpCheckTest, - ModelSerializeOpCheckTest, - ::testing::ValuesIn({static_cast( - "dynamic_shape_tensor.tflite")})); +INSTANTIATE_TEST_SUITE_P(ModelSerializeOpCheckTest, ModelSerializeOpCheckTest, + ::testing::ValuesIn({kOneMul, kDynamicShapeModel})); INSTANTIATE_TEST_SUITE_P(ModelSerializeQuantizedOpCheckTest, ModelSerializeOpCheckTest, diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc index 06e2f334ee9107..55bb72fa0c2961 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc +++ b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.cc @@ -14,11 +14,12 @@ #include "tensorflow/lite/experimental/litert/core/model/model_file_test_util.h" +#include + #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_detail.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" @@ -43,6 +44,23 @@ bool EqualsFbQuantizationDetail( litert_quantization.scale == tfl_q_params->second; } +template <> +bool EqualsFbQuantizationDetail( + LiteRtQuantizationPerChannel litert_quantization, + const TflQuantization* tfl_quantization) { + auto tfl_q_params = AsPerChannelQparams(tfl_quantization); + if (!tfl_q_params) return false; + const auto& [quantized_dimension, num_channels, zero_points, scales] = + *tfl_q_params; + const auto qd_eq = + litert_quantization.quantized_dimension == quantized_dimension; + const auto num_chan_eq = litert_quantization.num_channels == num_channels; + const auto zeros_eq = std::equal(zero_points.begin(), zero_points.end(), + litert_quantization.zero_points); + const auto scales_eq = + std::equal(scales.begin(), scales.end(), litert_quantization.scales); + return qd_eq && num_chan_eq && zeros_eq && scales_eq; +} template bool EqualsFbTensorTypeDetail(LiteRtTenzorType litert_tensor_type, const TflTensorType& tfl_tensor) { @@ -92,6 +110,9 @@ bool EqualsFbQuantization(const Quantization& litert_quantization, case kLiteRtQuantizationPerTensor: return EqualsFbQuantizationDetail(litert_quantization.second.per_tensor, tfl_quantization); + case kLiteRtQuantizationPerChannel: + return EqualsFbQuantizationDetail(litert_quantization.second.per_channel, + tfl_quantization); case kLiteRtQuantizationNone: return !IsQuantized(tfl_quantization); default: @@ -115,14 +136,25 @@ bool EqualsFbTensorType(const TensorType& litert_tensor_type, } } -// Compare litert op to flatbuffer op along with their input/output tensors -// types and quantization. Takes a callback to lookup tfl tensors the indices -// within the tfl op. +bool EqualsFbTensor(const LiteRtTensorT& litert_tensor, + const TflTensor& tfl_tensor) { + if (!EqualsFbTensorType(litert_tensor.Type(), + {tfl_tensor.type, TflShapeInfo(tfl_tensor)})) { + LITERT_LOG(LITERT_ERROR, "Tensor not same type"); + return false; + } + + if (!EqualsFbQuantization(litert_tensor.Qparams(), + tfl_tensor.quantization.get())) { + LITERT_LOG(LITERT_ERROR, "Tensor not same quantization"); + return false; + } + + return true; +} + bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, GetTflTensor get_tfl_tensor) { - const auto& litert_inputs = litert_op.inputs; - const auto& litert_outputs = litert_op.outputs; - auto check_tensors = [&](auto& litert_tensors, auto& tfl_tensors) { if (litert_tensors.size() != tfl_tensors.size()) { LITERT_LOG(LITERT_ERROR, "Tensors not same size"); @@ -133,17 +165,8 @@ bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, const auto& fb_tensor = get_tfl_tensor(tfl_tensors.at(i)).get(); const auto& litert_tensor = *litert_tensors.at(i); - if (!EqualsFbTensorType( - {litert_tensor.type_id, litert_tensor.type_detail}, - {fb_tensor.type, TflShapeInfo(fb_tensor)})) { - LITERT_LOG(LITERT_ERROR, "Tensor %d not same type", i); - return false; - } - - if (!EqualsFbQuantization( - {litert_tensor.q_type_id, litert_tensor.q_type_detail}, - fb_tensor.quantization.get())) { - LITERT_LOG(LITERT_ERROR, "Tensor %d not same quantization", i); + if (!EqualsFbTensor(litert_tensor, fb_tensor)) { + LITERT_LOG(LITERT_ERROR, "Tensor %d not same", i); return false; } } @@ -151,8 +174,8 @@ bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, return true; }; - return check_tensors(litert_inputs, tfl_op.inputs) && - check_tensors(litert_outputs, tfl_op.outputs); + return check_tensors(litert_op.Inputs(), tfl_op.inputs) && + check_tensors(litert_op.Outputs(), tfl_op.outputs); } } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h index 33337e4d257b8f..4e958d5f301d30 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h +++ b/tensorflow/lite/experimental/litert/core/model/model_file_test_util.h @@ -17,7 +17,6 @@ #include -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" @@ -41,6 +40,11 @@ bool EqualsFbTensorType(const TensorType& litert_tensor_type, bool EqualsFbOp(const LiteRtOpT& litert_op, const TflOp& tfl_op, GetTflTensor get_tfl_tensor); +// Compare litert tensor to flatbuffer tensor for having same types and +// quantization. +bool EqualsFbTensor(const LiteRtTensorT& litert_tensor, + const TflTensor& tfl_tensor); + } // namespace litert::internal #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_FILE_TEST_UTIL_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph.cc b/tensorflow/lite/experimental/litert/core/model/model_graph.cc new file mode 100644 index 00000000000000..4f5a5ed0fae557 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/model_graph.cc @@ -0,0 +1,180 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" + +namespace litert::internal { + +namespace { + +bool IsOpDead(const LiteRtOpT& op) { + return op.Inputs().empty() && op.Outputs().empty(); +} + +bool IsTensorDead(const LiteRtTensorT& tensor) { + return tensor.DefiningOp() == nullptr && tensor.NumUses() == 0; +} + +} // namespace + +void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest) { + dest.SetName({src.Name().cbegin(), src.Name().cend()}); + dest.SetQarams(src.Qparams()); + dest.SetType(src.Type()); + // TODO: b/383906683 Avoid copying for better performance. + dest.Weights().SetFromBuf(src.Weights().Buf()); +} + +void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest) { + dest.SetCustomOptions(src.CustomOptions().Data(), src.CustomOptions().Size()); + detail::SetTflOptions(dest, detail::GetTflOptions(src)); + detail::SetTflOpCodeInd(dest, detail::GetTflOpCodeInd(src)); + dest.SetOpCode(src.OpCode()); +} + +LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src) { + auto& new_tensor = parent.EmplaceTensor(); + CloneTo(src, new_tensor); + return new_tensor; +} + +LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src) { + auto& new_op = parent.EmplaceOp(); + CloneTo(src, new_op); + return new_op; +} + +std::optional FindInput(const LiteRtOpT& op, + const LiteRtTensorT& tensor) { + return FindInd(op.Inputs().cbegin(), op.Inputs().cend(), &tensor); +} + +std::optional FindOutput(const LiteRtOpT& op, + const LiteRtTensorT& tensor) { + return FindInd(op.Outputs().cbegin(), op.Outputs().cend(), &tensor); +} + +std::optional FindInput(const LiteRtSubgraphT& subgraph, + const LiteRtTensorT& tensor) { + return FindInd(subgraph.Inputs().cbegin(), subgraph.Inputs().cend(), &tensor); +} + +std::optional FindOutput(const LiteRtSubgraphT& subgraph, + const LiteRtTensorT& tensor) { + return FindInd(subgraph.Outputs().cbegin(), subgraph.Outputs().cend(), + &tensor); +} + +UseIndices FindUseInds(const LiteRtTensorT& tensor, const LiteRtOpT& op) { + UseIndices res; + for (auto i = 0; i < tensor.NumUses(); ++i) { + if (tensor.Users().at(i) == &op) { + res.push_back(i); + } + } + return res; +} + +bool IsConstant(const LiteRtTensorT& tensor) { + const auto is_const = tensor.Weights().Buf().Size() > 0; + ABSL_DCHECK(!is_const || tensor.DefiningOp() == nullptr) + << "Constant tensors should not be defined by an op"; + return is_const; +} + +void AttachInput(LiteRtTensor tensor, LiteRtOpT& op) { + op.Inputs().push_back(tensor); + tensor->Users().push_back(&op); + tensor->UserArgInds().push_back(op.Inputs().size() - 1); +} + +void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op) { + ABSL_DCHECK(tensor->DefiningOp() == nullptr) + << "Cannot add an already defined tensor as op output"; + op.Outputs().push_back(tensor); + tensor->SetDefiningOp(op, op.Outputs().size() - 1); +} + +LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind) { + ABSL_DCHECK(input_ind < op.Inputs().size()) << "Removing tensor index oob"; + auto& input = op.Input(input_ind); + + // Find the index of the use for the given in edge. + auto target_use_ind = -1; + for (auto i = 0; i < input.NumUses(); ++i) { + if (input.Users().at(i) == &op && input.UserArgInds().at(i) == input_ind) { + target_use_ind = i; + } + } + ABSL_DCHECK_GE(target_use_ind, 0) << "Malformed graph"; + + // Slide latter input use arg inds to the left. + for (auto i = input_ind + 1; i < op.Inputs().size(); ++i) { + auto& r_in = op.Input(i); + for (auto u = 0; u < r_in.NumUses(); ++u) { + auto& r_arg_ind = r_in.UserArgInds().at(u); + if (r_in.Users().at(u) == &op && r_arg_ind > input_ind) { + r_arg_ind -= 1; + } + } + } + + // Update the edges. + input.RemoveUse(target_use_ind); + op.RemoveInput(input_ind); + + return &input; +} + +bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor) { + return FindInput(subgraph, tensor) || FindOutput(subgraph, tensor); +} + +LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind) { + ABSL_DCHECK(output_ind < op.Outputs().size()) << "Removing tensor index oob"; + auto& output = op.Output(output_ind); + output.ClearDefiningOp(); + op.RemoveOutput(output_ind); + return &output; +} + +void Drop(LiteRtOpT& litert_op) { + while (!litert_op.Inputs().empty()) { + DisconnectInput(litert_op, 0); + } + while (!litert_op.Outputs().empty()) { + DisconnectOutput(litert_op, 0); + } +} + +bool DCE(LiteRtSubgraphT& subgraph) { + const auto ops_removed = subgraph.RemoveOpIf(IsOpDead); + + auto rm_tensor = [&subgraph = std::as_const(subgraph)](const auto& t) { + return IsTensorDead(t) && !IsIO(subgraph, t); + }; + const auto tensors_removed = subgraph.RemoveTensorIf(rm_tensor); + + return (ops_removed + tensors_removed) > 0; +} + +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph.h b/tensorflow/lite/experimental/litert/core/model/model_graph.h new file mode 100644 index 00000000000000..a6c5f27580ccd1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/model_graph.h @@ -0,0 +1,108 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" +#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" + +namespace litert::internal { + +// using IrMapping = absl::flat_hash_map; + +// CLONING + +// Clones the basic data between tensors (like name and data) but not +// things related to incoming/outgoing edges (users, defining op) or weights. +void CloneTo(const LiteRtTensorT& src, LiteRtTensorT& dest); + +// Clones the basic data between ops (like op code and options) but +// things related to incoming/outgoing edges (input/output tensors). +void CloneTo(const LiteRtOpT& src, LiteRtOpT& dest); + +// Same as clone to, but allocates a the dest tensor into given subgraph. +LiteRtTensorT& MakeClone(LiteRtSubgraphT& parent, const LiteRtTensorT& src); + +// Same as clone to, but allocates a the dest op into given subgraph. +LiteRtOpT& MakeClone(LiteRtSubgraphT& parent, const LiteRtOpT& src); + +// OBSERVERS + +// Checks if tensor is input to given op, return its index if so. +std::optional FindInput(const LiteRtOpT& op, + const LiteRtTensorT& tensor); + +// Checks if tensor is output to given op, return its index if so. +std::optional FindOutput(const LiteRtOpT& op, + const LiteRtTensorT& tensor); + +// Checks if tensor is input to given subgraph, return its index if so. +std::optional FindInput(const LiteRtSubgraphT& subgraph, + const LiteRtTensorT& tensor); + +// Checks if tensor is output to given subgraph, return its index if so. +std::optional FindOutput(const LiteRtSubgraphT& subgraph, + const LiteRtTensorT& tensor); + +// Check if tensor is part of subgraph IO. +bool IsIO(const LiteRtSubgraphT& subgraph, const LiteRtTensorT& tensor); + +using UseIndices = + absl::InlinedVector; + +// Checks if tensor is used by op, return the use inds for each use of tensor by +// op (there may be multiple). These are the indexes to call +// LiteRtTensorT::GetUse with. +UseIndices FindUseInds(const LiteRtTensorT& tensor, const LiteRtOpT& op); + +// Is this tensor a constant tensor? +bool IsConstant(const LiteRtTensorT& tensor); + +// MUTATORS + +// Attaches the pre-allocated tensor to be an input of given op. +void AttachInput(LiteRtTensor tensor, LiteRtOpT& op); + +// Attaches the pre-allocated tensor to be an output of given op. +void AttachOutput(LiteRtTensor tensor, LiteRtOpT& op); + +// Remove the input edge from an op. Return the disconnected tensor. +LiteRtTensor DisconnectInput(LiteRtOpT& op, LiteRtParamIndex input_ind); + +// Remove an output edge from an op. Return the disconnected tensor. +LiteRtTensor DisconnectOutput(LiteRtOpT& op, LiteRtParamIndex output_ind); + +// Remove all incoming and outgoing edges from this op. This can prep nodes +// for removal in DCE. +void Drop(LiteRtOpT& litert_op); + +// Run very naive dead code elimination. Removes only ops/tensors that have no +// in/out edges. Ops are handled first. Ignores subgraph IO. Not recursive and +// does only one pass. Returns if the graph was modified. +// NOTE: This de-allocates removed objects, only use when references to these +// objects will not be used. +// TODO: Update this with complete work-list based approach. +bool DCE(LiteRtSubgraphT& subgraph); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_MODEL_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc b/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc new file mode 100644 index 00000000000000..62abc3ecb97b0c --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model/model_graph_test.cc @@ -0,0 +1,344 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" + +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/core/model/graph_validation.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" + +namespace litert::internal { +namespace { + +using ::testing::UnorderedElementsAreArray; + +// Custom matcher; example: +// ``` +// LiteRtTensor tensor ... +// EXPECT_THAT(tensor, HasRankedType(kLiteRtInt, absl::MakeSpan({2, 2}))); +// ``` +// TODO: Update to use dumping API directly and move to shared header. +MATCHER_P2(HasRankedType, element_type, shape, "") { + if (arg.Type().first != kLiteRtRankedTensorType) { + *result_listener << "Not ranked tensor type"; + return false; + } + const auto& ranked_tensor_type = arg.Type().second.ranked_tensor_type; + const auto& layout = ranked_tensor_type.layout; + + const auto element_type_eq = ranked_tensor_type.element_type == element_type; + const auto rank_eq = layout.rank == std::size(shape); + + auto actual_shape = absl::MakeConstSpan(layout.dimensions, layout.rank); + auto expected_shape = + absl::MakeConstSpan(std::cbegin(shape), std::cend(shape)); + const auto shape_eq = actual_shape == expected_shape; + + if (shape_eq && element_type_eq && rank_eq) { + return true; + } + + *result_listener << "\n"; + if (!shape_eq) { + *result_listener << "Not correct shape\n"; + } + if (!element_type_eq) { + *result_listener << "Not correct element type\n"; + } + if (!rank_eq) { + *result_listener << "Not correct rank\n"; + } + + *result_listener << absl::StreamFormat("Actual ElementType is: %d\n", + ranked_tensor_type.element_type); + *result_listener << absl::StreamFormat("Actual Rank is: %lu\n", layout.rank); + *result_listener << "Actual shape is: { "; + for (const auto d : actual_shape) { + *result_listener << absl::StreamFormat("%d, ", d); + } + *result_listener << "}\n"; + + return false; +} + +using ::testing::ElementsAreArray; + +static constexpr size_t kRank = 1; +static constexpr int32_t kDims[] = {2}; +static constexpr absl::Span kDimsSpan(kDims); +static constexpr auto kType = kLiteRtElementTypeInt32; +static constexpr absl::string_view kCustomOptions = "OPTIONS"; +static constexpr auto kOpCode = kLiteRtOpCodeTflMul; + +LiteRtTensorT TestTensor() { + LiteRtTensorT tensor; + tensor.Type().first = kLiteRtRankedTensorType; + tensor.Type().second.ranked_tensor_type.element_type = kType; + tensor.Type().second.ranked_tensor_type.layout.dimensions[0] = kDims[0]; + tensor.Type().second.ranked_tensor_type.layout.rank = kRank; + return tensor; +} + +LiteRtOpT TestOp() { + LiteRtOpT op; + op.SetOpCode(kOpCode); + op.SetCustomOptions(kCustomOptions); + return op; +} + +TEST(ModelGraphTest, CloneTensor) { + LiteRtTensorT dest; + CloneTo(TestTensor(), dest); + EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan)); +} + +TEST(ModelGraphTest, MakeCloneTensor) { + LiteRtSubgraphT subgraph; + auto& dest = MakeClone(subgraph, TestTensor()); + EXPECT_THAT(dest, HasRankedType(kType, kDimsSpan)); +} + +TEST(ModelGraphTest, CloneOp) { + LiteRtOpT dest; + CloneTo(TestOp(), dest); + EXPECT_EQ(dest.OpCode(), kOpCode); + EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions); +} + +TEST(ModelGraphTest, MakeCloneOp) { + LiteRtSubgraphT subgraph; + auto& dest = MakeClone(subgraph, TestOp()); + EXPECT_EQ(dest.OpCode(), kOpCode); + EXPECT_EQ(dest.CustomOptions().StrView(), kCustomOptions); +} + +TEST(ModelGraphTest, OpFindInput) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachInput(&tensor, op); + auto input = FindInput(op, tensor); + ASSERT_TRUE(input); + EXPECT_EQ(*input, 0); +} + +TEST(ModelGraphTest, OpFindOutput) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachOutput(&tensor, op); + auto output = FindOutput(op, tensor); + ASSERT_TRUE(output); + EXPECT_EQ(*output, 0); +} + +TEST(ModelGraphTest, SubgraphFindInput) { + LiteRtSubgraphT subgraph; + auto tensor = TestTensor(); + subgraph.Inputs().push_back(&tensor); + auto input = FindInput(subgraph, tensor); + ASSERT_TRUE(input); + EXPECT_EQ(*input, 0); +} + +TEST(ModelGraphTest, SubgraphFindOutput) { + LiteRtSubgraphT subgraph; + auto tensor = TestTensor(); + subgraph.Outputs().push_back(&tensor); + auto output = FindOutput(subgraph, tensor); + ASSERT_TRUE(output); + EXPECT_EQ(*output, 0); +} + +TEST(ModelGraphTest, TensorFindUseInds) { + auto op1 = TestOp(); + auto op2 = TestOp(); + auto tensor = TestTensor(); + + AttachInput(&tensor, op1); + AttachInput(&tensor, op2); + AttachInput(&tensor, op1); + + auto use_inds = FindUseInds(tensor, op1); + auto uses = GetTensorUses(tensor, use_inds); + ASSERT_EQ(uses.size(), 2); + + LiteRtTensorT::UseVec expected = {{&op1, 0}, {&op1, 1}}; + EXPECT_THAT(uses, UnorderedElementsAreArray(expected)); +} + +TEST(ModelGraphTest, OpAttachInput) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachInput(&tensor, op); + EXPECT_THAT(op.Inputs(), ElementsAreArray({&tensor})); + EXPECT_THAT(tensor.Users(), ElementsAreArray({&op})); + EXPECT_THAT(tensor.UserArgInds(), ElementsAreArray({0})); +} + +TEST(ModelGraphTest, OpAttachOutput) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachOutput(&tensor, op); + EXPECT_THAT(op.Outputs(), ElementsAreArray({&tensor})); + EXPECT_EQ(tensor.DefiningOp(), &op); + EXPECT_EQ(tensor.DefiningOpOutInd(), 0); +} + +TEST(ModelGraphTest, DisconnectInputOp) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachInput(&tensor, op); + auto disconnected = DisconnectInput(op, 0); + EXPECT_EQ(disconnected, &tensor); + EXPECT_TRUE(op.Inputs().empty()); + EXPECT_TRUE(tensor.Users().empty()); + EXPECT_TRUE(tensor.UserArgInds().empty()); +} + +TEST(ModelGraphTest, DisconnectMiddleInputOp) { + auto op = TestOp(); + + auto tensor1 = TestTensor(); + auto tensor2 = TestTensor(); + auto tensor3 = TestTensor(); + + AttachInput(&tensor1, op); + AttachInput(&tensor2, op); + AttachInput(&tensor3, op); + + auto disconnected = DisconnectInput(op, 1); + + EXPECT_EQ(disconnected, &tensor2); + ASSERT_EQ(op.Inputs().size(), 2); + EXPECT_EQ(op.Inputs().front(), &tensor1); + EXPECT_EQ(op.Inputs().back(), &tensor3); + ASSERT_TRUE(tensor2.Users().empty()); + ASSERT_TRUE(tensor2.UserArgInds().empty()); + + ASSERT_TRUE(ValidateLocalTopology(op)); +} + +TEST(ModelGraphTest, DisconnectOutputOp) { + auto op = TestOp(); + auto tensor = TestTensor(); + AttachOutput(&tensor, op); + auto disconnected = DisconnectOutput(op, 0); + EXPECT_EQ(disconnected, &tensor); + EXPECT_EQ(tensor.DefiningOp(), nullptr); + EXPECT_TRUE(op.Outputs().empty()); +} + +TEST(ModelGraphTest, DropOp) { + LiteRtOpT op; + + LiteRtTensorT input1; + LiteRtTensorT input2; + LiteRtTensorT output; + + AttachInput(&input1, op); + AttachInput(&input2, op); + AttachOutput(&output, op); + + Drop(op); + + EXPECT_TRUE(op.Inputs().empty()); + EXPECT_TRUE(op.Outputs().empty()); + EXPECT_TRUE(input1.Users().empty()); + EXPECT_TRUE(input2.Users().empty()); + EXPECT_EQ(output.DefiningOp(), nullptr); +} + +TEST(ModelGraphTestDCE, NoDeadCode) { + LiteRtSubgraphT subgraph; + + auto& input = subgraph.EmplaceTensor(); + auto& output = subgraph.EmplaceTensor(); + + auto& op = subgraph.EmplaceOp(); + + AttachInput(&input, op); + AttachOutput(&output, op); + + subgraph.Inputs().push_back(&input); + subgraph.Outputs().push_back(&output); + + ASSERT_FALSE(DCE(subgraph)); + EXPECT_EQ(subgraph.Ops().size(), 1); + EXPECT_EQ(subgraph.Tensors().size(), 2); + + ASSERT_TRUE( + ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); + ASSERT_TRUE(ValidateSubgraphIO(subgraph)); +} + +TEST(ModelGraphTestDCE, DeadTensor) { + LiteRtSubgraphT subgraph; + subgraph.EmplaceTensor(); + + ASSERT_TRUE(DCE(subgraph)); + EXPECT_TRUE(subgraph.Tensors().empty()); + + ASSERT_TRUE( + ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); + ASSERT_TRUE(ValidateSubgraphIO(subgraph)); +} + +TEST(ModelGraphTestDCE, DeadOp) { + LiteRtSubgraphT subgraph; + subgraph.EmplaceOp(); + + ASSERT_TRUE(DCE(subgraph)); + EXPECT_TRUE(subgraph.Ops().empty()); + + ASSERT_TRUE( + ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); + ASSERT_TRUE(ValidateSubgraphIO(subgraph)); +} + +TEST(ModelGraphTestDCE, SomeDead) { + LiteRtSubgraphT subgraph; + + auto& input = subgraph.EmplaceTensor(); + auto& output = subgraph.EmplaceTensor(); + + auto& op = subgraph.EmplaceOp(); + + AttachInput(&input, op); + AttachOutput(&output, op); + + // Dead + subgraph.EmplaceTensor(); + subgraph.EmplaceOp(); + + subgraph.Inputs().push_back(&input); + subgraph.Outputs().push_back(&output); + + ASSERT_TRUE(DCE(subgraph)); + EXPECT_EQ(subgraph.Ops().size(), 1); + EXPECT_EQ(subgraph.Tensors().size(), 2); + + ASSERT_TRUE( + ValidateLocalTopology(subgraph.Ops().cbegin(), subgraph.Ops().cend())); + ASSERT_TRUE(ValidateSubgraphIO(subgraph)); +} + +} // namespace +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_load.cc b/tensorflow/lite/experimental/litert/core/model/model_load.cc index e74ef31208fe03..4d15ba291ea1ef 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_load.cc +++ b/tensorflow/lite/experimental/litert/core/model/model_load.cc @@ -14,9 +14,8 @@ #include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include +#include #include -#include #include #include #include @@ -24,255 +23,305 @@ #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/core/model/flatbuffer_to_litert.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" #include "tensorflow/lite/schema/schema_generated.h" namespace litert::internal { namespace { -using GetBuffer = std::function(uint32_t ind)>; -using GetOpCode = std::function(uint32_t ind)>; -using GetTensor = std::function(size_t ind)>; +// Provides a view of model-level resources when constructing litert graph. +class FlatbufferContext { + public: + explicit FlatbufferContext(TflModel& tfl_model) : tfl_model_(tfl_model) {} -LiteRtStatus ConvertTensor(const TflTensor& tfl_tensor, GetBuffer get_buffer, - LiteRtTensorT& target) { - LITERT_RETURN_STATUS_IF_NOT_OK(IsTensorSupported(tfl_tensor)); + void SetOpCode(LiteRtOpT& litert_op, uint32_t ind) { + auto tfl_op_code = GetTflOpCode(tfl_model_, ind); + litert_op.SetOpCode(static_cast(*tfl_op_code)); + detail::SetTflOpCodeInd(litert_op, ind); + } - const auto buffer_ind = tfl_tensor.buffer; - if (buffer_ind != 0) { - auto buffer = get_buffer(tfl_tensor.buffer); - if (!buffer) { - return buffer.Error().Status(); + // Take ownership of the tfl buffer under the given index if it exists. + Expected TakeTflBuffer(uint32_t ind) { + // TODO: Return (and store in litert model) these as shared pointers + // and remove copy. + auto tfl_buf = GetBuffer(tfl_model_, ind); + if (!tfl_buf) { + return tfl_buf.Error(); } - LITERT_RETURN_STATUS_IF_NOT_OK(IsBufferSupported(**buffer)); - target.weights.fb_buffer = std::move(*buffer); + return std::make_unique(**tfl_buf); } - TflTensorType tfl_tensor_type(tfl_tensor.type, TflShapeInfo(tfl_tensor)); - auto tensor_type = MapTensorType(tfl_tensor_type); - if (!tensor_type) { - return tensor_type.Error().Status(); - } + private: + TflModel& tfl_model_; +}; - target.type_id = tensor_type->first; - target.type_detail = tensor_type->second; +LiteRtStatus UnpackOp(FlatbufferContext& context, LiteRtSubgraphT& parent, + TflOpPtr tfl_op, LiteRtOpT& litert_op) { + // I/O TENSORS - auto quantization = MapQuantization(tfl_tensor.quantization.get()); - if (!quantization) { - return quantization.Error().Status(); + if (!tfl_op->intermediates.empty()) { + // TODO: b/365299994 - Support intermediates. + LITERT_LOG(LITERT_ERROR, "Intermediate tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; } - target.q_type_id = quantization->first; - target.q_type_detail = quantization->second; - - target.name = tfl_tensor.name; - - return kLiteRtStatusOk; -} - -LiteRtStatus ConvertOp(const TflOp& op, GetTensor get_tensor, - GetOpCode get_op_code, LiteRtOpT& target) { - LITERT_RETURN_STATUS_IF_NOT_OK(IsOpSupported(op)); - - auto op_code = get_op_code(op.opcode_index); - if (!op_code) { - return op_code.Error().Status(); + for (auto m_input : tfl_op->mutating_variable_inputs) { + if (m_input) { + // TODO: b/365299994 - Support mutating variable inputs. + LITERT_LOG(LITERT_ERROR, "Mutating variable inputs not yet supported."); + return kLiteRtStatusErrorUnsupported; + } } - target.op_code = *op_code; - for (auto input_ind : op.inputs) { + for (auto input_ind : tfl_op->inputs) { // Skipping optional input tensor. if (input_ind == -1) { continue; } + AttachInput(&parent.Tensor(input_ind), litert_op); + } - auto input_tensor = get_tensor(input_ind); - if (!input_tensor) { - return input_tensor.Error().Status(); - } - - target.AddInput(input_tensor->get()); + for (auto output_ind : tfl_op->outputs) { + AttachOutput(&parent.Tensor(output_ind), litert_op); } - for (auto output_ind : op.outputs) { - auto output_tensor = get_tensor(output_ind); - if (!output_tensor) { - return output_tensor.Error().Status(); - } + // OPTIONS - target.AddOutput(output_tensor->get()); + if (tfl_op->large_custom_options_size != 0) { + // TODO: b/365299994 - Support large custom options. + LITERT_LOG(LITERT_ERROR, "Large custom options not yet supported."); + return kLiteRtStatusErrorUnsupported; } - target.option = op.builtin_options; - target.custom_options = OwningBufferRef(op.custom_options.data(), - op.custom_options.size()); + const auto& tfl_custom_opts = tfl_op->custom_options; + litert_op.SetCustomOptions(tfl_custom_opts.data(), tfl_custom_opts.size()); + detail::SetTflOptions(litert_op, std::move(tfl_op->builtin_options)); + + // OP CODE + + context.SetOpCode(litert_op, tfl_op->opcode_index); return kLiteRtStatusOk; } -class ModelUnpacker { - public: - static LiteRtStatus Unpack(LiteRtModel model); +LiteRtStatus UnpackTensor(FlatbufferContext& context, TflTensorPtr tfl_tensor, + LiteRtTensorT& litert_tensor) { + // WEIGHTS - private: - explicit ModelUnpacker(LiteRtModel model) : model_(model) {} - - LiteRtStatus UnpackSubgraph(LiteRtSubgraphT& target); + const auto buffer_ind = tfl_tensor->buffer; + if (buffer_ind != 0) { + auto buffer = context.TakeTflBuffer(buffer_ind); + if (!buffer) { + return buffer.Error().Status(); + } - GetBuffer GetBufferCallback() { - return [&](auto buffer_ind) { return TakeBuffer(Fb(), buffer_ind); }; + if (buffer->get()->offset != 0) { + // TODO: b/365299994 - Support buffer with offset. + LITERT_LOG(LITERT_ERROR, "Buffers with offset not yet supported."); + return kLiteRtStatusErrorUnsupported; + } + detail::SetTflBuffer(litert_tensor.Weights(), std::move(*buffer)); } - GetOpCode GetOpCodeCallback() { - return [&](auto opcode_ind) -> Expected { - auto tfl_op_code = GetTflOpCode(Fb(), opcode_ind); - if (!tfl_op_code) { - return tfl_op_code.Error(); - } - return static_cast(*tfl_op_code); - }; - } + // TENSOR TYPE - GetTensor GetTensorCallBack(const LiteRtSubgraphT& subgraph) { - return [&](auto tensor_ind) -> Expected { - if (tensor_ind >= subgraph.tensors.size()) { - return Error(kLiteRtStatusErrorIndexOOB); - } - return std::ref(*subgraph.tensors.at(tensor_ind)); - }; + TflTensorType tfl_tensor_type(tfl_tensor->type, TflShapeInfo(*tfl_tensor)); + auto tensor_type = MapTensorType(tfl_tensor_type); + if (!tensor_type) { + return tensor_type.Error().Status(); } - TflModel& Fb() { return *model_->flatbuffer_model; } - - LiteRtModel model_; -}; + litert_tensor.SetType(std::move(*tensor_type)); -LiteRtStatus ModelUnpacker::UnpackSubgraph(LiteRtSubgraphT& target) { - auto& flatbuffer_subgraph = target.flatbuffer_subgraph; + // QUANTIZATION - for (auto& flatbuffer_tensor : flatbuffer_subgraph->tensors) { - LITERT_RETURN_STATUS_IF_NOT_OK(IsTensorSupported(*flatbuffer_tensor)); - LITERT_RETURN_STATUS_IF_NOT_OK(ConvertTensor( - *flatbuffer_tensor, GetBufferCallback(), target.EmplaceTensor())); + auto quantization = + MapQuantization(tfl_tensor->quantization.get(), litert_tensor); + if (!quantization) { + return quantization.Error().Status(); } - for (auto& flatbuffer_op : flatbuffer_subgraph->operators) { - LITERT_RETURN_STATUS_IF_NOT_OK( - ConvertOp(*flatbuffer_op, GetTensorCallBack(target), - GetOpCodeCallback(), target.EmplaceOp())); + litert_tensor.SetQarams(std::move(*quantization)); + + // MISC + + litert_tensor.SetName(tfl_tensor->name); + + if (tfl_tensor->is_variable) { + // TODO: b/365299994 - Support variable tensors. + LITERT_LOG(LITERT_ERROR, "Variable tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; } - for (auto input : flatbuffer_subgraph->inputs) { - target.inputs.push_back(target.tensors[input]); + if (!tfl_tensor->variant_tensors.empty()) { + // TODO: b/365299994 - Support variant tensors. + LITERT_LOG(LITERT_ERROR, "Variant tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; } - for (auto output : flatbuffer_subgraph->outputs) { - target.outputs.push_back(target.tensors[output]); + if (tfl_tensor->sparsity) { + // TODO: b/365299994 - Support sparsity tensors. + LITERT_LOG(LITERT_ERROR, "Sparsity tensors not yet supported."); + return kLiteRtStatusErrorUnsupported; } return kLiteRtStatusOk; } -LiteRtStatus ModelUnpacker::Unpack(LiteRtModel model) { - ModelUnpacker unpacker(model); +LiteRtStatus UnpackSubgraph(FlatbufferContext& context, + TflSubgraphPtr tfl_subgraph, + LiteRtSubgraphT& litert_subgraph) { + // Unpack tensors. + for (auto& tfl_tensor : tfl_subgraph->tensors) { + LITERT_RETURN_STATUS_IF_NOT_OK(UnpackTensor( + context, std::move(tfl_tensor), litert_subgraph.EmplaceTensor())); + } - if (unpacker.Fb().subgraphs.size() != 1) { - // TODO: b/365299994 - Support multi subgraph. - LITERT_LOG(LITERT_ERROR, "%s", - "Only models with 1 subgraph current supported\n"); - return kLiteRtStatusErrorUnsupported; + // Unpack ops, pass litert_subgraph so they can look up the new litert + // tensors. + for (auto& tfl_op : tfl_subgraph->operators) { + LITERT_RETURN_STATUS_IF_NOT_OK(UnpackOp(context, litert_subgraph, + std::move(tfl_op), + litert_subgraph.EmplaceOp())); } - auto& subgraph = model->subgraphs.emplace_back(); - subgraph.flatbuffer_subgraph = std::move(unpacker.Fb().subgraphs[0]); - LITERT_RETURN_STATUS_IF_NOT_OK(unpacker.UnpackSubgraph(subgraph)); - - // Unpack signatures. If there are no signatures, create a default one with - // LiteRtDefaultSignatureKey. - if (unpacker.Fb().signature_defs.empty()) { - model->signatures.reserve(1); - auto signature = std::make_unique(); - signature->key = LITERT_DEFAULT_SIGNATURE_KEY; - signature->subgraph_index = 0; - signature->input_names.reserve(subgraph.inputs.size()); - for (auto& input : subgraph.inputs) { - signature->input_names.push_back(input->name); - } - signature->output_names.reserve(subgraph.outputs.size()); - for (auto& output : subgraph.outputs) { - signature->output_names.push_back(output->name); + // Update subgraph I/O. + for (auto tfl_input_ind : tfl_subgraph->inputs) { + litert_subgraph.Inputs().push_back(&litert_subgraph.Tensor(tfl_input_ind)); + } + for (auto tfl_output_ind : tfl_subgraph->outputs) { + litert_subgraph.Outputs().push_back( + &litert_subgraph.Tensor(tfl_output_ind)); + } + + return kLiteRtStatusOk; +} + +LiteRtStatus UnpackSignatures(std::vector& tfl_signatures, + LiteRtModelT& parent) { + for (auto& tfl_signature : tfl_signatures) { + auto* litert_subgraph = + parent.Subgraphs().at(tfl_signature->subgraph_index); + + auto& tfl_inputs = tfl_signature->inputs; + auto& tfl_outputs = tfl_signature->outputs; + +#ifndef NDEBUG + // Tflite signatures map a tensor index to a name. We just assume + // that the indexes are exactly those of the subgraph inputs. Check + // this in debug mode. + if (tfl_inputs.size() != litert_subgraph->Inputs().size() || + tfl_outputs.size() != litert_subgraph->Outputs().size()) { + LITERT_LOG(LITERT_ERROR, + "Signature has incorrect number of input/outputs"); } - model->signatures.push_back(std::move(signature)); - } else { - model->signatures.reserve(unpacker.Fb().signature_defs.size()); - for (auto& signature_def : unpacker.Fb().signature_defs) { - auto signature = std::make_unique(); - signature->key = signature_def->signature_key; - signature->subgraph_index = signature_def->subgraph_index; - signature->input_names.reserve(signature_def->inputs.size()); - for (auto& input : signature_def->inputs) { - signature->input_names.push_back(input->name); + + for (auto i = 0; i < tfl_inputs.size(); ++i) { + const auto& tfl_input = tfl_inputs.at(i); + const auto* litert_input = litert_subgraph->Inputs().at(i); + const auto* index_litert_input = + litert_subgraph->Tensors().at(tfl_input->tensor_index); + if (litert_input != index_litert_input) { + LITERT_LOG(LITERT_ERROR, + "Signature inputs reference tensors not in subgraph i/o"); } - signature->output_names.reserve(signature_def->outputs.size()); - for (auto& output : signature_def->outputs) { - signature->output_names.push_back(output->name); + } + + for (auto i = 0; i < tfl_outputs.size(); ++i) { + const auto& tfl_output = tfl_outputs.at(i); + const auto* litert_output = litert_subgraph->Outputs().at(i); + const auto* index_litert_output = + litert_subgraph->Tensors().at(tfl_output->tensor_index); + if (litert_output != index_litert_output) { + LITERT_LOG(LITERT_ERROR, + "Signature outputs reference tensors not in subgraph i/o"); } - model->signatures.push_back(std::move(signature)); } +#endif + + auto get_name = [](const auto& tfl_tensor) { return tfl_tensor->name; }; + + std::vector input_names(tfl_inputs.size()); + std::transform(tfl_inputs.cbegin(), tfl_inputs.cend(), input_names.begin(), + get_name); + + std::vector output_names(tfl_outputs.size()); + std::transform(tfl_outputs.cbegin(), tfl_outputs.cend(), + output_names.begin(), get_name); + + parent.EmplaceSignature(litert_subgraph, std::move(input_names), + std::move(output_names), + tfl_signature->signature_key); + } + + if (tfl_signatures.empty()) { + parent.EmplaceSignature(MakeDefaultSignature(parent.MainSubgraph())); } return kLiteRtStatusOk; } -Expected> LoadModelFromFlatbuffer( - std::unique_ptr flatbuffer) { +LiteRtStatus UnpackMetadata(FlatbufferContext& context, + std::vector& tfl_metadata, + LiteRtModelT& parent) { + for (auto& tfl_m_data : tfl_metadata) { + auto tfl_buffer = context.TakeTflBuffer(tfl_m_data->buffer); + if (!tfl_buffer) { + return tfl_buffer.Error().Status(); + } + + const auto& tfl_vec = tfl_buffer->get()->data; + parent.PushMetadata(tfl_m_data->name, tfl_vec.data(), tfl_vec.size()); + } + + return kLiteRtStatusOk; +} + +Expected UnpackModel(TflModelPtr tfl_model) { auto litert_model = std::make_unique(); - litert_model->flatbuffer_model = std::move(flatbuffer); - litert_model->subgraphs.reserve(100); + FlatbufferContext context(*tfl_model); - if (auto status = ModelUnpacker::Unpack(litert_model.get()); - status != kLiteRtStatusOk) { - return Unexpected(status); + for (auto& tfl_subgraph : tfl_model->subgraphs) { + LITERT_EXPECT_OK(UnpackSubgraph(context, std::move(tfl_subgraph), + litert_model->EmplaceSubgraph())); } - litert_model->flatbuffer_model->subgraphs.clear(); + LITERT_EXPECT_OK(UnpackSignatures(tfl_model->signature_defs, *litert_model)); + LITERT_EXPECT_OK(UnpackMetadata(context, tfl_model->metadata, *litert_model)); + detail::SetTflOpCodes(*litert_model, std::move(tfl_model->operator_codes)); return litert_model; } } // namespace -Expected> LoadModelFromBuffer( - BufferRef buffer) { +Expected LoadModelFromBuffer(BufferRef buffer) { auto flatbuffer = FlatbufferWrapper::CreateFromBuffer(buffer); if (!flatbuffer) { return flatbuffer.Error(); } - auto litert_model = LoadModelFromFlatbuffer( - std::make_unique(std::move((*flatbuffer)->UnpackedModel()))); + auto litert_model = UnpackModel(flatbuffer->get()->Unpack()); if (litert_model) { // Save the original FB pointer to use it later on CompiledModel. - (*litert_model)->model_buffer = buffer.Data(); - (*litert_model)->model_buffer_size = buffer.Size(); + detail::SetTflInitFlatbuffer(**litert_model, buffer); } return litert_model; } -Expected> LoadModelFromFile( - absl::string_view filename) { +Expected LoadModelFromFile(absl::string_view filename) { auto flatbuffer = FlatbufferWrapper::CreateFromTflFile(filename); if (!flatbuffer) { return flatbuffer.Error(); } - - return LoadModelFromFlatbuffer( - std::make_unique(std::move((*flatbuffer)->UnpackedModel()))); + return UnpackModel(flatbuffer->get()->Unpack()); } } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_serialize.cc b/tensorflow/lite/experimental/litert/core/model/model_serialize.cc index 3fb75ebb18dbba..02ed871f6a9b78 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_serialize.cc +++ b/tensorflow/lite/experimental/litert/core/model/model_serialize.cc @@ -14,25 +14,24 @@ #include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" +#include #include #include -#include +#include #include +#include +#include #include #include #include #include "absl/container/flat_hash_map.h" -#include "absl/log/absl_check.h" -#include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/lite/experimental/litert/c/litert_common.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/byte_code_util.h" #include "tensorflow/lite/experimental/litert/core/model/litert_to_flatbuffer.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" @@ -41,207 +40,278 @@ namespace litert::internal { namespace { -using OpCodeMap = absl::flat_hash_map; using TensorMap = absl::flat_hash_map; -TflOpCodePtr MakeCustomOpCode(absl::string_view custom_code_name) { - auto custom_code = std::make_unique(); - custom_code->builtin_code = ::tflite::BuiltinOperator_CUSTOM; - custom_code->custom_code.assign(custom_code_name.begin(), - custom_code_name.end()); - custom_code->version = 1; - return custom_code; -} +// Pop npu related stuff from model if it exists and requires a post process +// step (i.e. appending byte code to tflite). +std::optional> PopByteCodeIfNeedsPostProcess( + LiteRtModelT& model) { + auto build_stamp_buf = model.FindMetadata(kLiteRtBuildStampKey); + if (!build_stamp_buf) { + return std::nullopt; + } -OpCodeMap BuildOpCodeMap(const std::vector& op_codes) { - OpCodeMap map; - for (auto i = 0; i < op_codes.size(); ++i) { - const auto tfl_code = op_codes[i]->builtin_code; - map.insert({static_cast(tfl_code), i}); + auto build_stamp = ParseBuildStamp(*build_stamp_buf); + if (!build_stamp) { + LITERT_LOG(LITERT_WARNING, + "Model has a build stamp but it couldn't be parsed"); + return std::nullopt; } - return map; -} -void SetOptions(const LiteRtOpT& litert_op, TflOp& tfl_op) { - tfl_op.builtin_options = litert_op.option; + // Only appending needs separate strategy. + if (std::get<2>(*build_stamp) != kAppend) { + return std::nullopt; + } - if (litert_op.custom_options.Size() != 0) { - tfl_op.custom_options = litert_op.custom_options.ToVec(); - tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; + // Pop the actual byte and and replace it with a placeholder value + // which will be + auto byte_code = model.PopMetadata(kByteCodeMetadataKey); + if (!byte_code) { + LITERT_LOG(LITERT_WARNING, "Model has npu build stamp but no byte code"); + return std::nullopt; } + model.PushMetadata(kByteCodeMetadataKey, MakeByteCodePlaceholder()); + + return *byte_code; } -class ModelRepacker { - public: - static LiteRtStatus Repack(LiteRtModelT& model); +Expected> AppendByteCode( + OwningBufferRef flatbuffer, + OwningBufferRef npu_byte_code) { + LITERT_EXPECT_OK( + FinishByteCodePlaceholders(flatbuffer, npu_byte_code.Size())); - private: - explicit ModelRepacker(LiteRtModelT::Ref model) : model_(model) { - if (!model_.get().custom_op_code.empty()) { - model_.get().flatbuffer_model->operator_codes.emplace_back( - MakeCustomOpCode(model_.get().custom_op_code)); - } - op_code_map_ = - BuildOpCodeMap(model_.get().flatbuffer_model->operator_codes); + const auto res_size = flatbuffer.Size() + npu_byte_code.Size(); + OwningBufferRef res(res_size); + + uint8_t* it = res.Data(); + std::memcpy(it, flatbuffer.Data(), flatbuffer.Size()); + it += flatbuffer.Size(); + std::memcpy(it, npu_byte_code.Data(), npu_byte_code.Size()); + + return res; +} + +// This is expected to be used to serialize the dispatch op custom code. +TflOpCodePtr MakeCustomOpCode(std::string custom_code_name) { + auto custom_code = std::make_unique(); + custom_code->builtin_code = ::tflite::BuiltinOperator_CUSTOM; + custom_code->custom_code = std::move(custom_code_name); + custom_code->version = 1; + return custom_code; +} + +// Utility for accessing flatbuffer state. +class FlatbufferBuilder { + public: + explicit FlatbufferBuilder(uint32_t dispatch_op_code_ind) + : tfl_model_(std::make_unique()), + dispatch_op_code_ind_(dispatch_op_code_ind) { + // Tfl expects empty buffer 0. + tfl_model_->buffers.push_back(std::make_unique()); } - LiteRtStatus SerializeTensor(LiteRtTensorT& tensor, TflTensor& target); + TflModel& Model() { return *tfl_model_.get(); } - LiteRtStatus SerializeOp(LiteRtOpT& op, TflOp& target, - const TensorMap& tensor_map); + TflModelPtr Release() && { return std::move(tfl_model_); } - LiteRtStatus SerializeSubgraph(LiteRtSubgraphT& subgraph, - TflSubgraph& target); + // Move given buffer into tfl model and get its index. + uint32_t SubmitBuffer(TflBufferPtr tfl_buffer) { + tfl_model_->buffers.push_back(std::move(tfl_buffer)); + return tfl_model_->buffers.size() - 1; + } - uint32_t SubmitBuffer(TflBufferPtr buffer) { - OldFb().buffers.push_back(std::move(buffer)); - return OldFb().buffers.size() - 1; + // Add to tfl model metadata. + void PushMetadata(std::string key, BufferRef data) { + auto tfl_buffer = std::make_unique(); + tfl_buffer->data.assign(data.Data(), data.Data() + data.Size()); + auto tfl_buffer_ind = SubmitBuffer(std::move(tfl_buffer)); + tfl_model_->metadata_buffer.push_back(tfl_buffer_ind); + auto tfl_metadata = std::make_unique(); + tfl_metadata->name = key; + tfl_metadata->buffer = tfl_buffer_ind; + tfl_model_->metadata.push_back(std::move(tfl_metadata)); } - TflModel& OldFb() { return *model_.get().flatbuffer_model; } + // Get the index in the tfl op codes for the dispatch custom code. + // This should be the only new custom code added after loading the initial + // tfl. + uint32_t DispatchOpCodeInd() const { return dispatch_op_code_ind_; } - LiteRtModelT::Ref model_; - OpCodeMap op_code_map_; + private: + TflModelPtr tfl_model_; + uint32_t dispatch_op_code_ind_; }; -LiteRtStatus ModelRepacker::SerializeTensor(LiteRtTensorT& tensor, - TflTensor& target) { - auto tfl_tensor_type = MapTensorType({tensor.type_id, tensor.type_detail}); - if (!tfl_tensor_type) { - return tfl_tensor_type.Error().Status(); +void SetOptions(const LiteRtOpT& litert_op, TflOp& tfl_op) { + tfl_op.builtin_options = detail::GetTflOptions(litert_op); + if (litert_op.CustomOptions().Size() != 0) { + tfl_op.custom_options = litert_op.CustomOptions().ToVec(); + tfl_op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; } - auto [tfl_elem_type, tfl_shape] = *tfl_tensor_type; +} - target.type = tfl_elem_type; - target.shape.assign(tfl_shape.shape.begin(), tfl_shape.shape.end()); - target.has_rank = tfl_shape.has_rank; - target.shape_signature.assign(tfl_shape.shape_signature.begin(), - tfl_shape.shape_signature.end()); +LiteRtStatus PackOp(FlatbufferBuilder& builder, LiteRtOpT& litert_op, + TflOp& tfl_op, const TensorMap& tensor_map) { + auto tfl_op_code_ind = detail::GetTflOpCodeInd(litert_op); + if (tfl_op_code_ind < 0) { + tfl_op_code_ind = builder.DispatchOpCodeInd(); + } + tfl_op.opcode_index = tfl_op_code_ind; - auto tfl_quantization = - MapQuantization(std::make_pair(tensor.q_type_id, tensor.q_type_detail)); - if (!tfl_quantization) { - return tfl_quantization.Error().Status(); + for (auto* in : litert_op.Inputs()) { + tfl_op.inputs.push_back(tensor_map.at(in)); } - target.quantization = std::move(*tfl_quantization); - ABSL_DCHECK(tensor.weights.fb_buffer != nullptr) - << "Submitting a null buffer"; - target.buffer = SubmitBuffer(std::move(tensor.weights.fb_buffer)); + for (auto* out : litert_op.Outputs()) { + tfl_op.outputs.push_back(tensor_map.at(out)); + } - target.name = tensor.name; + SetOptions(litert_op, tfl_op); return kLiteRtStatusOk; } -LiteRtStatus ModelRepacker::SerializeOp(LiteRtOpT& op, TflOp& target, - const TensorMap& tensor_map) { - target.opcode_index = op_code_map_.at(op.op_code); - - for (auto in : op.inputs) { - target.inputs.push_back(tensor_map.at(in)); +LiteRtStatus PackTensor(FlatbufferBuilder& builder, + LiteRtTensorT& litert_tensor, TflTensor& tfl_tensor) { + auto tfl_tensor_type = MapTensorType(litert_tensor.Type()); + if (!tfl_tensor_type) { + return tfl_tensor_type.Error().Status(); } + auto [tfl_elem_type, tfl_shape] = *tfl_tensor_type; - for (auto out : op.outputs) { - target.outputs.push_back(tensor_map.at(out)); - } + tfl_tensor.type = tfl_elem_type; + tfl_tensor.shape.assign(tfl_shape.shape.begin(), tfl_shape.shape.end()); + tfl_tensor.has_rank = tfl_shape.has_rank; + tfl_tensor.shape_signature.assign(tfl_shape.shape_signature.begin(), + tfl_shape.shape_signature.end()); - SetOptions(op, target); + auto tfl_quantization = MapQuantization(litert_tensor.Qparams()); + if (!tfl_quantization) { + return tfl_quantization.Error().Status(); + } + tfl_tensor.quantization = std::move(*tfl_quantization); - // TODO: b/365299994 - Support exotic op fields in serialize. + tfl_tensor.buffer = + builder.SubmitBuffer(detail::TakeTflBuffer(litert_tensor.Weights())); + tfl_tensor.name = std::string(litert_tensor.Name()); return kLiteRtStatusOk; } -LiteRtStatus ModelRepacker::SerializeSubgraph(LiteRtSubgraphT& subgraph, - TflSubgraph& target) { - TensorMap tensor_map; - - for (auto tensor : subgraph.tensors) { - tensor_map.insert({tensor, tensor_map.size()}); - target.tensors.push_back(std::make_unique()); +LiteRtStatus PackSubgraph(FlatbufferBuilder& builder, + LiteRtSubgraphT& litert_subgraph, + TflSubgraph& tfl_subgraph, TensorMap& tensor_map) { + for (auto* tensor : litert_subgraph.Tensors()) { + tfl_subgraph.tensors.push_back(std::make_unique()); + tensor_map.insert({tensor, tfl_subgraph.tensors.size() - 1}); LITERT_RETURN_STATUS_IF_NOT_OK( - SerializeTensor(*tensor, *target.tensors.back())); + PackTensor(builder, *tensor, *tfl_subgraph.tensors.back())); } - for (auto op : subgraph.ops) { - target.operators.push_back(std::make_unique()); + for (auto* op : litert_subgraph.Ops()) { + tfl_subgraph.operators.push_back(std::make_unique()); LITERT_RETURN_STATUS_IF_NOT_OK( - SerializeOp(*op, *target.operators.back(), tensor_map)); + PackOp(builder, *op, *tfl_subgraph.operators.back(), tensor_map)); } - for (auto in : subgraph.inputs) { - target.inputs.push_back(tensor_map.at(in)); + for (auto* in : litert_subgraph.Inputs()) { + tfl_subgraph.inputs.push_back(tensor_map.at(in)); } - for (auto out : subgraph.outputs) { - target.outputs.push_back(tensor_map.at(out)); + + for (auto* out : litert_subgraph.Outputs()) { + tfl_subgraph.outputs.push_back(tensor_map.at(out)); } return kLiteRtStatusOk; } -LiteRtStatus ModelRepacker::Repack(LiteRtModelT& model) { - ModelRepacker repacker(model); +Expected PackAsTflite(LiteRtModelT& litert_model) { + // Pass the op code list through that was saved during loading. Add one more + // op code for the dispatch ops. + auto tfl_op_codes = detail::TakeTflOpCodes(litert_model); + tfl_op_codes.push_back( + MakeCustomOpCode(std::string(kLiteRtDispatchOpCustomCode))); - auto& target = repacker.OldFb(); + FlatbufferBuilder builder(tfl_op_codes.size() - 1); + builder.Model().operator_codes = std::move(tfl_op_codes); - std::vector>> - metadata; - for (auto& flatbuffer_metadata : target.metadata) { - const auto metadata_buffer_ind = flatbuffer_metadata->buffer; - metadata.push_back({flatbuffer_metadata->name, - std::move(target.buffers[metadata_buffer_ind])}); + // Pack litert subgraphs into tfl subgraphs and save the mapping of tensors. + TensorMap tensor_map; + for (auto* litert_subgraph : litert_model.Subgraphs()) { + auto& tfl_subgraph = *builder.Model().subgraphs.emplace_back( + std::make_unique()); + LITERT_EXPECT_OK( + PackSubgraph(builder, *litert_subgraph, tfl_subgraph, tensor_map)); } - target.subgraphs.clear(); - target.buffers.clear(); - target.metadata.clear(); - target.metadata_buffer.clear(); - - target.buffers.push_back(std::make_unique()); + // Serialize the signatures using saved tensor mapping. + for (auto* litert_signature : litert_model.Signatures()) { + auto* litert_subgraph = &litert_signature->GetSubgraph(); + + auto& tfl_signature = *builder.Model().signature_defs.emplace_back( + std::make_unique()); + tfl_signature.signature_key = std::string(litert_signature->Key()); + + auto begin = litert_model.Subgraphs().cbegin(); + auto end = litert_model.Subgraphs().cend(); + const auto litert_subgraph_ind = + std::find(begin, end, litert_subgraph) - begin; + tfl_signature.subgraph_index = litert_subgraph_ind; + + auto input_ind = 0; + for (const auto& litert_name : litert_signature->InputNames()) { + auto& tfl_input = *tfl_signature.inputs.emplace_back( + std::make_unique<::tflite::TensorMapT>()); + tfl_input.name = litert_name; + tfl_input.tensor_index = + tensor_map.find(litert_subgraph->Inputs().at(input_ind))->second; + ++input_ind; + } - for (auto& subgraph : model.subgraphs) { - target.subgraphs.push_back(std::make_unique()); - LITERT_RETURN_STATUS_IF_NOT_OK( - repacker.SerializeSubgraph(subgraph, *target.subgraphs.back())); + auto output_ind = 0; + for (const auto& litert_name : litert_signature->OutputNames()) { + auto& tfl_output = *tfl_signature.outputs.emplace_back( + std::make_unique<::tflite::TensorMapT>()); + tfl_output.name = litert_name; + tfl_output.tensor_index = + tensor_map.find(litert_subgraph->Outputs().at(output_ind))->second; + ++output_ind; + } } - for (auto& [name, buf] : metadata) { - const auto new_ind = target.buffers.size(); - auto new_metadata = std::make_unique(); - new_metadata->name = name; - new_metadata->buffer = new_ind; - target.metadata.emplace_back(std::move(new_metadata)); - target.metadata_buffer.push_back(new_ind); - target.buffers.emplace_back(std::move(buf)); + // Serialize metadata. + for (auto it = litert_model.MetadataBegin(); it != litert_model.MetadataEnd(); + ++it) { + builder.PushMetadata(it->first, it->second); } - return kLiteRtStatusOk; + builder.Model().version = 3; + + return std::move(builder).Release(); } } // namespace Expected> SerializeModel(LiteRtModelT&& model) { - LITERT_EXPECT_OK(ModelRepacker::Repack(model)); - - flatbuffers::FlatBufferBuilder b; - auto model_offset = tflite::Model::Pack(b, model.flatbuffer_model.get()); - tflite::FinishModelBuffer(b, model_offset); + // Check if the model has fresh npu stuff. It it does, pop it off + // for post processing after packing to tflite model. + auto maybe_byte_code = PopByteCodeIfNeedsPostProcess(model); - OwningBufferRef buffer; - auto [new_buf, new_size, new_offset] = buffer.GetWeak(); - new_buf = b.ReleaseRaw(new_size, new_offset); - - if (!VerifyFlatbuffer(buffer.Span())) { - return Unexpected(kLiteRtStatusErrorInvalidFlatbuffer); + auto tfl_model = PackAsTflite(model); + if (!tfl_model) { + return tfl_model.Error(); } - return std::move(buffer); -} + auto serialized_tfl = SerializeFlatbuffer(**tfl_model); + if (!VerifyFlatbuffer(serialized_tfl.Span())) { + return Error(kLiteRtStatusErrorInvalidFlatbuffer); + } -Expected> SerializeModel(Model&& model) { - LiteRtModelT* m = model.Get(); - return SerializeModel(std::move(*m)); + if (!maybe_byte_code) { + return serialized_tfl; + } + return AppendByteCode(serialized_tfl, *maybe_byte_code); } } // namespace litert::internal @@ -249,10 +319,10 @@ Expected> SerializeModel(Model&& model) { LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf, size_t* size, size_t* offset, bool destroy_model) { - auto serialized = - (destroy_model) - ? SerializeModel(::litert::Model::CreateFromOwnedHandle(model)) - : SerializeModel(::litert::Model::CreateFromNonOwnedHandle(model)); + auto serialized = litert::internal::SerializeModel(std::move(*model)); + if (destroy_model) { + delete model; + } if (!serialized) { return serialized.Error().Status(); } diff --git a/tensorflow/lite/experimental/litert/core/model/model_serialize.h b/tensorflow/lite/experimental/litert/core/model/model_serialize.h index 4b6fe69cc6636a..61a4b51b40f1ac 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_serialize.h +++ b/tensorflow/lite/experimental/litert/core/model/model_serialize.h @@ -34,12 +34,9 @@ LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf, #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model.h" namespace litert::internal { -Expected> SerializeModel(Model&& model); Expected> SerializeModel(LiteRtModelT&& model); } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/model/model_test.cc b/tensorflow/lite/experimental/litert/core/model/model_test.cc index 764f9a655f878b..5c2327d901db02 100644 --- a/tensorflow/lite/experimental/litert/core/model/model_test.cc +++ b/tensorflow/lite/experimental/litert/core/model/model_test.cc @@ -14,17 +14,21 @@ #include "tensorflow/lite/experimental/litert/core/model/model.h" +#include #include -#include +#include +#include +#include #include #include #include "absl/strings/string_view.h" -#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" -#include "tensorflow/lite/experimental/litert/core/model/model_load.h" -#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" #include "tensorflow/lite/schema/schema_generated.h" namespace litert::internal { @@ -32,96 +36,244 @@ namespace { using ::testing::ElementsAreArray; -TEST(ModelTest, GetMetadata) { - LiteRtModelT model; - model.flatbuffer_model = std::make_unique(); +// +// Model +// +TEST(ModelTest, GetMetadata) { static constexpr absl::string_view kMetadata = "VALUE"; static constexpr absl::string_view kKey = "KEY"; - LITERT_ASSERT_STATUS_OK( - model.PushMetadata(kKey, OwningBufferRef(kMetadata))); + LiteRtModelT model; + LITERT_ASSERT_STATUS_OK(model.PushMetadata(kKey, kMetadata)); auto found_metadata = model.FindMetadata(kKey); - + ASSERT_TRUE(found_metadata); EXPECT_EQ(found_metadata->StrView(), kMetadata); } TEST(ModelTest, MetadataDNE) { LiteRtModelT model; - model.flatbuffer_model = std::make_unique(); - auto res = model.FindMetadata("FOO"); ASSERT_FALSE(res.HasValue()); } -TEST(ModelOpTest, AddInput) { - LiteRtOpT op; - LiteRtTensorT tensor; +TEST(ModelTest, PopMetadata) { + static constexpr absl::string_view kMetadata = "VALUE"; + static constexpr absl::string_view kKey = "KEY"; - op.AddInput(tensor); + LiteRtModelT model; + LITERT_ASSERT_STATUS_OK(model.PushMetadata(kKey, kMetadata)); - EXPECT_THAT(tensor.users, ElementsAreArray({&op})); - EXPECT_THAT(tensor.user_arg_inds, ElementsAreArray({0})); + auto popped_metadata = model.PopMetadata(kKey); + ASSERT_TRUE(popped_metadata); + EXPECT_EQ(popped_metadata->StrView(), kMetadata); - EXPECT_THAT(op.inputs, ElementsAreArray({&tensor})); + EXPECT_FALSE(model.FindMetadata(kKey)); } -TEST(ModelOpTest, AddOutput) { - LiteRtOpT op; - LiteRtTensorT tensor; +TEST(ModelTest, EmplaceSubgraph) { + LiteRtModelT model; + model.EmplaceSubgraph(); + EXPECT_EQ(model.Subgraphs().size(), 1); +} + +TEST(ModelTest, Signature) { + static constexpr absl::string_view kSignatureName = "MY_SIGNATURE"; - op.AddOutput(tensor); + const std::vector inputs = {"input_1", "input_2"}; + const std::vector outputs = {"output_1"}; - EXPECT_EQ(tensor.defining_op, &op); - EXPECT_EQ(tensor.defining_op_out_ind, 0); + LiteRtModelT model; + auto& subgraph = model.EmplaceSubgraph(); + + auto& signature = model.EmplaceSignature(&subgraph, inputs, outputs, + std::string(kSignatureName)); + + auto found_signature = model.FindSignature(kSignatureName); + ASSERT_TRUE(found_signature); + EXPECT_EQ(found_signature->get(), signature); +} - EXPECT_THAT(op.outputs, ElementsAreArray({&tensor})); +TEST(ModelTest, SignatureDNE) { + static constexpr absl::string_view kSignatureName = "MY_SIGNATURE"; + LiteRtModelT model; + auto found_signature = model.FindSignature(kSignatureName); + EXPECT_FALSE(found_signature); +} + +// +// Subgraph +// + +TEST(ModelSubgraphTest, Input) { + LiteRtTensorT tensor; + LiteRtSubgraphT subgraph; + subgraph.Inputs().push_back(&tensor); + EXPECT_EQ(&subgraph.Input(0), subgraph.Inputs().front()); +} + +TEST(ModelSubgraphTest, Output) { + LiteRtTensorT tensor; + LiteRtSubgraphT subgraph; + subgraph.Outputs().push_back(&tensor); + EXPECT_EQ(&subgraph.Output(0), subgraph.Outputs().front()); } TEST(ModelSubgraphTest, EmplaceTensor) { LiteRtSubgraphT subgraph; auto& tensor = subgraph.EmplaceTensor(); - ASSERT_EQ(subgraph.tensors_storage.size(), 1); - EXPECT_THAT(subgraph.tensors, ElementsAreArray({&tensor})); + ASSERT_EQ(subgraph.Tensors().size(), 1); + EXPECT_THAT(subgraph.Tensors(), ElementsAreArray({&tensor})); } TEST(ModelSubgraphTest, EmplaceOp) { LiteRtSubgraphT subgraph; - auto& tensor = subgraph.EmplaceOp(); - ASSERT_EQ(subgraph.ops_storage.size(), 1); - EXPECT_THAT(subgraph.ops, ElementsAreArray({&tensor})); -} - -TEST(ModelSignatureTest, Basic) { - constexpr absl::string_view kTfliteFile = - "third_party/tensorflow/lite/experimental/litert/test/testdata/" - "simple_model.tflite"; - LiteRtModel model; - auto status = LiteRtCreateModelFromFile(kTfliteFile.data(), &model); - ASSERT_EQ(status, kLiteRtStatusOk); - ASSERT_EQ(model->signatures.size(), 1); - EXPECT_EQ(model->signatures[0]->key, LITERT_DEFAULT_SIGNATURE_KEY); - EXPECT_THAT(model->signatures[0]->input_names, - ElementsAreArray({"arg0", "arg1"})); - EXPECT_THAT(model->signatures[0]->output_names, - ElementsAreArray({"tfl.add"})); - LiteRtDestroyModel(model); -} - -TEST(ModelSignatureTest, Lookup) { - constexpr absl::string_view kTfliteFile = - "third_party/tensorflow/lite/experimental/litert/test/testdata/" - "simple_model.tflite"; - LiteRtModel model; - auto status = LiteRtCreateModelFromFile(kTfliteFile.data(), &model); - ASSERT_EQ(status, kLiteRtStatusOk); - ASSERT_EQ(model->signatures.size(), 1); - auto signature = model->FindSignature(LITERT_DEFAULT_SIGNATURE_KEY); - ASSERT_TRUE(signature); - EXPECT_EQ((*signature)->key, LITERT_DEFAULT_SIGNATURE_KEY); - EXPECT_THAT((*signature)->input_names, ElementsAreArray({"arg0", "arg1"})); - EXPECT_THAT((*signature)->output_names, ElementsAreArray({"tfl.add"})); - LiteRtDestroyModel(model); + auto& op = subgraph.EmplaceOp(); + ASSERT_EQ(subgraph.Ops().size(), 1); + EXPECT_THAT(subgraph.Ops(), ElementsAreArray({&op})); +} + +// +// Op +// + +TEST(ModelOpTest, Input) { + LiteRtOpT op; + LiteRtTensorT tensor; + op.Inputs().push_back(&tensor); + EXPECT_EQ(&op.Input(0), op.Inputs().front()); +} + +TEST(ModelOpTest, Output) { + LiteRtOpT op; + LiteRtTensorT tensor; + op.Outputs().push_back(&tensor); + EXPECT_EQ(&op.Output(0), op.Outputs().front()); +} + +TEST(ModelOpTest, CustomOptions) { + static constexpr absl::string_view kOpts = "OPTIONS"; + + LiteRtOpT op; + op.SetCustomOptions(kOpts); + EXPECT_EQ(op.CustomOptions().StrView(), kOpts); +} + +TEST(ModelOpTest, Options) { + static constexpr auto kOptsType = ::tflite::BuiltinOptions_AddOptions; + + TflOptions options; + options.type = kOptsType; + options.Set(::tflite::AddOptionsT()); + + LiteRtOpT op; + detail::SetTflOptions(op, std::move(options)); + + ASSERT_EQ(detail::GetTflOptions(op).type, kOptsType); +} + +TEST(ModelOpTest, OpCode) { + constexpr static auto kOpCode = kLiteRtOpCodeTflMul; + + LiteRtOpT op; + op.SetOpCode(kOpCode); + EXPECT_EQ(op.OpCode(), kOpCode); +} + +// +// Tensor +// + +TEST(ModelTensorTypeTest, MakeRankedTensorType) { + static constexpr const int32_t kDims[] = {2, 2}; + static constexpr auto kDimsSpan = absl::MakeConstSpan(kDims); + static constexpr auto kElementType = kLiteRtElementTypeFloat32; + const auto tensor_type = MakeRankedTensorType(kElementType, kDimsSpan); + ASSERT_EQ(tensor_type.first, kLiteRtRankedTensorType); + EXPECT_EQ(tensor_type.second.ranked_tensor_type.element_type, kElementType); + const auto& layout = tensor_type.second.ranked_tensor_type.layout; + ASSERT_EQ(layout.rank, kDimsSpan.size()); + EXPECT_THAT(absl::MakeConstSpan(layout.dimensions, kDimsSpan.size()), + ElementsAreArray(kDimsSpan)); +} + +TEST(ModelQuantizationTypeTest, MakePerTensor) { + static constexpr auto kScale = 1.0f; + static constexpr auto kZero = 1L; + const auto quant = MakePerTensorQuantization(kScale, kZero); + ASSERT_EQ(quant.first, kLiteRtQuantizationPerTensor); + const auto& per_tensor = quant.second.per_tensor; + EXPECT_EQ(per_tensor.scale, kScale); + EXPECT_EQ(per_tensor.zero_point, kZero); +} + +TEST(ModelQuantizationTypeTest, MakePerChannel) { + static constexpr std::array kScale = {1.0f, 2.0f}; + static constexpr std::array kZero = {1L, 2L}; + static constexpr int32_t kQdim = 0; + + LiteRtTensorT tensor; + const auto quant = MakePerChannelQuantization( + kScale, kZero, kQdim, + [&tensor](auto s) { return tensor.RequestBuffer(s); }); + + ASSERT_EQ(quant.first, kLiteRtQuantizationPerChannel); + const auto& per_channel = quant.second.per_channel; + + const auto size = per_channel.num_channels; + ASSERT_EQ(size, 2); + EXPECT_EQ(per_channel.quantized_dimension, 0); + + auto scales = absl::MakeConstSpan(per_channel.scales, size); + auto zeros = absl::MakeConstSpan(per_channel.zero_points, size); + + EXPECT_THAT(scales, ElementsAreArray(kScale)); + EXPECT_THAT(zeros, ElementsAreArray(kZero)); +} + +TEST(ModelWeightsTest, WeightsFromBuf) { + static constexpr absl::string_view kData = "some_data"; + + LiteRtWeightsT weights; + weights.SetFromBuf(BufferRef(kData.data(), kData.size())); + EXPECT_EQ(weights.Buf().StrView(), kData); +} + +TEST(ModelTensorTest, Name) { + static constexpr absl::string_view kName = "TENSOR_NAME"; + + LiteRtTensorT tensor; + tensor.SetName(std::string(kName.begin(), kName.end())); + EXPECT_EQ(tensor.Name(), kName); +} + +TEST(ModelTensorTest, Use) { + LiteRtTensorT tensor; + tensor.Users().emplace_back(); + tensor.UserArgInds().push_back(0); + auto [user, ind] = tensor.GetUse(0); + EXPECT_EQ(user, tensor.Users().front()); + EXPECT_EQ(ind, 0); +} + +TEST(ModelTensorTest, DefiningOp) { + LiteRtTensorT tensor; + LiteRtOpT op; + tensor.SetDefiningOp(op, 0); + EXPECT_EQ(tensor.DefiningOp(), &op); + EXPECT_EQ(tensor.DefiningOpOutInd(), 0); +} + +// +// Util +// + +TEST(ModelOpListTest, Push) { + LiteRtOpListT op_list; + LiteRtOpT op; + op_list.Push(&op); + auto vec = op_list.Vec(); + EXPECT_EQ(vec.front(), &op); } } // namespace diff --git a/tensorflow/lite/experimental/litert/core/util/BUILD b/tensorflow/lite/experimental/litert/core/util/BUILD index 8521efd82d8a71..3b519ec91170ce 100644 --- a/tensorflow/lite/experimental/litert/core/util/BUILD +++ b/tensorflow/lite/experimental/litert/core/util/BUILD @@ -22,6 +22,7 @@ cc_library( srcs = ["flatbuffer_tools.cc"], hdrs = [ "flatbuffer_tools.h", + "//tensorflow/lite/experimental/litert/cc:litert_consts.h", ], deps = [ "//tensorflow/compiler/mlir/lite:allocation", @@ -29,9 +30,11 @@ cc_library( "//tensorflow/lite:stderr_reporter", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", + "//tensorflow/lite/experimental/litert/cc:litert_detail", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/core:filesystem", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@flatbuffers//:runtime_cc", @@ -49,6 +52,7 @@ cc_test( ":flatbuffer_tools", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:test_macros", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc index bc66dbb43530a1..598183614c2dba 100644 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc +++ b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc @@ -16,8 +16,11 @@ #include #include +#include #include +#include +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/allocation.h" #include "tensorflow/lite/experimental/litert/core/filesystem.h" @@ -156,6 +159,14 @@ Expected> GetTflBuffer(const TflModel& tfl_model, return *buffer; } +Expected GetBuffer(const TflModel& tfl_model, + uint32_t buffer_ind) { + if (buffer_ind >= tfl_model.buffers.size()) { + return Error(kLiteRtStatusErrorIndexOOB); + } + return tfl_model.buffers.at(buffer_ind).get(); +} + Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind) { if (buffer_ind >= tfl_model.buffers.size()) { return Error(kLiteRtStatusErrorIndexOOB); @@ -235,7 +246,7 @@ bool IsCustomQuantized(const TflQuantization* tfl_quantization) { tflite::QuantizationDetails_CustomQuantization; } -Expected> AsPerTensorQparams( +Expected AsPerTensorQparams( const TflQuantization* tfl_quantization) { if (!IsPerTensorQuantized(tfl_quantization)) { return Error(kLiteRtStatusErrorInvalidArgument); @@ -244,6 +255,17 @@ Expected> AsPerTensorQparams( tfl_quantization->scale.front()); } +Expected AsPerChannelQparams( + const TflQuantization* tfl_quantization) { + if (!IsPerChannelQuantized(tfl_quantization)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return TflPerChannelQParams(tfl_quantization->quantized_dimension, + tfl_quantization->zero_point.size(), + tfl_quantization->zero_point, + tfl_quantization->scale); +} + ::tflite::Allocation::Ptr MakeAllocation(BufferRef buf) { return std::make_unique<::tflite::MemoryAllocation>( buf.Data(), buf.Size(), ::tflite::DefaultErrorReporter()); @@ -286,4 +308,22 @@ Expected FlatbufferWrapper::CreateFromTflFile( return FlatbufferWrapper::CreateFromBuffer(std::move(*buf)); } +OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model) { + flatbuffers::FlatBufferBuilder b; + auto model_offset = tflite::Model::Pack(b, &tfl_model); + tflite::FinishModelBuffer(b, model_offset); + + OwningBufferRef buffer; + auto [new_buf, new_size, new_offset] = buffer.GetWeak(); + new_buf = b.ReleaseRaw(new_size, new_offset); + + return buffer; +} + +OwningBufferRef SerializeFlatbuffer( + const FlatbufferWrapper& flatbuffer) { + auto tfl_model = flatbuffer.Unpack(); + return SerializeFlatbuffer(*tfl_model); +} + } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h index a9727bf53d51e1..bfeca1e77aa31e 100644 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h +++ b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h @@ -15,13 +15,21 @@ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ +#include #include +#include +#include #include +#include +#include +#include #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/mlir/lite/allocation.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" +#include "tensorflow/lite/experimental/litert/cc/litert_consts.h" +#include "tensorflow/lite/experimental/litert/cc/litert_detail.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -39,14 +47,30 @@ using TflOpCodeEnum = ::tflite::BuiltinOperator; using TflOpCode = ::tflite::OperatorCodeT; using TflQuantization = ::tflite::QuantizationParametersT; using TflElementType = ::tflite::TensorType; +using TflOptions = ::tflite::BuiltinOptionsUnion; +using TflSignature = ::tflite::SignatureDefT; +using TflMetadata = ::tflite::MetadataT; using TflBufferPtr = std::unique_ptr; using TflModelPtr = std::unique_ptr; using TflQuantizationPtr = std::unique_ptr; using TflOpCodePtr = std::unique_ptr; +using TflSubgraphPtr = std::unique_ptr; +using TflTensorPtr = std::unique_ptr; +using TflOpPtr = std::unique_ptr; +using TflSignaturePtr = std::unique_ptr; +using TflMetadataPtr = std::unique_ptr; +// Code and verion. +using TflOpCodeDetail = std::pair; + +// Zero-point, scale. using TflPerTensorQParams = std::pair; +// Quantized dim, num channels, zero-points, scales. +using TflPerChannelQParams = + std::tuple, std::vector>; + // Mirror of all the tensor type related fields in flatbuffer tensor definition. struct TflShapeInfo { // Fixed or dynamic rank. @@ -54,12 +78,12 @@ struct TflShapeInfo { // Basic shape, all elements are non-negative (even if this is a dynamic // shape). - SmallVec shape; + absl::InlinedVector shape; // Dynamic dyn info. If this is not empty, then its length is equal to shape. // If i is a dyn dim, then shape[i] == 1 and shape_signature[i] < 0. Otherwise // shape_signature[i] == shape[i]. - SmallVec shape_signature; + absl::InlinedVector shape_signature; // Convert from a single dims array. Will detect if array is static/dynamic // and populate fields accordingly. @@ -85,10 +109,9 @@ struct TflShapeInfo { // Convert from tensor. explicit TflShapeInfo(const TflTensor& tfl_tensor) : has_rank(tfl_tensor.has_rank), - shape(SmallVec(tfl_tensor.shape.begin(), - tfl_tensor.shape.end())), - shape_signature(SmallVec(tfl_tensor.shape_signature.begin(), - tfl_tensor.shape_signature.end())) {} + shape(tfl_tensor.shape.begin(), tfl_tensor.shape.end()), + shape_signature(tfl_tensor.shape_signature.begin(), + tfl_tensor.shape_signature.end()) {} }; using TflTensorType = std::pair; @@ -138,6 +161,10 @@ Expected> GetTflBuffer(const TflModel& tfl_model, Expected> GetMutableTflBuffer(TflModel& tfl_model, uint32_t buffer_ind); +// Get a non-owning view of tfl buffer if it exists. +Expected GetBuffer(const TflModel& tfl_model, + uint32_t buffer_ind); + // Move and take ownership of the buffer object at given index if it exists. Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind); @@ -145,6 +172,22 @@ Expected TakeBuffer(TflModel& tfl_model, uint32_t buffer_ind); Expected PushTflBuffer(TflModel& tfl_model, BufferRef buffer); +// Make a tflite buffer from data. +template +TflBufferPtr MakeTflBuffer(std::initializer_list data) { + auto res = std::make_unique(); + const auto byte_size = data.size() * sizeof(T); + res->data.resize(byte_size); + for (auto it = data.begin(); it != data.end(); ++it) { + auto* write_to = + reinterpret_cast(res->data.data()) + (it - data.begin()); + *write_to = *it; + } + res->size = res->data.size(); + res->offset = 0; + return res; +} + // Get the op code from the model at the given index if it exists. Expected GetTflOpCode(const TflModel& tfl_model, uint32_t op_code_ind); @@ -179,10 +222,14 @@ bool IsBlockWiseQuantized(const TflQuantization* tfl_quantization); // Does tensor have custom quantization. bool IsCustomQuantized(const TflQuantization* tfl_quantization); -// Get the per-tensor q-params if given tensor has them. +// Get the per-tensor tensor q-params if given tensor has them. Expected AsPerTensorQparams( const TflQuantization* tfl_quantization); +// Get the per-channel tensor q-params if given tensor has them. +Expected AsPerChannelQparams( + const TflQuantization* tfl_quantization); + // Flatbuffer management helpers. // Make a tfl allocation from buffer. @@ -212,9 +259,10 @@ class FlatbufferWrapper { return *fb_model_; } - // Unpacked version of underlying model object. - const TflModel& UnpackedModel() const { return *unpacked_; } - TflModel& UnpackedModel() { return *unpacked_; } + // Unpack the contained flatbuffer. + TflModelPtr Unpack() const { + return TflModelPtr(fb_model_->GetModel()->UnPack()); + } private: FlatbufferWrapper(::tflite::FlatBufferModel::Ptr fb_model, @@ -222,15 +270,18 @@ class FlatbufferWrapper { OwningBufferRef&& model_buf) : fb_model_(std::move(fb_model)), alloc_(std::move(alloc)), - model_buf_(std::forward>(model_buf)), - unpacked_(TflModelPtr(fb_model_->GetModel()->UnPack())) {} + model_buf_(std::forward>(model_buf)) {} ::tflite::FlatBufferModel::Ptr fb_model_; ::tflite::Allocation::Ptr alloc_; OwningBufferRef model_buf_; - TflModelPtr unpacked_; }; +// Re-serialize the unpacked model from flatbuffer wrapper. +OwningBufferRef SerializeFlatbuffer( + const FlatbufferWrapper& flatbuffer); +OwningBufferRef SerializeFlatbuffer(const TflModel& tfl_model); + } // namespace litert::internal #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc index f70874b482c10c..4d3badc471e587 100644 --- a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc +++ b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools_test.cc @@ -21,6 +21,7 @@ #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" namespace litert::internal { namespace { @@ -41,31 +42,33 @@ static const absl::string_view kData = "MyData"; TEST(FlatbufferToolsTest, Metadata) { auto flatbuffer = TestFlatbuffer(); ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); - LITERT_ASSERT_STATUS_OK( - PushMetadata(kKey, flatbuffer->UnpackedModel(), - BufferRef(kData.data(), kData.size()))); + LITERT_ASSERT_STATUS_OK(PushMetadata( + kKey, *tfl_model, BufferRef(kData.data(), kData.size()))); - auto metadata = GetMetadata(kKey, flatbuffer->UnpackedModel()); + auto metadata = GetMetadata(kKey, *tfl_model); ASSERT_TRUE(metadata); EXPECT_EQ(metadata->StrView(), kData); } TEST(FlatbufferToolsTest, GetMetadataNotFound) { auto flatbuffer = TestFlatbuffer(); + auto tfl_model = flatbuffer->Unpack(); ASSERT_NE(flatbuffer, nullptr); - EXPECT_FALSE(GetMetadata(kKey, flatbuffer->UnpackedModel())); + EXPECT_FALSE(GetMetadata(kKey, *tfl_model)); } TEST(FlatbufferToolsTest, TflBuffer) { auto flatbuffer = TestFlatbuffer(); ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); - auto ind = PushTflBuffer(flatbuffer->UnpackedModel(), + auto ind = PushTflBuffer((*tfl_model), BufferRef(kData.data(), kData.size())); ASSERT_TRUE(ind); - auto buf = GetTflBuffer(flatbuffer->UnpackedModel(), *ind); + auto buf = GetTflBuffer((*tfl_model), *ind); ASSERT_TRUE(buf); ASSERT_EQ(buf->StrView(), kData); } @@ -73,30 +76,34 @@ TEST(FlatbufferToolsTest, TflBuffer) { TEST(FlatbufferToolsTest, GetTflBufferNotFound) { auto flatbuffer = TestFlatbuffer(); ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); - auto buf = GetTflBuffer(flatbuffer->UnpackedModel(), 100); + auto buf = GetTflBuffer((*tfl_model), 100); ASSERT_FALSE(buf); } TEST(FlatbufferToolsTest, GetTflOpCode) { auto flatbuffer = TestFlatbuffer(); ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); - auto op_code = GetTflOpCode(flatbuffer->UnpackedModel(), 0); + auto op_code = GetTflOpCode((*tfl_model), 0); ASSERT_TRUE(op_code); } TEST(FlatbufferToolsTest, GetTflOpCodeNotFound) { auto flatbuffer = TestFlatbuffer(); ASSERT_NE(flatbuffer, nullptr); + auto tfl_model = flatbuffer->Unpack(); - auto op_code = GetTflOpCode(flatbuffer->UnpackedModel(), 100); + auto op_code = GetTflOpCode((*tfl_model), 100); ASSERT_FALSE(op_code); } TEST(FlatbufferToolsTest, StaticTensorTypeTest) { auto flatbuffer = TestFlatbuffer(); - auto& tensor = flatbuffer->UnpackedModel().subgraphs.front()->tensors.front(); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); TflShapeInfo shape(*tensor); @@ -111,7 +118,8 @@ TEST(FlatbufferToolsTest, StaticTensorTypeTest) { TEST(FlatbufferToolsTest, UnrankedTensorTypeTest) { auto flatbuffer = TestFlatbuffer("unranked_tensor.tflite"); - auto& tensor = flatbuffer->UnpackedModel().subgraphs.front()->tensors.front(); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); TflShapeInfo shape(*tensor); @@ -120,7 +128,8 @@ TEST(FlatbufferToolsTest, UnrankedTensorTypeTest) { TEST(FlatbufferToolsTest, RankedDynamicTensorTypeTest) { auto flatbuffer = TestFlatbuffer("dynamic_shape_tensor.tflite"); - auto& tensor = flatbuffer->UnpackedModel().subgraphs.front()->tensors.front(); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); TflShapeInfo shape(*tensor); @@ -136,7 +145,8 @@ TEST(FlatbufferToolsTest, RankedDynamicTensorTypeTest) { TEST(FlatbufferToolsTest, PerTensorQuantizedTest) { auto flatbuffer = TestFlatbuffer("single_add_default_a16w8_recipe_quantized.tflite"); - auto& tensor = flatbuffer->UnpackedModel().subgraphs.front()->tensors.front(); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors.front(); const auto* const q_parms = tensor->quantization.get(); @@ -147,5 +157,19 @@ TEST(FlatbufferToolsTest, PerTensorQuantizedTest) { ASSERT_TRUE(per_tensor); } +TEST(FlatbufferToolsTest, PerChannelQuantizedTest) { + auto flatbuffer = TestFlatbuffer("static_w8_a16_quantized_k_einsum.tflite"); + auto tfl_model = flatbuffer->Unpack(); + auto& tensor = tfl_model->subgraphs.front()->tensors[1]; + + const auto* const q_parms = tensor->quantization.get(); + + ASSERT_TRUE(IsQuantized(q_parms)); + EXPECT_TRUE(IsPerChannelQuantized(q_parms)); + + auto per_channel = AsPerChannelQparams(q_parms); + ASSERT_TRUE(per_channel); +} + } // namespace } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/runtime/BUILD b/tensorflow/lite/experimental/litert/runtime/BUILD index e122f465855cd1..3f12ca299fff21 100644 --- a/tensorflow/lite/experimental/litert/runtime/BUILD +++ b/tensorflow/lite/experimental/litert/runtime/BUILD @@ -17,6 +17,21 @@ package( default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], ) +cc_library( + name = "event", + srcs = [ + "event.cc", + ], + hdrs = [ + "event.h", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + ], +) + cc_library( name = "tensor_buffer", srcs = [ @@ -40,8 +55,10 @@ cc_library( ], deps = [ "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_event", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_event", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", "@com_google_absl//absl/base:core_headers", @@ -106,20 +123,25 @@ cc_library( "//tensorflow/lite/c:c_api_opaque", "//tensorflow/lite/c:common", "//tensorflow/lite/core:cc_api_stable", + "//tensorflow/lite/delegates/utils:simple_opaque_delegate", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_compiled_model_options", "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", + "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/cc:litert_detail", + "//tensorflow/lite/experimental/litert/cc:litert_event", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer_requirements", + "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", "//tensorflow/lite/experimental/litert/core/model", "//tensorflow/lite/experimental/litert/core/model:model_serialize", "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", ], ) @@ -130,8 +152,13 @@ cc_test( data = [ "//tensorflow/lite/experimental/litert/test:testdata/simple_model.tflite", ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), deps = [ ":compiled_model", + ":tensor_buffer", "//tensorflow/lite:framework", "//tensorflow/lite/c:c_api_opaque", "//tensorflow/lite/c:common", @@ -139,9 +166,12 @@ cc_test( "//tensorflow/lite/experimental/litert/c:litert_compiled_model_options", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_environment", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model", "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:string_view", diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model.cc b/tensorflow/lite/experimental/litert/runtime/compiled_model.cc index bfe61e6700d540..a3b24dffb00f39 100644 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model.cc +++ b/tensorflow/lite/experimental/litert/runtime/compiled_model.cc @@ -17,10 +17,18 @@ #include #include #include +#include #include #include #include +#include "tensorflow/lite/experimental/litert/cc/litert_event.h" + +#if defined(__ANDROID__) +#include +#endif + +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/lite/allocation.h" #include "tensorflow/lite/c/common.h" @@ -28,6 +36,7 @@ #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h" #include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" @@ -36,6 +45,7 @@ #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" #include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" #include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" @@ -45,9 +55,11 @@ #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/stderr_reporter.h" +using litert::Error; using litert::Expected; -using litert::SmallVec; +using litert::OwningBufferRef; using litert::TensorBuffer; +using litert::TensorBufferScopedLock; using litert::Unexpected; using litert::internal::ExternalLiteRtBufferContext; @@ -62,8 +74,8 @@ Expected LiteRtCompiledModelT::Initialize() { signature_keys_ = interp_->signature_keys(); if (signature_keys_.empty()) { - static std::string* default_signature_key = - new std::string(LITERT_DEFAULT_SIGNATURE_KEY); + static auto* default_signature_key = + new std::string(LiteRtSignatureT::kDefaultSignatureKey); signature_keys_.push_back(default_signature_key); } // Register the ExternalLiteRtBufferContext for TensorBuffer handshaking. @@ -76,58 +88,87 @@ Expected LiteRtCompiledModelT::Initialize() { } Expected LiteRtCompiledModelT::Create( - LiteRtModel model, LiteRtComplicationOptions complication_options) { - auto runtime = std::make_unique(); + LiteRtModel model, LiteRtCompilationOptions compilation_options) { + auto compiled_model = std::make_unique(); + + std::optional> new_flatbuffer; + // TODO: b/379317134 - Support other delegates with compilation options. + if (compilation_options != kLiteRtHwAccelatorNone) { + LITERT_LOG(LITERT_INFO, "Applying compiler plugins..."); + if (auto result = + litert::internal::ApplyPlugins(model, compilation_options); + !result) { + LITERT_LOG(LITERT_WARNING, "Failed to apply compiler plugins: %s", + result.Error().Message().data()); + } else { + if (result->num_applied_plugins > 0) { + LITERT_LOG(LITERT_INFO, "Successfully applied %d compiler plugins: %s", + result->num_applied_plugins, + result->success_message.c_str()); + new_flatbuffer = std::move(result->new_flatbuffer); + } + if (!result->error_message.empty()) { + LITERT_LOG(LITERT_WARNING, "Some compiler plugins failed to apply: %s", + result->error_message.c_str()); + } + } + } const char* model_buffer = nullptr; size_t model_buffer_size = 0; // The following code gets the original FB pointer from LiteRtModel. // TODO b/383120429 - Use a better way of getting the FB pointer. - if (model->model_buffer) { + if (new_flatbuffer) { + model_buffer = reinterpret_cast(new_flatbuffer->Data()); + model_buffer_size = new_flatbuffer->Size(); + } else if (auto init_model_buffer = detail::GetTflInitFlatbuffer(*model); + init_model_buffer.Size() != 0) { // Use the saved the original FB pointer when the LiteRtModel was created // from a buffer. - model_buffer = reinterpret_cast(model->model_buffer); - model_buffer_size = model->model_buffer_size; + model_buffer = init_model_buffer.StrData(); + model_buffer_size = init_model_buffer.Size(); } else { // TODO b/383120429 - Once LiteRtModel provide tflite::Model object, switch // to use it to initialize Interpreter instead of serializing LiteRtModel. - auto [data, size, offset] = runtime->model_buf_.GetWeak(); + auto [data, size, offset] = compiled_model->model_buf_.GetWeak(); if (LiteRtSerializeModel(model, &data, &size, &offset, /*destroy_model=*/false) != kLiteRtStatusOk) { return Unexpected(kLiteRtStatusErrorRuntimeFailure); } - runtime->alloc_ = std::make_unique( - runtime->model_buf_.Data(), runtime->model_buf_.Size(), + compiled_model->alloc_ = std::make_unique( + compiled_model->model_buf_.Data(), compiled_model->model_buf_.Size(), tflite::DefaultErrorReporter()); - model_buffer = reinterpret_cast(runtime->alloc_->base()); - model_buffer_size = runtime->alloc_->bytes(); + model_buffer = + reinterpret_cast(compiled_model->alloc_->base()); + model_buffer_size = compiled_model->alloc_->bytes(); } - runtime->fb_model_ = + compiled_model->fb_model_ = tflite::FlatBufferModel::BuildFromBuffer(model_buffer, model_buffer_size); - if (runtime->fb_model_ == nullptr) { + if (compiled_model->fb_model_ == nullptr) { return Unexpected(kLiteRtStatusErrorFileIO); } - if (auto res = runtime->Initialize(); !res.HasValue()) { + if (auto res = compiled_model->Initialize(); !res.HasValue()) { return Unexpected(kLiteRtStatusErrorRuntimeFailure); } - // TODO: b/379317134 - Support other delegates with compilation options. - if (complication_options & kHwAccelNpu) { - auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr(); - LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), - model_buffer); - auto dispatch_delegate = - litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); - if (auto status = - runtime->interp_->ModifyGraphWithDelegate(dispatch_delegate.get()); - status != kTfLiteOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to modify graph with delegate"); - } + // Apply the dispatch delegate, unconditionally, since the loaded model may + // have been compiled for NPU at AOT. + auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr(); + LiteRtDispatchDelegateAddAllocBaseOption(dispatch_delegate_options.get(), + model_buffer); + auto dispatch_delegate = + litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + if (auto status = compiled_model->interp_->ModifyGraphWithDelegate( + dispatch_delegate.get()); + status != kTfLiteOk) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to modify graph with delegate"); } - return runtime; + compiled_model->RegisterDelegate(std::move(dispatch_delegate)); + + return compiled_model; } litert::Expected @@ -200,76 +241,167 @@ tflite::SignatureRunner* LiteRtCompiledModelT::GetSignatureRunner( if (signature_runners_.contains(signature_key)) { return signature_runners_[signature_key]; } - auto runner = - interp_->GetSignatureRunner(signature_key == LITERT_DEFAULT_SIGNATURE_KEY - ? nullptr - : std::string(signature_key).c_str()); + auto runner = interp_->GetSignatureRunner( + signature_key == LiteRtSignatureT::kDefaultSignatureKey + ? nullptr + : std::string(signature_key).c_str()); signature_runners_[signature_key] = runner; return runner; } +Expected LiteRtCompiledModelT::RegisterBuffer( + tflite::SignatureRunner* runner, const TfLiteTensor* tensor, + const char* tensor_name, LiteRtTensorBuffer buffer, bool is_input, + std::vector& scoped_locks) { + bool backend_requires_cpu_buffer = false; + + auto requirements = buffer_context_->GetBufferRequirement(tensor); + if (requirements) { + auto supported_types = (*requirements)->SupportedTypes(); + if (!supported_types) { + return supported_types.Error(); + } + + for (auto& type : *supported_types) { + if (type == buffer->buffer_type()) { + // Register tensor buffer if it can be used by the backend. + buffer->Duplicate(); + TensorBuffer duplicated_buffer(buffer); + if (auto status = buffer_context_->RegisterTensorBuffer( + tensor, std::move(duplicated_buffer)); + status != kLiteRtStatusOk) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to register tensor buffer"); + } + return {}; + } + if (type == kLiteRtTensorBufferTypeHostMemory) { + backend_requires_cpu_buffer = true; + } + } + } else { + // If the BufferRequirement is not registered, assumes the backend requires + // CPU buffer. + backend_requires_cpu_buffer = true; + } + + if (backend_requires_cpu_buffer) { + // When backend requires CPU buffer. + bool buffer_is_cpu_compatible = + buffer->buffer_type() == kLiteRtTensorBufferTypeHostMemory; +#if defined(__ANDROID__) + if (buffer->buffer_type() == kLiteRtTensorBufferTypeAhwb) { + if (__builtin_available(android 26, *)) { + auto ahwb = buffer->GetAhwbBuffer(); + if (ahwb) { + // TODO: b/382330322 - Update logic to check if the AHWB (stride) is + // CPU compatible. + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(*ahwb, &desc); + buffer_is_cpu_compatible = true; + } + } + } +#endif + if (buffer_is_cpu_compatible) { + auto lock_and_addr = TensorBufferScopedLock::Create(buffer); + if (!lock_and_addr) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + absl::StrCat("Failed to lock input tensor buffer: ", + lock_and_addr.Error().Message())); + } + scoped_locks.push_back(std::move(lock_and_addr->first)); + TfLiteCustomAllocation custom_allocation{lock_and_addr->second, + tensor->bytes}; + if (is_input) { + runner->SetCustomAllocationForInputTensor(tensor_name, + custom_allocation, + /*flags=*/0); + } else { + runner->SetCustomAllocationForOutputTensor(tensor_name, + custom_allocation, + /*flags=*/0); + } + return {}; + } + } + // TODO: b/382330322 - Add buffer conversion logic instead of returning error. + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + "The given buffer type is not supported."); +} + Expected LiteRtCompiledModelT::Run( absl::string_view signature_key, - std::vector& input_buffers, - std::vector& output_buffers) { + const std::vector& input_buffers, + const std::vector& output_buffers) { auto runner = GetSignatureRunner(signature_key); if (runner == nullptr) { return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get signature runner"); } - if (input_buffers.size() != runner->input_names().size()) { + size_t num_inputs = input_buffers.size(); + if (num_inputs != runner->input_names().size()) { return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Input buffer size mismatch"); } - if (output_buffers.size() != runner->output_names().size()) { + size_t num_outputs = output_buffers.size(); + if (num_outputs != runner->output_names().size()) { return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Output buffer size mismatch"); } - for (int i = 0; i < runner->input_names().size(); ++i) { + // In general output buffer events are assigned by the runtime and not the + // caller; here we check for any violation of that condition. + for (auto litert_output_buffer : output_buffers) { + if (litert_output_buffer->HasEvent()) { + return Error(kLiteRtStatusErrorInvalidArgument, + "Output buffers cannot have events attached"); + } + } + + // TODO: If input buffers have events, we wait on them before we launch the + // inference. This is inefficient when using HW acceleration, since in that + // case it would be best to make the HW accelerator wait for those events as + // opposed to blocking the CPU here. + for (auto input_buffer : input_buffers) { + if (input_buffer->HasEvent()) { + auto litert_event = input_buffer->GetEvent(); + if (!litert_event) { + return litert_event.Error(); + } + litert::Event event(*litert_event, /*owned=*/false); + if (auto status = event.Wait(/*timeout_in_ms=*/-1); !status) { + return status.Error(); + } + } + } + + std::vector scoped_locks; + scoped_locks.reserve(num_inputs + num_outputs); + for (int i = 0; i < num_inputs; ++i) { const auto& input_name = runner->input_names()[i]; auto* input_tensor = runner->input_tensor(input_name); - if (input_buffers[i]->buffer_type() == kLiteRtTensorBufferTypeHostMemory) { - // Assign CPU buffer via CustomAllocation. - TensorBuffer cpu_buffer(input_buffers[i], /*owned=*/false); - auto lock_and_addr = litert::TensorBufferScopedLock::Create(cpu_buffer); - TfLiteCustomAllocation custom_allocation{lock_and_addr->second, - input_tensor->bytes}; - runner->SetCustomAllocationForInputTensor(input_name, custom_allocation, - /*flags=*/0); - } else { - // Register tensor buffer for non CPU buffers. - input_buffers[i]->Duplicate(); - TensorBuffer duplicated_buffer(input_buffers[i]); - if (auto status = buffer_context_->RegisterTensorBuffer( - input_tensor, std::move(duplicated_buffer)); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register input tensor buffer"); - } + auto res = + RegisterBuffer(runner, input_tensor, input_name, input_buffers[i], + /*is_input=*/true, scoped_locks); + if (!res) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + absl::StrCat("Failed to register input tensor buffer: ", + res.Error().Message())); } } for (int i = 0; i < runner->output_names().size(); ++i) { const auto& output_name = runner->output_names()[i]; auto* output_tensor = runner->output_tensor(output_name); - if (output_buffers[i]->buffer_type() == kLiteRtTensorBufferTypeHostMemory) { - // Assign CPU buffer via CustomAllocation. - TensorBuffer cpu_buffer(output_buffers[i], /*owned=*/false); - auto lock_and_addr = litert::TensorBufferScopedLock::Create(cpu_buffer); - TfLiteCustomAllocation custom_allocation{lock_and_addr->second, - output_tensor->bytes}; - runner->SetCustomAllocationForOutputTensor(output_name, custom_allocation, - /*flags=*/0); - } else { - output_buffers[i]->Duplicate(); - TensorBuffer duplicated_buffer(output_buffers[i]); - if (auto status = buffer_context_->RegisterTensorBuffer( - output_tensor, std::move(duplicated_buffer)); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to register output tensor buffer"); - } + auto res = + RegisterBuffer(runner, output_tensor, output_name, output_buffers[i], + /*is_input=*/false, scoped_locks); + if (!res) { + return Unexpected( + kLiteRtStatusErrorRuntimeFailure, + absl::StrCat("Failed to register output tensor buffer: ", + res.Error().Message())); } } diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model.h b/tensorflow/lite/experimental/litert/runtime/compiled_model.h index 821d4de2919649..9398c98aa946bc 100644 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model.h +++ b/tensorflow/lite/experimental/litert/runtime/compiled_model.h @@ -24,6 +24,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" @@ -31,8 +32,10 @@ #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" +#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model_builder.h" @@ -48,7 +51,7 @@ class LiteRtCompiledModelT { // The model is loaded into memory and the caller takes ownership of the // returned object. static litert::Expected Create( - LiteRtModel model, LiteRtComplicationOptions complication_options); + LiteRtModel model, LiteRtCompilationOptions compilation_options); // Returns the buffer requirements for the n-th input tensor. The returned // LiteRtTensorBufferRequirements is used to create the input tensor @@ -88,9 +91,10 @@ class LiteRtCompiledModelT { // Runs the model of the given signature with the provided input/output // litert::TensorBuffers. - litert::Expected Run(absl::string_view signature_key, - std::vector& input_buffers, - std::vector& output_buffers); + litert::Expected Run( + absl::string_view signature_key, + const std::vector& input_buffers, + const std::vector& output_buffers); // The same as Run() for C API. litert::Expected RunCApi(size_t signature_index, @@ -112,6 +116,20 @@ class LiteRtCompiledModelT { // If the signature key is not found, returns nullptr. tflite::SignatureRunner* GetSignatureRunner(absl::string_view signature_key); + // Registers the TensorBuffer for the given tensor with the SignatureRunner. + // If the TensorBuffer can be directly consumed as CPU Tensors, they'll be + // locked and use it with CustomAllocation. The buffer is locked by + // LiteRtTensorBufferScopedLock and kept in the `scoped_locks`. It will be + // unlocked automatically when the `scoped_locks` are destroyed. + litert::Expected RegisterBuffer( + tflite::SignatureRunner* runner, const TfLiteTensor* tensor, + const char* tensor_name, LiteRtTensorBuffer buffer, bool is_input, + std::vector& scoped_locks); + + void RegisterDelegate(tflite::TfLiteOpaqueDelegateUniquePtr&& delegate) { + delegates_.push_back(std::move(delegate)); + } + // Map from signature key to SignatureRunner. This is used to lazy calling // GetSignatureRunner() which is expensive. absl::flat_hash_map @@ -137,6 +155,8 @@ class LiteRtCompiledModelT { // Interpreter. std::unique_ptr buffer_context_; + + std::vector delegates_; }; #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_COMPILED_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc b/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc index 724c83444f262b..45730efb511c26 100644 --- a/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc +++ b/tensorflow/lite/experimental/litert/runtime/compiled_model_test.cc @@ -33,6 +33,9 @@ #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/runtime/tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" namespace litert { namespace { @@ -40,102 +43,109 @@ namespace { using ::testing::FloatNear; using ::testing::Pointwise; +// Creates input buffers for the given LiteRtTensorBufferType and size. Expected> CreateInputBuffers( - LiteRtModel& model, LiteRtCompiledModelT& compiled_model, - absl::string_view signature_key) { + LiteRtModel& model, absl::string_view signature_key, + LiteRtTensorBufferType buffer_type, size_t bytes) { std::vector input_buffers; - auto subgraph = model->FindSubgraph(signature_key); - auto& input_tensors = (*subgraph)->inputs; - size_t num_inputs = input_tensors.size(); + auto* subgraph = *LookupSubgraph(*model, signature_key); + auto& input_tensors = subgraph->Inputs(); + const size_t num_inputs = subgraph->NumInputs(); input_buffers.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { - auto litert_input_buffer_requirements = - compiled_model.GetInputBufferRequirements(signature_key, i); - if (!litert_input_buffer_requirements.HasValue()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - litert_input_buffer_requirements.Error().Message()); - } - TensorBufferRequirements input_buffer_requirements = - TensorBufferRequirements(*litert_input_buffer_requirements, - /*owned=*/false); - auto ranked_tensor_type = input_tensors[i]->type_detail.ranked_tensor_type; - LiteRtTensorBufferType tensor_buffer_type = - input_buffer_requirements.SupportedTypes()->at(0); + const auto& ranked_tensor_type = + input_tensors[i]->Type().second.ranked_tensor_type; LiteRtTensorBuffer input_buffer; if (auto status = LiteRtCreateManagedTensorBuffer( - tensor_buffer_type, &ranked_tensor_type, - input_buffer_requirements.BufferSize().Value(), &input_buffer); + buffer_type, &ranked_tensor_type, bytes, &input_buffer); status != kLiteRtStatusOk) { return Unexpected(status, "Failed to create input tensor buffer"); } input_buffers.push_back(input_buffer); } - return std::move(input_buffers); } -Expected> CreateOutputBuffers( +// Creates input buffers for the given LiteRtCompiledModelT by leveraging +// TensorBufferRequirements. +Expected> CreateInputBuffers( LiteRtModel& model, LiteRtCompiledModelT& compiled_model, absl::string_view signature_key) { - std::vector output_buffers; + auto litert_input_buffer_requirements = + compiled_model.GetInputBufferRequirements(signature_key, 0); + if (!litert_input_buffer_requirements.HasValue()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + litert_input_buffer_requirements.Error().Message()); + } + TensorBufferRequirements input_buffer_requirements = + TensorBufferRequirements(*litert_input_buffer_requirements, + /*owned=*/false); + LiteRtTensorBufferType tensor_buffer_type = + input_buffer_requirements.SupportedTypes()->at(0); + + return CreateInputBuffers(model, signature_key, tensor_buffer_type, + input_buffer_requirements.BufferSize().Value()); +} - auto subgraph = model->FindSubgraph(signature_key); - auto& output_tensors = (*subgraph)->outputs; - size_t num_outputs = output_tensors.size(); +// Creates output buffers for the given LiteRtTensorBufferType and size. +Expected> CreateOutputBuffers( + LiteRtModel& model, absl::string_view signature_key, + LiteRtTensorBufferType buffer_type, size_t bytes) { + std::vector output_buffers; + auto* subgraph = *LookupSubgraph(*model, signature_key); + auto& output_tensors = subgraph->Outputs(); + size_t num_outputs = subgraph->NumOutputs(); output_buffers.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { - auto litert_output_buffer_requirements = - compiled_model.GetOutputBufferRequirements(signature_key, i); - if (!litert_output_buffer_requirements.HasValue()) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - litert_output_buffer_requirements.Error().Message()); - } - TensorBufferRequirements output_buffer_requirements = - TensorBufferRequirements(*litert_output_buffer_requirements, - /*owned=*/false); - auto ranked_tensor_type = output_tensors[i]->type_detail.ranked_tensor_type; - LiteRtTensorBufferType tensor_buffer_type = - output_buffer_requirements.SupportedTypes()->at(0); + auto ranked_tensor_type = + output_tensors[i]->Type().second.ranked_tensor_type; LiteRtTensorBuffer output_buffer; if (auto status = LiteRtCreateManagedTensorBuffer( - tensor_buffer_type, &ranked_tensor_type, - output_buffer_requirements.BufferSize().Value(), &output_buffer); + buffer_type, &ranked_tensor_type, bytes, &output_buffer); status != kLiteRtStatusOk) { return Unexpected(status, "Failed to create output tensor buffer"); } output_buffers.push_back(output_buffer); } - return std::move(output_buffers); } -constexpr const float kTestInput0Tensor[] = {1, 2}; -constexpr const size_t kTestInput0Size = - sizeof(kTestInput0Tensor) / sizeof(kTestInput0Tensor[0]); -constexpr const float kTestInput1Tensor[] = {10, 20}; -constexpr const size_t kTestInput1Size = - sizeof(kTestInput1Tensor) / sizeof(kTestInput1Tensor[0]); -constexpr const float kTestOutputTensor[] = {11, 22}; -constexpr const size_t kTestOutputSize = - sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); - -static constexpr absl::string_view kTfliteFile = - "third_party/tensorflow/lite/experimental/litert/test/testdata/" - "simple_model.tflite"; +// Creates output buffers for the given LiteRtCompiledModelT by leveraging +// TensorBufferRequirements. +Expected> CreateOutputBuffers( + LiteRtModel& model, LiteRtCompiledModelT& compiled_model, + absl::string_view signature_key) { + auto litert_output_buffer_requirements = + compiled_model.GetOutputBufferRequirements(signature_key, 0); + if (!litert_output_buffer_requirements.HasValue()) { + return Unexpected(kLiteRtStatusErrorRuntimeFailure, + litert_output_buffer_requirements.Error().Message()); + } + TensorBufferRequirements output_buffer_requirements = + TensorBufferRequirements(*litert_output_buffer_requirements, + /*owned=*/false); + LiteRtTensorBufferType tensor_buffer_type = + output_buffer_requirements.SupportedTypes()->at(0); + return CreateOutputBuffers(model, signature_key, tensor_buffer_type, + output_buffer_requirements.BufferSize().Value()); +} TEST(CompiledModelTest, Basic) { + auto path = testing::GetTestFilePath(kModelFileName); + LiteRtModel model; - auto status = LiteRtCreateModelFromFile(kTfliteFile.data(), &model); - ASSERT_EQ(status, kLiteRtStatusOk); + ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); - auto res_compiled_model = LiteRtCompiledModelT::Create(model, kHwAccelCpu); - ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + auto res_compiled_model = + LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorCpu); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel: " + << res_compiled_model.Error().Message(); auto& compiled_model = **res_compiled_model; - auto& signatures = model->signatures; + auto signatures = model->Signatures(); ASSERT_EQ(signatures.size(), 1); - auto signature_key = signatures[0]->key; - EXPECT_EQ(signature_key, LITERT_DEFAULT_SIGNATURE_KEY); + auto signature_key = signatures[0]->Key(); + EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey); auto input_buffers_res = CreateInputBuffers(model, compiled_model, signature_key); @@ -148,7 +158,7 @@ TEST(CompiledModelTest, Basic) { auto output_buffers = std::move(*output_buffers_res); // Fill model inputs. - auto input_names = signatures[0]->input_names; + auto& input_names = signatures[0]->InputNames(); EXPECT_EQ(input_names.size(), 2); EXPECT_EQ(input_names.at(0), "arg0"); EXPECT_EQ(input_names.at(1), "arg1"); @@ -169,21 +179,102 @@ TEST(CompiledModelTest, Basic) { compiled_model.Run(signature_key, input_buffers, output_buffers); // Check model output. - auto output_names = signatures[0]->output_names; + auto output_names = signatures[0]->OutputNames(); EXPECT_EQ(output_names.size(), 1); EXPECT_EQ(output_names.at(0), "tfl.add"); - auto& output_buffer = output_buffers[0]; { - TensorBuffer cpu_buffer(output_buffer, /*owned=*/false); - float output_buffer_data[kTestOutputSize]; - auto output_span = absl::MakeSpan(output_buffer_data, kTestOutputSize); - auto read_success = cpu_buffer.Read(output_span); + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk); + } + + // Since Buffers in LiteRtTensorBuffer, we need to destroy them explicitly. + for (auto& input_buffer : input_buffers) { + LiteRtDestroyTensorBuffer(input_buffer); + } + for (auto& output_buffer : output_buffers) { + LiteRtDestroyTensorBuffer(output_buffer); + } + + LiteRtDestroyModel(model); +} +TEST(CompiledModelTest, UseAhwbBuffer) { +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices"; +#endif + auto path = testing::GetTestFilePath(kModelFileName); + LiteRtModel model; + ASSERT_EQ(LiteRtCreateModelFromFile(path.c_str(), &model), kLiteRtStatusOk); + + auto res_compiled_model = + LiteRtCompiledModelT::Create(model, kLiteRtHwAccelatorCpu); + ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; + auto& compiled_model = **res_compiled_model; + + auto signatures = model->Signatures(); + ASSERT_EQ(signatures.size(), 1); + auto signature_key = signatures[0]->Key(); + EXPECT_EQ(signature_key, LiteRtSignatureT::kDefaultSignatureKey); + + auto input_buffers_res = + CreateInputBuffers(model, signature_key, kLiteRtTensorBufferTypeAhwb, + sizeof(float) * kTestInput0Size); + EXPECT_TRUE(input_buffers_res); + auto input_buffers = std::move(*input_buffers_res); + + auto output_buffers_res = + CreateOutputBuffers(model, signature_key, kLiteRtTensorBufferTypeAhwb, + sizeof(float) * kTestOutputSize); + EXPECT_TRUE(output_buffers_res); + auto output_buffers = std::move(*output_buffers_res); + + // Fill model inputs. + auto input_names = signatures[0]->InputNames(); + EXPECT_EQ(input_names.size(), 2); + EXPECT_EQ(input_names.at(0), "arg0"); + EXPECT_EQ(input_names.at(1), "arg1"); + auto& input_0_buffer = input_buffers[0]; + EXPECT_EQ(input_0_buffer->buffer_type(), kLiteRtTensorBufferTypeAhwb); + { + TensorBuffer ahwb_buffer(input_0_buffer, /*owned=*/false); + ahwb_buffer.Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size)); + } + auto& input_1_buffer = input_buffers[1]; + { + TensorBuffer ahwb_buffer(input_1_buffer, /*owned=*/false); + ahwb_buffer.Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size)); + } + + // Execute model. + compiled_model.Run(signature_key, input_buffers, output_buffers); + + // Check model output. + auto output_names = signatures[0]->OutputNames(); + EXPECT_EQ(output_names.size(), 1); + EXPECT_EQ(output_names.at(0), "tfl.add"); + { + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_buffers[0], &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << "Result: " << output_span.at(i) << "\t" - << kTestOutputTensor[i]; + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; } - EXPECT_THAT(output_span, Pointwise(FloatNear(1e-5), kTestOutputTensor)); + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_buffers[0]), kLiteRtStatusOk); } // Since Buffers in LiteRtTensorBuffer, we need to destroy them explicitly. diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/BUILD b/tensorflow/lite/experimental/litert/runtime/compiler/BUILD index ff5c489dd70d18..dc6013689c391e 100644 --- a/tensorflow/lite/experimental/litert/runtime/compiler/BUILD +++ b/tensorflow/lite/experimental/litert/runtime/compiler/BUILD @@ -35,20 +35,56 @@ cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", - "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", + "//tensorflow/lite/experimental/litert/cc:litert_environment", "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", - "//tensorflow/lite/experimental/litert/compiler/plugin:algo", "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", - "//tensorflow/lite/experimental/litert/core/model:model_buffer", - "//tensorflow/lite/experimental/litert/core/model:model_serialize", "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", "//tensorflow/lite/experimental/litert/test:common", "//tensorflow/lite/experimental/litert/test:simple_model_npu", + "//tensorflow/lite/experimental/litert/test:test_macros", "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/log", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "jit_compilation_mediatek_test", + srcs = ["jit_compilation_mediatek_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:simple_model", + "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:compiler_plugin_so", + "//tensorflow/lite/experimental/litert/vendors/mediatek/dispatch:dispatch_api_so", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_opaque", + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", + "//tensorflow/lite/experimental/litert/cc:litert_compiled_model", + "//tensorflow/lite/experimental/litert/cc:litert_environment", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", + "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model_npu", + "//tensorflow/lite/experimental/litert/test:test_macros", + "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc new file mode 100644 index 00000000000000..b30d0ce8f1fa51 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_mediatek_test.cc @@ -0,0 +1,106 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/signature_runner.h" + +constexpr const char* kCompilerPluginLibSearchPath = "/data/local/tmp"; + +using testing::FloatNear; +using testing::Pointwise; + +TEST(JitCompilation, MediaTek) { + const std::array environment_options = { + litert::Environment::Option{ + /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryPath, + /*.value=*/kCompilerPluginLibSearchPath, + }, + }; + ASSERT_TRUE(litert::Environment::Create(environment_options)); + + auto model_path = litert::testing::GetTestFilePath(kModelFileName); + auto model = litert::Model::CreateFromFile(model_path); + ASSERT_TRUE(model); + + auto num_signatures = model->GetNumSignatures(); + ASSERT_EQ(num_signatures, 1); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "MediaTek NPU"; +#endif + + auto compiled_model = + litert::CompiledModel::Create(*model, kLiteRtHwAccelatorNpu); + ASSERT_TRUE(compiled_model); + + auto input_buffers = + compiled_model->CreateInputBuffers(/*signature_index=*/0); + ASSERT_TRUE(input_buffers); + EXPECT_EQ(input_buffers->size(), 2); + + auto output_buffers = + compiled_model->CreateOutputBuffers(/*signature_index=*/0); + ASSERT_TRUE(output_buffers); + EXPECT_EQ(output_buffers->size(), 1); + + ASSERT_TRUE((*input_buffers)[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE((*input_buffers)[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); + + // Execute model. + compiled_model->Run(/*signature_index=*/0, *input_buffers, *output_buffers); + + // Check model output. + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create( + (*output_buffers)[0]); + ASSERT_TRUE(lock_and_addr); + auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); + } + + litert::Environment::Destroy(); +} diff --git a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc index 7b6373649aff9d..68d93f7df82c28 100644 --- a/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc +++ b/tensorflow/lite/experimental/litert/runtime/compiler/jit_compilation_qualcomm_test.cc @@ -18,19 +18,24 @@ #include #include +#include #include #include "absl/log/absl_log.h" #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/lite/c/c_api_opaque.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_environment.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" -#include "tensorflow/lite/experimental/litert/core/model/model_buffer.h" #include "tensorflow/lite/experimental/litert/runtime/external_litert_buffer_context.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" #include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" @@ -39,99 +44,63 @@ constexpr const char* kCompilerPluginLibSearchPath = "/data/local/tmp"; +using testing::FloatNear; +using testing::Pointwise; + TEST(JitCompilation, Qualcomm) { + const std::array environment_options = { + litert::Environment::Option{ + /*.tag=*/litert::Environment::OptionTag::CompilerPluginLibraryPath, + /*.value=*/kCompilerPluginLibSearchPath, + }, + }; + ASSERT_TRUE(litert::Environment::Create(environment_options)); + auto model_path = litert::testing::GetTestFilePath(kModelFileName); auto model = litert::Model::CreateFromFile(model_path); ASSERT_TRUE(model); + auto num_signatures = model->GetNumSignatures(); + ASSERT_EQ(num_signatures, 1); + #if !defined(__ANDROID__) GTEST_SKIP() << "The rest of this test is specific to Android devices with a " "Qualcomm HTP"; #endif - constexpr const std::array - compiler_plugin_lib_search_paths = {kCompilerPluginLibSearchPath}; - auto compiler_plugin = litert::internal::CompilerPlugin::LoadPlugin( - compiler_plugin_lib_search_paths, "Qualcomm"); - ASSERT_TRUE(compiler_plugin); - - auto api_version = compiler_plugin->ApiVersion(); - ASSERT_TRUE(api_version); - - ABSL_LOG(INFO) << "Found compiler plugin with version " << api_version->major - << "." << api_version->minor << "." << api_version->patch; - - auto npu_bytecode = ApplyPlugin(*compiler_plugin, *model); - EXPECT_TRUE(npu_bytecode); - EXPECT_GT(npu_bytecode->Size(), 0); - - auto serialized_model = litert::internal::GetModelBufWithByteCode( - std::move(*model->Get()), *npu_bytecode); - EXPECT_TRUE(serialized_model); - - model = litert::Model::CreateFromBuffer(*serialized_model); - - auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer( - reinterpret_cast(serialized_model->Data()), - serialized_model->Size()); - - EXPECT_TRUE(flatbuffer_model != nullptr); + auto compiled_model = + litert::CompiledModel::Create(*model, kLiteRtHwAccelatorNpu); + ASSERT_TRUE(compiled_model); - tflite::Interpreter::Ptr interpreter = nullptr; - tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::InterpreterBuilder(*flatbuffer_model, resolver)(&interpreter); - EXPECT_TRUE(interpreter != nullptr); + auto input_buffers = + compiled_model->CreateInputBuffers(/*signature_index=*/0); + ASSERT_TRUE(input_buffers); + EXPECT_EQ(input_buffers->size(), 2); - EXPECT_EQ(interpreter->nodes_size(), 1); - EXPECT_EQ(interpreter->inputs().size(), 2); - EXPECT_EQ(interpreter->outputs().size(), 1); - ASSERT_EQ(interpreter->execution_plan().size(), 1); + auto output_buffers = + compiled_model->CreateOutputBuffers(/*signature_index=*/0); + ASSERT_TRUE(output_buffers); + EXPECT_EQ(output_buffers->size(), 1); - litert::internal::ExternalLiteRtBufferContext buffer_context; - interpreter->SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context); + ASSERT_TRUE((*input_buffers)[0].Write( + absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size))); + ASSERT_TRUE((*input_buffers)[1].Write( + absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size))); - auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr(); - LiteRtDispatchDelegateAddAllocBaseOption( - dispatch_delegate_options.get(), flatbuffer_model->allocation()->base()); - auto dispatch_delegate = - litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); - - ASSERT_EQ(interpreter->ModifyGraphWithDelegate(dispatch_delegate.get()), - kTfLiteOk); - - // Get the list of signatures and check it. - auto signature_defs = interpreter->signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); - - tflite::impl::SignatureRunner* runner = - interpreter->GetSignatureRunner(/*signature_key=*/nullptr); - ASSERT_NE(runner, nullptr); - - EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); - - // Fill model inputs. - ASSERT_STREQ(runner->input_names()[0], "arg0"); - auto input_0_tensor = runner->input_tensor("arg0"); - ASSERT_NE(input_0_tensor, nullptr); - auto* input_0 = input_0_tensor->data.f; - std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); - - ASSERT_STREQ(runner->input_names()[1], "arg1"); - auto input_1_tensor = runner->input_tensor("arg1"); - ASSERT_NE(input_1_tensor, nullptr); - auto* input_1 = input_1_tensor->data.f; - std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); - - EXPECT_EQ(runner->Invoke(), kTfLiteOk); + // Execute model. + compiled_model->Run(/*signature_index=*/0, *input_buffers, *output_buffers); // Check model output. - auto output_tensor = runner->output_tensor(runner->output_names()[0]); - ASSERT_NE(output_tensor, nullptr); - auto* output = output_tensor->data.f; - for (auto i = 0; i < kTestOutputSize; ++i) { - ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; - } - for (auto i = 0; i < kTestOutputSize; ++i) { - EXPECT_NEAR(output[i], kTestOutputTensor[i], 1e-5); + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create( + (*output_buffers)[0]); + ASSERT_TRUE(lock_and_addr); + auto output = absl::MakeSpan(lock_and_addr->second, kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor)); } + + litert::Environment::Destroy(); } diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD b/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD index 30f8050f7fabde..a016875563c339 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/BUILD @@ -27,7 +27,9 @@ cc_library( "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_api.h", ], deps = [ + "//tensorflow/lite/experimental/litert/c:litert_any", "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_event", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", @@ -53,6 +55,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:c_api_opaque_without_op_resolver", "//tensorflow/lite/delegates/utils:simple_opaque_delegate", + "//tensorflow/lite/experimental/litert/c:litert_any", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", @@ -63,7 +66,7 @@ cc_library( "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/core:byte_code_util", - "//tensorflow/lite/experimental/litert/core/util:tensor_type_util", + "//tensorflow/lite/experimental/litert/core:environment", "//tensorflow/lite/experimental/litert/runtime:external_litert_buffer_context", "//tensorflow/lite/experimental/litert/runtime:tfl_utils", "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc index ba8f7187a7ca45..f18c72ab8342a8 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate.cc @@ -136,6 +136,9 @@ void LiteRtDestroyDispatchDelegateOptions( TfLiteDelegate* LiteRtCreateDispatchDelegate( LiteRtDispatchDelegateOptions* options) { + if (!options) { + options = LiteRtCreateDefaultDispatchDelegateOptions(); + } return DispatchDelegate::Create(options); } diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc index 82f4d61f97dea1..7701b908a49c1a 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc @@ -80,7 +80,7 @@ TEST(DispatchDelegate, GoogleTensorCpuBuffer) { // Get the list of signatures and check it. auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); + ASSERT_EQ(signature_defs.size(), 1); tflite::impl::SignatureRunner* runner = interpreter.GetSignatureRunner(/*signature_key=*/nullptr); @@ -186,7 +186,7 @@ TEST(DispatchDelegate, GoogleTensorHwBuffer) { // Get the list of signatures and check it. auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); + ASSERT_EQ(signature_defs.size(), 1); tflite::impl::SignatureRunner* runner = interpreter.GetSignatureRunner(/*signature_key=*/nullptr); @@ -234,7 +234,7 @@ TEST(DispatchDelegate, CompiledModel) { "GoogleTensor eTPU"; #endif - auto res_compiled_model = CompiledModel::Create(*model, kHwAccelNpu); + auto res_compiled_model = CompiledModel::Create(*model); ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; auto& compiled_model = *res_compiled_model; diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc index 21afed952952be..b59b7b5a461c6e 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_kernel.cc @@ -576,7 +576,7 @@ TfLiteStatus DispatchDelegateKernel::RegisterLiteRtTensorBuffers( TfLiteStatus DispatchDelegateKernel::Eval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { if (auto status = RegisterLiteRtTensorBuffers(context, node); - status != kLiteRtStatusOk) { + status != kTfLiteOk) { LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffers: %d", status); return kTfLiteError; } diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc index 954c327ab85bd3..a7bb0c52ef6b70 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc @@ -80,7 +80,7 @@ TEST(DispatchDelegate, MediaTekCpuBuffer) { // Get the list of signatures and check it. auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); + ASSERT_EQ(signature_defs.size(), 1); tflite::impl::SignatureRunner* runner = interpreter.GetSignatureRunner(/*signature_key=*/nullptr); @@ -186,7 +186,7 @@ TEST(DispatchDelegate, MediaTekHwBuffer) { // Get the list of signatures and check it. auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); + ASSERT_EQ(signature_defs.size(), 1); tflite::impl::SignatureRunner* runner = interpreter.GetSignatureRunner(/*signature_key=*/nullptr); @@ -234,7 +234,7 @@ TEST(DispatchDelegate, CompiledModel) { "MediaTek NPU"; #endif - auto res_compiled_model = CompiledModel::Create(*model, kHwAccelNpu); + auto res_compiled_model = CompiledModel::Create(*model); ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; auto& compiled_model = *res_compiled_model; diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h index af4e8d3046e44e..030c022db1fd4a 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_options.h @@ -25,14 +25,45 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_any.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/c/litert_environment.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/cc/litert_any.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/core/environment.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" class LiteRtDispatchDelegateOptions { public: + LiteRtDispatchDelegateOptions() { + auto environment = litert::internal::Environment::Instance(); + if (!environment) { + LITERT_LOG(LITERT_WARNING, "LiteRT environment not found"); + return; + } + + auto option = + (*environment)->GetOption(kLiteRtEnvOptionTagDispatchLibraryPath); + if (!option.has_value()) { + return; + } + + if (option->type != kLiteRtAnyTypeString) { + LITERT_LOG(LITERT_WARNING, + "Ingoring option kLiteRtEnvOptionTagDispatchLibraryPath due " + "to invalid value"); + return; + } + + LiteRtDispatchOption dispatch_option = { + /*.name=*/kDispatchOptionSharedLibraryDir, + /*.value=*/*option, + }; + AddOption(dispatch_option); + } + // Push a new dispatch option. void AddOption(LiteRtDispatchOption option) { options_.push_back(option); } diff --git a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc index 211180e322bf75..e97aaec3c646bf 100644 --- a/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc +++ b/tensorflow/lite/experimental/litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc @@ -79,7 +79,7 @@ TEST(DispatchDelegate, QualcommCpuBuffer) { // Get the list of signatures and check it. auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); + ASSERT_EQ(signature_defs.size(), 1); tflite::impl::SignatureRunner* runner = interpreter.GetSignatureRunner(/*signature_key=*/nullptr); @@ -185,7 +185,7 @@ TEST(DispatchDelegate, QualcommHwBuffer) { // Get the list of signatures and check it. auto signature_defs = interpreter.signature_keys(); - ASSERT_EQ(signature_defs.size(), 0); + ASSERT_EQ(signature_defs.size(), 1); tflite::impl::SignatureRunner* runner = interpreter.GetSignatureRunner(/*signature_key=*/nullptr); @@ -233,7 +233,7 @@ TEST(DispatchDelegate, CompiledModel) { "Qualcomm HTP"; #endif - auto res_compiled_model = CompiledModel::Create(*model, kHwAccelNpu); + auto res_compiled_model = CompiledModel::Create(*model); ASSERT_TRUE(res_compiled_model) << "Failed to initialize CompiledModel"; auto& compiled_model = *res_compiled_model; diff --git a/tensorflow/lite/experimental/litert/runtime/event.cc b/tensorflow/lite/experimental/litert/runtime/event.cc index 74b3ee72999c78..12b4458823df03 100644 --- a/tensorflow/lite/experimental/litert/runtime/event.cc +++ b/tensorflow/lite/experimental/litert/runtime/event.cc @@ -24,8 +24,12 @@ #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" -LiteRtStatus LiteRtEventT::Wait(int64_t timeout_in_ms) { +using litert::Error; +using litert::Expected; + +Expected LiteRtEventT::Wait(int64_t timeout_in_ms) { #if LITERT_HAS_SYNC_FENCE_SUPPORT struct pollfd fds = { .fd = fd, @@ -38,21 +42,19 @@ LiteRtStatus LiteRtEventT::Wait(int64_t timeout_in_ms) { if (ret == 1) { break; } else if (ret == 0) { - LITERT_LOG(LITERT_WARNING, "Timeout expired: %d", timeout_in_ms); - return kLiteRtStatusErrorTimeoutExpired; + return Error(kLiteRtStatusErrorTimeoutExpired, "Timeout expired"); } } while (ret == -1 && (errno == EINTR || errno == EAGAIN)); if (ret < 0) { - LITERT_LOG(LITERT_ERROR, "Error waiting for fence: %s", ::strerror(errno)); - return kLiteRtStatusErrorRuntimeFailure; + return Error(kLiteRtStatusErrorRuntimeFailure, "Error waiting for fence"); } - return kLiteRtStatusOk; + return {}; #else - LITERT_LOG(LITERT_ERROR, "LiteRtEventWait not implemented for this platform"); - return kLiteRtStatusErrorUnsupported; + return Error(kLiteRtStatusErrorUnsupported, + "LiteRtEventWait not implemented for this platform"); #endif } diff --git a/tensorflow/lite/experimental/litert/runtime/event.h b/tensorflow/lite/experimental/litert/runtime/event.h index e2ca93974cb3f0..8cc665e95f2ae1 100644 --- a/tensorflow/lite/experimental/litert/runtime/event.h +++ b/tensorflow/lite/experimental/litert/runtime/event.h @@ -18,14 +18,15 @@ #include #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" struct LiteRtEventT { #if LITERT_HAS_SYNC_FENCE_SUPPORT - int fd; - bool owns_fd; + int fd = -1; + bool owns_fd = false; #endif ~LiteRtEventT(); - LiteRtStatus Wait(int64_t timeout_in_ms); + litert::Expected Wait(int64_t timeout_in_ms); }; #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/BUILD b/tensorflow/lite/experimental/litert/runtime/opencl/BUILD new file mode 100644 index 00000000000000..727f1e9faf84a1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/BUILD @@ -0,0 +1,90 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "cl_command_queue", + srcs = [ + "cl_command_queue.cc", + ], + hdrs = [ + "cl_command_queue.h", + ], + deps = [ + ":cl_context", + ":cl_device", + ":opencl_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@opencl_headers", + ], +) + +cc_library( + name = "cl_device", + srcs = [ + "cl_device.cc", + ], + hdrs = [ + "cl_device.h", + ], + deps = [ + ":opencl_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@opencl_headers", + ], +) + +cc_library( + name = "cl_context", + srcs = [ + "cl_context.cc", + ], + hdrs = [ + "cl_context.h", + ], + deps = [ + ":cl_device", + ":opencl_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@opencl_headers", + ], +) + +cc_library( + name = "opencl_wrapper", + srcs = [ + "opencl_wrapper.cc", + ], + hdrs = [ + "opencl_wrapper.h", + ], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@opencl_headers", + ], +) + +cc_library( + name = "buffer", + srcs = [ + "buffer.cc", + ], + hdrs = [ + "buffer.h", + ], + deps = [ + ":cl_command_queue", + ":cl_context", + ":opencl_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@opencl_headers", + ], +) diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc new file mode 100644 index 00000000000000..57d831e030d7c1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.cc @@ -0,0 +1,116 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is a copy of third_party/ml_drift/cl/buffer.cc. +#include "tensorflow/lite/experimental/litert/runtime/opencl/buffer.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "third_party/opencl_headers/CL/cl_platform.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" + +namespace litert { +namespace cl { +namespace { +absl::Status CreateClBuffer(cl_context context, int size_in_bytes, + bool read_only, void* data, cl_mem* result) { + cl_mem_flags flags = read_only ? CL_MEM_READ_ONLY : CL_MEM_READ_WRITE; + if (data) { + flags |= CL_MEM_COPY_HOST_PTR; + } + cl_int error_code; + *result = clCreateBuffer(context, flags, size_in_bytes, data, &error_code); + if (!*result) { + return absl::UnknownError( + absl::StrCat("Failed to allocate device memory (clCreateBuffer): ", + std::to_string(error_code))); + } + return absl::OkStatus(); +} +absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, + const void* data, ClContext* context, + Buffer* result) { + cl_mem buffer; + auto status = CreateClBuffer(context->context(), size_in_bytes, gpu_read_only, + const_cast(data), &buffer); + if (!status.ok()) { + return status; + } + *result = Buffer(buffer, size_in_bytes); + + return absl::OkStatus(); +} +} // namespace + +Buffer::Buffer(cl_mem buffer, size_t size_in_bytes, bool is_sub_buffer) + : buffer_(buffer), size_(size_in_bytes), is_sub_buffer_(is_sub_buffer) {} + +Buffer::Buffer(cl_mem buffer) + : buffer_(buffer), size_(0), is_sub_buffer_(false), owner_(false) {} + +Buffer::Buffer(Buffer&& buffer) + : buffer_(buffer.buffer_), + size_(buffer.size_), + is_sub_buffer_(buffer.is_sub_buffer_), + owner_(buffer.owner_) { + buffer.buffer_ = nullptr; + buffer.size_ = 0; + buffer.is_sub_buffer_ = false; +} + +Buffer& Buffer::operator=(Buffer&& buffer) { + if (this != &buffer) { + Release(); + std::swap(size_, buffer.size_); + std::swap(buffer_, buffer.buffer_); + std::swap(is_sub_buffer_, buffer.is_sub_buffer_); + std::swap(owner_, buffer.owner_); + } + return *this; +} + +void Buffer::Release() { + if (owner_ && buffer_) { + clReleaseMemObject(buffer_); + buffer_ = nullptr; + size_ = 0; + is_sub_buffer_ = false; + } +} + +Buffer CreateBufferShared(cl_mem buffer) { return Buffer(buffer); } + +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, ClContext* context, + Buffer* result) { + return CreateBuffer(size_in_bytes, true, nullptr, context, result); +} + +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, + ClContext* context, Buffer* result) { + return CreateBuffer(size_in_bytes, true, data, context, result); +} + +absl::Status CreateReadWriteBuffer(size_t size_in_bytes, ClContext* context, + Buffer* result) { + return CreateBuffer(size_in_bytes, false, nullptr, context, result); +} + +} // namespace cl +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h new file mode 100644 index 00000000000000..e9b8d877641f45 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/buffer.h @@ -0,0 +1,116 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is a copy of third_party/ml_drift/cl/buffer.h. +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" + +namespace litert { +namespace cl { + +// Buffer represent linear GPU data storage with arbitrary data format. +// Buffer is moveable but not copyable. +class Buffer { + public: + Buffer() = default; // just for using Buffer as a class members + Buffer(cl_mem buffer, size_t size_in_bytes, bool is_sub_buffer = false); + explicit Buffer(cl_mem buffer); + + // Move only + Buffer(Buffer&& buffer); + Buffer& operator=(Buffer&& buffer); + Buffer(const Buffer&) = delete; + Buffer& operator=(const Buffer&) = delete; + + ~Buffer() { Release(); } + + // for profiling and memory statistics + uint64_t GetMemorySizeInBytes() const { return size_; } + + cl_mem GetMemoryPtr() const { return buffer_; } + + bool IsSubBuffer() const { return is_sub_buffer_; } + + // Writes data to a buffer. Data should point to a region that + // has exact size in bytes as size_in_bytes(constructor parameter). + template + absl::Status WriteData(ClCommandQueue* queue, absl::Span data); + + // Reads data from Buffer into CPU memory. + template + absl::Status ReadData(ClCommandQueue* queue, std::vector* result) const; + + private: + void Release(); + + cl_mem buffer_ = nullptr; + size_t size_ = 0; + bool is_sub_buffer_ = false; + bool owner_ = true; +}; + +Buffer CreateBufferShared(cl_mem buffer); + +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, ClContext* context, + Buffer* result); + +absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data, + ClContext* context, Buffer* result); + +absl::Status CreateReadWriteBuffer(size_t size_in_bytes, ClContext* context, + Buffer* result); + +absl::Status CreateReadWriteSubBuffer(const Buffer& parent, + size_t origin_in_bytes, + size_t size_in_bytes, ClContext* context, + Buffer* result); + +template +absl::Status Buffer::WriteData(ClCommandQueue* queue, + const absl::Span data) { + if (sizeof(T) * data.size() > size_) { + return absl::InvalidArgumentError( + "absl::Span data size is greater from buffer allocated size."); + } + RETURN_IF_ERROR(queue->EnqueueWriteBuffer(buffer_, size_, data.data())); + return absl::OkStatus(); +} + +template +absl::Status Buffer::ReadData(ClCommandQueue* queue, + std::vector* result) const { + if (size_ % sizeof(T) != 0) { + return absl::UnknownError("Wrong element size(typename T is not correct?"); + } + + const int elements_count = size_ / sizeof(T); + result->resize(elements_count); + + return queue->EnqueueReadBuffer(buffer_, size_, result->data()); +} + +} // namespace cl +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc new file mode 100644 index 00000000000000..c194671848f344 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.cc @@ -0,0 +1,141 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is a copy of third_party/ml_drift/cl/cl_command_queue.cc. +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" + +namespace litert { +namespace cl { +namespace { + +absl::StatusOr CreateClCommandQueueWithProperties( + const ClDevice& device, const ClContext& context, + cl_command_queue_properties queue_properties) { + int error_code; + cl_command_queue queue; + if (clCreateCommandQueueWithProperties) { + std::vector props; + if (queue_properties != 0) { + props.push_back(CL_QUEUE_PROPERTIES); + props.push_back(queue_properties); + } + props.push_back(0); + + queue = clCreateCommandQueueWithProperties(context.context(), device.id(), + props.data(), &error_code); + } else { + // Backwards compatibility for OpenCL versions before 2.0. + queue = clCreateCommandQueue(context.context(), device.id(), + queue_properties, &error_code); + } + if (!queue) { + return absl::UnknownError(absl::StrCat( + "Failed to create a command queue - ", std::to_string(error_code))); + } + return queue; +} + +} // namespace + +ClCommandQueue::ClCommandQueue() = default; + +ClCommandQueue::ClCommandQueue(cl_command_queue queue, bool has_ownership) + : queue_(queue), has_ownership_(has_ownership) {} + +ClCommandQueue::ClCommandQueue(ClCommandQueue&& queue) + : queue_(queue.queue_), has_ownership_(queue.has_ownership_) { + queue.queue_ = nullptr; +} + +ClCommandQueue& ClCommandQueue::operator=(ClCommandQueue&& queue) { + if (this != &queue) { + Release(); + std::swap(queue_, queue.queue_); + has_ownership_ = queue.has_ownership_; + } + return *this; +} + +ClCommandQueue::~ClCommandQueue() { Release(); } + +void ClCommandQueue::Release() { + if (has_ownership_ && queue_) { + clReleaseCommandQueue(queue_); + queue_ = nullptr; + } +} + +absl::Status ClCommandQueue::EnqueueWriteBuffer(cl_mem memory, + size_t size_in_bytes, + const void* data, bool async) { + const cl_bool blocking = async ? CL_FALSE : CL_TRUE; + auto error_code = clEnqueueWriteBuffer( + queue_, memory, blocking, 0, size_in_bytes, data, 0, nullptr, nullptr); + if (error_code != CL_SUCCESS) { + return absl::UnknownError( + absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ", + std::to_string(error_code))); + } + return absl::OkStatus(); +} + +absl::Status ClCommandQueue::EnqueueReadBuffer(cl_mem memory, + size_t size_in_bytes, void* data, + bool async) { + const cl_bool blocking = async ? CL_FALSE : CL_TRUE; + auto error_code = clEnqueueReadBuffer( + queue_, memory, blocking, 0, size_in_bytes, data, 0, nullptr, nullptr); + if (error_code != CL_SUCCESS) { + return absl::UnknownError( + absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ", + std::to_string(error_code))); + } + return absl::OkStatus(); +} + +absl::Status ClCommandQueue::WaitForCompletion() { + auto error_code = clFinish(queue_); + if (error_code != CL_SUCCESS) { + return absl::UnknownError( + absl::StrCat("Failed to clFinish - ", std::to_string(error_code))); + } + return absl::OkStatus(); +} + +absl::Status CreateClCommandQueue(const ClDevice& device, + const ClContext& context, + ClCommandQueue* result) { + auto queue = CreateClCommandQueueWithProperties(device, context, 0); + if (!queue.ok()) { + return queue.status(); + } + *result = ClCommandQueue(*queue, true); + return absl::OkStatus(); +} + +} // namespace cl +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h new file mode 100644 index 00000000000000..4149e5b0dbb33d --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/cl_command_queue.h @@ -0,0 +1,82 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is a copy of third_party/ml_drift/cl/cl_command_queue.h. +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" + +namespace litert { +namespace cl { + +// A wrapper around opencl command queue +class ClCommandQueue { + public: + ClCommandQueue(); + ClCommandQueue(cl_command_queue queue, bool has_ownership); + + // Move only + ClCommandQueue(ClCommandQueue&& queue); + ClCommandQueue& operator=(ClCommandQueue&& queue); + ClCommandQueue(const ClCommandQueue&) = delete; + ClCommandQueue& operator=(const ClCommandQueue&) = delete; + + virtual ~ClCommandQueue(); + + cl_command_queue queue() const { return queue_; } + + absl::Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, + const void* data, bool async = false); + absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, + void* data, bool async = false); + + absl::Status WaitForCompletion(); + + protected: + void Release(); + + cl_command_queue queue_ = nullptr; + bool has_ownership_ = false; +}; + +class ProfilingCommandQueue : public ClCommandQueue { + public: + ProfilingCommandQueue(); + explicit ProfilingCommandQueue(cl_command_queue queue); + + // Move only + ProfilingCommandQueue(ProfilingCommandQueue&& queue); + ProfilingCommandQueue& operator=(ProfilingCommandQueue&& queue); + ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; + ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; + + private: + std::string current_label_; +}; + +absl::Status CreateClCommandQueue(const ClDevice& device, + const ClContext& context, + ClCommandQueue* result); + +} // namespace cl +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_COMMAND_QUEUE_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc new file mode 100644 index 00000000000000..b7d6e074d2c239 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.cc @@ -0,0 +1,105 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" + +namespace litert { +namespace cl { +namespace { + +absl::Status CreateClContext(const ClDevice& device, + const std::vector& props, + ClContext* result) { + int error_code; + cl_device_id device_id = device.id(); + std::vector props_local = props; + if (!props_local.empty()) { + props_local.push_back(0); + } + cl_context_properties* properties_ptr = + props_local.empty() ? nullptr : props_local.data(); + cl_context context = clCreateContext(properties_ptr, 1, &device_id, nullptr, + nullptr, &error_code); + if (!context) { + return absl::UnknownError( + absl::StrCat("Failed to create a compute context - ", error_code)); + } + + *result = ClContext(context, true); + return absl::OkStatus(); +} + +} // namespace + +ClContext::ClContext() = default; + +ClContext::ClContext(cl_context context, bool has_ownership) + : context_(context), has_ownership_(has_ownership) {} + +ClContext::ClContext(cl_context context, bool has_ownership, ClDevice& device) + : context_(context), has_ownership_(has_ownership) {} + +ClContext::ClContext(ClContext&& context) + : context_(context.context_), has_ownership_(context.has_ownership_) { + context.context_ = nullptr; +} + +ClContext& ClContext::operator=(ClContext&& context) { + if (this != &context) { + Release(); + std::swap(context_, context.context_); + has_ownership_ = context.has_ownership_; + } + return *this; +} + +ClContext::~ClContext() { Release(); } + +void ClContext::Release() { + if (has_ownership_ && context_) { + clReleaseContext(context_); + context_ = nullptr; + } +} + +absl::Status CreateClContext(const ClDevice& device, ClContext* result) { + std::vector props; + return CreateClContext(device, props, result); +} + +absl::Status CreateClGlContext(const ClDevice& device, + cl_context_properties egl_context, + cl_context_properties egl_display, + ClContext* result) { + cl_context_properties platform = + reinterpret_cast(device.platform()); + + std::vector props = {CL_GL_CONTEXT_KHR, egl_context, + CL_EGL_DISPLAY_KHR, egl_display, + CL_CONTEXT_PLATFORM, platform}; + + return CreateClContext(device, props, result); +} + +} // namespace cl +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h new file mode 100644 index 00000000000000..8773059511dee3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/cl_context.h @@ -0,0 +1,57 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ + +#include "absl/status/status.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" + +namespace litert { +namespace cl { + +// A RAII wrapper around opencl context +class ClContext { + public: + ClContext(); + ClContext(cl_context context, bool has_ownership); + ClContext(cl_context context, bool has_ownership, ClDevice& device); + // Move only + ClContext(ClContext&& context); + ClContext& operator=(ClContext&& context); + ClContext(const ClContext&) = delete; + ClContext& operator=(const ClContext&) = delete; + + ~ClContext(); + + cl_context context() const { return context_; } + + private: + void Release(); + + cl_context context_ = nullptr; + bool has_ownership_ = false; +}; + +absl::Status CreateClContext(const ClDevice& device, ClContext* result); +absl::Status CreateClGlContext(const ClDevice& device, + cl_context_properties egl_context, + cl_context_properties egl_display, + ClContext* result); + +} // namespace cl +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc new file mode 100644 index 00000000000000..72f90133c5ef2b --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.cc @@ -0,0 +1,104 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// this is a copy of ml_drift/cl/cl_device.cc +#include "tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "third_party/opencl_headers/CL/cl_platform.h" +#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" + +namespace litert { +namespace cl { + +ClDevice::ClDevice(cl_device_id id, cl_platform_id platform_id) + : id_(id), platform_id_(platform_id) {} + +ClDevice::ClDevice(const ClDevice& device) = default; + +ClDevice& ClDevice::operator=(const ClDevice& device) { + if (this != &device) { + id_ = device.id_; + platform_id_ = device.platform_id_; + } + return *this; +} + +ClDevice::ClDevice(ClDevice&& device) + : id_(device.id_), platform_id_(device.platform_id_) { + device.id_ = nullptr; + device.platform_id_ = nullptr; +} + +ClDevice& ClDevice::operator=(ClDevice&& device) { + if (this != &device) { + id_ = nullptr; + platform_id_ = nullptr; + std::swap(id_, device.id_); + std::swap(platform_id_, device.platform_id_); + } + return *this; +} + +absl::Status CreateDefaultGPUDevice(ClDevice* result) { + cl_uint num_platforms; + cl_int status = clGetPlatformIDs(0, nullptr, &num_platforms); + if (status != CL_SUCCESS) { + return absl::UnknownError( + absl::StrFormat("clGetPlatformIDs returned %d", status)); + } + if (num_platforms == 0) { + return absl::UnknownError("No supported OpenCL platform."); + } + std::vector platforms(num_platforms); + status = clGetPlatformIDs(num_platforms, platforms.data(), nullptr); + if (status != CL_SUCCESS) { + return absl::UnknownError( + absl::StrFormat("clGetPlatformIDs returned %d", status)); + } + + cl_platform_id platform_id = platforms[0]; + cl_uint num_devices; + status = + clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices); + if (status != CL_SUCCESS) { + return absl::UnknownError( + absl::StrFormat("clGetDeviceIDs returned %d", status)); + } + if (num_devices == 0) { + return absl::UnknownError("No GPU on current platform."); + } + + std::vector devices(num_devices); + status = clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, num_devices, + devices.data(), nullptr); + if (status != CL_SUCCESS) { + return absl::UnknownError( + absl::StrFormat("clGetDeviceIDs returned %d", status)); + } + + *result = ClDevice(devices[0], platform_id); + LoadOpenCLFunctionExtensions(platform_id); + return absl::OkStatus(); +} + +} // namespace cl +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h new file mode 100644 index 00000000000000..28a0226a7f274b --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/cl_device.h @@ -0,0 +1,73 @@ +// Copyright 2024 The ML Drift Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ + +#include + +#include "absl/status/status.h" +#include "third_party/opencl_headers/CL/cl.h" +#include "third_party/opencl_headers/CL/cl_platform.h" + +namespace litert { +namespace cl { + +// A wrapper around opencl device id +class ClDevice { + public: + ClDevice() = default; + ClDevice(cl_device_id id, cl_platform_id platform_id); + + ClDevice(ClDevice&& device); + ClDevice& operator=(ClDevice&& device); + ClDevice(const ClDevice&); + ClDevice& operator=(const ClDevice&); + + ~ClDevice() = default; + + cl_device_id id() const { return id_; } + cl_platform_id platform() const { return platform_id_; } + std::string GetPlatformVersion() const; + + private: + cl_device_id id_ = nullptr; + cl_platform_id platform_id_ = nullptr; +}; + +absl::Status CreateDefaultGPUDevice(ClDevice* result); + +template +T GetDeviceInfo(cl_device_id id, cl_device_info info) { + T result; + cl_int error = clGetDeviceInfo(id, info, sizeof(T), &result, nullptr); + if (error != CL_SUCCESS) { + return {}; + } + return result; +} + +template +absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) { + cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr); + if (error != CL_SUCCESS) { + return absl::InvalidArgumentError("cl error:" + std::to_string(error)); + } + return absl::OkStatus(); +} + +} // namespace cl +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_CL_DEVICE_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc new file mode 100644 index 00000000000000..79c4e33e2eb72f --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.cc @@ -0,0 +1,470 @@ +// Copyright 2024 The Tensorflow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is copied from third_party/ml_drift/cl/opencl_wrapper.cc. +#include "tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h" + +#if defined(_WIN32) +#define __WINDOWS__ +#endif + +#ifdef __WINDOWS__ +#include +#else +#include +#endif + +#include + +#include "absl/strings/str_cat.h" + +namespace litert { +namespace cl { + +#ifdef __ANDROID__ +#define LoadFunction(function) \ + if (use_wrapper) { \ + function = reinterpret_cast(loadOpenCLPointer(#function)); \ + } else { \ + function = reinterpret_cast(dlsym(libopencl, #function)); \ + } + +namespace { + +// Loads a library from Android SP-HAL namespace which includes libraries from +// the path /vendor/lib[64] directly and several sub-folders in it. +// First tries using dlopen(), which should work if the process is running with +// linker namespace "sphal" (so has permissions to sphal paths). +// If it fails, for example if process is running with linker default namespace +// because it's a sub-process of the app, then tries loading the library using +// a sphal helper loader function from Vendor NDK support library. +void* AndroidDlopenSphalLibrary(const char* filename, int dlopen_flags) { + void* lib = dlopen(filename, dlopen_flags); + if (lib != nullptr) { + return lib; + } + static void* (*android_load_sphal_library)(const char*, int) = nullptr; + if (android_load_sphal_library != nullptr) { + return android_load_sphal_library(filename, dlopen_flags); + } + android_load_sphal_library = + reinterpret_cast( + dlsym(RTLD_NEXT, "android_load_sphal_library")); + if (android_load_sphal_library == nullptr) { + void* vndk = dlopen("libvndksupport.so", RTLD_NOW); + if (vndk != nullptr) { + android_load_sphal_library = + reinterpret_cast( + dlsym(vndk, "android_load_sphal_library")); + } + if (android_load_sphal_library == nullptr) { + return nullptr; + } + } + return android_load_sphal_library(filename, dlopen_flags); +} + +} // namespace + +#elif defined(__WINDOWS__) +#define LoadFunction(function) \ + function = \ + reinterpret_cast(GetProcAddress(libopencl, #function)); +#else +#define LoadFunction(function) \ + function = reinterpret_cast(dlsym(libopencl, #function)); +#endif + +#define LoadFunctionExtension(plat_id, function) \ + function = reinterpret_cast( \ + clGetExtensionFunctionAddressForPlatform(plat_id, #function)); + +#ifdef __WINDOWS__ +void LoadOpenCLFunctions(HMODULE libopencl); +#else +void LoadOpenCLFunctions(void* libopencl, bool use_wrapper); +#endif + +absl::Status LoadOpenCL() { +#ifdef __WINDOWS__ + HMODULE libopencl = LoadLibraryA("OpenCL.dll"); + if (libopencl) { + LoadOpenCLFunctions(libopencl); + return absl::OkStatus(); + } else { + DWORD error_code = GetLastError(); + return absl::UnknownError(absl::StrCat( + "Can not open OpenCL library on this device, error code - ", + error_code)); + } +#else + void* libopencl = nullptr; +#ifdef __APPLE__ + static const char* kClLibName = + "/System/Library/Frameworks/OpenCL.framework/OpenCL"; +#else + static const char* kClLibName = "libOpenCL.so"; +#endif +#ifdef __ANDROID__ + libopencl = AndroidDlopenSphalLibrary(kClLibName, RTLD_NOW | RTLD_LOCAL); + if (!libopencl) { + // Legacy Pixel phone or auto path? + libopencl = + AndroidDlopenSphalLibrary("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); + if (!libopencl) { + libopencl = + AndroidDlopenSphalLibrary("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL); + } + if (libopencl) { + typedef void (*enableOpenCL_t)(); + enableOpenCL_t enableOpenCL = + reinterpret_cast(dlsym(libopencl, "enableOpenCL")); + enableOpenCL(); + LoadOpenCLFunctions(libopencl, true); + return absl::OkStatus(); + } + } +#else + libopencl = dlopen(kClLibName, RTLD_NOW | RTLD_LOCAL); +#endif + if (libopencl) { + LoadOpenCLFunctions(libopencl, false); + return absl::OkStatus(); + } + // record error + std::string error(dlerror()); + + // Check if OpenCL functions are found via OpenCL ICD Loader. + LoadOpenCLFunctions(libopencl, /*use_wrapper=*/false); + if (clGetPlatformIDs != nullptr) { + cl_uint num_platforms; + cl_int status = clGetPlatformIDs(0, nullptr, &num_platforms); + if (status == CL_SUCCESS && num_platforms != 0) { + return absl::OkStatus(); + } + return absl::UnknownError("OpenCL is not supported."); + } + return absl::UnknownError( + absl::StrCat("Can not open OpenCL library on this device - ", error)); +#endif +} + +void LoadOpenCLFunctionExtensions(cl_platform_id platform_id) { + // cl_khr_command_buffer extension + LoadFunctionExtension(platform_id, clCreateCommandBufferKHR); + LoadFunctionExtension(platform_id, clRetainCommandBufferKHR); + LoadFunctionExtension(platform_id, clReleaseCommandBufferKHR); + LoadFunctionExtension(platform_id, clFinalizeCommandBufferKHR); + LoadFunctionExtension(platform_id, clEnqueueCommandBufferKHR); + LoadFunctionExtension(platform_id, clCommandNDRangeKernelKHR); + LoadFunctionExtension(platform_id, clGetCommandBufferInfoKHR); +} + +#ifdef __WINDOWS__ +void LoadOpenCLFunctions(HMODULE libopencl) { +#else +void LoadOpenCLFunctions(void* libopencl, bool use_wrapper) { +#ifdef __ANDROID__ + typedef void* (*loadOpenCLPointer_t)(const char* name); + loadOpenCLPointer_t loadOpenCLPointer; + if (use_wrapper) { + loadOpenCLPointer = reinterpret_cast( + dlsym(libopencl, "loadOpenCLPointer")); + } +#endif +#endif + + LoadFunction(clGetPlatformIDs); + LoadFunction(clGetPlatformInfo); + LoadFunction(clGetDeviceIDs); + LoadFunction(clGetDeviceInfo); + LoadFunction(clCreateSubDevices); + LoadFunction(clRetainDevice); + LoadFunction(clReleaseDevice); + LoadFunction(clCreateContext); + LoadFunction(clCreateContextFromType); + LoadFunction(clRetainContext); + LoadFunction(clReleaseContext); + LoadFunction(clGetContextInfo); + LoadFunction(clCreateCommandQueueWithProperties); + LoadFunction(clRetainCommandQueue); + LoadFunction(clReleaseCommandQueue); + LoadFunction(clGetCommandQueueInfo); + LoadFunction(clCreateBuffer); + LoadFunction(clCreateSubBuffer); + LoadFunction(clCreateImage); + LoadFunction(clCreatePipe); + LoadFunction(clRetainMemObject); + LoadFunction(clReleaseMemObject); + LoadFunction(clGetSupportedImageFormats); + LoadFunction(clGetMemObjectInfo); + LoadFunction(clGetImageInfo); + LoadFunction(clGetPipeInfo); + LoadFunction(clSetMemObjectDestructorCallback); + LoadFunction(clSVMAlloc); + LoadFunction(clSVMFree); + LoadFunction(clCreateSamplerWithProperties); + LoadFunction(clRetainSampler); + LoadFunction(clReleaseSampler); + LoadFunction(clGetSamplerInfo); + LoadFunction(clCreateProgramWithSource); + LoadFunction(clCreateProgramWithBinary); + LoadFunction(clCreateProgramWithBuiltInKernels); + LoadFunction(clRetainProgram); + LoadFunction(clReleaseProgram); + LoadFunction(clBuildProgram); + LoadFunction(clCompileProgram); + LoadFunction(clLinkProgram); + LoadFunction(clUnloadPlatformCompiler); + LoadFunction(clGetProgramInfo); + LoadFunction(clGetProgramBuildInfo); + LoadFunction(clCreateKernel); + LoadFunction(clCreateKernelsInProgram); + LoadFunction(clRetainKernel); + LoadFunction(clReleaseKernel); + LoadFunction(clSetKernelArg); + LoadFunction(clSetKernelArgSVMPointer); + LoadFunction(clSetKernelExecInfo); + LoadFunction(clGetKernelInfo); + LoadFunction(clGetKernelArgInfo); + LoadFunction(clGetKernelWorkGroupInfo); + LoadFunction(clWaitForEvents); + LoadFunction(clGetEventInfo); + LoadFunction(clCreateUserEvent); + LoadFunction(clRetainEvent); + LoadFunction(clReleaseEvent); + LoadFunction(clSetUserEventStatus); + LoadFunction(clSetEventCallback); + LoadFunction(clGetEventProfilingInfo); + LoadFunction(clFlush); + LoadFunction(clFinish); + LoadFunction(clEnqueueReadBuffer); + LoadFunction(clEnqueueReadBufferRect); + LoadFunction(clEnqueueWriteBuffer); + LoadFunction(clEnqueueWriteBufferRect); + LoadFunction(clEnqueueFillBuffer); + LoadFunction(clEnqueueCopyBuffer); + LoadFunction(clEnqueueCopyBufferRect); + LoadFunction(clEnqueueReadImage); + LoadFunction(clEnqueueWriteImage); + LoadFunction(clEnqueueFillImage); + LoadFunction(clEnqueueCopyImage); + LoadFunction(clEnqueueCopyImageToBuffer); + LoadFunction(clEnqueueCopyBufferToImage); + LoadFunction(clEnqueueMapBuffer); + LoadFunction(clEnqueueMapImage); + LoadFunction(clEnqueueUnmapMemObject); + LoadFunction(clEnqueueMigrateMemObjects); + LoadFunction(clEnqueueNDRangeKernel); + LoadFunction(clEnqueueNativeKernel); + LoadFunction(clEnqueueMarkerWithWaitList); + LoadFunction(clEnqueueBarrierWithWaitList); + LoadFunction(clEnqueueSVMFree); + LoadFunction(clEnqueueSVMMemcpy); + LoadFunction(clEnqueueSVMMemFill); + LoadFunction(clEnqueueSVMMap); + LoadFunction(clEnqueueSVMUnmap); + LoadFunction(clGetExtensionFunctionAddressForPlatform); + LoadFunction(clCreateImage2D); + LoadFunction(clCreateImage3D); + LoadFunction(clEnqueueMarker); + LoadFunction(clEnqueueWaitForEvents); + LoadFunction(clEnqueueBarrier); + LoadFunction(clUnloadCompiler); + LoadFunction(clGetExtensionFunctionAddress); + LoadFunction(clCreateCommandQueue); + LoadFunction(clCreateSampler); + LoadFunction(clEnqueueTask); + + // OpenGL sharing + LoadFunction(clCreateFromGLBuffer); + LoadFunction(clCreateFromGLTexture); + LoadFunction(clEnqueueAcquireGLObjects); + LoadFunction(clEnqueueReleaseGLObjects); + + // cl_khr_egl_event extension + LoadFunction(clCreateEventFromEGLSyncKHR); + + // EGL sharing + LoadFunction(clCreateFromEGLImageKHR); + LoadFunction(clEnqueueAcquireEGLObjectsKHR); + LoadFunction(clEnqueueReleaseEGLObjectsKHR); + + // OpenCL 3.0 + LoadFunction(clCreateBufferWithProperties); + LoadFunction(clCreateImageWithProperties); +} + +// No OpenCL support, do not set function addresses +PFN_clGetPlatformIDs clGetPlatformIDs; +PFN_clGetPlatformInfo clGetPlatformInfo; +PFN_clGetDeviceIDs clGetDeviceIDs; +PFN_clGetDeviceInfo clGetDeviceInfo; +PFN_clCreateSubDevices clCreateSubDevices; +PFN_clRetainDevice clRetainDevice; +PFN_clReleaseDevice clReleaseDevice; +PFN_clCreateContext clCreateContext; +PFN_clCreateContextFromType clCreateContextFromType; +PFN_clRetainContext clRetainContext; +PFN_clReleaseContext clReleaseContext; +PFN_clGetContextInfo clGetContextInfo; +PFN_clCreateCommandQueueWithProperties clCreateCommandQueueWithProperties; +PFN_clRetainCommandQueue clRetainCommandQueue; +PFN_clReleaseCommandQueue clReleaseCommandQueue; +PFN_clGetCommandQueueInfo clGetCommandQueueInfo; +PFN_clCreateBuffer clCreateBuffer; +PFN_clCreateSubBuffer clCreateSubBuffer; +PFN_clCreateImage clCreateImage; +PFN_clCreatePipe clCreatePipe; +PFN_clRetainMemObject clRetainMemObject; +PFN_clReleaseMemObject clReleaseMemObject; +PFN_clGetSupportedImageFormats clGetSupportedImageFormats; +PFN_clGetMemObjectInfo clGetMemObjectInfo; +PFN_clGetImageInfo clGetImageInfo; +PFN_clGetPipeInfo clGetPipeInfo; +PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback; +PFN_clSVMAlloc clSVMAlloc; +PFN_clSVMFree clSVMFree; +PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties; +PFN_clRetainSampler clRetainSampler; +PFN_clReleaseSampler clReleaseSampler; +PFN_clGetSamplerInfo clGetSamplerInfo; +PFN_clCreateProgramWithSource clCreateProgramWithSource; +PFN_clCreateProgramWithBinary clCreateProgramWithBinary; +PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels; +PFN_clRetainProgram clRetainProgram; +PFN_clReleaseProgram clReleaseProgram; +PFN_clBuildProgram clBuildProgram; +PFN_clCompileProgram clCompileProgram; +PFN_clLinkProgram clLinkProgram; +PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler; +PFN_clGetProgramInfo clGetProgramInfo; +PFN_clGetProgramBuildInfo clGetProgramBuildInfo; +PFN_clCreateKernel clCreateKernel; +PFN_clCreateKernelsInProgram clCreateKernelsInProgram; +PFN_clRetainKernel clRetainKernel; +PFN_clReleaseKernel clReleaseKernel; +PFN_clSetKernelArg clSetKernelArg; +PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer; +PFN_clSetKernelExecInfo clSetKernelExecInfo; +PFN_clGetKernelInfo clGetKernelInfo; +PFN_clGetKernelArgInfo clGetKernelArgInfo; +PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo; +PFN_clWaitForEvents clWaitForEvents; +PFN_clGetEventInfo clGetEventInfo; +PFN_clCreateUserEvent clCreateUserEvent; +PFN_clRetainEvent clRetainEvent; +PFN_clReleaseEvent clReleaseEvent; +PFN_clSetUserEventStatus clSetUserEventStatus; +PFN_clSetEventCallback clSetEventCallback; +PFN_clGetEventProfilingInfo clGetEventProfilingInfo; +PFN_clFlush clFlush; +PFN_clFinish clFinish; +PFN_clEnqueueReadBuffer clEnqueueReadBuffer; +PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect; +PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer; +PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect; +PFN_clEnqueueFillBuffer clEnqueueFillBuffer; +PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer; +PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect; +PFN_clEnqueueReadImage clEnqueueReadImage; +PFN_clEnqueueWriteImage clEnqueueWriteImage; +PFN_clEnqueueFillImage clEnqueueFillImage; +PFN_clEnqueueCopyImage clEnqueueCopyImage; +PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer; +PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage; +PFN_clEnqueueMapBuffer clEnqueueMapBuffer; +PFN_clEnqueueMapImage clEnqueueMapImage; +PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject; +PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects; +PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel; +PFN_clEnqueueNativeKernel clEnqueueNativeKernel; +PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList; +PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList; +PFN_clEnqueueSVMFree clEnqueueSVMFree; +PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy; +PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill; +PFN_clEnqueueSVMMap clEnqueueSVMMap; +PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap; +PFN_clGetExtensionFunctionAddressForPlatform + clGetExtensionFunctionAddressForPlatform; +PFN_clCreateImage2D clCreateImage2D; +PFN_clCreateImage3D clCreateImage3D; +PFN_clEnqueueMarker clEnqueueMarker; +PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents; +PFN_clEnqueueBarrier clEnqueueBarrier; +PFN_clUnloadCompiler clUnloadCompiler; +PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress; +PFN_clCreateCommandQueue clCreateCommandQueue; +PFN_clCreateSampler clCreateSampler; +PFN_clEnqueueTask clEnqueueTask; + +// OpenGL sharing +PFN_clCreateFromGLBuffer clCreateFromGLBuffer; +PFN_clCreateFromGLTexture clCreateFromGLTexture; +PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects; +PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects; + +// cl_khr_egl_event extension +PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR; + +// EGL sharing +PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; +PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; +PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; + +// cl_khr_command_buffer extension +PFN_clCreateCommandBufferKHR clCreateCommandBufferKHR; +PFN_clRetainCommandBufferKHR clRetainCommandBufferKHR; +PFN_clReleaseCommandBufferKHR clReleaseCommandBufferKHR; +PFN_clFinalizeCommandBufferKHR clFinalizeCommandBufferKHR; +PFN_clEnqueueCommandBufferKHR clEnqueueCommandBufferKHR; +PFN_clCommandNDRangeKernelKHR clCommandNDRangeKernelKHR; +PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; + +// OpenCL 3.0 +PFN_clCreateBufferWithProperties clCreateBufferWithProperties; +PFN_clCreateImageWithProperties clCreateImageWithProperties; + +cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, + const cl_image_format* image_format, + const cl_image_desc* image_desc, void* host_ptr, + cl_int* errcode_ret) { + if (clCreateImage) { // clCreateImage available since OpenCL 1.2 + return clCreateImage(context, flags, image_format, image_desc, host_ptr, + errcode_ret); + } else { + return clCreateImage2D(context, flags, image_format, + image_desc->image_width, image_desc->image_height, + image_desc->image_row_pitch, host_ptr, errcode_ret); + } +} + +cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags, + const cl_image_format* image_format, + const cl_image_desc* image_desc, void* host_ptr, + cl_int* errcode_ret) { + if (clCreateImage) { // clCreateImage available since OpenCL 1.2 + return clCreateImage(context, flags, image_format, image_desc, host_ptr, + errcode_ret); + } else { + return clCreateImage3D(context, flags, image_format, + image_desc->image_width, image_desc->image_height, + image_desc->image_depth, image_desc->image_row_pitch, + image_desc->image_slice_pitch, host_ptr, + errcode_ret); + } +} +} // namespace cl +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h new file mode 100644 index 00000000000000..07d57212646ecb --- /dev/null +++ b/tensorflow/lite/experimental/litert/runtime/opencl/opencl_wrapper.h @@ -0,0 +1,737 @@ +// Copyright 2024 The TensorFlow Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is copied from third_party/ml_drift/cl/opencl_wrapper.h. +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ + +#include + +#include "absl/status/status.h" +#include "third_party/opencl_headers/CL/cl.h" // IWYU pragma: export +#include "third_party/opencl_headers/CL/cl_egl.h" // IWYU pragma: export +#include "third_party/opencl_headers/CL/cl_ext.h" // IWYU pragma: export +#include "third_party/opencl_headers/CL/cl_gl.h" // IWYU pragma: export +#include "third_party/opencl_headers/CL/cl_platform.h" // IWYU pragma: export + +namespace litert { +namespace cl { + +absl::Status LoadOpenCL(); +void LoadOpenCLFunctionExtensions(cl_platform_id platform_id); + +typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)( + cl_uint /* num_entries */, cl_platform_id * /* platforms */, + cl_uint * /* num_platforms */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetPlatformInfo)( + cl_platform_id /* platform */, cl_platform_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetDeviceIDs)( + cl_platform_id /* platform */, cl_device_type /* device_type */, + cl_uint /* num_entries */, cl_device_id * /* devices */, + cl_uint * /* num_devices */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetDeviceInfo)( + cl_device_id /* device */, cl_device_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clCreateSubDevices)( + cl_device_id /* in_device */, + const cl_device_partition_property * /* properties */, + cl_uint /* num_devices */, cl_device_id * /* out_devices */, + cl_uint * /* num_devices_ret */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clRetainDevice)(cl_device_id /* device */) + CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clReleaseDevice)(cl_device_id /* device */) + CL_API_SUFFIX__VERSION_1_2; +typedef cl_context(CL_API_CALL *PFN_clCreateContext)( + const cl_context_properties * /* properties */, cl_uint /* num_devices */, + const cl_device_id * /* devices */, + void(CL_CALLBACK * /* pfn_notify */)(const char *, const void *, size_t, + void *), + void * /* user_data */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_context(CL_API_CALL *PFN_clCreateContextFromType)( + const cl_context_properties * /* properties */, + cl_device_type /* device_type */, + void(CL_CALLBACK * /* pfn_notify*/)(const char *, const void *, size_t, + void *), + void * /* user_data */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clRetainContext)(cl_context /* context */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseContext)(cl_context /* context */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetContextInfo)( + cl_context /* context */, cl_context_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueueWithProperties)( + cl_context /* context */, cl_device_id /* device */, + const cl_queue_properties * /* properties */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clRetainCommandQueue)( + cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseCommandQueue)( + cl_command_queue /* command_queue */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetCommandQueueInfo)( + cl_command_queue /* command_queue */, + cl_command_queue_info /* param_name */, size_t /* param_value_size */, + void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_mem(CL_API_CALL *PFN_clCreateBuffer)( + cl_context /* context */, cl_mem_flags /* flags */, size_t /* size */, + void * /* host_ptr */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_mem(CL_API_CALL *PFN_clCreateSubBuffer)( + cl_mem /* buffer */, cl_mem_flags /* flags */, + cl_buffer_create_type /* buffer_create_type */, + const void * /* buffer_create_info */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1; +typedef cl_mem(CL_API_CALL *PFN_clCreateImage)( + cl_context /* context */, cl_mem_flags /* flags */, + const cl_image_format * /* image_format */, + const cl_image_desc * /* image_desc */, void * /* host_ptr */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_mem(CL_API_CALL *PFN_clCreatePipe)( + cl_context /* context */, cl_mem_flags /* flags */, + cl_uint /* pipe_packet_size */, cl_uint /* pipe_max_packets */, + const cl_pipe_properties * /* properties */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clRetainMemObject)(cl_mem /* memobj */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseMemObject)(cl_mem /* memobj */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetSupportedImageFormats)( + cl_context /* context */, cl_mem_flags /* flags */, + cl_mem_object_type /* image_type */, cl_uint /* num_entries */, + cl_image_format * /* image_formats */, + cl_uint * /* num_image_formats */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetMemObjectInfo)( + cl_mem /* memobj */, cl_mem_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetImageInfo)( + cl_mem /* image */, cl_image_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetPipeInfo)( + cl_mem /* pipe */, cl_pipe_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clSetMemObjectDestructorCallback)( + cl_mem /* memobj */, + void(CL_CALLBACK * /*pfn_notify*/)(cl_mem /* memobj */, + void * /*user_data*/), + void * /*user_data */) CL_API_SUFFIX__VERSION_1_1; +typedef void *(CL_API_CALL *PFN_clSVMAlloc)( + cl_context /* context */, cl_svm_mem_flags /* flags */, size_t /* size */, + cl_uint /* alignment */)CL_API_SUFFIX__VERSION_2_0; +typedef void(CL_API_CALL *PFN_clSVMFree)(cl_context /* context */, + void * /* svm_pointer */) + CL_API_SUFFIX__VERSION_2_0; +typedef cl_sampler(CL_API_CALL *PFN_clCreateSamplerWithProperties)( + cl_context /* context */, + const cl_sampler_properties * /* normalized_coords */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clRetainSampler)(cl_sampler /* sampler */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseSampler)(cl_sampler /* sampler */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetSamplerInfo)( + cl_sampler /* sampler */, cl_sampler_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithSource)( + cl_context /* context */, cl_uint /* count */, const char ** /* strings */, + const size_t * /* lengths */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBinary)( + cl_context /* context */, cl_uint /* num_devices */, + const cl_device_id * /* device_list */, const size_t * /* lengths */, + const unsigned char ** /* binaries */, cl_int * /* binary_status */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBuiltInKernels)( + cl_context /* context */, cl_uint /* num_devices */, + const cl_device_id * /* device_list */, const char * /* kernel_names */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clRetainProgram)(cl_program /* program */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseProgram)(cl_program /* program */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clBuildProgram)( + cl_program /* program */, cl_uint /* num_devices */, + const cl_device_id * /* device_list */, const char * /* options */, + void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, + void * /* user_data */), + void * /* user_data */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clCompileProgram)( + cl_program /* program */, cl_uint /* num_devices */, + const cl_device_id * /* device_list */, const char * /* options */, + cl_uint /* num_input_headers */, const cl_program * /* input_headers */, + const char ** /* header_include_names */, + void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, + void * /* user_data */), + void * /* user_data */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_program(CL_API_CALL *PFN_clLinkProgram)( + cl_context /* context */, cl_uint /* num_devices */, + const cl_device_id * /* device_list */, const char * /* options */, + cl_uint /* num_input_programs */, const cl_program * /* input_programs */, + void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, + void * /* user_data */), + void * /* user_data */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clUnloadPlatformCompiler)( + cl_platform_id /* platform */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clGetProgramInfo)( + cl_program /* program */, cl_program_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetProgramBuildInfo)( + cl_program /* program */, cl_device_id /* device */, + cl_program_build_info /* param_name */, size_t /* param_value_size */, + void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_kernel(CL_API_CALL *PFN_clCreateKernel)( + cl_program /* program */, const char * /* kernel_name */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clCreateKernelsInProgram)( + cl_program /* program */, cl_uint /* num_kernels */, + cl_kernel * /* kernels */, + cl_uint * /* num_kernels_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clRetainKernel)(cl_kernel /* kernel */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseKernel)(cl_kernel /* kernel */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clSetKernelArg)( + cl_kernel /* kernel */, cl_uint /* arg_index */, size_t /* arg_size */, + const void * /* arg_value */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clSetKernelArgSVMPointer)( + cl_kernel /* kernel */, cl_uint /* arg_index */, + const void * /* arg_value */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clSetKernelExecInfo)( + cl_kernel /* kernel */, cl_kernel_exec_info /* param_name */, + size_t /* param_value_size */, + const void * /* param_value */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clGetKernelInfo)( + cl_kernel /* kernel */, cl_kernel_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetKernelArgInfo)( + cl_kernel /* kernel */, cl_uint /* arg_indx */, + cl_kernel_arg_info /* param_name */, size_t /* param_value_size */, + void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clGetKernelWorkGroupInfo)( + cl_kernel /* kernel */, cl_device_id /* device */, + cl_kernel_work_group_info /* param_name */, size_t /* param_value_size */, + void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clWaitForEvents)( + cl_uint /* num_events */, + const cl_event * /* event_list */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clGetEventInfo)( + cl_event /* event */, cl_event_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_event(CL_API_CALL *PFN_clCreateUserEvent)(cl_context /* context */, + cl_int * /* errcode_ret */) + CL_API_SUFFIX__VERSION_1_1; +typedef cl_int(CL_API_CALL *PFN_clRetainEvent)(cl_event /* event */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clReleaseEvent)(cl_event /* event */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clSetUserEventStatus)( + cl_event /* event */, + cl_int /* execution_status */) CL_API_SUFFIX__VERSION_1_1; +typedef cl_int(CL_API_CALL *PFN_clSetEventCallback)( + cl_event /* event */, cl_int /* command_exec_callback_type */, + void(CL_CALLBACK * /* pfn_notify */)(cl_event, cl_int, void *), + void * /* user_data */) CL_API_SUFFIX__VERSION_1_1; +typedef cl_int(CL_API_CALL *PFN_clGetEventProfilingInfo)( + cl_event /* event */, cl_profiling_info /* param_name */, + size_t /* param_value_size */, void * /* param_value */, + size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clFlush)(cl_command_queue /* command_queue */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clFinish)(cl_command_queue /* command_queue */) + CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBuffer)( + cl_command_queue /* command_queue */, cl_mem /* buffer */, + cl_bool /* blocking_read */, size_t /* offset */, size_t /* size */, + void * /* ptr */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBufferRect)( + cl_command_queue /* command_queue */, cl_mem /* buffer */, + cl_bool /* blocking_read */, const size_t * /* buffer_offset */, + const size_t * /* host_offset */, const size_t * /* region */, + size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, + size_t /* host_row_pitch */, size_t /* host_slice_pitch */, + void * /* ptr */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; +typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBuffer)( + cl_command_queue /* command_queue */, cl_mem /* buffer */, + cl_bool /* blocking_write */, size_t /* offset */, size_t /* size */, + const void * /* ptr */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBufferRect)( + cl_command_queue /* command_queue */, cl_mem /* buffer */, + cl_bool /* blocking_write */, const size_t * /* buffer_offset */, + const size_t * /* host_offset */, const size_t * /* region */, + size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, + size_t /* host_row_pitch */, size_t /* host_slice_pitch */, + const void * /* ptr */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; +typedef cl_int(CL_API_CALL *PFN_clEnqueueFillBuffer)( + cl_command_queue /* command_queue */, cl_mem /* buffer */, + const void * /* pattern */, size_t /* pattern_size */, size_t /* offset */, + size_t /* size */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBuffer)( + cl_command_queue /* command_queue */, cl_mem /* src_buffer */, + cl_mem /* dst_buffer */, size_t /* src_offset */, size_t /* dst_offset */, + size_t /* size */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferRect)( + cl_command_queue /* command_queue */, cl_mem /* src_buffer */, + cl_mem /* dst_buffer */, const size_t * /* src_origin */, + const size_t * /* dst_origin */, const size_t * /* region */, + size_t /* src_row_pitch */, size_t /* src_slice_pitch */, + size_t /* dst_row_pitch */, size_t /* dst_slice_pitch */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1; +typedef cl_int(CL_API_CALL *PFN_clEnqueueReadImage)( + cl_command_queue /* command_queue */, cl_mem /* image */, + cl_bool /* blocking_read */, const size_t * /* origin[3] */, + const size_t * /* region[3] */, size_t /* row_pitch */, + size_t /* slice_pitch */, void * /* ptr */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteImage)( + cl_command_queue /* command_queue */, cl_mem /* image */, + cl_bool /* blocking_write */, const size_t * /* origin[3] */, + const size_t * /* region[3] */, size_t /* input_row_pitch */, + size_t /* input_slice_pitch */, const void * /* ptr */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueFillImage)( + cl_command_queue /* command_queue */, cl_mem /* image */, + const void * /* fill_color */, const size_t * /* origin[3] */, + const size_t * /* region[3] */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImage)( + cl_command_queue /* command_queue */, cl_mem /* src_image */, + cl_mem /* dst_image */, const size_t * /* src_origin[3] */, + const size_t * /* dst_origin[3] */, const size_t * /* region[3] */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImageToBuffer)( + cl_command_queue /* command_queue */, cl_mem /* src_image */, + cl_mem /* dst_buffer */, const size_t * /* src_origin[3] */, + const size_t * /* region[3] */, size_t /* dst_offset */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferToImage)( + cl_command_queue /* command_queue */, cl_mem /* src_buffer */, + cl_mem /* dst_image */, size_t /* src_offset */, + const size_t * /* dst_origin[3] */, const size_t * /* region[3] */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef void *(CL_API_CALL *PFN_clEnqueueMapBuffer)( + cl_command_queue /* command_queue */, cl_mem /* buffer */, + cl_bool /* blocking_map */, cl_map_flags /* map_flags */, + size_t /* offset */, size_t /* size */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, cl_event * /* event */, + cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0; +typedef void *(CL_API_CALL *PFN_clEnqueueMapImage)( + cl_command_queue /* command_queue */, cl_mem /* image */, + cl_bool /* blocking_map */, cl_map_flags /* map_flags */, + const size_t * /* origin[3] */, const size_t * /* region[3] */, + size_t * /* image_row_pitch */, size_t * /* image_slice_pitch */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, cl_event * /* event */, + cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueUnmapMemObject)( + cl_command_queue /* command_queue */, cl_mem /* memobj */, + void * /* mapped_ptr */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueMigrateMemObjects)( + cl_command_queue /* command_queue */, cl_uint /* num_mem_objects */, + const cl_mem * /* mem_objects */, cl_mem_migration_flags /* flags */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clEnqueueNDRangeKernel)( + cl_command_queue /* command_queue */, cl_kernel /* kernel */, + cl_uint /* work_dim */, const size_t * /* global_work_offset */, + const size_t * /* global_work_size */, const size_t * /* local_work_size */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueNativeKernel)( + cl_command_queue /* command_queue */, + void(CL_CALLBACK * /*user_func*/)(void *), void * /* args */, + size_t /* cb_args */, cl_uint /* num_mem_objects */, + const cl_mem * /* mem_list */, const void ** /* args_mem_loc */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueMarkerWithWaitList)( + cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrierWithWaitList)( + cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMFree)( + cl_command_queue /* command_queue */, cl_uint /* num_svm_pointers */, + void *[] /* svm_pointers[] */, + void(CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */, + cl_uint /* num_svm_pointers */, + void *[] /* svm_pointers[] */, + void * /* user_data */), + void * /* user_data */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemcpy)( + cl_command_queue /* command_queue */, cl_bool /* blocking_copy */, + void * /* dst_ptr */, const void * /* src_ptr */, size_t /* size */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemFill)( + cl_command_queue /* command_queue */, void * /* svm_ptr */, + const void * /* pattern */, size_t /* pattern_size */, size_t /* size */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMap)( + cl_command_queue /* command_queue */, cl_bool /* blocking_map */, + cl_map_flags /* flags */, void * /* svm_ptr */, size_t /* size */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; +typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMUnmap)( + cl_command_queue /* command_queue */, void * /* svm_ptr */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0; +typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddressForPlatform)( + cl_platform_id /* platform */, + const char * /* func_name */)CL_API_SUFFIX__VERSION_1_2; +typedef cl_mem(CL_API_CALL *PFN_clCreateImage2D)( + cl_context /* context */, cl_mem_flags /* flags */, + const cl_image_format * /* image_format */, size_t /* image_width */, + size_t /* image_height */, size_t /* image_row_pitch */, + void * /* host_ptr */, cl_int * /* errcode_ret */); +typedef cl_mem(CL_API_CALL *PFN_clCreateImage3D)( + cl_context /* context */, cl_mem_flags /* flags */, + const cl_image_format * /* image_format */, size_t /* image_width */, + size_t /* image_height */, size_t /* image_depth */, + size_t /* image_row_pitch */, size_t /* image_slice_pitch */, + void * /* host_ptr */, cl_int * /* errcode_ret */); +typedef cl_int(CL_API_CALL *PFN_clEnqueueMarker)( + cl_command_queue /* command_queue */, cl_event * /* event */); +typedef cl_int(CL_API_CALL *PFN_clEnqueueWaitForEvents)( + cl_command_queue /* command_queue */, cl_uint /* num_events */, + const cl_event * /* event_list */); +typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrier)( + cl_command_queue /* command_queue */); +typedef cl_int(CL_API_CALL *PFN_clUnloadCompiler)(); +typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddress)( + const char * /* func_name */); +typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueue)( + cl_context /* context */, cl_device_id /* device */, + cl_command_queue_properties /* properties */, cl_int * /* errcode_ret */); +typedef cl_sampler(CL_API_CALL *PFN_clCreateSampler)( + cl_context /* context */, cl_bool /* normalized_coords */, + cl_addressing_mode /* addressing_mode */, cl_filter_mode /* filter_mode */, + cl_int * /* errcode_ret */); +typedef cl_int(CL_API_CALL *PFN_clEnqueueTask)( + cl_command_queue /* command_queue */, cl_kernel /* kernel */, + cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, cl_event * /* event */); + +// OpenGL sharing +typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLBuffer)(cl_context, cl_mem_flags, + cl_GLuint, int *); +typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLTexture)( + cl_context /* context */, cl_mem_flags /* flags */, cl_GLenum /* target */, + cl_GLint /* miplevel */, cl_GLuint /* texture */, + cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2; +typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireGLObjects)( + cl_command_queue /* command_queue */, cl_uint /* num_objects */, + const cl_mem * /* mem_objects */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, cl_event * /* event */); +typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseGLObjects)( + cl_command_queue /* command_queue */, cl_uint /* num_objects */, + const cl_mem * /* mem_objects */, cl_uint /* num_events_in_wait_list */, + const cl_event * /* event_wait_list */, + cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0; + +// cl_khr_egl_event extension + +// CLeglDisplayKHR is an opaque handle to an EGLDisplay +typedef void *CLeglDisplayKHR; + +// CLeglSyncKHR is an opaque handle to an EGLSync object +typedef void *CLeglSyncKHR; + +typedef cl_event(CL_API_CALL *PFN_clCreateEventFromEGLSyncKHR)( + cl_context /* context */, CLeglSyncKHR /* sync */, + CLeglDisplayKHR /* display */, cl_int * /* errcode_ret */); + +// EGL sharing +typedef cl_mem(CL_API_CALL *PFN_clCreateFromEGLImageKHR)( + cl_context /*context*/, CLeglDisplayKHR /*display*/, + CLeglImageKHR /*image*/, cl_mem_flags /*flags*/, + const cl_egl_image_properties_khr * /*properties*/, + cl_int * /*errcode_ret*/); +typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireEGLObjectsKHR)( + cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, + const cl_mem * /*mem_objects*/, cl_uint /*num_events_in_wait_list*/, + const cl_event * /*event_wait_list*/, cl_event * /*event*/); +typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseEGLObjectsKHR)( + cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, + const cl_mem * /*mem_objects*/, cl_uint /*num_events_in_wait_list*/, + const cl_event * /*event_wait_list*/, cl_event * /*event*/); + +// cl_khr_command_buffer +typedef cl_command_buffer_khr(CL_API_CALL *PFN_clCreateCommandBufferKHR)( + cl_uint /*num_queues*/, const cl_command_queue * /*queues*/, + const cl_command_buffer_properties_khr * /*properties*/, + cl_int * /*errcode_ret*/); + +typedef cl_int(CL_API_CALL *PFN_clRetainCommandBufferKHR)( + cl_command_buffer_khr /*command_buffer*/); + +typedef cl_int(CL_API_CALL *PFN_clReleaseCommandBufferKHR)( + cl_command_buffer_khr /*command_buffer*/); + +typedef cl_int(CL_API_CALL *PFN_clFinalizeCommandBufferKHR)( + cl_command_buffer_khr /*command_buffer*/); + +typedef cl_int(CL_API_CALL *PFN_clEnqueueCommandBufferKHR)( + cl_uint /*num_queues*/, cl_command_queue * /*queues*/, + cl_command_buffer_khr /*command_buffer*/, + cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, + cl_event * /*event*/); + +#if CL_KHR_COMMAND_BUFFER_EXTENSION_VERSION >= CL_MAKE_VERSION(0, 9, 5) +typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( + cl_command_buffer_khr /*command_buffer*/, + cl_command_queue /*command_queue*/, + const cl_command_properties_khr * /*properties*/, cl_kernel /*kernel*/, + cl_uint /*work_dim*/, const size_t * /*global_work_offset*/, + const size_t * /*global_work_size*/, const size_t * /*local_work_size*/, + cl_uint /*num_sync_points_in_wait_list*/, + const cl_sync_point_khr * /*sync_point_wait_list*/, + cl_sync_point_khr * /*sync_point*/, + cl_mutable_command_khr * /*mutable_handle*/); +#else +typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( + cl_command_buffer_khr /*command_buffer*/, + cl_command_queue /*command_queue*/, + const cl_ndrange_kernel_command_properties_khr * /*properties*/, + cl_kernel /*kernel*/, cl_uint /*work_dim*/, + const size_t * /*global_work_offset*/, const size_t * /*global_work_size*/, + const size_t * /*local_work_size*/, + cl_uint /*num_sync_points_in_wait_list*/, + const cl_sync_point_khr * /*sync_point_wait_list*/, + cl_sync_point_khr * /*sync_point*/, + cl_mutable_command_khr * /*mutable_handle*/); +#endif + +typedef cl_int(CL_API_CALL *PFN_clGetCommandBufferInfoKHR)( + cl_command_buffer_khr /*command_buffer*/, + cl_command_buffer_info_khr /*param_name*/, size_t /*param_value_size*/, + void * /*param_value*/, size_t * /*param_value_size_ret*/); + +// OpenCL 3.0 +typedef cl_mem(CL_API_CALL *PFN_clCreateBufferWithProperties)( + cl_context /*context*/, const cl_mem_properties * /*properties*/, + cl_mem_flags /*flags*/, size_t /*size*/, void * /*host_ptr*/, + cl_int * /*errcode_ret*/); +typedef cl_mem(CL_API_CALL *PFN_clCreateImageWithProperties)( + cl_context /*context*/, const cl_mem_properties * /*properties*/, + cl_mem_flags /*flags*/, const cl_image_format * /*image_format*/, + const cl_image_desc * /*image_desc*/, void * /*host_ptr*/, + cl_int * /*errcode_ret*/); + +extern PFN_clGetPlatformIDs clGetPlatformIDs; +extern PFN_clGetPlatformInfo clGetPlatformInfo; +extern PFN_clGetDeviceIDs clGetDeviceIDs; +extern PFN_clGetDeviceInfo clGetDeviceInfo; +extern PFN_clCreateSubDevices clCreateSubDevices; +extern PFN_clRetainDevice clRetainDevice; +extern PFN_clReleaseDevice clReleaseDevice; +extern PFN_clCreateContext clCreateContext; +extern PFN_clCreateContextFromType clCreateContextFromType; +extern PFN_clRetainContext clRetainContext; +extern PFN_clReleaseContext clReleaseContext; +extern PFN_clGetContextInfo clGetContextInfo; +extern PFN_clCreateCommandQueueWithProperties + clCreateCommandQueueWithProperties; +extern PFN_clRetainCommandQueue clRetainCommandQueue; +extern PFN_clReleaseCommandQueue clReleaseCommandQueue; +extern PFN_clGetCommandQueueInfo clGetCommandQueueInfo; +extern PFN_clCreateBuffer clCreateBuffer; +extern PFN_clCreateSubBuffer clCreateSubBuffer; +extern PFN_clCreateImage clCreateImage; +extern PFN_clCreatePipe clCreatePipe; +extern PFN_clRetainMemObject clRetainMemObject; +extern PFN_clReleaseMemObject clReleaseMemObject; +extern PFN_clGetSupportedImageFormats clGetSupportedImageFormats; +extern PFN_clGetMemObjectInfo clGetMemObjectInfo; +extern PFN_clGetImageInfo clGetImageInfo; +extern PFN_clGetPipeInfo clGetPipeInfo; +extern PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback; +extern PFN_clSVMAlloc clSVMAlloc; +extern PFN_clSVMFree clSVMFree; +extern PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties; +extern PFN_clRetainSampler clRetainSampler; +extern PFN_clReleaseSampler clReleaseSampler; +extern PFN_clGetSamplerInfo clGetSamplerInfo; +extern PFN_clCreateProgramWithSource clCreateProgramWithSource; +extern PFN_clCreateProgramWithBinary clCreateProgramWithBinary; +extern PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels; +extern PFN_clRetainProgram clRetainProgram; +extern PFN_clReleaseProgram clReleaseProgram; +extern PFN_clBuildProgram clBuildProgram; +extern PFN_clCompileProgram clCompileProgram; +extern PFN_clLinkProgram clLinkProgram; +extern PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler; +extern PFN_clGetProgramInfo clGetProgramInfo; +extern PFN_clGetProgramBuildInfo clGetProgramBuildInfo; +extern PFN_clCreateKernel clCreateKernel; +extern PFN_clCreateKernelsInProgram clCreateKernelsInProgram; +extern PFN_clRetainKernel clRetainKernel; +extern PFN_clReleaseKernel clReleaseKernel; +extern PFN_clSetKernelArg clSetKernelArg; +extern PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer; +extern PFN_clSetKernelExecInfo clSetKernelExecInfo; +extern PFN_clGetKernelInfo clGetKernelInfo; +extern PFN_clGetKernelArgInfo clGetKernelArgInfo; +extern PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo; +extern PFN_clWaitForEvents clWaitForEvents; +extern PFN_clGetEventInfo clGetEventInfo; +extern PFN_clCreateUserEvent clCreateUserEvent; +extern PFN_clRetainEvent clRetainEvent; +extern PFN_clReleaseEvent clReleaseEvent; +extern PFN_clSetUserEventStatus clSetUserEventStatus; +extern PFN_clSetEventCallback clSetEventCallback; +extern PFN_clGetEventProfilingInfo clGetEventProfilingInfo; +extern PFN_clFlush clFlush; +extern PFN_clFinish clFinish; +extern PFN_clEnqueueReadBuffer clEnqueueReadBuffer; +extern PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect; +extern PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer; +extern PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect; +extern PFN_clEnqueueFillBuffer clEnqueueFillBuffer; +extern PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer; +extern PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect; +extern PFN_clEnqueueReadImage clEnqueueReadImage; +extern PFN_clEnqueueWriteImage clEnqueueWriteImage; +extern PFN_clEnqueueFillImage clEnqueueFillImage; +extern PFN_clEnqueueCopyImage clEnqueueCopyImage; +extern PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer; +extern PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage; +extern PFN_clEnqueueMapBuffer clEnqueueMapBuffer; +extern PFN_clEnqueueMapImage clEnqueueMapImage; +extern PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject; +extern PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects; +extern PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel; +extern PFN_clEnqueueNativeKernel clEnqueueNativeKernel; +extern PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList; +extern PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList; +extern PFN_clEnqueueSVMFree clEnqueueSVMFree; +extern PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy; +extern PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill; +extern PFN_clEnqueueSVMMap clEnqueueSVMMap; +extern PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap; +extern PFN_clGetExtensionFunctionAddressForPlatform + clGetExtensionFunctionAddressForPlatform; +extern PFN_clCreateImage2D clCreateImage2D; +extern PFN_clCreateImage3D clCreateImage3D; +extern PFN_clEnqueueMarker clEnqueueMarker; +extern PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents; +extern PFN_clEnqueueBarrier clEnqueueBarrier; +extern PFN_clUnloadCompiler clUnloadCompiler; +extern PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress; +extern PFN_clCreateCommandQueue clCreateCommandQueue; +extern PFN_clCreateSampler clCreateSampler; +extern PFN_clEnqueueTask clEnqueueTask; + +// OpenGL sharing +extern PFN_clCreateFromGLBuffer clCreateFromGLBuffer; +extern PFN_clCreateFromGLTexture clCreateFromGLTexture; +extern PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects; +extern PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects; + +// cl_khr_egl_event extension +extern PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR; + +// EGL sharing +extern PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR; +extern PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR; +extern PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR; + +// cl_khr_command_buffer extension +extern PFN_clCreateCommandBufferKHR clCreateCommandBufferKHR; +extern PFN_clRetainCommandBufferKHR clRetainCommandBufferKHR; +extern PFN_clReleaseCommandBufferKHR clReleaseCommandBufferKHR; +extern PFN_clFinalizeCommandBufferKHR clFinalizeCommandBufferKHR; +extern PFN_clEnqueueCommandBufferKHR clEnqueueCommandBufferKHR; +extern PFN_clCommandNDRangeKernelKHR clCommandNDRangeKernelKHR; +extern PFN_clGetCommandBufferInfoKHR clGetCommandBufferInfoKHR; + +// OpenCL 3.0 +extern PFN_clCreateBufferWithProperties clCreateBufferWithProperties; +extern PFN_clCreateImageWithProperties clCreateImageWithProperties; + +// For convenient image creation +// It uses clCreateImage if it available (clCreateImage available since cl 1.2) +// otherwise it will use legacy clCreateImage2D +cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags, + const cl_image_format *image_format, + const cl_image_desc *image_desc, void *host_ptr, + cl_int *errcode_ret); + +// It uses clCreateImage if it available (clCreateImage available since cl 1.2) +// otherwise it will use legacy clCreateImage3D +cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags, + const cl_image_format *image_format, + const cl_image_desc *image_desc, void *host_ptr, + cl_int *errcode_ret); + +} // namespace cl +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_OPENCL_OPENCL_WRAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc index 5ac2f023393e5f..dda81d5ab516bd 100644 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc +++ b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.cc @@ -26,11 +26,11 @@ #include "tensorflow/lite/experimental/litert/c/litert_event.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/cc/litert_event.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/core/util/tensor_type_util.h" #include "tensorflow/lite/experimental/litert/runtime/ahwb_buffer.h" #include "tensorflow/lite/experimental/litert/runtime/dmabuf_buffer.h" -#include "tensorflow/lite/experimental/litert/runtime/event.h" #include "tensorflow/lite/experimental/litert/runtime/fastrpc_buffer.h" #include "tensorflow/lite/experimental/litert/runtime/ion_buffer.h" @@ -402,10 +402,9 @@ Expected LiteRtTensorBufferT::Lock(LiteRtEvent event) { // Only AHWB supports waiting on an input sync fence when locking the // buffer. For all other buffer types we wait here. if (buffer_type() != kLiteRtTensorBufferTypeAhwb) { - if (auto status = event->Wait(/*timeout_in_ms*/ -1); - status != kLiteRtStatusOk) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to wait on input event"); + litert::Event e(event, /*owned=*/false); + if (auto status = e.Wait(/*timeout_in_ms=*/-1); !status) { + return status.Error(); } } } diff --git a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h index 7b75d3d02ce50e..7997c9073bd85a 100644 --- a/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h +++ b/tensorflow/lite/experimental/litert/runtime/tensor_buffer.h @@ -16,9 +16,12 @@ #define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_RUNTIME_TENSOR_BUFFER_H_ #include +#include #include +#include #include #include +#include #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" @@ -73,6 +76,19 @@ class LiteRtTensorBufferT { size_t buffer_size() const { return buffer_size_; } size_t buffer_offset() const { return buffer_offset_; } + bool HasEvent() const { return event_.has_value(); } + + litert::Expected GetEvent() const { + if (!HasEvent()) { + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "TensorBuffer has no event"); + } + return *event_; + } + + void SetEvent(LiteRtEvent e) { event_ = e; } + void ClearEvent() { event_ = std::nullopt; } + litert::Expected GetHostBuffer(); litert::Expected GetAhwbBuffer(); litert::Expected> GetIonBuffer(); @@ -160,6 +176,7 @@ class LiteRtTensorBufferT { size_t buffer_offset_; std::variant buffer_; + std::optional event_; mutable std::atomic_int_fast32_t ref_; }; diff --git a/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc b/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc index 2104acdf12c7bf..d77bd0b58e4f9f 100644 --- a/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc +++ b/tensorflow/lite/experimental/litert/runtime/tfl_utils.cc @@ -87,7 +87,7 @@ Expected ConvertTensorType( } size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor); - SmallVec dimensions(rank); + Dimensions dimensions(rank); for (size_t i = 0; i < rank; ++i) { dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i); } diff --git a/tensorflow/lite/experimental/litert/test/BUILD b/tensorflow/lite/experimental/litert/test/BUILD index 947d577a21cf12..d5864c0a68f519 100644 --- a/tensorflow/lite/experimental/litert/test/BUILD +++ b/tensorflow/lite/experimental/litert/test/BUILD @@ -60,6 +60,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", @@ -67,7 +68,10 @@ cc_library( "//tensorflow/lite/experimental/litert/core/model:model_buffer", "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform", ], ) diff --git a/tensorflow/lite/experimental/litert/test/common.cc b/tensorflow/lite/experimental/litert/test/common.cc index 163834e4c419c7..bb212e09382f36 100644 --- a/tensorflow/lite/experimental/litert/test/common.cc +++ b/tensorflow/lite/experimental/litert/test/common.cc @@ -14,13 +14,23 @@ #include "tensorflow/lite/experimental/litert/test/common.h" +#include +#include +#include // NOLINT +#include #include +#include +#include #include #include #include +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" @@ -33,6 +43,38 @@ namespace litert { namespace testing { +Expected UniqueTestDirectory::Create() { + constexpr size_t kMaxTries = 1000; + ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); + + // We don't want multiple threads to create the same directory. + absl::MutexLock l(&mutex); + + auto tmp_dir = std::filesystem::temp_directory_path(); + std::random_device dev; + std::mt19937 prng(dev()); + std::uniform_int_distribution rand(0); + std::stringstream ss; + + for (auto i = 0; i < kMaxTries; ++i) { + ss.clear(); + ss << std::hex << rand(prng); + auto path = tmp_dir / ss.str(); + if (std::filesystem::create_directory(path)) { + LITERT_LOG(LITERT_INFO, "Created unique temporary directory %s", + path.c_str()); + return UniqueTestDirectory(path); + } + } + + return Error(kLiteRtStatusErrorRuntimeFailure, + "Could not create a unique temporary directory"); +} + +UniqueTestDirectory::~UniqueTestDirectory() { + std::filesystem::remove_all(tmpdir_); +} + std::string GetTestFilePath(absl::string_view filename) { static constexpr absl::string_view kTestDataDir = "tensorflow/lite/experimental/litert/" @@ -49,28 +91,6 @@ Model LoadTestFileModel(absl::string_view filename) { return *Model::CreateFromFile(GetTestFilePath(filename)); } -bool ValidateTopology(const std::vector& ops) { - for (const auto& op : ops) { - const auto inputs = op.Inputs(); - for (int i = 0; i < inputs.size(); ++i) { - if (!MatchUse(inputs.at(i), UseInfo{op.Code(), i})) { - return false; - } - } - const auto outputs = op.Outputs(); - for (int i = 0; i < outputs.size(); ++i) { - const auto defining_op = outputs.at(i).DefiningOp(); - if (!defining_op.has_value()) { - return false; - } - if (defining_op->op != op.Get() || defining_op->op_output_index != i) { - return false; - } - } - } - return true; -} - Expected TflRuntime::CreateFromFlatBuffer( internal::FlatbufferWrapper::Ptr flatbuffer) { ::tflite::Interpreter::Ptr interp; diff --git a/tensorflow/lite/experimental/litert/test/common.h b/tensorflow/lite/experimental/litert/test/common.h index 4a1a455a8365c0..6b6148c1802040 100644 --- a/tensorflow/lite/experimental/litert/test/common.h +++ b/tensorflow/lite/experimental/litert/test/common.h @@ -29,12 +29,29 @@ namespace litert { namespace testing { +// A x-platform compatible replacement for testing::UniqueTestDirectory. +class UniqueTestDirectory { + public: + static Expected Create(); + ~UniqueTestDirectory(); + + UniqueTestDirectory(const UniqueTestDirectory&) = delete; + UniqueTestDirectory(UniqueTestDirectory&&) = default; + UniqueTestDirectory& operator=(const UniqueTestDirectory&) = delete; + UniqueTestDirectory& operator=(UniqueTestDirectory&&) = default; + + absl::string_view Str() const { return tmpdir_; } + + private: + explicit UniqueTestDirectory(std::string&& tmpdir) + : tmpdir_(std::move(tmpdir)) {} + std::string tmpdir_; +}; + std::string GetTestFilePath(absl::string_view filename); Model LoadTestFileModel(absl::string_view filename); -bool ValidateTopology(const std::vector& ops); - class TflRuntime { public: using Ptr = std::unique_ptr; diff --git a/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir new file mode 100644 index 00000000000000..8a11bf4f58ba4f --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/cst_multi_subgraph.mlir @@ -0,0 +1,12 @@ +module { + func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = "tfl.pseudo_const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> + %1 = tfl.mul %arg0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> + return %1 : tensor<4xf32> + } + func.func @other(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = "tfl.pseudo_const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}> : () -> tensor<4xf32> + %1 = tfl.mul %arg0, %0 {fused_activation_function = "NONE"} : tensor<4xf32> + return %1 : tensor<4xf32> + } +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir new file mode 100644 index 00000000000000..7c1f0fe4e0f5b0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph.mlir @@ -0,0 +1,21 @@ +module { + +func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[-1.0, -1.0, -1.0, -1.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} + +func.func @func1(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} + +func.func @func2(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %cst = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : tensor<4xf32> + %0 = tfl.add %arg0, %cst {fused_activation_function = "NONE"} : tensor<4xf32> + return %0 : tensor<4xf32> +} + +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir new file mode 100644 index 00000000000000..607100dbc389b6 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/multi_subgraph_mul.mlir @@ -0,0 +1,13 @@ +module { + +func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +func.func @func1(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir new file mode 100644 index 00000000000000..39ebcf24e972d0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_gelu_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x100x1xf32>) -> tensor<8x100x1xf32> { + %0 = "tfl.gelu"(%arg0) : (tensor<8x100x1xf32>) -> tensor<8x100x1xf32> + return %0 : tensor<8x100x1xf32> +} +} diff --git a/tensorflow/lite/experimental/litert/tools/BUILD b/tensorflow/lite/experimental/litert/tools/BUILD index dde0fd157399e6..78ea2cdcb53233 100644 --- a/tensorflow/lite/experimental/litert/tools/BUILD +++ b/tensorflow/lite/experimental/litert/tools/BUILD @@ -29,13 +29,13 @@ cc_library( ":outstream", ":tool_display", "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", "//tensorflow/lite/experimental/litert/cc:litert_detail", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_macros", "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/compiler/plugin:algo", "//tensorflow/lite/experimental/litert/compiler/plugin:compiler_plugin", "//tensorflow/lite/experimental/litert/core:byte_code_util", "//tensorflow/lite/experimental/litert/core/model:model_serialize", diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin.cc index af23240aabe9a2..f16108215daee1 100644 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin.cc @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -29,12 +28,12 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" -#include "tensorflow/lite/experimental/litert/compiler/plugin/algo.h" #include "tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h" #include "tensorflow/lite/experimental/litert/core/byte_code_util.h" #include "tensorflow/lite/experimental/litert/core/model/model_serialize.h" @@ -45,19 +44,11 @@ namespace litert::tools { using ::litert::BufferRef; -using ::litert::OwningBufferRef; using ::litert::internal::CompilerPlugin; using ::litert::internal::Dump; -using ::litert::internal::FinishByteCodePlaceholders; -using ::litert::internal::GroupPartitions; -using ::litert::internal::kByteCodeMetadataKey; -using ::litert::internal::kLiteRtBuildStampKey; -using ::litert::internal::kLiteRtDispatchOpCustomCode; -using ::litert::internal::MakeBuildStamp; -using ::litert::internal::MakeByteCodePlaceholder; -using ::litert::internal::MakeExecInfo; -using ::litert::internal::OutlinePartition; +using ::litert::internal::PartitionResult; using ::litert::internal::Serialization; +using ::litert::internal::SerializeModel; using ::litert::internal::VerifyFlatbuffer; using ::litert::tools::ApplyPluginRun; @@ -119,6 +110,41 @@ class Context { ToolDisplay display_; }; +void DumpSubgraphs(ToolDisplay& display, absl::string_view label, + absl::Span subgraphs) { + for (auto* subgraph : subgraphs) { + display.Labeled(); + display.Indented() << absl::StreamFormat("(%s graph)", label); + Dump(*subgraph, display.Display()); + } +} + +void DumpCompilationRequest(ToolDisplay& display, absl::string_view soc_model, + size_t num_subgraphs) { + display.Labeled() << absl::StreamFormat( + "Requesting compilation for target `%s` on %lu partitions\n", soc_model, + num_subgraphs); +} + +void DumpCompilationResult(ToolDisplay& display, size_t byte_code_size, + size_t num_entry_points) { + display.Labeled() << absl::StreamFormat( + "Compiled %lu partitions into %lu bytes\n", num_entry_points, + byte_code_size); +} + +void DumpModelStats(ToolDisplay& display, BufferRef buf) { + display.Labeled() << absl::StreamFormat( + "Serialized a model of size %lu bytes\n", buf.Size()); +} + +void DumpPartitionResult(ToolDisplay& display, const PartitionResult& result) { + display.Labeled() << absl::StreamFormat( + "Partitioning yielded %lu new subgraphs\n", result.second.Size()); + + DumpSubgraphs(display, "new subgraphs", result.second.Elements()); +} + absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { switch (cmd) { case ApplyPluginRun::Cmd::INFO: @@ -134,12 +160,7 @@ absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { } } -void DumpModelStats(Context& ctx, BufferRef buf) { - ctx.Dump().Labeled() << absl::StreamFormat( - "Serialized a model of size %lu bytes\n", buf.Size()); -} - -Expected> LoadAllPlugins(Context& ctx) { +Expected> LoadAllPlugins(Context& ctx) { ctx.Dump().Start("Load Plugins"); ctx.Dump().Labeled() << "Loading plugins from: "; const auto paths = ctx.LibSearchPaths(); @@ -203,91 +224,6 @@ Expected LoadModel(Context& ctx) { return model_result; } -std::vector ApplyPartition(Context& ctx, const Model& model, - CompilerPlugin& plugin) { - ctx.Dump().Start("Partition Model"); - model.Get()->custom_op_code = kLiteRtDispatchOpCustomCode; - - ctx.Dump().Labeled() << "Input model: \n"; - for (auto it = model.Get()->subgraphs.begin(); - it < model.Get()->subgraphs.end(); ++it) { - ctx.Dump().Labeled(); - ctx.Dump().Indented() << "(input graph) "; - Dump(*it, ctx.Dump().Display()); - } - - auto partition = plugin.PartitionModel(model); - if (!partition.HasValue()) { - return {}; - } - auto grouped_partitions = GroupPartitions(partition.Value()); - if (grouped_partitions.empty()) { - return {}; - } - ctx.Dump().Labeled() << absl::StreamFormat( - "Plugin selected %lu ops, yielding %lu partitions\n", - partition.Value().size(), grouped_partitions.size()); - - std::vector res; - for (auto& partition : grouped_partitions) { - LiteRtOp custom_op = - OutlinePartition(model.Get()->subgraphs.front(), - &model.Get()->subgraphs.emplace_back(), partition); - res.push_back(custom_op); - } - - ctx.Dump().Labeled() << "Partitioned model: \n"; - ctx.Dump().Labeled(); - ctx.Dump().Indented() << "(initial graph) "; - Dump(model.Get()->subgraphs.front(), ctx.Dump().Display()); - for (auto it = model.Get()->subgraphs.begin() + 1; - it < model.Get()->subgraphs.end(); ++it) { - ctx.Dump().Labeled(); - ctx.Dump().Indented() << "(new graph) "; - Dump(*it, ctx.Dump().Display()); - } - - ctx.Dump().Done(); - return res; -} - -Expected PartitionModel(Context& ctx, Model&& model, - CompilerPlugin& plugin) { - auto custom_ops = ApplyPartition(ctx, model, plugin); - if (custom_ops.empty()) { - return Unexpected(kLiteRtStatusErrorGraphModification); - } - return std::move(model); -} - -Expected> CompilePartitions( - Context& ctx, std::vector& partitions, - CompilerPlugin& plugin) { - ctx.Dump().Start("Compile Model"); - ctx.Dump().Labeled() << absl::StreamFormat( - "Requesting compilation for target \"%s\" on %lu subgraphs\n", - ctx.SocModelTarget(), partitions.size()); - - std::vector call_info_out; - if (plugin.Compile(ctx.SocModelTarget(), partitions, ctx.Out(), - call_info_out) != kLiteRtStatusOk) { - ctx.Dump().Fail(); - return Unexpected(kLiteRtStatusErrorCompilation); - } - - ctx.Dump().Labeled() << "Entry point info: "; - for (auto it = call_info_out.begin(); it < call_info_out.end(); ++it) { - ctx.Dump().Display() << absl::StreamFormat("\"%s\"", *it); - if (it < call_info_out.end() - 1) { - ctx.Dump().Display() << ", "; - } - } - ctx.Dump().Display() << "\n"; - - ctx.Dump().Done(); - return std::move(call_info_out); -} - // // INFO Command // @@ -335,7 +271,7 @@ LiteRtStatus Noop(Context& ctx) { return model.Error().Status(); } - auto serialized = SerializeModel(std::move(*model)); + auto serialized = SerializeModel(std::move(*model->Get())); if (!serialized) { return serialized.Error().Status(); } @@ -364,19 +300,26 @@ LiteRtStatus Partition(Context& ctx) { return plugin.Error().Status(); } - auto model = LoadModel(ctx); - if (!model) { - return model.Error().Status(); + auto model_wrap = LoadModel(ctx); + if (!model_wrap) { + return model_wrap.Error().Status(); } + auto& model = *model_wrap->Get(); - auto partitioned_model = PartitionModel(ctx, std::move(*model), *plugin); - if (!partitioned_model) { - return partitioned_model.Error().Status(); + ctx.Dump().Start("Partitioning model"); + auto partition_result = PartitionModel(*plugin, model); + if (!partition_result) { + return partition_result.Error().Status(); } + ctx.Dump().Done(); + DumpPartitionResult(ctx.Dump(), *partition_result); + + auto& new_subgraphs = partition_result->second; + model.TransferSubgraphs(std::move(new_subgraphs)); ctx.Dump().Start("Serializing model"); - auto serialized = SerializeModel(std::move(*partitioned_model)); - DumpModelStats(ctx, *serialized); + auto serialized = SerializeModel(std::move(model)); + DumpModelStats(ctx.Dump(), *serialized); ctx.Dump().Done(); ctx.Dump().Start("Verifying flatbuffer"); @@ -408,137 +351,50 @@ LiteRtStatus ValidateCompileRun(const ApplyPluginRun& run) { } LiteRtStatus Compile(Context& ctx) { - auto model = LoadModel(ctx); - if (!model) { - return model.Error().Status(); + auto model_wrap = LoadModel(ctx); + if (!model_wrap) { + return model_wrap.Error().Status(); } + auto& model = *model_wrap->Get(); auto plugin = LoadPlugin(ctx); if (!plugin) { return plugin.Error().Status(); } - std::vector compilation_input; - compilation_input.reserve(model->Get()->subgraphs.size()); - for (auto& subgraph : model->Get()->subgraphs) { - compilation_input.push_back(&subgraph); - } - - auto entry_points = CompilePartitions(ctx, compilation_input, *plugin); - if (!entry_points) { - return entry_points.Error().Status(); - } - - return kLiteRtStatusOk; -} - -// -// APPLY Command -// - -LiteRtStatus StampModel(Context& ctx, LiteRtModel model) { - auto stamp = MakeBuildStamp(ctx.SocManufacturer(), ctx.SocModelTarget(), - ctx.Serialization()); - if (!stamp) { - return stamp.Error().Status(); + ctx.Dump().Start("Compiling"); + DumpCompilationRequest(ctx.Dump(), ctx.SocModelTarget(), + model.NumSubgraphs()); + auto compilation_result = + plugin->Compile(model.Subgraphs(), ctx.SocModelTarget()); + if (!compilation_result) { + ctx.Dump().Fail(); + return compilation_result.Error().Status(); } - ctx.Dump().Labeled() << absl::StreamFormat("Stamping model: %s\n", - stamp->StrView()); - return model->PushMetadata(kLiteRtBuildStampKey, *stamp); -} -Expected> DoMetadataSerialization( - Context& ctx, std::vector& custom_ops, - std::vector& call_info, BufferRef compilation_out, - Model&& model) { - ctx.Dump().Start("Serializing with bytecode in METADATA"); - - { - auto call_it = call_info.begin(); - auto custom_op_it = custom_ops.begin(); - for (; call_it < call_info.end() && custom_op_it < custom_ops.end();) { - (*custom_op_it)->custom_options = - OwningBufferRef((*call_it).c_str()); - ++call_it; - ++custom_op_it; - } + auto byte_code = compilation_result->ByteCode(); + if (!byte_code) { + ctx.Dump().Fail(); + return compilation_result.Error().Status(); } - { - ctx.Dump().Labeled() << absl::StreamFormat( - "Adding metadata byte code of size: %lu bytes\n", - compilation_out.Size()); - - LITERT_EXPECT_OK( - model.Get()->PushMetadata(kByteCodeMetadataKey, compilation_out)); + auto num_calls = compilation_result->NumCalls(); + if (!num_calls) { + ctx.Dump().Fail(); + return compilation_result.Error().Status(); } - auto serialized = SerializeModel(std::move(model)); - if (!serialized) { - return serialized.Error(); - } + DumpCompilationResult(ctx.Dump(), byte_code->Size(), *num_calls); - ctx.Dump().Labeled() << absl::StreamFormat( - "Serialized model of size: %lu bytes\n", serialized->Size()); - if (!VerifyFlatbuffer(serialized->Span())) { - ctx.Dump().Fail(); - return Unexpected(kLiteRtStatusErrorInvalidFlatbuffer); - } + byte_code->WriteStr(ctx.Out()); ctx.Dump().Done(); - return serialized; + return kLiteRtStatusOk; } -Expected> DoAppendSerialization( - Context& ctx, std::vector& custom_ops, - std::vector& call_info, BufferRef compilation_out, - Model&& model) { - ctx.Dump().Start("Serializing with bytecode APPEND"); - - // This need not be the same for all custom ops. - static constexpr absl::string_view kSharedByteCodePlaceholderName = - kByteCodeMetadataKey; - LITERT_EXPECT_OK(model.Get()->PushMetadata(kSharedByteCodePlaceholderName, - MakeByteCodePlaceholder())); - - { - auto call_it = call_info.begin(); - auto custom_op_it = custom_ops.begin(); - for (; call_it < call_info.end() && custom_op_it < custom_ops.end();) { - auto exec_info = MakeExecInfo(*call_it, kSharedByteCodePlaceholderName); - if (!exec_info) { - return exec_info; - } - (*custom_op_it)->custom_options = std::move(*exec_info); - ++call_it; - ++custom_op_it; - } - } - - auto serialized = SerializeModel(std::move(model)); - if (!serialized) { - return serialized; - } - - ctx.Dump().Labeled() << absl::StreamFormat( - "Serialized model of size: %lu bytes\n", serialized->Size()); - LITERT_EXPECT_OK( - FinishByteCodePlaceholders(*serialized, compilation_out.Size())); - - OwningBufferRef with_append(serialized->Size() + - compilation_out.Size()); - - uint8_t* write = with_append.Data(); - std::memcpy(write, serialized->Data(), serialized->Size()); - write += serialized->Size(); - std::memcpy(write, compilation_out.Data(), compilation_out.Size()); - - ctx.Dump().Labeled() << absl::StreamFormat("Appended byte code of size %lu\n", - compilation_out.Size()); - - ctx.Dump().Done(); - return with_append; -} +// +// APPLY Command +// LiteRtStatus ValidateApplyRun(const ApplyPluginRun& run) { LITERT_ENSURE_CONFIG(!run.lib_search_paths.empty()); @@ -554,82 +410,41 @@ LiteRtStatus ValidateApplyRun(const ApplyPluginRun& run) { } LiteRtStatus Apply(Context& ctx) { - auto model = LoadModel(ctx); - if (!model) { - return model.Error().Status(); + auto model_wrap = LoadModel(ctx); + if (!model_wrap) { + return model_wrap.Error().Status(); } + auto& model = *model_wrap->Get(); auto plugin = LoadPlugin(ctx); if (!plugin) { return plugin.Error().Status(); } - static constexpr size_t kNumInputSubgraphs = 1; - LITERT_ENSURE_SUPPORTED(model->Get()->subgraphs.size() == kNumInputSubgraphs, - "Only single subgraph models currently supported."); - - // Query plugin for compilable ops and slice partitions out of the graph, - // replacing use with single custom op.. - auto custom_ops = ApplyPartition(ctx, *model, *plugin); - LITERT_ENSURE(!custom_ops.empty(), kLiteRtStatusErrorGraphModification, - "Failed to partition graph."); - // All new subgraphs to be compiled are appended to the model's subgraphs. - std::vector compilation_input; - for (auto it = model->Get()->subgraphs.begin() + kNumInputSubgraphs; - it < model->Get()->subgraphs.end(); ++it) { - compilation_input.push_back(&*it); + ctx.Dump().Start("Applying plugin"); + if (auto status = litert::internal::ApplyPlugin( + *plugin, model, ctx.SocModelTarget(), ctx.Serialization()); + !status) { + LITERT_LOG(LITERT_ERROR, "%s", status.Error().Message().data()); + return status.Error().Status(); } + ctx.Dump().Done(); - // Call compilation method on the plugin. - std::stringstream compilation_out; - OutStream out = ctx.SwapOut(compilation_out); - - auto call_info = CompilePartitions(ctx, compilation_input, *plugin); - - // Update custom op info the it's respective entry point info from the plugin. - LITERT_ENSURE(call_info->size() == custom_ops.size(), - kLiteRtStatusErrorCompilation, - "Failed to verify entry point information."); - - model->Get()->subgraphs.resize(kNumInputSubgraphs); - LITERT_RETURN_STATUS_IF_NOT_OK(StampModel(ctx, model->Get())); - - BufferRef compiled_buffer(compilation_out.view().data(), - compilation_out.view().size()); - - // For each custom op, if the input tensor is a constant, it should be removed - // from the input list. - for (auto& custom_op : custom_ops) { - std::vector new_inputs; - for (auto& input : custom_op->inputs) { - litert::Tensor input_tensor = litert::Tensor(input); - if (!input_tensor.IsConstant()) { - new_inputs.push_back(input); - } - } - custom_op->inputs = new_inputs; - } + ctx.Dump().Start("Serializing model"); + auto serialized = SerializeModel(std::move(model)); + DumpModelStats(ctx.Dump(), *serialized); + ctx.Dump().Done(); - ctx.SwapOut(out); - if (ctx.Serialization() == Serialization::kMetadata) { - auto serialized = DoMetadataSerialization( - ctx, custom_ops, *call_info, compiled_buffer, std::move(*model)); - if (!serialized) { - return serialized.Error().Status(); - } - serialized->WriteStr(ctx.Out()); + ctx.Dump().Start("Verifying flatbuffer"); + LITERT_ENSURE(VerifyFlatbuffer(serialized->Span()), + kLiteRtStatusErrorInvalidFlatbuffer, + "Failed to invalidate flatbuffer"); + ctx.Dump().Done(); - } else if (ctx.Serialization() == Serialization::kAppend) { - auto serialized = DoAppendSerialization(ctx, custom_ops, *call_info, - compiled_buffer, std::move(*model)); - if (!serialized) { - return serialized.Error().Status(); - } - serialized->WriteStr(ctx.Out()); + ctx.Dump().Start("Writing to out"); + serialized->WriteStr(ctx.Out()); + ctx.Dump().Done(); - } else { - return kLiteRtStatusErrorUnsupported; - } return kLiteRtStatusOk; } diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.h b/tensorflow/lite/experimental/litert/tools/apply_plugin.h index 414ebb677b750d..46caf8ac10456f 100644 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin.h +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin.h @@ -128,7 +128,7 @@ struct ApplyPluginRun { // select the first ".so" file found with prefix "libLiteRtPlugin" that has // the "soc_manufacturer" tag passed. Providing more than one plugin shared // library for the same manufacturer results in an error. - SmallVec lib_search_paths = {}; + std::vector lib_search_paths = {}; // Path to ".tflite" model the tool should operated on. std::optional model = {}; @@ -139,13 +139,13 @@ struct ApplyPluginRun { std::optional soc_manufacturer = {}; // Collection of soc models tags the tool should target for compilation. - SmallVec soc_models = {}; + std::vector soc_models = {}; // Where the tool should write its result file(s) to. If the command runs // compilation, an "out" stream should be passed for each "soc_model" target // requested for compilation. Output for the "ith" target will be written to // the "ith" outs stream. - SmallVec outs = {std::cout}; + std::vector outs = {std::cout}; // Where to direct logging for this run. Passing nullopt here indicates // "silent" behavior and should only be used when this tool is part of a diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc index 08eb6a13f73d77..0671c45a1f2632 100644 --- a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc @@ -105,7 +105,7 @@ TEST(TestApplyPluginTool, TestNoop) { auto model = Model::CreateFromBuffer( BufferRef(out.view().data(), out.view().size())); - EXPECT_EQ(model->Get()->subgraphs.size(), 1); + EXPECT_EQ(model->Get()->NumSubgraphs(), 1); } TEST(TestApplyPluginTool, TestPartitionBadConfig) { @@ -154,7 +154,7 @@ TEST(TestApplyPluginTool, TestApply) { auto model = Model::CreateFromBuffer( BufferRef(out.str().data(), out.str().size())); - EXPECT_EQ(model->Get()->subgraphs.size(), 1); + EXPECT_EQ(model->Get()->NumSubgraphs(), 1); { auto stamp_buffer = model->Get()->FindMetadata(kLiteRtBuildStampKey); @@ -166,9 +166,9 @@ TEST(TestApplyPluginTool, TestApply) { } { - auto custom_op = model->Get()->subgraphs.front().ops.front(); - ASSERT_EQ(custom_op->op_code, kLiteRtOpCodeTflCustom); - EXPECT_EQ(custom_op->custom_options.StrView(), "Partition_0"); + const auto& custom_op = model->Get()->Subgraph(0).Op(0); + ASSERT_EQ(custom_op.OpCode(), kLiteRtOpCodeTflCustom); + EXPECT_THAT(custom_op.CustomOptions().StrView(), HasSubstr("Partition_0")); } { @@ -194,7 +194,7 @@ TEST(TestApplyPluginTool, TestApplyWithAppendSerialization) { BufferRef serialized(out.str().data(), out.str().size()); auto model = Model::CreateFromBuffer(serialized); - EXPECT_EQ(model->Get()->subgraphs.size(), 1); + EXPECT_EQ(model->Get()->NumSubgraphs(), 1); { auto stamp_buffer = model->Get()->FindMetadata(kLiteRtBuildStampKey); @@ -206,10 +206,10 @@ TEST(TestApplyPluginTool, TestApplyWithAppendSerialization) { } { - auto custom_op = model->Get()->subgraphs.front().ops.front(); - ASSERT_EQ(custom_op->op_code, kLiteRtOpCodeTflCustom); + const auto& custom_op = model->Get()->Subgraph(0).Op(0); + ASSERT_EQ(custom_op.OpCode(), kLiteRtOpCodeTflCustom); - auto options = ParseExecInfo(custom_op->custom_options); + auto options = ParseExecInfo(custom_op.CustomOptions()); auto [entry_point, metadata_key] = *options; EXPECT_EQ(entry_point, "Partition_0"); diff --git a/tensorflow/lite/experimental/litert/tools/dump.cc b/tensorflow/lite/experimental/litert/tools/dump.cc index d84eeb07e6fb9a..30917a13106619 100644 --- a/tensorflow/lite/experimental/litert/tools/dump.cc +++ b/tensorflow/lite/experimental/litert/tools/dump.cc @@ -17,13 +17,14 @@ #include #ifndef __ANDROID__ +#if __has_include() #include #endif +#endif #include #include #include -#include #include #include "absl/strings/str_format.h" @@ -37,21 +38,25 @@ namespace litert::internal { namespace { +static constexpr int kMaxDisplayCount = 16; + void DumpNode(const LiteRtTensorT& tensor, std::ostream& out) { - switch (tensor.type_id) { + switch (tensor.Type().first) { case kLiteRtRankedTensorType: - Dump(tensor.type_detail.ranked_tensor_type, out); + Dump(tensor.Type().second.ranked_tensor_type, out); break; case kLiteRtUnrankedTensorType: - Dump(tensor.type_detail.unranked_tensor_type.element_type, out); + Dump(tensor.Type().second.unranked_tensor_type.element_type, out); break; default: - out << "UKNOWN_TENSOR_TYPE" << tensor.type_id; + out << "UKNOWN_TENSOR_TYPE" << tensor.Type().first; } - Dump(std::make_pair(tensor.q_type_id, tensor.q_type_detail), out); + Dump(tensor.Qparams(), out); } -void DumpNode(const LiteRtOpT& op, std::ostream& out) { Dump(op.op_code, out); } +void DumpNode(const LiteRtOpT& op, std::ostream& out) { + Dump(op.OpCode(), out); +} void DumpSignature(const std::vector& ins, const std::vector& outs, std::ostream& out) { @@ -156,6 +161,9 @@ void Dump(LiteRtOpCode code, std::ostream& out) { case kLiteRtOpCodeTflGreater: out << "TFL_GREATER"; break; + case kLiteRtOpCodeTflGelu: + out << "TFL_GELU"; + break; default: out << "UKNOWN_OP_CODE: " << code; break; @@ -210,17 +218,17 @@ void Dump(const LiteRtTensorT& tensor, std::ostream& out) { out << "LiteRtTensor : "; DumpNode(tensor, out); out << " [ "; - if (tensor.defining_op == nullptr) { + if (tensor.DefiningOp() == nullptr) { out << "*"; } else { - DumpNode(*tensor.defining_op, out); + DumpNode(*tensor.DefiningOp(), out); } out << " ] "; out << "("; - for (auto it = tensor.users.begin(); it < tensor.users.end(); ++it) { + for (auto it = tensor.Users().begin(); it < tensor.Users().end(); ++it) { DumpNode(**it, out); - if (it != tensor.users.end() - 1) { + if (it != tensor.Users().end() - 1) { out << ", "; } } @@ -232,16 +240,16 @@ void Dump(const LiteRtOpT& op, std::ostream& out) { out << "LiteRtOp : [ "; DumpNode(op, out); out << " ] "; - DumpSignature(op.inputs, op.outputs, out); + DumpSignature(op.Inputs(), op.Outputs(), out); out << "\n"; } void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out) { constexpr absl::string_view kSubgraphTpl = "LiteRtSubgraph : [ #ops=%d #tensors=%d ] "; - out << absl::StreamFormat(kSubgraphTpl, subgraph.ops.size(), - subgraph.tensors.size()); - DumpSignature(subgraph.inputs, subgraph.outputs, out); + out << absl::StreamFormat(kSubgraphTpl, subgraph.Ops().size(), + subgraph.Tensors().size()); + DumpSignature(subgraph.Inputs(), subgraph.Outputs(), out); out << "\n"; } @@ -262,8 +270,8 @@ void Dump(const CompilerPlugin& plugin, std::ostream& out) { out << "}\n"; } -void Dump(void* lib_handle, std::ostream& out) { -#ifndef __ANDROID__ +void DumpDLL(void* lib_handle, std::ostream& out) { +#if !defined(__ANDROID__) && !defined(__APPLE__) out << "\n--- Lib Info ---\n"; if (lib_handle == nullptr) { out << "Handle is nullptr\n"; @@ -312,95 +320,90 @@ void Dump(void* lib_handle, std::ostream& out) { void Dump(const LiteRtModelT& model, std::ostream& out) { out << absl::StreamFormat("LiteRtModel : [ #subgraphs=%d ]\n", - model.subgraphs.size()); + model.Subgraphs().size()); } void DumpOptions(const LiteRtOpT& op, std::ostream& out) { - if (op.option.value == nullptr) { + auto& opts = detail::GetTflOptions(op); + if (opts.value == nullptr) { out << "null options\n"; return; } - switch (op.op_code) { + switch (op.OpCode()) { case kLiteRtOpCodeTflAdd: out << "fused_activation_function: " - << op.option.AsAddOptions()->fused_activation_function << "\n"; + << opts.AsAddOptions()->fused_activation_function << "\n"; break; case kLiteRtOpCodeTflMul: out << "fused_activation_function: " - << op.option.AsMulOptions()->fused_activation_function << "\n"; + << opts.AsMulOptions()->fused_activation_function << "\n"; break; case kLiteRtOpCodeTflBatchMatmul: - out << "adj_x: " << op.option.AsBatchMatMulOptions()->adj_x << "\n"; - out << "adj_y: " << op.option.AsBatchMatMulOptions()->adj_y << "\n"; + out << "adj_x: " << opts.AsBatchMatMulOptions()->adj_x << "\n"; + out << "adj_y: " << opts.AsBatchMatMulOptions()->adj_y << "\n"; out << "asymmetric_quantize_input: " - << op.option.AsBatchMatMulOptions()->asymmetric_quantize_inputs - << "\n"; + << opts.AsBatchMatMulOptions()->asymmetric_quantize_inputs << "\n"; break; case kLiteRtOpCodeTflConcatenation: - out << "axis: " << op.option.AsConcatenationOptions()->axis << "\n"; + out << "axis: " << opts.AsConcatenationOptions()->axis << "\n"; out << "fused_activation_function: " - << op.option.AsConcatenationOptions()->fused_activation_function - << "\n"; + << opts.AsConcatenationOptions()->fused_activation_function << "\n"; break; case kLiteRtOpCodeTflDiv: out << "fused_activation_function: " - << op.option.AsDivOptions()->fused_activation_function << "\n"; + << opts.AsDivOptions()->fused_activation_function << "\n"; break; case kLiteRtOpCodeTflFullyConnected: out << "weights_format: " - << op.option.AsFullyConnectedOptions()->weights_format << "\n"; - out << "keep_num_dims: " - << op.option.AsFullyConnectedOptions()->keep_num_dims << "\n"; + << opts.AsFullyConnectedOptions()->weights_format << "\n"; + out << "keep_num_dims: " << opts.AsFullyConnectedOptions()->keep_num_dims + << "\n"; out << "quantized_bias_type: " - << op.option.AsFullyConnectedOptions()->quantized_bias_type << "\n"; + << opts.AsFullyConnectedOptions()->quantized_bias_type << "\n"; out << "asymmetric_quantize_input: " - << op.option.AsFullyConnectedOptions()->asymmetric_quantize_inputs - << "\n"; + << opts.AsFullyConnectedOptions()->asymmetric_quantize_inputs << "\n"; out << "fused_activation_function: " - << op.option.AsFullyConnectedOptions()->fused_activation_function - << "\n"; + << opts.AsFullyConnectedOptions()->fused_activation_function << "\n"; break; case kLiteRtOpCodeTflSoftmax: - out << "beta: " << op.option.AsSoftmaxOptions()->beta << "\n"; + out << "beta: " << opts.AsSoftmaxOptions()->beta << "\n"; break; case kLiteRtOpCodeTflStridedSlice: - out << "begin_mask: " << op.option.AsStridedSliceOptions()->begin_mask + out << "begin_mask: " << opts.AsStridedSliceOptions()->begin_mask << "\n"; + out << "end_mask: " << opts.AsStridedSliceOptions()->end_mask << "\n"; + out << "ellipsis_mask: " << opts.AsStridedSliceOptions()->ellipsis_mask << "\n"; - out << "end_mask: " << op.option.AsStridedSliceOptions()->end_mask + out << "new_axis_mask: " << opts.AsStridedSliceOptions()->new_axis_mask << "\n"; - out << "ellipsis_mask: " - << op.option.AsStridedSliceOptions()->ellipsis_mask << "\n"; - out << "new_axis_mask: " - << op.option.AsStridedSliceOptions()->new_axis_mask << "\n"; out << "shrink_axis_mask: " - << op.option.AsStridedSliceOptions()->shrink_axis_mask << "\n"; - out << "offset: " << op.option.AsStridedSliceOptions()->offset << "\n"; + << opts.AsStridedSliceOptions()->shrink_axis_mask << "\n"; + out << "offset: " << opts.AsStridedSliceOptions()->offset << "\n"; break; case kLiteRtOpCodeTflSub: out << "fused_activation_function: " - << op.option.AsSubOptions()->fused_activation_function << "\n"; + << opts.AsSubOptions()->fused_activation_function << "\n"; break; case kLiteRtOpCodeTflReshape: out << "new_shape: "; - if (op.option.AsReshapeOptions() != nullptr) { - const int32_t* new_shape = - op.option.AsReshapeOptions()->new_shape.data(); - int32_t new_shape_size = op.option.AsReshapeOptions()->new_shape.size(); + if (opts.AsReshapeOptions() != nullptr) { + const int32_t* new_shape = opts.AsReshapeOptions()->new_shape.data(); + int32_t new_shape_size = opts.AsReshapeOptions()->new_shape.size(); for (int i = 0; i < new_shape_size; ++i) { out << new_shape[i] << " "; } } break; case kLiteRtOpCodeTflSum: - out << "keepdims: " << op.option.AsReducerOptions()->keep_dims << "\n"; + out << "keepdims: " << opts.AsReducerOptions()->keep_dims << "\n"; break; default: - out << "No options for op code: " << op.op_code; + out << "No options for op code: " << op.OpCode(); break; } } void Dump(Quantization quantization, std::ostream& out) { + int max_display_count; switch (quantization.first) { case kLiteRtQuantizationNone: return; @@ -409,6 +412,25 @@ void Dump(Quantization quantization, std::ostream& out) { quantization.second.per_tensor.zero_point, quantization.second.per_tensor.scale); return; + case kLiteRtQuantizationPerChannel: + max_display_count = + kMaxDisplayCount < quantization.second.per_channel.num_channels + ? kMaxDisplayCount + : quantization.second.per_channel.num_channels; + out << absl::StreamFormat(" ", quantization.second.per_channel.quantized_dimension); + return; default: out << " "; return; diff --git a/tensorflow/lite/experimental/litert/tools/dump.h b/tensorflow/lite/experimental/litert/tools/dump.h index 68dbbe5e7c4929..4012bb3e9e7aa5 100644 --- a/tensorflow/lite/experimental/litert/tools/dump.h +++ b/tensorflow/lite/experimental/litert/tools/dump.h @@ -64,7 +64,7 @@ void DumpOptions(const LiteRtOpT& op, std::ostream& out = std::cerr); void Dump(const CompilerPlugin& plugin, std::ostream& out = std::cerr); // Dumps details about the dynamic library (see "dlinfo"). -void Dump(void* lib_handle, std::ostream& out = std::cerr); +void DumpDLL(void* lib_handle, std::ostream& out = std::cerr); } // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/tools/dump_test.cc b/tensorflow/lite/experimental/litert/tools/dump_test.cc index 3a133fd73009fa..ff89547c2350aa 100644 --- a/tensorflow/lite/experimental/litert/tools/dump_test.cc +++ b/tensorflow/lite/experimental/litert/tools/dump_test.cc @@ -14,6 +14,8 @@ #include "tensorflow/lite/experimental/litert/tools/dump.h" +#include +#include #include #include @@ -39,8 +41,7 @@ TEST(DumpTest, TestDump) { } { - const LiteRtTensorT& in_tensor = - *model.Get()->subgraphs.front().inputs.front(); + const LiteRtTensorT& in_tensor = model.Get()->Subgraph(0).Input(0); std::ostringstream in_tensor_dump; Dump(in_tensor, in_tensor_dump); EXPECT_EQ(in_tensor_dump.view(), @@ -48,8 +49,7 @@ TEST(DumpTest, TestDump) { } { - const LiteRtTensorT& out_tensor = - *model.Get()->subgraphs.front().outputs.front(); + const LiteRtTensorT& out_tensor = model.Get()->Subgraph(0).Output(0); std::ostringstream out_tensor_dump; Dump(out_tensor, out_tensor_dump); EXPECT_EQ(out_tensor_dump.view(), @@ -57,7 +57,7 @@ TEST(DumpTest, TestDump) { } { - const LiteRtOpT& op = *model.Get()->subgraphs.front().ops.front(); + const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0); std::ostringstream op_dump; Dump(op, op_dump); EXPECT_EQ(op_dump.view(), @@ -65,7 +65,7 @@ TEST(DumpTest, TestDump) { } { - const LiteRtSubgraphT& subgraph = model.Get()->subgraphs.front(); + const LiteRtSubgraphT& subgraph = model.Get()->Subgraph(0); std::ostringstream subgraph_dump; Dump(subgraph, subgraph_dump); EXPECT_EQ( @@ -77,7 +77,7 @@ TEST(DumpTest, TestDump) { TEST(DumpTest, TestDumpOptions) { auto model = LoadTestFileModel("simple_strided_slice_op.tflite"); - const LiteRtOpT& op = *model.Get()->subgraphs.front().ops.front(); + const LiteRtOpT& op = model.Get()->Subgraph(0).Op(0); std::ostringstream op_dump; DumpOptions(op, op_dump); EXPECT_EQ(op_dump.view(), @@ -90,7 +90,7 @@ TEST(DumpTest, TestDumpOptions) { } TEST(DumpTest, TestDumpPerTensorQuantization) { - LiteRtQuantizationTypeDetail per_tensor_detail; + QuantizationDetail per_tensor_detail; per_tensor_detail.per_tensor.scale = 1.0; per_tensor_detail.per_tensor.zero_point = 2; std::ostringstream q_dump; @@ -98,15 +98,31 @@ TEST(DumpTest, TestDumpPerTensorQuantization) { EXPECT_EQ(q_dump.view(), " "); } +TEST(DumpTest, TestDumpPerChannelQuantization) { + static constexpr size_t kRank = 2; + static constexpr size_t kQuantizedDimension = 1; + static constexpr float kScales[kRank] = {1.0, 2.0}; + static constexpr int64_t kZps[kRank] = {2, 3}; + QuantizationDetail per_channel_detail; + per_channel_detail.per_channel.scales = const_cast(kScales); + per_channel_detail.per_channel.zero_points = const_cast(kZps); + per_channel_detail.per_channel.quantized_dimension = kQuantizedDimension; + per_channel_detail.per_channel.num_channels = kRank; + std::ostringstream q_dump; + Dump(std::make_pair(kLiteRtQuantizationPerChannel, per_channel_detail), + q_dump); + EXPECT_FALSE(q_dump.view().empty()); +} + TEST(DumpTest, TestDumpNoQuantization) { - LiteRtQuantizationTypeDetail none_detail; + QuantizationDetail none_detail; std::ostringstream q_dump; Dump(std::make_pair(kLiteRtQuantizationNone, none_detail), q_dump); EXPECT_TRUE(q_dump.view().empty()); } TEST(DumpTest, TestDumpUnknownQuantization) { - LiteRtQuantizationTypeDetail detail; + QuantizationDetail detail; std::ostringstream q_dump; Dump(std::make_pair(kLiteRtQuantizationBlockWise, detail), q_dump); EXPECT_EQ(q_dump.view(), " "); diff --git a/tensorflow/lite/experimental/litert/vendors/c/BUILD b/tensorflow/lite/experimental/litert/vendors/c/BUILD index 686bd0021d1d7e..8b8018451256a2 100644 --- a/tensorflow/lite/experimental/litert/vendors/c/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/c/BUILD @@ -23,8 +23,6 @@ cc_library( deps = [ "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_expected", - "//tensorflow/lite/experimental/litert/cc:litert_macros", ], ) @@ -35,6 +33,7 @@ cc_library( ":litert_compiler_plugin", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_model", + "@com_google_absl//absl/strings:string_view", ], ) @@ -45,7 +44,9 @@ cc_library( "litert_dispatch_api.h", ], deps = [ + "//tensorflow/lite/experimental/litert/c:litert_any", "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_event", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/runtime/dispatch", diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h index e32829396974d8..d0196e99a0d358 100644 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h @@ -43,6 +43,11 @@ LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin); void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin); +// Return the HW supported by this plugin (e.g., GPU, NPU) +LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( + LiteRtCompilerPlugin compiler_plugin, + LiteRtHwAccelerators* supported_hardware); + // Number of SoC models supported by this plugin. LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( LiteRtCompilerPlugin compiler_plugin, @@ -54,12 +59,11 @@ LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, const char** soc_model_name); -// Select desired ops for compilation. This will be called only once -// during the plugin application flow, all ops should be selected during this -// call. -LiteRtStatus LiteRtCompilerPluginPartitionModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtModel model, - LiteRtOpList selected_ops); +// Select desired ops for compilation. This will only be called once +// per subgraph, plugins should select all supportable ops. +LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, + LiteRtSubgraph subgraph, + LiteRtOpList selected_ops); // Prepare result to pass to the runtime for given partition and, optionally, // for a given SoC model (parameter `soc_model` can be NULL to specify a default @@ -67,7 +71,7 @@ LiteRtStatus LiteRtCompilerPluginPartitionModel( // partition step. LiteRtStatus LiteRtCompilerPluginCompile(LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtSubgraphArray partitions, + LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result); diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h index 5746e845b8e328..b376f5a91cb38a 100644 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h @@ -40,6 +40,9 @@ typedef LiteRtStatus (*LiteRtCreateCompilerPluginT)(LiteRtCompilerPlugin*); typedef void (*LiteRtDestroyCompilerPluginT)(LiteRtCompilerPlugin); +typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedHardwareT)( + LiteRtCompilerPlugin, LiteRtHwAccelerators*); + typedef LiteRtStatus (*LiteRtGetNumCompilerPluginSupportedSocModelsT)( LiteRtCompilerPlugin, LiteRtParamIndex*); @@ -47,11 +50,11 @@ typedef LiteRtStatus (*LiteRtGetCompilerPluginSupportedSocModelT)( LiteRtCompilerPlugin, LiteRtParamIndex soc_model_idx, const char** soc_moel_idx); -typedef LiteRtStatus (*LiteRtCompilerPluginPartitionModelT)( - LiteRtCompilerPlugin, LiteRtModel model, LiteRtOpList selected_ops); +typedef LiteRtStatus (*LiteRtCompilerPluginPartitionT)( + LiteRtCompilerPlugin, LiteRtSubgraph subgraph, LiteRtOpList selected_ops); typedef LiteRtStatus (*LiteRtCompilerPluginCompileT)( - LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraphArray partitions, + LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result); typedef void (*LiteRtDestroyCompiledResultT)(LiteRtCompiledResult); @@ -77,12 +80,14 @@ struct LiteRtCompilerPluginApi { LiteRtCreateCompilerPluginT create_compiler_plugin; LiteRtDestroyCompilerPluginT destroy_compiler_plugin; + LiteRtGetCompilerPluginSupportedHardwareT + get_compiler_plugin_supported_hardware; LiteRtGetNumCompilerPluginSupportedSocModelsT get_num_compiler_plugin_supported_models; LiteRtGetCompilerPluginSupportedSocModelT get_compiler_plugin_supported_soc_model; - LiteRtCompilerPluginPartitionModelT compiler_plugin_partition_model; + LiteRtCompilerPluginPartitionT compiler_plugin_partition; LiteRtCompilerPluginCompileT compiler_plugin_compile; LiteRtDestroyCompiledResultT destroy_compiled_result; @@ -93,6 +98,42 @@ struct LiteRtCompilerPluginApi { #ifdef __cplusplus } + +#include "absl/strings/string_view.h" + +static constexpr absl::string_view kLiteRtGetCompilerPluginVersion = + "LiteRtGetCompilerPluginVersion"; + +static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedHardware = + "LiteRtGetCompilerPluginSupportedHardware"; + +static constexpr absl::string_view kLiteRtGetCompilerPluginSocManufacturer = + "LiteRtGetCompilerPluginSocManufacturer"; +static constexpr absl::string_view + kLiteRtGetNumCompilerPluginSupportedSocModels = + "LiteRtGetNumCompilerPluginSupportedSocModels"; +static constexpr absl::string_view kLiteRtGetCompilerPluginSupportedSocModel = + "LiteRtGetCompilerPluginSupportedSocModel"; + +static constexpr absl::string_view kLiteRtCreateCompilerPlugin = + "LiteRtCreateCompilerPlugin"; +static constexpr absl::string_view kLiteRtDestroyCompilerPlugin = + "LiteRtDestroyCompilerPlugin"; + +static constexpr absl::string_view kLiteRtCompilerPluginPartition = + "LiteRtCompilerPluginPartition"; +static constexpr absl::string_view kLiteRtCompilerPluginCompile = + "LiteRtCompilerPluginCompile"; + +static constexpr absl::string_view kLiteRtDestroyCompiledResult = + "LiteRtDestroyCompiledResult"; +static constexpr absl::string_view kLiteRtGetCompiledResultByteCode = + "LiteRtGetCompiledResultByteCode"; +static constexpr absl::string_view kLiteRtGetCompiledResultCallInfo = + "LiteRtGetCompiledResultCallInfo"; +static constexpr absl::string_view kLiteRtGetNumCompiledResultCalls = + "LiteRtGetNumCompiledResultCalls"; + #endif // __cplusplus #endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h index 9b70692cd83e83..fa735fed564f53 100644 --- a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h @@ -19,6 +19,7 @@ #include #include +#include "tensorflow/lite/experimental/litert/c/litert_any.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_event.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" diff --git a/tensorflow/lite/experimental/litert/vendors/cc/BUILD b/tensorflow/lite/experimental/litert/vendors/cc/BUILD index d02f4b67506a8b..e101607f1ba6d8 100644 --- a/tensorflow/lite/experimental/litert/vendors/cc/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/cc/BUILD @@ -25,3 +25,99 @@ cc_library( "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", ], ) + +cc_library( + name = "conversion", + hdrs = ["conversion.h"], + deps = [ + ":backend_ir", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "backend_ir", + hdrs = ["backend_ir.h"], + deps = ["//tensorflow/lite/experimental/litert/c:litert_common"], +) + +cc_library( + name = "partition_with_capabilities", + hdrs = ["partition_with_capabilities.h"], + deps = [ + ":conversion", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_model", + ], +) + +cc_library( + name = "convert_graph", + hdrs = ["convert_graph.h"], + deps = [ + ":conversion", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_model", + ], +) + +cc_library( + name = "ir_types", + hdrs = ["ir_types.h"], + deps = [ + ":backend_ir", + ":conversion", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + ], +) + +cc_test( + name = "partition_with_capabilities_test", + srcs = ["partition_with_capabilities_test.cc"], + deps = [ + ":partition_with_capabilities", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/core/model:model_graph", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/vendors/examples:example_conversion_impl", + "//tensorflow/lite/experimental/litert/vendors/examples:example_ir", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "convert_graph_test", + srcs = ["convert_graph_test.cc"], + deps = [ + ":backend_ir", + ":convert_graph", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_buffer_ref", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/core/model:model_graph", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:test_macros", + "//tensorflow/lite/experimental/litert/vendors/examples:example_conversion_impl", + "//tensorflow/lite/experimental/litert/vendors/examples:example_ir", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h b/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h new file mode 100644 index 00000000000000..34cf95bd3643e6 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h @@ -0,0 +1,79 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +namespace litert { + +// Interfaces and types for managing backend IR to be targeted by LiteRt for +// compilation. + +// Memory Management +//===--------------------------------------------------------------------------- + +// Callable for allocating a new instance of a backend IR type. This facilitates +// external memory management for the backend IR implementented by the backend. +// It is encouraged for implementations provide pointer stability (consider +// std::list for storage). +template +using BackendIrAllocator = std::function; + +// Allocator for backend tensors. +template +using TensorAllocator = BackendIrAllocator; + +// Allocator for backend ops. +template +using OpAllocator = BackendIrAllocator; + +// Graph Construction +//===--------------------------------------------------------------------------- + +// Wrapper for an in memory graph for a particular backend. Implementations +// should contain an instance of a backend graph that can be iteratively +// constructed via calls to this interface. +template +class BackendGraphBuilder { + public: + // Hook called to initialize state for a new backend graph with a name. This + // will be called once per-instance before any other method. + virtual void InitGraph(std::string graph_name) = 0; + + // Hook called to register a backend tensor once it + // has been converted. This will be called once per tensor. + virtual LiteRtStatus RegisterTensor(BackendTensor& tensor) = 0; + + // Hook called to register a backend op once it has been converted. This will + // be called once per op (in a toplogogical order). All input/output tensors + // will have been registered before called. + virtual LiteRtStatus RegisterOp(BackendOp& op) = 0; + + // Hook called to register a graph when graph + // conversion is completed. Backend graph context should be stored as internal + // state. This will be called once per instance after all ops/tensors have + // been finalized. + virtual LiteRtStatus FinalizeGraph() = 0; + + virtual ~BackendGraphBuilder() = default; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_BACKEND_IR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/conversion.h b/tensorflow/lite/experimental/litert/vendors/cc/conversion.h new file mode 100644 index 00000000000000..139ba594bb1e8a --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/conversion.h @@ -0,0 +1,262 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utility types for mapping LiteRt IR to arbitrary backend specific +// types. Implementations of these types define mapping for ops and tensors +// that may be used in a stndalone fashion. They also may be composed +// to create lowerings of entire graphs with topology. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" + +namespace litert { + +// Interfaces and types for implementing "conversions" that map LiteRt IR to +// backend IR. +// NOTE: Conversions depend on external memory management for the backend IR +// types. User defined conversions are usually expected to leverage callbacks +// to allocate backend IR types rather than constructing them directly. + +// Conversion Result Type +//===--------------------------------------------------------------------------- + +// Result of a one->many general mapping from LiteRt op to any number of +// backend specific ops. Does not own the memory of the backend ops or tensors. +template +struct GeneralConversionResult { + // Ops emitted from translation pattern. + std::vector ops; + + // Any backend tensors used within the results ops. Not relevant when + // size of backend ops == 1. This does not include input/output tensors of the + // op being converted. + std::vector intermediate_tensors; +}; + +// The result of a one->one specialized mapping from LiteRt op to backend op. +template +using SimpleConversionResult = BackendOp*; + +// A tag-type for a conversion result that is a non-error non-match. +struct NoMatch {}; + +// Type union for conversion results. +// TODO(lukeboyer): Update conversion result types to handle the case where +// backend ops add extra inputs. +template +using ConversionResult = + std::variant, + GeneralConversionResult, NoMatch>; + +// Short hand for holds_alternative. +template +bool ConversionIsA(const ConversionResult& result) { + return std::holds_alternative(result); +} + +// Short hand for holds_alternative. +template +bool ConversionMatched( + const ConversionResult& result) { + return !std::holds_alternative(result); +} + +// Short hand for holds_alternative. +template +bool IsSimpleResult(const ConversionResult& result) { + return ConversionIsA>(result); +} + +// Short hand for holds_alternative. +template +bool IsGeneralResult(const ConversionResult& result) { + return ConversionIsA>( + result); +} + +// Short hand for std::get. Also checks if match and wraps in expected. +template +Expected GetConversionResult( + const ConversionResult& result) { + if (ConversionMatched(result)) { + return Expected(std::get(result)); + } + return Error(kLiteRtStatusLegalizeNoMatch); +} + +// Get simple result if there was a match. +template +Expected> GetSimpleConversionResult( + const ConversionResult& result) { + if (!IsSimpleResult(result)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return GetConversionResult>(result); +} + +// Get general result if there was a match. +template +Expected> +GetGeneralConversionResult( + const ConversionResult& result) { + if (!IsGeneralResult(result)) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + return GetConversionResult>( + result); +} + +// Common IR Conversion +//===--------------------------------------------------------------------------- + +// User defined callback for converting a LiteRt tensor to a backend tensor. +// These are leveraged in various higher-level conversion routines. +// TensorConverters should not stack allocate memory for the backend tensor. In +// most situations, these will be bound to an external allocator. +template +using TensorConverter = + std::function(const Tensor& litert_tensor)>; + +// User defined callback for creating a TensorConverter. This facilitates +// TensoConverters that are bound to an external allocator. +template +using TensorConverterFactory = std::function( + TensorAllocator alloc)>; + +// Mapping from LiteRt tensor to backend tensor, used during iterative graph +// conversions to store current scope. +template +using TensorMap = absl::flat_hash_map; + +// User-defined hook that calls backend to determine if an op is supported. +template +using Capability = std::function; + +// Legalization +//===--------------------------------------------------------------------------- + +// A legalization is a particlar type of user-defined conversion that is +// scheduled for execution on a particular type of LiteRtOp. They may be +// one-to-one or one-to-many conversions. +template +class Legalization { + private: + using Self = Legalization; + + public: + using Result = ConversionResult; + using TensorConverter = TensorConverter; + using TensorConverterFactory = TensorConverterFactory; + using Ptr = std::unique_ptr; + using TensorAllocator = TensorAllocator; + using OpAllocator = OpAllocator; + using Tensors = std::vector; + + // The type of op to schedule on. + virtual LiteRtOpCode OpToMatch() const = 0; + + // Invoke this legalization on the given LiteRt op. All new backend IR will be + // allocated via given allocators. NOTE: In most cases, input and output + // converters will be the same. They are separated here for compatibility with + // graph-level conversions routines. + Expected Legalize(const Op& litert_op, + TensorConverterFactory input_converter, + TensorConverterFactory output_converter, + TensorAllocator tensor_allocator, + OpAllocator op_allocator) const { + const auto litert_inputs = litert_op.Inputs(); + Tensors inputs(litert_inputs.size()); + auto convert_input = input_converter(tensor_allocator); + + for (size_t i = 0; i < litert_inputs.size(); ++i) { + const auto& litert_input = litert_inputs[i]; + auto result = convert_input(litert_input); + if (!result) { + return result.Error(); + } + inputs[i] = *result; + } + + const auto litert_outputs = litert_op.Outputs(); + Tensors outputs(litert_outputs.size()); + auto convert_output = output_converter(tensor_allocator); + + for (size_t i = 0; i < litert_outputs.size(); ++i) { + const auto& litert_output = litert_outputs[i]; + auto result = convert_output(litert_output); + if (!result) { + return result.Error(); + } + outputs[i] = *result; + } + + return LegalizeImpl(litert_op, inputs, outputs, tensor_allocator, + op_allocator); + } + + virtual ~Legalization() = default; + + private: + // The user defined implementation of a legalization. Users must use the + // given allocators to allocate any new backend IR types (e.g. intermediate + // ops/tensors in the case of a one-to-many legalization). BackendTensors + // corresponding to LiteRt inputs and outputs have been pre-converted. + virtual Expected LegalizeImpl(const Op& litert_op, + const Tensors& inputs, + const Tensors& outputs, + TensorAllocator tensor_allocator, + OpAllocator op_allocator) const = 0; +}; + +// Collection of legalizations for a specific backend. +template +using Legalizations = + std::vector::Ptr>; + +// Map for instance lookup by op code. +template +using LegalizationMap = + absl::flat_hash_map*>; + +// Construct a LegalizationMap from a collection of legalizations. +// TODO: Consider wrapping the legalization map in a class to avoid +// re-constructing it & better syntax. +template +LegalizationMap MakeLegalizationMap( + const Legalizations& legalizations) { + LegalizationMap map; + for (const auto& l : legalizations) { + map.insert({l->OpToMatch(), l.get()}); + } + return map; +} + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERSION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h new file mode 100644 index 00000000000000..cd7221c7bba028 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h @@ -0,0 +1,177 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" + +namespace litert { + +// Performs iterative graph conversion with user provided hooks. This function +// traverses the IR in toplogical order, converting ops and tensors with given +// tensor converter and legalizations. Registers converted ops and tensors with +// the backend graph builder after they have been converted. The following are +// true: +// * Each tensor and op will be converted & registered at most once. +// * An ops input and output tensors will be registered before the op is +// converted (and before its registered). +// * The graph builder will be initialized before any registration. +// * The graph builder will be finalized after all registration. +template +LiteRtStatus ConvertGraph( + const Subgraph& subgraph, std::string graph_name, + typename Ir::TensorConverterFactory tensor_converter_factory, + typename Ir::TensorAllocator tensor_alloc, + typename Ir::OpAllocator op_alloc, + const typename Ir::Legalizations& legalizations, + typename Ir::GraphBuilder& builder) { + // Store mapping between evaluated litert tensors and corresponding backend + // tensors. + typename Ir::TensorMap tensor_map; + + // Initialize backend graph builder. + builder.InitGraph(std::move(graph_name)); + + // Convert tensor, add to scope and register in backend graph builder. + auto handle_tensor = [&tensor_map, &builder]( + const auto& litert_tensor, + auto tensor_converter) -> Ir::TensorResult { + auto converted = tensor_converter(litert_tensor); + if (!converted) { + LITERT_LOG(LITERT_ERROR, "Failed to convert tensor %lu", + litert_tensor.Get()); + return converted.Error(); + } + + if (auto status = builder.RegisterTensor(**converted); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to register tensor %lu, with status %d", + litert_tensor.Get(), status); + return Error(status); + } + + tensor_map.insert({litert_tensor.Get(), *converted}); + return *converted; + }; + + // Wrap provided tensor conversion logic for converting subgraph or op input + // tensors. We want functionality that provides user-defined conversions with + // tensors to be aware of the tensor map and graph builder registration. + auto input_tensor_convert_factory = [tensor_converter_factory, &tensor_map, + handle_tensor](auto tensor_alloc) { + return [tensor_alloc, tensor_converter_factory, &tensor_map, + handle_tensor](const Tensor& litert_tensor) -> Ir::TensorResult { + auto tensor_converter = tensor_converter_factory(tensor_alloc); + + // Check if tensor has been converted already. + auto it = tensor_map.find(litert_tensor.Get()); + const auto in_scope = it != tensor_map.end(); + if (in_scope) { + LITERT_LOG(LITERT_VERBOSE, "Tensor %lu is in scope", + litert_tensor.Get()); + return it->second; + } + + // If its a subgraph input or constant, we can convert it and add to + // scope. + const auto is_cst = litert_tensor.IsConstant(); + const auto is_sg_input = litert_tensor.IsSubgraphInput(); + if (is_sg_input || is_cst) { + return handle_tensor(litert_tensor, tensor_converter); + } + + // Tensor must be added to scope before conversion, or not have a parent + // (e.g. subgraph input or constant) so error at this point. + LITERT_LOG(LITERT_ERROR, "Tensor %lu not handled", litert_tensor.Get()); + return Error(kLiteRtStatusErrorInvalidArgument); + }; + }; + + // Wrap provided tensor conversion logic for op output tensors. Adds to map + // and backend graph after conversion. + auto output_tensor_convert_factory = [tensor_converter_factory, + handle_tensor](auto tensor_alloc) { + return [tensor_alloc, tensor_converter_factory, + handle_tensor](const Tensor& litert_tensor) { + auto tensor_converter = tensor_converter_factory(tensor_alloc); + return handle_tensor(litert_tensor, tensor_converter); + }; + }; + + // Convert all ops in subgraph in toplogical order. + auto legalization_map = Ir::MakeLegalizationMap(legalizations); + for (const auto& op : subgraph.Ops()) { + auto it = legalization_map.find(op.Code()); + if (it == legalization_map.end()) { + LITERT_LOG(LITERT_ERROR, "No legalization found for op %d", op.Code()); + return kLiteRtStatusErrorUnsupported; + } + + auto result = it->second->Legalize(op, input_tensor_convert_factory, + output_tensor_convert_factory, + tensor_alloc, op_alloc); + if (!result) { + LITERT_LOG(LITERT_ERROR, "Failed to legalize op %d, with status %d", + op.Code(), result.Error().Status()); + return result.Error().Status(); + } + + auto simple_result = GetSimpleConversionResult(*result); + if (simple_result) { + if (auto stat = builder.RegisterOp(**simple_result); + stat != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to register op %d, with status %d", + op.Code(), stat); + return stat; + } + } + + auto general_result = GetGeneralConversionResult(*result); + if (general_result) { + for (auto* tensor : general_result->intermediate_tensors) { + if (auto stat = builder.RegisterTensor(*tensor); + stat != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, + "Failed to register tensor %d, with status %d", tensor->id, + stat); + return stat; + } + } + + for (auto* op : general_result->ops) { + if (auto stat = builder.RegisterOp(*op); stat != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to register op %d, with status %d", + op->op_code, stat); + return stat; + } + } + } + } + + builder.FinalizeGraph(); + + return kLiteRtStatusOk; +} + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_CONVERT_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc new file mode 100644 index 00000000000000..3314cfe8a78117 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/convert_graph_test.cc @@ -0,0 +1,390 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_buffer_ref.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" + +namespace litert { +namespace { + +using ::litert::example::ExampleOpAllocator; +using ::litert::example::ExampleOpType; +using ::litert::example::ExampleTensorAllocator; +using ::litert::example::ExampleTypes; +using ::litert::example::MakeAllLegalizations; +using ::litert::example::MakeTensorConverter; +using ::testing::AllOf; +using ::testing::ElementsAreArray; +using ::testing::Expectation; +using ::testing::ExpectationSet; +using ::testing::Field; +using ::testing::Return; + +static constexpr std::array kDims = {2, 2}; +static constexpr auto kElementType = kLiteRtElementTypeFloat32; +static constexpr absl::string_view kGraphName = "graph_name"; + +TensorType GetTestTensorType() { + return MakeRankedTensorType(kElementType, absl::MakeConstSpan(kDims)); +} + +class MockGraphBuilder + : public BackendGraphBuilder { + public: + MOCK_METHOD(void, InitGraph, (std::string name), (override)); + MOCK_METHOD(LiteRtStatus, RegisterTensor, (ExampleTypes::Tensor & tensor), + (override)); + MOCK_METHOD(LiteRtStatus, RegisterOp, (ExampleTypes::Op & op), (override)); + MOCK_METHOD(LiteRtStatus, FinalizeGraph, (), (override)); +}; + +TEST(ConvertGraphTest, ConvertSingleSimpleConversion) { + LiteRtSubgraphT subgraph; + + auto& op = subgraph.EmplaceOp(); + op.SetOpCode(kLiteRtOpCodeTflMul); + + auto& input1 = subgraph.EmplaceTensor(); + input1.SetType(GetTestTensorType()); + input1.SetName("input1"); + + auto& input2 = subgraph.EmplaceTensor(); + input2.SetType(GetTestTensorType()); + input2.SetName("input2"); + + auto& output = subgraph.EmplaceTensor(); + output.SetType(GetTestTensorType()); + output.SetName("output"); + + internal::AttachInput(&input1, op); + internal::AttachInput(&input2, op); + internal::AttachOutput(&output, op); + + subgraph.Inputs().push_back(&input1); + subgraph.Inputs().push_back(&input2); + subgraph.Outputs().push_back(&output); + + Subgraph litert_subgraph(&subgraph); + + ExampleOpAllocator op_alloc; + ExampleTensorAllocator tensor_alloc; + + MockGraphBuilder builder; + + Expectation init_graph = + EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); + + ExpectationSet reg_inputs; + reg_inputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + input1.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + reg_inputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + input2.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + ExpectationSet reg_outputs; + reg_outputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + output.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + auto match_reg_op_args = + AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::MUL), + Field(&ExampleTypes::Op::input_names, + ElementsAreArray({input1.Name(), input2.Name()})), + Field(&ExampleTypes::Op::output_names, + ElementsAreArray({output.Name()}))); + + Expectation reg_op = EXPECT_CALL(builder, RegisterOp(match_reg_op_args)) + .Times(1) + .After(reg_inputs, reg_outputs) + .WillOnce(Return(kLiteRtStatusOk)); + + Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) + .Times(1) + .After(reg_op) + .WillOnce(Return(kLiteRtStatusOk)); + + auto stat = ConvertGraph( + litert_subgraph, std::string(kGraphName), MakeTensorConverter, + tensor_alloc, op_alloc, MakeAllLegalizations(), builder); + + LITERT_ASSERT_STATUS_OK(stat); +} + +TEST(ConvertGraphTest, ConvertSingleGeneralConversion) { + LiteRtSubgraphT subgraph; + + auto& op = subgraph.EmplaceOp(); + op.SetOpCode(kLiteRtOpCodeTflAdd); + + tflite::AddOptionsT add_opts; + add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; + internal::TflOptions tfl_opts; + tfl_opts.Set(std::move(add_opts)); + detail::SetTflOptions(op, std::move(tfl_opts)); + + auto& input1 = subgraph.EmplaceTensor(); + input1.SetType(GetTestTensorType()); + input1.SetName("input1"); + + auto& input2 = subgraph.EmplaceTensor(); + input2.SetType(GetTestTensorType()); + input2.SetName("input2"); + + auto& output = subgraph.EmplaceTensor(); + output.SetType(GetTestTensorType()); + output.SetName("output"); + + internal::AttachInput(&input1, op); + internal::AttachInput(&input2, op); + internal::AttachOutput(&output, op); + + subgraph.Inputs().push_back(&input1); + subgraph.Inputs().push_back(&input2); + subgraph.Outputs().push_back(&output); + + Subgraph litert_subgraph(&subgraph); + + ExampleOpAllocator op_alloc; + ExampleTensorAllocator tensor_alloc; + + MockGraphBuilder builder; + + Expectation init_graph = + EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); + + ExpectationSet reg_inputs; + reg_inputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + input1.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + reg_inputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + input2.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + ExpectationSet reg_intermediates; + reg_intermediates += + EXPECT_CALL(builder, + RegisterTensor(Field(&ExampleTypes::Tensor::name, + example::kIntermediateTensorName))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + ExpectationSet reg_outputs; + reg_outputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + output.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + auto match_reg_add_args = + AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::ADD), + Field(&ExampleTypes::Op::input_names, + ElementsAreArray({input1.Name(), input2.Name()})), + Field(&ExampleTypes::Op::output_names, + ElementsAreArray({example::kIntermediateTensorName}))); + + Expectation reg_add = EXPECT_CALL(builder, RegisterOp(match_reg_add_args)) + .Times(1) + .After(reg_inputs, reg_intermediates) + .WillOnce(Return(kLiteRtStatusOk)); + + auto match_reg_relu_args = + AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::RELU), + Field(&ExampleTypes::Op::input_names, + ElementsAreArray({example::kIntermediateTensorName})), + Field(&ExampleTypes::Op::output_names, + ElementsAreArray({output.Name()}))); + + Expectation reg_relu = EXPECT_CALL(builder, RegisterOp(match_reg_relu_args)) + .Times(1) + .After(reg_add, reg_intermediates, reg_outputs) + .WillOnce(Return(kLiteRtStatusOk)); + + Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) + .Times(1) + .After(reg_relu) + .WillOnce(Return(kLiteRtStatusOk)); + + auto stat = ConvertGraph( + litert_subgraph, std::string(kGraphName), MakeTensorConverter, + tensor_alloc, op_alloc, MakeAllLegalizations(), builder); + + LITERT_ASSERT_STATUS_OK(stat); +} + +TEST(ConvertGraphTest, ConvertMultipleOps) { + LiteRtSubgraphT subgraph; + + auto& op = subgraph.EmplaceOp(); + op.SetOpCode(kLiteRtOpCodeTflMul); + + auto& input1 = subgraph.EmplaceTensor(); + input1.SetType(GetTestTensorType()); + input1.SetName("input1"); + + auto& input2 = subgraph.EmplaceTensor(); + input2.SetType(GetTestTensorType()); + input2.SetName("input2"); + + auto& output1 = subgraph.EmplaceTensor(); + output1.SetType(GetTestTensorType()); + output1.SetName("output1"); + + auto& cst = subgraph.EmplaceTensor(); + OwningBufferRef weights(8); + cst.Weights().SetFromBuf(weights); + cst.SetName("cst"); + cst.SetType(GetTestTensorType()); + + auto& op2 = subgraph.EmplaceOp(); + op2.SetOpCode(kLiteRtOpCodeTflAdd); + + auto& output2 = subgraph.EmplaceTensor(); + output2.SetType(GetTestTensorType()); + output2.SetName("output2"); + + internal::AttachInput(&input1, op); + internal::AttachInput(&input2, op); + internal::AttachOutput(&output1, op); + + internal::AttachInput(&output1, op2); + internal::AttachInput(&cst, op2); + internal::AttachOutput(&output2, op2); + + subgraph.Inputs().push_back(&input1); + subgraph.Inputs().push_back(&input2); + subgraph.Outputs().push_back(&output2); + + Subgraph litert_subgraph(&subgraph); + + ExampleOpAllocator op_alloc; + ExampleTensorAllocator tensor_alloc; + + MockGraphBuilder builder; + + Expectation init_graph = + EXPECT_CALL(builder, InitGraph(std::string(kGraphName))).Times(1); + + ExpectationSet reg_inputs; + reg_inputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + input1.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + reg_inputs += + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + input2.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + Expectation reg_output1 = + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + output1.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + Expectation reg_cst = + EXPECT_CALL(builder, RegisterTensor( + Field(&ExampleTypes::Tensor::name, cst.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + Expectation reg_output2 = + EXPECT_CALL(builder, RegisterTensor(Field(&ExampleTypes::Tensor::name, + output2.Name()))) + .Times(1) + .After(init_graph) + .WillOnce(Return(kLiteRtStatusOk)); + + auto match_reg_op1_args = + AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::MUL), + Field(&ExampleTypes::Op::input_names, + ElementsAreArray({input1.Name(), input2.Name()})), + Field(&ExampleTypes::Op::output_names, + ElementsAreArray({output1.Name()}))); + + Expectation reg_op1 = EXPECT_CALL(builder, RegisterOp(match_reg_op1_args)) + .Times(1) + .After(reg_inputs, reg_output1) + .WillOnce(Return(kLiteRtStatusOk)); + + auto match_reg_op2_args = + AllOf(Field(&ExampleTypes::Op::op_code, ExampleOpType::ADD), + Field(&ExampleTypes::Op::input_names, + ElementsAreArray({output1.Name(), cst.Name()})), + Field(&ExampleTypes::Op::output_names, + ElementsAreArray({output2.Name()}))); + + Expectation reg_op2 = EXPECT_CALL(builder, RegisterOp(match_reg_op2_args)) + .Times(1) + .After(reg_op1, reg_cst, reg_output2, reg_output1) + .WillOnce(Return(kLiteRtStatusOk)); + + Expectation finalize_graph = EXPECT_CALL(builder, FinalizeGraph()) + .Times(1) + .After(reg_op2) + .WillOnce(Return(kLiteRtStatusOk)); + + auto stat = ConvertGraph( + litert_subgraph, std::string(kGraphName), MakeTensorConverter, + tensor_alloc, op_alloc, MakeAllLegalizations(), builder); + + LITERT_ASSERT_STATUS_OK(stat); +} + +} // namespace +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h b/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h new file mode 100644 index 00000000000000..a1da917de18a74 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/ir_types.h @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ + +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" + +namespace litert { + +// Holds particular backends IR template aliases for convenience. +template +struct IrTypes { + using Op = BackendOp; + using Tensor = BackendTensor; + using OpAllocator = OpAllocator; + using TensorAllocator = TensorAllocator; + using GraphBuilder = BackendGraphBuilder; + using GeneralConversionResult = GeneralConversionResult; + using SimpleConversionResult = SimpleConversionResult; + using ConversionResult = Expected>; + using Legalization = Legalization; + using Legalizations = Legalizations; + using LegalizationMap = LegalizationMap; + using TensorConverter = TensorConverter; + using TensorResult = Expected; + using TensorConverterFactory = TensorConverterFactory; + using TensorMap = TensorMap; + using Capability = Capability; + // NOLINTNEXTLINE + inline static auto MakeLegalizationMap = + litert::MakeLegalizationMap; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_IR_TYPES_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h new file mode 100644 index 00000000000000..a462d1744c3886 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h @@ -0,0 +1,97 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" + +namespace litert { + +// Higher-level functions for partitioning by leveraging user-defined +// conversions. This method selects ops for partitioning via a callback that +// checks if an op is supported by the backend. + +// Selects ops for partitioning from given subgraph based on given Capability +// check. Returns all ops in the given supbgraph that are supported by the +// backend. Suitable for use in implementing LiteRtCompilerPluginPartition. Any +// allocations of new backend ir types will be done through given external +// allocators. +// NOTE: A missing legalization or any legalization failure will result in +// an op not being supported, rather than a failure of this function. +template +Expected> PartitionWithCapabilities( + const typename Ir::Legalizations& legalizations, + typename Ir::Capability capability, + typename Ir::TensorConverterFactory convert_tensor_fact, + typename Ir::TensorAllocator tensor_allocator, + typename Ir::OpAllocator op_allocator, const Subgraph& litert_subgraph) { + std::vector results; + + // Build map for legalization lookup by op code. + auto map = Ir::MakeLegalizationMap(legalizations); + + // Convert all ops from the given subgraph and check backend support. + for (const auto& litert_op : litert_subgraph.Ops()) { + const auto code = litert_op.Code(); + LITERT_LOG(LITERT_INFO, "Checking support for LiteRtOp: %d", code); + + auto it = map.find(code); + if (it == map.end()) { + LITERT_LOG(LITERT_WARNING, "No legalization found for LiteRtOp: %d", + code); + continue; + } + + // Call user-defined conversion. + auto result = it->second->Legalize(litert_op, convert_tensor_fact, + convert_tensor_fact, tensor_allocator, + op_allocator); + if (!result) { + LITERT_LOG(LITERT_WARNING, "Failed to legalize LiteRtOp: %d", code); + continue; + } + + if (auto simple_result = GetSimpleConversionResult(*result)) { + if (capability(*simple_result)) { + LITERT_LOG(LITERT_INFO, "Selected LiteRtOp: %d", litert_op.Code()); + results.push_back(litert_op.Get()); + } + continue; + } + + // Check all ops emitted from a one-to-many conversion are supported. + if (auto gen_result = GetGeneralConversionResult(*result)) { + const auto b_ops_start = gen_result->ops.cbegin(); + const auto b_ops_end = gen_result->ops.cend(); + if (std::all_of(b_ops_start, b_ops_end, capability)) { + LITERT_LOG(LITERT_INFO, "Selected LiteRtOp: %d", litert_op.Code()); + results.push_back(litert_op.Get()); + } + continue; + } + } + + return results; +} + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_CC_PARTITION_WITH_CAPABILITIES_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc new file mode 100644 index 00000000000000..cfdb49ec5eec46 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities_test.cc @@ -0,0 +1,207 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utility types for mapping LiteRt IR to arbitrary backend specific +// types. Implementations of these types define mapping for ops and tensors +// that may be used in a stndalone fashion. They also may be composed +// to create lowerings of entire graphs with topology. + +#include "tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h" + +#include +#include +#include + +#include +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" + +namespace litert { +namespace { + +using ::litert::example::ExampleLegalizeAdd; +using ::litert::example::ExampleLegalizeMul; +using ::litert::example::ExampleOpAllocator; +using ::litert::example::ExampleOpType; +using ::litert::example::ExampleTensorAllocator; +using ::litert::example::ExampleTypes; +using ::litert::example::MakeTensorConverter; + +bool ExampleCapability(const ExampleTypes::Op* op) { + return op->op_code == ExampleOpType::ADD || + op->op_code == ExampleOpType::RELU; +} + +TEST(PartitionWithCapabilitiesTest, EmptyGraph) { + ExampleTypes::Legalizations legalizations; + legalizations.push_back(ExampleLegalizeAdd::Make()); + + LiteRtSubgraphT subgraph; + Subgraph litert_subgraph(&subgraph); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + auto ops = PartitionWithCapabilities( + legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, + op_alloc, litert_subgraph); + ASSERT_TRUE(ops); + EXPECT_TRUE(ops->empty()); +} + +TEST(PartitionWithCapabilitiesTest, SingleSelectedOp) { + static constexpr std::array kDims = {2, 2}; + + ExampleTypes::Legalizations legalizations; + legalizations.push_back(ExampleLegalizeAdd::Make()); + + LiteRtSubgraphT subgraph; + + const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); + + auto& input1 = subgraph.EmplaceTensor(); + input1.SetType(type); + + auto& input2 = subgraph.EmplaceTensor(); + input2.SetType(type); + + auto& output = subgraph.EmplaceTensor(); + output.SetType(type); + + auto& op = subgraph.EmplaceOp(); + op.SetOpCode(kLiteRtOpCodeTflAdd); + + internal::AttachInput(&input1, op); + internal::AttachInput(&input2, op); + internal::AttachOutput(&output, op); + + Subgraph litert_subgraph(&subgraph); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + auto ops = PartitionWithCapabilities( + legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, + op_alloc, litert_subgraph); + + ASSERT_TRUE(ops); + EXPECT_EQ(ops->size(), 1); +} + +TEST(PartitionWithCapabilitiesTest, MultiSelectedOp) { + static constexpr std::array kDims = {2, 2}; + + ExampleTypes::Legalizations legalizations; + legalizations.push_back(ExampleLegalizeAdd::Make()); + + LiteRtSubgraphT subgraph; + + const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); + + auto& add1_input = subgraph.EmplaceTensor(); + add1_input.SetType(type); + auto& add1_output = subgraph.EmplaceTensor(); + add1_output.SetType(type); + auto& add1 = subgraph.EmplaceOp(); + add1.SetOpCode(kLiteRtOpCodeTflAdd); + + internal::AttachInput(&add1_input, add1); + internal::AttachInput(&add1_input, add1); + internal::AttachOutput(&add1_output, add1); + + auto& mul_output = subgraph.EmplaceTensor(); + mul_output.SetType(type); + auto& mul = subgraph.EmplaceOp(); + mul.SetOpCode(kLiteRtOpCodeTflMul); + + internal::AttachInput(&add1_output, mul); + internal::AttachOutput(&mul_output, mul); + + auto& add2_output = subgraph.EmplaceTensor(); + add2_output.SetType(type); + auto& add2 = subgraph.EmplaceOp(); + add2.SetOpCode(kLiteRtOpCodeTflAdd); + + internal::AttachInput(&mul_output, add2); + internal::AttachInput(&mul_output, add2); + internal::AttachOutput(&add2_output, add2); + + Subgraph litert_subgraph(&subgraph); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + auto ops = PartitionWithCapabilities( + legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, + op_alloc, litert_subgraph); + + ASSERT_TRUE(ops); + + ASSERT_EQ(ops->size(), 2); + EXPECT_EQ(ops->front(), &add1); + EXPECT_EQ(ops->back(), &add2); +} + +TEST(PartitionWithCapabilitiesTest, WithGeneralResult) { + static constexpr std::array kDims = {2, 2}; + + ExampleTypes::Legalizations legalizations; + legalizations.push_back(ExampleLegalizeAdd::Make()); + + LiteRtSubgraphT subgraph; + + const auto type = MakeRankedTensorType(kLiteRtElementTypeFloat32, kDims); + + auto& add1_input = subgraph.EmplaceTensor(); + add1_input.SetType(type); + auto& add1_output = subgraph.EmplaceTensor(); + add1_output.SetType(type); + auto& add1 = subgraph.EmplaceOp(); + add1.SetOpCode(kLiteRtOpCodeTflAdd); + + internal::AttachInput(&add1_input, add1); + internal::AttachInput(&add1_input, add1); + internal::AttachOutput(&add1_output, add1); + + tflite::AddOptionsT add_opts; + add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; + internal::TflOptions tfl_opts; + tfl_opts.Set(std::move(add_opts)); + detail::SetTflOptions(add1, std::move(tfl_opts)); + + Subgraph litert_subgraph(&subgraph); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + auto ops = PartitionWithCapabilities( + legalizations, ExampleCapability, MakeTensorConverter, tensor_alloc, + op_alloc, litert_subgraph); + + ASSERT_TRUE(ops); + + ASSERT_EQ(ops->size(), 1); + EXPECT_EQ(ops->front(), &add1); +} + +} // namespace + +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/examples/BUILD b/tensorflow/lite/experimental/litert/vendors/examples/BUILD index 1213a9061ed80d..4c6fb69a4a0435 100644 --- a/tensorflow/lite/experimental/litert/vendors/examples/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/examples/BUILD @@ -21,7 +21,11 @@ package( litert_dynamic_lib( name = "example_plugin", - srcs = ["example_plugin.cc"], + srcs = [ + "example_plugin.cc", + "example_plugin_common.cc", + "example_plugin_common.h", + ], hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], export_litert_only = True, linkstatic = 1, @@ -49,14 +53,102 @@ cc_test( ":example_plugin", # buildcleaner: keep "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_op_code", - "//tensorflow/lite/experimental/litert/cc:litert_macros", "//tensorflow/lite/experimental/litert/cc:litert_model", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", "//tensorflow/lite/experimental/litert/core/model", "//tensorflow/lite/experimental/litert/test:common", "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "example_conversion_impl", + srcs = ["example_conversion_impl.cc"], + hdrs = ["example_conversion_impl.h"], + visibility = ["//tensorflow/lite/experimental/litert/vendors/cc:__pkg__"], + deps = [ + ":example_ir", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_options", + "//tensorflow/lite/experimental/litert/cc:litert_element_type", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/vendors/cc:backend_ir", + "//tensorflow/lite/experimental/litert/vendors/cc:conversion", + "//tensorflow/lite/experimental/litert/vendors/cc:ir_types", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "example_conversion_impl_test", + srcs = ["example_conversion_impl_test.cc"], + deps = [ + ":example_conversion_impl", + ":example_ir", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/core/model:model_graph", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:test_macros", + "//tensorflow/lite/experimental/litert/vendors/cc:conversion", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "example_ir", + srcs = ["example_ir.cc"], + hdrs = ["example_ir.h"], + visibility = ["//tensorflow/lite/experimental/litert/vendors/cc:__pkg__"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/vendors/cc:backend_ir", + "//tensorflow/lite/experimental/litert/vendors/cc:ir_types", + ], +) + +cc_library( + name = "example_plugin_with_conversions", + srcs = [ + "example_plugin_common.cc", + "example_plugin_common.h", + "example_plugin_with_conversions.cc", + ], + deps = [ + ":example_conversion_impl", + ":example_ir", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", + "//tensorflow/lite/experimental/litert/vendors/cc:convert_graph", + "//tensorflow/lite/experimental/litert/vendors/cc:partition_with_capabilities", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "example_plugin_with_conversions_test", + srcs = ["example_plugin_with_conversions_test.cc"], + data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], + deps = [ + ":example_plugin_with_conversions", # buildcleaner: keep + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:test_macros", + "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin", + "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc new file mode 100644 index 00000000000000..fa6e163aee4b77 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.cc @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" + +namespace litert::example { + +TensorConverter MakeTensorConverter( + TensorAllocator alloc) { + return [alloc](const Tensor& litert_tensor) -> Expected { + auto& tensor = *alloc(); + tensor.name = litert_tensor.Name(); + + auto litert_type = litert_tensor.RankedTensorType(); + if (!litert_type) { + return Error(litert_type.Error().Status()); + } + + const auto litert_dims = litert_type->Layout().Dimensions(); + + tensor.dims.assign(litert_dims.cbegin(), litert_dims.cend()); + + switch (litert_tensor.RankedTensorType()->ElementType()) { + case ElementType::Float32: + tensor.type = ExampleTensorType::FLOAT; + break; + case ElementType::Int32: + tensor.type = ExampleTensorType::INT; + break; + default: + return Error(kLiteRtStatusErrorInvalidArgument); + } + + return &tensor; + }; +} + +ExampleTypes::Legalizations MakeAllLegalizations() { + ExampleTypes::Legalizations legalizations; + legalizations.push_back(ExampleLegalizeMul::Make()); + legalizations.push_back(ExampleLegalizeAdd::Make()); + return legalizations; +} + +} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h new file mode 100644 index 00000000000000..e7b932618bfcfb --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h @@ -0,0 +1,125 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_options.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/ir_types.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" + +namespace litert::example { + +// Conversion type implementations for the fictional "example" backend. + +ExampleTypes::TensorConverter MakeTensorConverter( + ExampleTypes::TensorAllocator alloc); + +static constexpr absl::string_view kIntermediateTensorName = + "intermediate_bin_output"; + +// Example legalization for simple binary ops. +template +class ExampleBinOpLegalization : public Legalization { + private: + using Self = ExampleBinOpLegalization; + + public: + using Ptr = std::unique_ptr; + + static Ptr Make() { return std::make_unique(); } + + // Return the litert op code to match on. + constexpr LiteRtOpCode OpToMatch() const override { return LiteRtOpType; } + + // Determines if the given litert op has a fused relu attribute. + bool HasFusedRelu(const Op& litert_op) const { + if constexpr (LiteRtOpType != kLiteRtOpCodeTflAdd) { + return false; + } + uint32_t faf; + if (LiteRtGetAddFusedActivationOption(litert_op.Get(), &faf) != + kLiteRtStatusOk) { + return false; + } + return faf == 1; + } + + // Transforms LiteRtAdd op into example op definition using the tensor + // converter to map tensors within. + ExampleTypes::ConversionResult LegalizeImpl( + const Op& litert_op, const Tensors& inputs, const Tensors& outputs, + ExampleTypes::TensorAllocator tensor_allocator, + ExampleTypes::OpAllocator op_allocator) const override { + ABSL_DCHECK_EQ(litert_op.Code(), LiteRtOpType); + + auto& bin_op = *op_allocator(); + bin_op.op_code = BackendOpType; + + if (inputs.size() != 2 || outputs.size() != 1) { + return Error(kLiteRtStatusErrorInvalidArgument); + } + + for (const auto* input : inputs) { + bin_op.inputs.push_back(input->id); + bin_op.input_names.push_back(input->name); + } + + auto& output_tensor = *outputs.front(); + if (!HasFusedRelu(litert_op)) { + bin_op.outputs.push_back(output_tensor.id); + bin_op.output_names.push_back(output_tensor.name); + return Expected(&bin_op); + } + + auto* bin_output = tensor_allocator(); + bin_output->dims = output_tensor.dims; + bin_output->type = output_tensor.type; + bin_output->name = std::string(kIntermediateTensorName); + bin_op.outputs.push_back(bin_output->id); + bin_op.output_names.push_back(bin_output->name); + + auto& relu = *op_allocator(); + relu.op_code = ExampleOpType::RELU; + relu.inputs.push_back(bin_output->id); + relu.input_names.push_back(bin_output->name); + relu.outputs.push_back(output_tensor.id); + relu.output_names.push_back(output_tensor.name); + + ExampleTypes::GeneralConversionResult result; + result.ops.push_back(&bin_op); + result.ops.push_back(&relu); + result.intermediate_tensors.push_back(bin_output); + + return ExampleTypes::ConversionResult(result); + } +}; + +using ExampleLegalizeAdd = + ExampleBinOpLegalization; +using ExampleLegalizeMul = + ExampleBinOpLegalization; + +ExampleTypes::Legalizations MakeAllLegalizations(); + +} // namespace litert::example + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_CONVERSION_IMPL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc new file mode 100644 index 00000000000000..8cf105f70471ac --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl_test.cc @@ -0,0 +1,213 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" + +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/core/model/model_graph.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/conversion.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace litert::example { +namespace { + +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; + +TEST(ExampleConversionImplTest, ConvertTensor) { + static constexpr std::array kDims = {2, 2}; + static constexpr absl::string_view kName = "foo"; + + LiteRtTensorT litert_tensor; + litert_tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + litert_tensor.SetName(std::string(kName)); + + ExampleTensorAllocator tensor_alloc; + auto tensor_convert = MakeTensorConverter(tensor_alloc); + + auto& example_tensor = **tensor_convert(Tensor(&litert_tensor)); + EXPECT_EQ(example_tensor.type, ExampleTensorType::FLOAT); + EXPECT_THAT(example_tensor.dims, ElementsAreArray(kDims)); + EXPECT_EQ(example_tensor.name, kName); +} + +TEST(ExampleConversionImplTest, ExampleGraphBuilder) { + ExampleTensor input; + input.type = ExampleTensorType::FLOAT; + input.dims = {2, 2}; + input.id = 1; + + ExampleTensor output; + output.type = ExampleTensorType::INT; + output.dims = {3, 3}; + output.id = 2; + + ExampleOp op; + op.op_code = ExampleOpType::ADD; + op.inputs = {1}; + op.outputs = {2}; + + ExampleGraphBuilder builder; + static constexpr absl::string_view kName = "FOO_GRAPH"; + + builder.InitGraph(std::string(kName)); + LITERT_ASSERT_STATUS_OK(builder.RegisterTensor(input)); + LITERT_ASSERT_STATUS_OK(builder.RegisterOp(op)); + LITERT_ASSERT_STATUS_OK(builder.RegisterTensor(output)); + LITERT_ASSERT_STATUS_OK(builder.FinalizeGraph()); + + const auto serialized = builder.Serialize(); + EXPECT_THAT(serialized, HasSubstr("1FLOAT[2, 2]")); + EXPECT_THAT(serialized, HasSubstr("2INT[3, 3]")); + EXPECT_THAT(serialized, HasSubstr("ADD(1)->(2)")); + EXPECT_THAT(serialized, HasSubstr("FINALIZED")); + EXPECT_THAT(serialized, HasSubstr(kName)); +} + +TEST(ExampleConversionImplTest, LegalizeAddSimpleResult) { + static constexpr std::array kDims = {2, 2}; + + LiteRtTensorT input1; + input1.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + input1.SetName("input1"); + + LiteRtTensorT input2; + input2.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + input2.SetName("input2"); + + LiteRtTensorT output; + output.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + output.SetName("output"); + + LiteRtOpT op; + op.SetOpCode(kLiteRtOpCodeTflAdd); + internal::AttachInput(&input1, op); + internal::AttachInput(&input2, op); + internal::AttachOutput(&output, op); + + tflite::AddOptionsT add_opts; + add_opts.fused_activation_function = tflite::ActivationFunctionType_NONE; + internal::TflOptions tfl_opts; + tfl_opts.Set(std::move(add_opts)); + detail::SetTflOptions(op, std::move(tfl_opts)); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + ExampleLegalizeAdd legalize_add; + EXPECT_EQ(legalize_add.OpToMatch(), kLiteRtOpCodeTflAdd); + + auto legalized = + legalize_add.Legalize(Op(&op), MakeTensorConverter, MakeTensorConverter, + tensor_alloc, op_alloc); + + ASSERT_TRUE(legalized); + + auto simple_result = GetSimpleConversionResult(*legalized); + ASSERT_TRUE(simple_result); + auto& example_op = **simple_result; + + EXPECT_EQ(example_op.op_code, ExampleOpType::ADD); + EXPECT_THAT(example_op.inputs, ElementsAreArray({0, 1})); + EXPECT_THAT(example_op.input_names, + ElementsAreArray({input1.Name(), input2.Name()})); + EXPECT_THAT(example_op.outputs, ElementsAreArray({2})); + EXPECT_THAT(example_op.output_names, ElementsAreArray({output.Name()})); +} + +TEST(ExampleConversionImplTest, LegalizeAddGeneralResult) { + static constexpr std::array kDims = {2, 2}; + LiteRtTensorT input1; + input1.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + input1.SetName("input1"); + + LiteRtTensorT input2; + input2.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + input2.SetName("input2"); + + LiteRtTensorT output; + output.SetType(MakeRankedTensorType(kLiteRtElementTypeFloat32, + absl::MakeConstSpan(kDims))); + output.SetName("output"); + + LiteRtOpT op; + op.SetOpCode(kLiteRtOpCodeTflAdd); + internal::AttachInput(&input1, op); + internal::AttachInput(&input2, op); + internal::AttachOutput(&output, op); + + tflite::AddOptionsT add_opts; + add_opts.fused_activation_function = tflite::ActivationFunctionType_RELU; + internal::TflOptions tfl_opts; + tfl_opts.Set(std::move(add_opts)); + detail::SetTflOptions(op, std::move(tfl_opts)); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + auto legalize_add = ExampleLegalizeAdd::Make(); + EXPECT_EQ(legalize_add->OpToMatch(), kLiteRtOpCodeTflAdd); + + auto legalized = + legalize_add->Legalize(Op(&op), MakeTensorConverter, MakeTensorConverter, + tensor_alloc, op_alloc); + ASSERT_TRUE(legalized); + + auto gen_result = GetGeneralConversionResult(*legalized); + ASSERT_TRUE(gen_result); + + ASSERT_EQ(gen_result->ops.size(), 2); + EXPECT_EQ(gen_result->ops[0]->op_code, ExampleOpType::ADD); + EXPECT_THAT(gen_result->ops[0]->inputs, ElementsAreArray({0, 1})); + EXPECT_THAT(gen_result->ops[0]->input_names, + ElementsAreArray({input1.Name(), input2.Name()})); + EXPECT_THAT(gen_result->ops[0]->outputs, ElementsAreArray({3})); + EXPECT_THAT(gen_result->ops[0]->output_names, + ElementsAreArray({kIntermediateTensorName})); + EXPECT_EQ(gen_result->ops[1]->op_code, ExampleOpType::RELU); + EXPECT_THAT(gen_result->ops[1]->inputs, ElementsAreArray({3})); + EXPECT_THAT(gen_result->ops[1]->input_names, + ElementsAreArray({kIntermediateTensorName})); + EXPECT_THAT(gen_result->ops[1]->outputs, ElementsAreArray({2})); + EXPECT_THAT(gen_result->ops[1]->output_names, + ElementsAreArray({output.Name()})); + EXPECT_EQ(gen_result->intermediate_tensors.size(), 1); + EXPECT_EQ(gen_result->intermediate_tensors.front()->id, 3); + EXPECT_EQ(gen_result->intermediate_tensors.front()->name, + kIntermediateTensorName); +} + +} // namespace + +} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc new file mode 100644 index 00000000000000..da06b617d9f15b --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.cc @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +namespace litert::example { + +namespace { + +template +void PrintWithCommas(It start, It end, std::ostream& out) { + for (auto it = start; it < end; ++it) { + out << std::to_string(*it); + if (it != end - 1) { + out << ", "; + } + } +} + +} // namespace + +LiteRtStatus ExampleGraphBuilder::RegisterOp(ExampleOp& op) { + switch (op.op_code) { + case ExampleOpType::ADD: + example_graph_ << "ADD"; + break; + case ExampleOpType::MUL: + example_graph_ << "MUL"; + break; + case ExampleOpType::RELU: + example_graph_ << "RELU"; + break; + } + example_graph_ << "("; + PrintWithCommas(op.inputs.cbegin(), op.inputs.cend(), example_graph_); + example_graph_ << ")->("; + PrintWithCommas(op.outputs.cbegin(), op.outputs.cend(), example_graph_); + example_graph_ << ")"; + return kLiteRtStatusOk; +} + +LiteRtStatus ExampleGraphBuilder::RegisterTensor(ExampleTensor& tensor) { + example_graph_ << std::to_string(tensor.id); + switch (tensor.type) { + case ExampleTensorType::FLOAT: + example_graph_ << "FLOAT"; + break; + case ExampleTensorType::INT: + example_graph_ << "INT"; + break; + } + example_graph_ << "["; + PrintWithCommas(tensor.dims.cbegin(), tensor.dims.cend(), example_graph_); + example_graph_ << "]"; + return kLiteRtStatusOk; +} + +LiteRtStatus ExampleGraphBuilder::FinalizeGraph() { + example_graph_ << "FINALIZED"; + return kLiteRtStatusOk; +} + +void ExampleGraphBuilder::InitGraph(std::string graph_name) { + example_graph_ << "name=" << graph_name << "\n"; +} + +std::string ExampleGraphBuilder::Serialize() const { + return example_graph_.str(); +} + +} // namespace litert::example diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h new file mode 100644 index 00000000000000..e423a53f382b8d --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_ir.h @@ -0,0 +1,153 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/vendors/cc/backend_ir.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/ir_types.h" + +namespace litert::example { + +// Example IR wrapper types for an imaginary backend. + +// Example backend knows only float and int 32. +enum class ExampleTensorType { + FLOAT, + INT, +}; + +// Example backend tensor wrapper that stores the type and shape and unique ID. +struct ExampleTensor { + using Id = int32_t; + ExampleTensorType type; + std::vector dims; + std::string name; + Id id = -1; +}; + +// Example backend knows only a few simple ops. +enum class ExampleOpType { + ADD, + MUL, + RELU, +}; + +// Example backend op that stores op type as well as input and output tensor +// IDs and names. +struct ExampleOp { + ExampleOpType op_code; + std::vector inputs; + std::vector input_names; + std::vector outputs; + std::vector output_names; +}; + +// Simple allocator(s) for example example IR types that provides pointer +// stability. +template +class ExampleIrAllocatorBase { + public: + ExampleIrAllocatorBase(const ExampleIrAllocatorBase&) = delete; + ExampleIrAllocatorBase& operator=(const ExampleIrAllocatorBase&) = delete; + ExampleIrAllocatorBase() = default; + + protected: + std::list ir_; +}; + +// Allocator for example tensors that provides pointer stability and unique IDs. +class ExampleTensorAllocator : public ExampleIrAllocatorBase { + private: + using Alloc = BackendIrAllocator; + + public: + ExampleTensor* operator()() { + auto& tensor = this->ir_.emplace_back(); + tensor.id = this->next_id_++; + return &tensor; + } + + // Return lambda instead of implicit copy construction when converting to + // function type. + // NOLINTNEXTLINE + operator Alloc() { + return [this]() { return this->operator()(); }; + } + + ExampleTensorAllocator(const ExampleTensorAllocator&) = delete; + ExampleTensorAllocator& operator=(const ExampleTensorAllocator&) = delete; + ExampleTensorAllocator() = default; + + private: + uint32_t next_id_ = 0; +}; + +// Allocator for example ops that provides pointer stability. +class ExampleOpAllocator : public ExampleIrAllocatorBase { + private: + using Alloc = BackendIrAllocator; + + public: + ExampleOp* operator()() { return &this->ir_.emplace_back(); } + + // Return lambda instead of implicit copy construction when converting to + // function type. + // NOLINTNEXTLINE + operator Alloc() { + return [this]() { return this->operator()(); }; + } + + ExampleOpAllocator(const ExampleOpAllocator&) = delete; + ExampleOpAllocator& operator=(const ExampleOpAllocator&) = delete; + ExampleOpAllocator() = default; +}; + +// Builder for graph conversion to example IR. The internal example IR graph is +// simply a string representation of the graph. +class ExampleGraphBuilder + : public BackendGraphBuilder { + public: + // Prefixes ir string. + void InitGraph(std::string graph_name) override; + + // Registers tensor into the currrent graph by simply appending its string + // representation. + LiteRtStatus RegisterTensor(ExampleTensor& tensor) override; + + // Registers op into the currrent graph by simply appending its string + // representation. + LiteRtStatus RegisterOp(ExampleOp& op) override; + + // Simply appends tag to IR string. + LiteRtStatus FinalizeGraph() override; + + // Gets the serialized IR representation. + std::string Serialize() const; + + private: + std::stringstream example_graph_; +}; + +using ExampleTypes = IrTypes; + +} // namespace litert::example + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_IR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc index 1b461c2968eed0..e994f7d9d70e7c 100644 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - -#include #include -#include -#include #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" @@ -25,97 +20,11 @@ #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" -// -// Configurations -// - -namespace { - -constexpr char kPluginManufacturer[] = "ExampleSocManufacturer"; -constexpr char kPluginSocModel[] = "ExampleSocModel"; - -} // namespace - -LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { - if (!api_version) { - return kLiteRtStatusErrorInvalidArgument; - } - api_version->major = LITERT_API_VERSION_MAJOR; - api_version->minor = LITERT_API_VERSION_MINOR; - api_version->patch = LITERT_API_VERSION_PATCH; - return kLiteRtStatusOk; -} - -const char* LiteRtGetCompilerPluginSocManufacturer() { - return kPluginManufacturer; -} - -LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( - LiteRtCompilerPlugin compiler_plugin, - LiteRtParamIndex* num_supported_soc_models) { - if (!compiler_plugin || !num_supported_soc_models) { - return kLiteRtStatusErrorInvalidArgument; - } - *num_supported_soc_models = 1; - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, - const char** soc_model_name) { - if (!compiler_plugin || !soc_model_name) { - return kLiteRtStatusErrorInvalidArgument; - } else if (soc_model_idx != 0) { - return kLiteRtStatusErrorUnsupported; - } - *soc_model_name = kPluginSocModel; - return kLiteRtStatusOk; -} - -// -// Compiled Result Definition -// - -struct LiteRtCompiledResultT { - std::string byte_code; - std::vector per_op_data; -}; - -LiteRtStatus LiteRtGetCompiledResultByteCode( - LiteRtCompiledResult compiled_result, const void** byte_code, - size_t* byte_code_size) { - *byte_code = compiled_result->byte_code.data(); - *byte_code_size = compiled_result->byte_code.size(); - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetCompiledResultCallInfo( - LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, - const void** call_info, size_t* call_info_size) { - if (call_idx >= compiled_result->per_op_data.size()) { - return kLiteRtStatusErrorIndexOOB; - } - - *call_info = compiled_result->per_op_data.at(call_idx).data(); - *call_info_size = compiled_result->per_op_data.at(call_idx).size(); - - return kLiteRtStatusOk; -} - -LiteRtStatus LiteRtGetNumCompiledResultCalls( - LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { - *num_calls = compiled_result->per_op_data.size(); - return kLiteRtStatusOk; -} - -void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { - delete compiled_result; -} - -// -// Plugin Definition -// +// A simple compiler plugin example that implements everything directly. +// This plugin matches on mul ops, and emits "byte code" that is simply +// a string representative of the ops consumed. // Plugins can hold state. struct LiteRtCompilerPluginT {}; @@ -129,16 +38,11 @@ void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { delete compiler_plugin; } -LiteRtStatus LiteRtCompilerPluginPartitionModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtModel model, - LiteRtOpList selected_ops) { - auto main_subgraph = - litert::Model::CreateFromNonOwnedHandle(model).MainSubgraph(); - if (!main_subgraph) { - return main_subgraph.Error().Status(); - } - - for (const auto& op : main_subgraph->Ops()) { +LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, + LiteRtSubgraph subgraph, + LiteRtOpList selected_ops) { + ::litert::Subgraph main_subgraph(subgraph); + for (const auto& op : main_subgraph.Ops()) { if (op.Code() != kLiteRtOpCodeTflMul) { continue; } @@ -184,7 +88,7 @@ LiteRtStatus CompileSinglePartition(LiteRtParamIndex partition_index, LiteRtStatus LiteRtCompilerPluginCompile( LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtSubgraphArray partitions, LiteRtParamIndex num_partitions, + LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result) { LiteRtCompiledResult result = new LiteRtCompiledResultT; diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc new file mode 100644 index 00000000000000..11af31d1b14dd3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.cc @@ -0,0 +1,123 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" + +// +// Configurations +// + +namespace litert::example { +namespace { + +constexpr char kPluginManufacturer[] = "ExampleSocManufacturer"; +constexpr char kPluginSocModel[] = "ExampleSocModel"; + +} // namespace +} // namespace litert::example + +LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { + if (!api_version) { + return kLiteRtStatusErrorInvalidArgument; + } + api_version->major = LITERT_API_VERSION_MAJOR; + api_version->minor = LITERT_API_VERSION_MINOR; + api_version->patch = LITERT_API_VERSION_PATCH; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( + LiteRtCompilerPlugin compiler_plugin, + LiteRtHwAccelerators* supported_hardware) { + if (!compiler_plugin || !supported_hardware) { + return kLiteRtStatusErrorInvalidArgument; + } + *supported_hardware = kLiteRtHwAccelatorCpu; + return kLiteRtStatusOk; +} + +const char* LiteRtGetCompilerPluginSocManufacturer() { + return litert::example::kPluginManufacturer; +} + +LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( + LiteRtCompilerPlugin compiler_plugin, + LiteRtParamIndex* num_supported_soc_models) { + if (!compiler_plugin || !num_supported_soc_models) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_supported_soc_models = 1; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( + LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, + const char** soc_model_name) { + if (!compiler_plugin || !soc_model_name) { + return kLiteRtStatusErrorInvalidArgument; + } else if (soc_model_idx != 0) { + return kLiteRtStatusErrorUnsupported; + } + *soc_model_name = litert::example::kPluginSocModel; + return kLiteRtStatusOk; +} + +// +// Compiled Result Definition +// + +LiteRtStatus LiteRtGetCompiledResultByteCode( + LiteRtCompiledResult compiled_result, const void** byte_code, + size_t* byte_code_size) { + if (!compiled_result) { + return kLiteRtStatusErrorInvalidArgument; + } + *byte_code = compiled_result->byte_code.data(); + *byte_code_size = compiled_result->byte_code.size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetCompiledResultCallInfo( + LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, + const void** call_info, size_t* call_info_size) { + if (call_idx >= compiled_result->per_op_data.size()) { + return kLiteRtStatusErrorIndexOOB; + } + *call_info = compiled_result->per_op_data.at(call_idx).data(); + *call_info_size = compiled_result->per_op_data.at(call_idx).size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumCompiledResultCalls( + LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { + if (!compiled_result) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_calls = compiled_result->per_op_data.size(); + return kLiteRtStatusOk; +} + +void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { + delete compiled_result; +} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h new file mode 100644 index 00000000000000..e592dafcadb9eb --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ + +#include +#include + +// Simple compiled result def holds byte code and per op data. +struct LiteRtCompiledResultT { + std::string byte_code; + std::vector per_op_data; +}; + +namespace litert::example {} // namespace litert::example + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_EXAMPLES_EXAMPLE_PLUGIN_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc index 91713be1af8e1f..2d7d5c6eb0cbad 100644 --- a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc @@ -50,13 +50,13 @@ TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) { auto model = testing::LoadTestFileModel("simple_multi_op.tflite"); LiteRtOpListT selected_op_list; - LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartitionModel( - plugin.get(), model.Get(), &selected_op_list)); + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartition( + plugin.get(), model.Subgraph(0)->Get(), &selected_op_list)); const auto selected_ops = selected_op_list.Vec(); ASSERT_EQ(selected_ops.size(), 2); - ASSERT_EQ(selected_ops[0]->op_code, kLiteRtOpCodeTflMul); - ASSERT_EQ(selected_ops[1]->op_code, kLiteRtOpCodeTflMul); + ASSERT_EQ(selected_ops[0]->OpCode(), kLiteRtOpCodeTflMul); + ASSERT_EQ(selected_ops[1]->OpCode(), kLiteRtOpCodeTflMul); } TEST(TestCallDummyPlugin, CompileMulSubgraph) { diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc new file mode 100644 index 00000000000000..a2ad552f2a76fa --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc @@ -0,0 +1,135 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/convert_graph.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/partition_with_capabilities.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_conversion_impl.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_ir.h" +#include "tensorflow/lite/experimental/litert/vendors/examples/example_plugin_common.h" + +using ::litert::PartitionWithCapabilities; +using ::litert::example::ExampleGraphBuilder; +using ::litert::example::ExampleOpAllocator; +using ::litert::example::ExampleOpType; +using ::litert::example::ExampleTensorAllocator; +using ::litert::example::ExampleTypes; +using ::litert::example::MakeAllLegalizations; +using ::litert::example::MakeTensorConverter; + +// Example plugin implementations that leverage the pluggable conversion +// infrastructure. Implementations of common interfaces are provided in +// example_conversion_impl.h. These are passed to higher-level litert functions +// to perform the actual conversion. +// The primary benifit of this approach is the re-use of conversion logic +// between the partition and compile phases. + +// Plugins can hold state. +struct LiteRtCompilerPluginT { + ExampleTypes::Legalizations legalizations; +}; + +namespace { + +bool MulCapability(const ExampleTypes::Op* op) { + return op->op_code == ExampleOpType::MUL; +} + +} // namespace + +// Initialize example plugin and register legalizations. +LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { + auto* plugin = new LiteRtCompilerPluginT; + plugin->legalizations = MakeAllLegalizations(); + *compiler_plugin = plugin; + return kLiteRtStatusOk; +} + +void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { + delete compiler_plugin; +} + +// Leverage the convert_type PartitionViaCapabilties algorithm for partitioning +// implementation. +LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, + LiteRtSubgraph subgraph, + LiteRtOpList selected_ops) { + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + auto ops = PartitionWithCapabilities( + compiler_plugin->legalizations, MulCapability, MakeTensorConverter, + tensor_alloc, op_alloc, ::litert::Subgraph(subgraph)); + if (!ops) { + return ops.Error().Status(); + } + + for (auto* op : *ops) { + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtPushOp(selected_ops, op)); + } + + return kLiteRtStatusOk; +} + +namespace { + +LiteRtStatus CompileSinglePartition( + const ExampleTypes::Legalizations& legalizations, std::string name, + LiteRtSubgraph subgraph, LiteRtCompiledResultT& result) { + ::litert::Subgraph litert_subgraph(subgraph); + + ExampleTensorAllocator tensor_alloc; + ExampleOpAllocator op_alloc; + + ExampleGraphBuilder builder; + + LITERT_RETURN_STATUS_IF_NOT_OK(::litert::ConvertGraph( + litert_subgraph, name, MakeTensorConverter, tensor_alloc, op_alloc, + legalizations, builder)); + + result.byte_code.append(builder.Serialize()); + result.per_op_data.push_back(std::move(name)); + + return kLiteRtStatusOk; +} + +} // namespace + +// Plugin compiler implementation that leverages the pluggable convert_types +// infrastructure. +LiteRtStatus LiteRtCompilerPluginCompile( + LiteRtCompilerPlugin compiler_plugin, const char* soc_model, + LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions, + LiteRtCompiledResult* compiled_result) { + auto* result = new LiteRtCompiledResultT; + + for (auto i = 0; i < num_partitions; ++i) { + auto name = absl::StrFormat("partition_%lu", i); + LITERT_RETURN_STATUS_IF_NOT_OK( + CompileSinglePartition(compiler_plugin->legalizations, std::move(name), + partitions[i], *result)); + } + + *compiled_result = result; + + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc new file mode 100644 index 00000000000000..76bb4a7f3baa6e --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc @@ -0,0 +1,112 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" + +namespace litert { +namespace { + +using ::testing::HasSubstr; + +TEST(ExamplePluginWithConvertTypesTest, GetConfigInfo) { + ASSERT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), + "ExampleSocManufacturer"); + + auto plugin = CreatePlugin(); + + LiteRtParamIndex num_supported_soc_models; + LITERT_ASSERT_STATUS_OK(LiteRtGetNumCompilerPluginSupportedSocModels( + plugin.get(), &num_supported_soc_models)); + ASSERT_EQ(num_supported_soc_models, 1); + + const char* soc_model_name; + LITERT_ASSERT_STATUS_OK(LiteRtGetCompilerPluginSupportedSocModel( + plugin.get(), 0, &soc_model_name)); + ASSERT_STREQ(soc_model_name, "ExampleSocModel"); +} + +TEST(ExamplePluginWithConvertTypesTest, PartitionSimpleMultiAdd) { + auto plugin = CreatePlugin(); + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + + LiteRtOpListT selected_op_list; + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartition( + plugin.get(), model.Get()->MainSubgraph(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); + + ASSERT_EQ(selected_ops.size(), 2); + ASSERT_EQ(selected_ops[0]->OpCode(), kLiteRtOpCodeTflMul); + ASSERT_EQ(selected_ops[1]->OpCode(), kLiteRtOpCodeTflMul); +} + +TEST(ExamplePluginWithConvertTypesTest, CompileMulSubgraph) { + static constexpr absl::string_view kName = "partition_0"; + + auto plugin = CreatePlugin(); + auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); + + auto main_subgraph = model.MainSubgraph(); + LiteRtSubgraph litert_subgraph = main_subgraph->Get(); + + LiteRtCompiledResult compiled; + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginCompile( + plugin.get(), /*soc_model=*/nullptr, &litert_subgraph, + /*num_partitions*/ 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + LITERT_ASSERT_STATUS_OK( + LiteRtGetCompiledResultByteCode(compiled, &byte_code, &byte_code_size)); + absl::string_view byte_code_str(reinterpret_cast(byte_code), + byte_code_size); + + EXPECT_THAT(byte_code_str, HasSubstr(kName)); + EXPECT_THAT(byte_code_str, HasSubstr("0FLOAT[2, 2]")); + EXPECT_THAT(byte_code_str, HasSubstr("1FLOAT[2, 2]")); + EXPECT_THAT(byte_code_str, HasSubstr("2FLOAT[2, 2]")); + EXPECT_THAT(byte_code_str, HasSubstr("MUL")); + EXPECT_THAT(byte_code_str, HasSubstr("FINALIZED")); + + LiteRtParamIndex num_call_infos; + LITERT_ASSERT_STATUS_OK( + LiteRtGetNumCompiledResultCalls(compiled, &num_call_infos)); + + ASSERT_EQ(num_call_infos, 1); + + const void* op_data; + size_t op_data_size; + LITERT_ASSERT_STATUS_OK( + LiteRtGetCompiledResultCallInfo(compiled, 0, &op_data, &op_data_size)); + absl::string_view op_data_str(reinterpret_cast(op_data), + op_data_size); + + EXPECT_EQ(op_data_str, kName); + + LiteRtDestroyCompiledResult(compiled); +} + +} // namespace +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD index 673c2d868ed31e..e9f160dff77bfd 100644 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD @@ -46,6 +46,7 @@ litert_dynamic_lib( visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], deps = [ "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_event", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", @@ -71,7 +72,6 @@ cc_test( "//conditions:default": [], }), deps = [ - ":dispatch_api", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/core:filesystem", diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc index 5c40232b11bf6c..5ccc8af94b7b40 100644 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc @@ -61,9 +61,10 @@ TEST(DispatchApi, GoogleTensor) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kGoogleTensorModelFileName; + auto model_file_name = + litert::testing::GetTestFilePath(kGoogleTensorModelFileName); auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc index 8a946bfbf4f398..ce8613526dcb81 100644 --- a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc @@ -31,7 +31,8 @@ namespace { constexpr const size_t kEdgeTpuPadding = 64; -inline constexpr auto Pad(auto x, auto align) { +template +inline constexpr auto Pad(X x, Align align) { return ((x + align - 1) / align) * align; } diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD index b886beb1914d27..db1c877036c609 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/BUILD @@ -30,7 +30,7 @@ cc_library( "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/core:dynamic_loading", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", ], ) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD new file mode 100644 index 00000000000000..fe3de559b469e1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/BUILD @@ -0,0 +1,135 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:__subpackages__"], +) + +litert_dynamic_lib( + name = "compiler_plugin", + srcs = ["compiler_plugin.cc"], + hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], + export_litert_only = True, + shared_lib_name = "compiler_plugin_so", + so_name = "libLiteRtCompilerPlugin_MediaTek.so", + tags = [ + # Don't build/test in OS until MediaTek SDK is available. + "nobuilder", + ], + ungrte = True, + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":compile_model", + ":create_model", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "create_model", + srcs = ["create_model.cc"], + hdrs = ["create_model.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_options", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter", + "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:add_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations:operand_map", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "compile_model", + srcs = ["compile_model.cc"], + hdrs = ["compile_model.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_options", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter", + "@com_google_absl//absl/strings:string_view", + ], +) + +litert_test( + name = "compiler_plugin_test", + srcs = [ + "compiler_plugin_test.cc", + ], + data = [ + "//tensorflow/lite/experimental/litert/test:mlir_test_data", + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + linkstatic = True, + tags = [ + # Tests with ungrte deps do not currently work on forge. + "no-remote-exec", + "notap", + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. + "nosan", + ], + # Currently this test can only be run on Android because we don't have x86 shared libraries for + # MTK. + target_compatible_with = select({ + "@platforms//os:android": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + use_sys_malloc = True, + deps = [ + ":compiler_plugin", # buildcleaner: keep + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:test_macros", + "//tensorflow/lite/experimental/litert/test:test_models", + "//tensorflow/lite/experimental/litert/vendors/cc:litert_compiler_plugin", + "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc new file mode 100644 index 00000000000000..705e91ff25516b --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.cc @@ -0,0 +1,105 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h" + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +Expected CompileModel( + const NeuronAdapter& neuron_adapter, NeuronModel* model, + std::optional soc_model) { +#if defined(__ANDROID__) + if (soc_model) { + return Error(kLiteRtStatusErrorInvalidArgument, + "JIT compilation for a specific SoC is not supported"); + } +#else + // TODO: Support offline compilation for a specific SoC by setting environment + // variables MTKNN_ADAPTER_DLA_PLATFORM and MTKNN_ADAPTER_DLA_DIR and fetching + // the content of the generated DLA file. + return Error(kLiteRtStatusErrorInvalidArgument, + "AOT compilation is not supported yet"); +#endif + + // Per MediaTek recommendation, Compilation_create, + // Compilation_createWithOptions, and Compilation_setOptimizationString + // should be used as follow: + // - AOT Compilation: Compilation_createWithOptions only + // - JIT Compilation: Compilation_create and Compilation_setOptimizationString + // The code below takes care of those conditions. + + const auto compile_options = +#if __ANDROID__ + std::string(neuron_adapter.JitCompileOptions()); +#else + std::string(neuron_adapter.AotCompileOptions()); +#endif + + auto compilation = +#if __ANDROID__ + neuron_adapter.CreateCompilation(model); +#else + neuron_adapter.CreateCompilation(model, compile_options); +#endif + if (!compilation) { + return compilation.Error(); + } + + if (neuron_adapter.api().compilation_set_priority( + compilation->get(), NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set compilation priority"); + } + + if (neuron_adapter.api().compilation_set_preference( + compilation->get(), NEURON_PREFER_SUSTAINED_SPEED) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set compilation preference"); + } + +#if __ANDROID__ + if (!compile_options.empty()) { + if (auto status = neuron_adapter.api().compilation_set_optimization_string( + compilation->get(), compile_options.c_str()); + status != NEURON_NO_ERROR) { + LITERT_LOG(LITERT_INFO, + "NeuronCompilation_setOptimizationString failed with error %d", + status); + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set optimization string"); + } + } +#endif + + if (auto status = neuron_adapter.api().compilation_finish(compilation->get()); + status != NEURON_NO_ERROR) { + LITERT_LOG(LITERT_INFO, "NeuronCompilation_finish failed with error %d", + status); + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to finish compilation"); + } + + return compilation; +} + +} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h new file mode 100644 index 00000000000000..d7ac0a51130b24 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +Expected CompileModel( + const NeuronAdapter& neuron_adapter, NeuronModel* model, + std::optional soc_model); + +} // namespace litert::mediatek + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_COMPILE_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc new file mode 100644 index 00000000000000..17758498184201 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc @@ -0,0 +1,320 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compile_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +// +// Configurations +// + +using litert::Error; +using litert::Expected; +using litert::mediatek::NEURON_NO_ERROR; +using litert::mediatek::NEURON_PREFER_SUSTAINED_SPEED; +using litert::mediatek::NEURON_PRIORITY_HIGH; +using litert::mediatek::NeuronAdapter; +using litert::mediatek::NeuronCompilation; +using litert::mediatek::NeuronCompilationPtr; +using litert::mediatek::NeuronModel; +using litert::mediatek::NeuronModelPtr; + +namespace { + +constexpr char kPluginManufacturer[] = "MediaTek"; + +// clang-format off +constexpr std::pair kPluginSocModels[] = { + {"mt6853", "mt6853"}, + {"mt6877", "mt6877"}, + {"mt6878", "mt6878"}, + {"mt6879", "mt6879"}, + {"mt6886", "mt6886"}, + {"mt6893", "mt6893"}, + {"mt6895", "mt6895"}, + {"mt6897", "mt6897"}, + {"mt6983", "mt6983"}, + {"mt6985", "mt6985"}, + {"mt6989", "mt6989"}, + {"mt6991", "mt6991"}, +}; + +constexpr LiteRtOpCode kSupportedOps[] = { + kLiteRtOpCodeTflAdd, +}; +// clang-format on + +constexpr auto kNumPluginSocModels = + sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); + +std::optional FindSocModel(absl::string_view soc_model_name) { + std::optional soc_model; + for (auto i = 0; i < kNumPluginSocModels; ++i) { + if (soc_model_name == kPluginSocModels[i].first) { + soc_model = kPluginSocModels[i].second; + break; + } + } + return soc_model; +} + +} // namespace + +LiteRtStatus LiteRtGetCompilerPluginVersion(LiteRtApiVersion* api_version) { + if (api_version == nullptr) { + return kLiteRtStatusErrorInvalidArgument; + } + api_version->major = LITERT_API_VERSION_MAJOR; + api_version->minor = LITERT_API_VERSION_MINOR; + api_version->patch = LITERT_API_VERSION_PATCH; + return kLiteRtStatusOk; +} + +const char* LiteRtGetCompilerPluginSocManufacturer() { + return kPluginManufacturer; +} + +LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( + LiteRtCompilerPlugin compiler_plugin, + LiteRtHwAccelerators* supported_hardware) { + if (!compiler_plugin || !supported_hardware) { + return kLiteRtStatusErrorInvalidArgument; + } + *supported_hardware = kLiteRtHwAccelatorNpu; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( + LiteRtCompilerPlugin compiler_plugin, + LiteRtParamIndex* num_supported_soc_models) { + if (!compiler_plugin || !num_supported_soc_models) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_supported_soc_models = kNumPluginSocModels; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetCompilerPluginSupportedSocModel( + LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, + const char** soc_model_name) { + if (!compiler_plugin || !soc_model_name) { + return kLiteRtStatusErrorInvalidArgument; + } else if (soc_model_idx < 0 || soc_model_idx >= kNumPluginSocModels) { + return kLiteRtStatusErrorInvalidArgument; + } + *soc_model_name = kPluginSocModels[soc_model_idx].first; + return kLiteRtStatusOk; +} + +// +// Compiled Result Definition +// + +// TODO: Revisit this struct after we extend the compiler plugin API to return +// results with more than one single bytecode. +struct LiteRtCompiledResultT { + using Bytecode = std::vector; + std::vector bytecodes; + std::vector graph_names; +}; + +LiteRtStatus LiteRtGetCompiledResultByteCode( + LiteRtCompiledResult compiled_result, const void** byte_code, + size_t* byte_code_size) { + if (!compiled_result || !byte_code || !byte_code_size) { + return kLiteRtStatusErrorInvalidArgument; + } else if (compiled_result->bytecodes.size() > 1) { + // TODO: Revisit this struct after we extend the compiler plugin API to + // return results with more than one single bytecode. + LITERT_LOG(LITERT_ERROR, "CompilerPlugin API supports only 1 NPU bytecode"); + return kLiteRtStatusErrorIndexOOB; + } + *byte_code = compiled_result->bytecodes[0].data(); + *byte_code_size = compiled_result->bytecodes[0].size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetCompiledResultCallInfo( + LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, + const void** call_info, size_t* call_info_size) { + if (!compiled_result || !call_info || !call_info_size) { + return kLiteRtStatusErrorInvalidArgument; + } else if (call_idx >= compiled_result->graph_names.size()) { + return kLiteRtStatusErrorIndexOOB; + } + + auto& graph_name = compiled_result->graph_names[call_idx]; + *call_info = graph_name.data(); + *call_info_size = graph_name.size(); + + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetNumCompiledResultCalls( + LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { + if (!compiled_result || !num_calls) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_calls = compiled_result->bytecodes.size(); + return kLiteRtStatusOk; +} + +void LiteRtDestroyCompiledResult(LiteRtCompiledResult compiled_result) { + delete compiled_result; +} + +// +// Plugin Definition +// + +// Plugins can hold state. +struct LiteRtCompilerPluginT {}; + +LiteRtStatus LiteRtCreateCompilerPlugin(LiteRtCompilerPlugin* compiler_plugin) { + auto* plugin = new LiteRtCompilerPluginT; + *compiler_plugin = plugin; + return kLiteRtStatusOk; +} + +void LiteRtDestroyCompilerPlugin(LiteRtCompilerPlugin compiler_plugin) { + delete compiler_plugin; +} + +namespace { + +// TODO update this function to match the new legalizations. +bool IsOpSupported(const litert::Op& op) { + // NOTE: Currently we are demoing by just mapping simple f32 mul ops. Use a + // very loose guard for now -- only checking if op code is supported. + for (auto supported_op : kSupportedOps) { + if (op.Code() == supported_op) { + return true; + } + } + return false; +} + +} // namespace + +LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, + LiteRtSubgraph subgraph, + LiteRtOpList selected_ops) { + litert::Subgraph graph(subgraph); + for (const auto& op : graph.Ops()) { + if (!IsOpSupported(op)) { + continue; + } + + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtPushOp(selected_ops, op.Get())); + } + + return kLiteRtStatusOk; +} + +namespace { + +Expected> CompilePartition( + NeuronAdapter& neuron_adapter, const litert::Subgraph& partition, + const std::string& graph_name, std::optional soc_model) { + auto model = CreateModel(neuron_adapter, partition, graph_name); + if (!model) { + return model.Error(); + } + + auto compilation = CompileModel(neuron_adapter, model->get(), soc_model); + if (!compilation) { + return compilation.Error(); + } + + size_t bytecode_size; + if (neuron_adapter.api().compilation_get_compiled_network_size( + compilation->get(), &bytecode_size) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get compiled network size"); + } + + std::vector bytecode(bytecode_size); + if (neuron_adapter.api().compilation_store_compiled_network( + compilation->get(), bytecode.data(), bytecode.size()) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get compiled network"); + } + + return bytecode; +} + +} // namespace + +LiteRtStatus LiteRtCompilerPluginCompile( + LiteRtCompilerPlugin compiler_plugin, const char* soc_model, + LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions, + LiteRtCompiledResult* compiled_result) { + LITERT_LOG(LITERT_INFO, + "Starting MediaTek Compilation for %d subgraphs, soc_model=%s", + num_partitions, soc_model); + + auto opt_soc_model = soc_model ? FindSocModel(soc_model) : std::nullopt; + if (opt_soc_model) { + LITERT_LOG(LITERT_ERROR, "Compiling for MediaTek architecture: %s", + *opt_soc_model); + } else if (soc_model) { + LITERT_LOG(LITERT_ERROR, "Unexpected SoC model: %s", soc_model); + return kLiteRtStatusErrorInvalidArgument; + } + + // Initialize SDK and load qnn shared libraries. + + auto neuron_adapter = + NeuronAdapter::Create(/*shared_library_dir=*/std::nullopt); + if (!neuron_adapter) { + return neuron_adapter.Error().Status(); + } + + auto result = std::make_unique(); + for (auto i = 0; i < num_partitions; ++i) { + auto partition = litert::Subgraph(partitions[i]); + auto graph_name = absl::StrFormat("Partition_%d", i); + auto bytecode = CompilePartition(**neuron_adapter, partition, graph_name, + opt_soc_model); + if (!bytecode) { + LITERT_LOG(LITERT_INFO, "%s", bytecode.Error().Message().data()); + return bytecode.Error().Status(); + } + + result->bytecodes.emplace_back(*bytecode); + result->graph_names.emplace_back(graph_name); + } + + *compiled_result = result.release(); + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc new file mode 100644 index 00000000000000..fc51ad177a08fb --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" +#include "tensorflow/lite/experimental/litert/core/model/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" +#include "tensorflow/lite/experimental/litert/test/test_models.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h" + +namespace litert { +namespace { + +using ::testing::Values; + +// clang-format off +const auto kSupportedOps = Values( + "add_cst.tflite", + "add_simple.tflite", + "simple_add_op.tflite"); +// clang-format on + +TEST(TestQnnPlugin, GetConfigInfo) { + EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "MediaTek"); + + auto plugin = CreatePlugin(); + + LiteRtParamIndex num_supported_soc_models; + LITERT_ASSERT_STATUS_OK(LiteRtGetNumCompilerPluginSupportedSocModels( + plugin.get(), &num_supported_soc_models)); + ASSERT_EQ(num_supported_soc_models, 12); + + const char* config_id; + LITERT_CHECK_STATUS_OK( + LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id)); + EXPECT_STREQ(config_id, "mt6853"); +} + +TEST(TestQnnPlugin, PartitionAdd) { + auto plugin = CreatePlugin(); + auto model = testing::LoadTestFileModel("add_simple.tflite"); + + LiteRtOpListT selected_op_list; + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartition( + plugin.get(), model.Subgraph(0)->Get(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); + + ASSERT_EQ(selected_ops.size(), 1); + EXPECT_EQ(selected_ops[0]->OpCode(), kLiteRtOpCodeTflAdd); +} + +// ///////////////////////////////////////////////////////////////////////////// + +class MtkPluginOpCompatibilityTest + : public ::testing::TestWithParam {}; + +TEST_P(MtkPluginOpCompatibilityTest, SupportedOpsTest) { + LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str()); + auto plugin = CreatePlugin(); + auto model = testing::LoadTestFileModel(GetParam()); + + const auto subgraph = model.MainSubgraph(); + LiteRtSubgraph litert_subgraph = subgraph->Get(); + + LiteRtCompiledResult compiled; + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginCompile( + plugin.get(), /*soc_model=*/nullptr, &litert_subgraph, 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + + LITERT_ASSERT_STATUS_OK( + LiteRtGetCompiledResultByteCode(compiled, &byte_code, &byte_code_size)); + + absl::string_view byte_code_string(reinterpret_cast(byte_code), + byte_code_size); + ASSERT_FALSE(byte_code_string.empty()); + + const void* op_data; + size_t op_data_size; + + LITERT_ASSERT_STATUS_OK( + LiteRtGetCompiledResultCallInfo(compiled, 0, &op_data, &op_data_size)); + + absl::string_view op_data_string(reinterpret_cast(op_data), + op_data_size); + ASSERT_EQ("Partition_0", op_data_string); + + LiteRtDestroyCompiledResult(compiled); +} + +INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, MtkPluginOpCompatibilityTest, + kSupportedOps); + +} // namespace +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc new file mode 100644 index 00000000000000..256c43c5e9b59d --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.cc @@ -0,0 +1,94 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h" + +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +Expected CreateModel(const NeuronAdapter& neuron_adapter, + const litert::Subgraph& partition, + const std::string& model_name) { + auto model = neuron_adapter.CreateModel(); + if (!model) { + return model.Error(); + } + + if (neuron_adapter.api().model_set_name(model->get(), model_name.c_str()) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to set model name"); + } + + OperandMap operand_map(neuron_adapter, model->get()); + + std::vector input_indices; + for (const auto& input : partition.Inputs()) { + auto operand_index = operand_map.GetOperandIndex(input); + if (!operand_index) { + return operand_index.Error(); + } + input_indices.push_back(*operand_index); + } + + std::vector output_indices; + for (const auto& output : partition.Outputs()) { + auto operand_index = operand_map.GetOperandIndex(output); + if (!operand_index) { + return operand_index.Error(); + } + output_indices.push_back(*operand_index); + } + + if (neuron_adapter.api().model_identify_inputs_and_outputs( + model->get(), input_indices.size(), input_indices.data(), + output_indices.size(), output_indices.data()) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to identify model I/Os"); + } + + for (const auto& op : partition.Ops()) { + Expected status; + switch (op.Code()) { + case kLiteRtOpCodeTflAdd: + status = LegalizeAddOp(neuron_adapter, model->get(), operand_map, op); + break; + + default: + return Error(kLiteRtStatusErrorRuntimeFailure, "Unsupported op"); + } + + if (!status) { + return status.Error(); + } + } + + if (neuron_adapter.api().model_finish(model->get()) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to finish model"); + } + + return model; +} + +} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h new file mode 100644 index 00000000000000..21af01d19f8b02 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/create_model.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +// Create a new NeuronModel Graph from given LiteRt Graph. +Expected CreateModel(const NeuronAdapter& neuron_adapter, + const Subgraph& partition, + const std::string& model_name); + +} // namespace litert::mediatek + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_CREATE_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD new file mode 100644 index 00000000000000..d15911fcc87838 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/BUILD @@ -0,0 +1,59 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert/vendors/mediatek/compiler:__subpackages__"], +) + +cc_library( + name = "operand_map", + srcs = ["operand_map.cc"], + hdrs = ["operand_map.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_options", + "//tensorflow/lite/experimental/litert/cc:litert_element_type", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "add_op_legalization", + srcs = ["add_op_legalization.cc"], + hdrs = ["add_op_legalization.h"], + deps = [ + "operand_map", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_options", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/core/model", + "//tensorflow/lite/experimental/litert/vendors/mediatek:neuron_adapter", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc new file mode 100644 index 00000000000000..c801bafb7dfd8a --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.cc @@ -0,0 +1,76 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h" + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_options.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +Expected LegalizeAddOp(const NeuronAdapter& neuron_adapter, + NeuronModel* model, OperandMap& operand_map, + const litert::Op& op) { + std::vector input_indices; + for (auto& input : op.Inputs()) { + auto id = operand_map.GetOperandIndex(input); + if (!id) { + return id.Error(); + } + input_indices.push_back(*id); + } + + // A NEURON_ADD operation takes a 3rd scalar operand, which is used to pass a + // TfLiteFusedActivation value. + uint32_t tfl_fused_activation; + if (auto status = + LiteRtGetAddFusedActivationOption(op.Get(), &tfl_fused_activation); + status != kLiteRtStatusOk) { + return Error(status, "Failed to get fused activation"); + } + auto fused_activation_operand_index = + operand_map.AddScalarInt32(tfl_fused_activation); + if (!fused_activation_operand_index) { + return fused_activation_operand_index.Error(); + } + input_indices.push_back(*fused_activation_operand_index); + + std::vector output_indices; + for (auto& output : op.Outputs()) { + auto id = operand_map.GetOperandIndex(output); + if (!id) { + return id.Error(); + } + output_indices.push_back(*id); + } + + if (neuron_adapter.api().model_add_operation( + model, /*type=*/NEURON_ADD, input_indices.size(), + input_indices.data(), output_indices.size(), + output_indices.data()) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set value of NEURON_ADD fused activation"); + } + + return {}; +} + +} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h new file mode 100644 index 00000000000000..fef6773e762bf0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/add_op_legalization.h @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +Expected LegalizeAddOp(const NeuronAdapter& neuron_adapter, + NeuronModel* model, OperandMap& operand_map, + const litert::Op& op); + +} // namespace litert::mediatek + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc new file mode 100644 index 00000000000000..94eda9dcfd9ae1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.cc @@ -0,0 +1,126 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +namespace { + +class OperandType : public NeuronOperandType { + public: + static Expected Create(const Tensor& t) { + auto ranked_tensor_type = t.RankedTensorType(); + if (!ranked_tensor_type) { + return ranked_tensor_type.Error(); + } + + auto tensor_dimensions = ranked_tensor_type->Layout().Dimensions(); + std::vector mtk_dimensions; + mtk_dimensions.reserve(tensor_dimensions.size()); + std::copy(tensor_dimensions.begin(), tensor_dimensions.end(), + std::back_inserter(mtk_dimensions)); + + int32_t mtk_type; + switch (ranked_tensor_type->ElementType()) { + case ElementType::Float32: + mtk_type = NEURON_TENSOR_FLOAT32; + break; + case ElementType::Int32: + mtk_type = NEURON_TENSOR_INT32; + break; + default: + return Error(kLiteRtStatusErrorRuntimeFailure, + "Unsupported element type"); + } + + return OperandType(mtk_type, std::move(mtk_dimensions)); + } + + OperandType(const OperandType&) = delete; + + OperandType(OperandType&& other) : dimensions_(std::move(other.dimensions_)) { + // Copy all the scalar fields from other. + *static_cast(this) = + *static_cast(&other); + // Reset the pointer fields by using own data. + dimensions = dimensions_.data(); + }; + + OperandType& operator=(const OperandType&) = delete; + OperandType& operator=(OperandType&& other) = delete; + + private: + explicit OperandType(int32_t mtk_type, std::vector&& mtk_dimensions) + : dimensions_(std::move(mtk_dimensions)) { + this->type = mtk_type; + this->dimensionCount = dimensions_.size(); + this->dimensions = dimensions_.data(); + }; + + std::vector dimensions_; +}; + +} // namespace + +// ///////////////////////////////////////////////////////////////////////////// + +Expected OperandMap::Register(const NeuronOperandType& operand_type) { + if (neuron_adapter_.api().model_add_operand(model_, &operand_type) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to register model operand"); + } + return AllocateOperandIndex(); +} + +Expected OperandMap::Register(const Tensor& t) { + auto operand_type = OperandType::Create(t); + if (!operand_type) { + return operand_type.Error(); + } + + auto operand_index = + Register(static_cast(*operand_type)); + if (!operand_index) { + return operand_index.Error(); + } + + if (t.HasWeights()) { + auto weights = t.Weights().Bytes(); + if (neuron_adapter_.api().model_set_operand_value( + model_, *operand_index, weights.data(), weights.size()) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set value of tensor weights"); + } + } + + map_[t.Get()] = *operand_index; + return *operand_index; +} + +} // namespace litert::mediatek diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h new file mode 100644 index 00000000000000..ce3b5d8ca9b7d5 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/compiler/legalizations/operand_map.h @@ -0,0 +1,92 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" + +namespace litert::mediatek { + +// This class takes care of registering Tensors and scalars with a given +// NeuronModel and returing their "operand index", which is how the MTK SDK +// handles them. +class OperandMap { + public: + OperandMap(const NeuronAdapter& neuron_adapter, NeuronModel* model) + : neuron_adapter_(neuron_adapter), model_(model) {} + + // Add a scalar operand to the model. + Expected AddScalarBool(bool value) { + return AddScalar(NEURON_BOOL, value); + } + Expected AddScalarInt32(int32_t value) { + return AddScalar(NEURON_INT32, value); + } + Expected AddScalarFloat32(float value) { + return AddScalar(NEURON_FLOAT32, value); + } + + // Find the operand index for a given tensor and, if not done already, add the + // tensor as an operand in the model. + Expected GetOperandIndex(const Tensor& t) { + auto i = map_.find(t.Get()); + if (i != map_.end()) { + return i->second; + } else { + return Register(t); + } + } + + private: + Expected Register(const Tensor& t); + Expected Register(const NeuronOperandType& operand_type); + uint32_t AllocateOperandIndex() { return next_operand_index_++; } + + template + Expected AddScalar(int32_t mtk_type, T value) { + const NeuronOperandType scalar_type = { + .type = mtk_type, + .dimensionCount = 0, + .dimensions = nullptr, + }; + auto operand_index = Register(scalar_type); + if (!operand_index) { + return operand_index.Error(); + } + if (neuron_adapter_.api().model_set_operand_value( + model_, *operand_index, &value, sizeof(value)) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set value of scalar operand"); + } + return operand_index; + } + + const NeuronAdapter& neuron_adapter_; + NeuronModel* model_; + int next_operand_index_ = 0; + absl::flat_hash_map map_; +}; + +} // namespace litert::mediatek + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_MEDIATEK_COMPILER_LEGALIZATIONS_OPERAND_MAP_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD index ca281af5bc7a0e..4373c78d11c948 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/BUILD @@ -75,6 +75,7 @@ cc_test( "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/core:filesystem", + "//tensorflow/lite/experimental/litert/test:common", "//tensorflow/lite/experimental/litert/test:simple_model_npu", "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", "@com_google_absl//absl/log", diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc index fe3997d5f9e103..f596a448fba534 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/dispatch_api_mediatek_test.cc @@ -24,12 +24,13 @@ #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/core/filesystem.h" +#include "tensorflow/lite/experimental/litert/test/common.h" #include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" using ::testing::Pointwise; -TEST(DispatchApi, MediaTek) { +TEST(MediaTek, DispatchApiWithAhwb) { #if !defined(__ANDROID__) GTEST_SKIP() << "This test is specific to Android devices with a MediaTek NPU"; @@ -60,9 +61,10 @@ TEST(DispatchApi, MediaTek) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kMediaTekModelFileName; + auto model_file_name = + litert::testing::GetTestFilePath(kMediaTekModelFileName); auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; @@ -329,3 +331,305 @@ TEST(DispatchApi, MediaTek) { EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), kLiteRtStatusOk); } + +TEST(MediaTek, DispatchApiWithDmaBuf) { +#if !defined(__ANDROID__) + GTEST_SKIP() + << "This test is specific to Android devices with a MediaTek NPU"; +#endif + + EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), + kLiteRtStatusOk); + + const char* vendor_id; + EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "vendor_id: " << vendor_id; + + const char* build_id; + EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "build_id: " << build_id; + + LiteRtApiVersion api_version; + EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); + ABSL_LOG(INFO) << "api_version: " << api_version.major << "." + << api_version.minor << "." << api_version.patch; + + int capabilities; + EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); + ABSL_LOG(INFO) << "capabilities: " << capabilities; + + LiteRtDispatchDeviceContext device_context = nullptr; + EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "device_context: " << device_context; + + auto model_file_name = + litert::testing::GetTestFilePath(kMediaTekModelFileName); + auto model = litert::internal::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model) << model.Error(); + ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() + << " bytes"; + + // /////////////////////////////////////////////////////////////////////////// + // Set up an invocation context for a given model. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtDispatchInvocationContext invocation_context = nullptr; + EXPECT_EQ(LiteRtDispatchInvocationContextCreate( + device_context, kLiteRtDispatchExecutableTypeMlModel, + model->Data(), model->Size(), /*function_name=*/nullptr, + /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "Invocation context: " << invocation_context; + + // /////////////////////////////////////////////////////////////////////////// + // Determine tensor buffer requirements. + // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 2); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/1, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 2); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/1, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetNumTensorBufferRequirementsSupportedBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 2); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/1, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Checking output..."; + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with more data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor_2, + sizeof(kTestInput0Tensor_2)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor_2, + sizeof(kTestInput1Tensor_2)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model once more. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking second execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Checking output..."; + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor_2[i]; + } + EXPECT_THAT(output, + Pointwise(testing::FloatNear(1e-3), kTestOutputTensor_2)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Clean up resources. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), + kLiteRtStatusOk); + LiteRtDestroyTensorBuffer(output_tensor_buffer); + LiteRtDestroyTensorBuffer(input_1_tensor_buffer); + LiteRtDestroyTensorBuffer(input_0_tensor_buffer); + EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), + kLiteRtStatusOk); +} diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc index 3907ec9b5a7ef8..7c1ade0439a2b2 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.cc @@ -14,6 +14,8 @@ #include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" +#include + #include #include @@ -46,12 +48,6 @@ LiteRtDispatchDeviceContextT::RegisterTensorBuffer( return tensor_buffer_type.Error(); } - if (*tensor_buffer_type != kLiteRtTensorBufferTypeAhwb) { - LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d", - *tensor_buffer_type); - return litert::Unexpected(kLiteRtStatusErrorUnsupported); - } - auto tensor_buffer_size = tensor_buffer.Size(); if (!tensor_buffer_size) { return tensor_buffer_size.Error(); @@ -62,26 +58,52 @@ LiteRtDispatchDeviceContextT::RegisterTensorBuffer( return tensor_buffer_offset.Error(); } - auto ahwb = tensor_buffer.GetAhwb(); - if (!ahwb) { - return ahwb.Error(); - } - + switch (*tensor_buffer_type) { + case kLiteRtTensorBufferTypeAhwb: + if (auto ahwb = tensor_buffer.GetAhwb(); ahwb) { #ifdef __ANDROID__ - NeuronMemory* neuron_memory; - if (neuron_adapter_.api().memory_create_from_ahwb(*ahwb, &neuron_memory) != - NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create NeuronMemory from AHWB"); - } - return neuron_memory_registry_.Register(neuron_memory, *tensor_buffer_size, - *tensor_buffer_offset); + NeuronMemory* neuron_memory; + if (neuron_adapter_.api().memory_create_from_ahwb( + *ahwb, &neuron_memory) != NEURON_NO_ERROR) { + return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to create NeuronMemory from AHWB"); + } + return neuron_memory_registry_.Register( + neuron_memory, *tensor_buffer_size, *tensor_buffer_offset); #else - (void)neuron_adapter_; - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "AHardwareBuffer is not supported on this platform"); + (void)neuron_adapter_; + return litert::Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "AHardwareBuffer is not supported on this platform"); #endif + } else { + return ahwb.Error(); + } + break; + + case kLiteRtTensorBufferTypeDmaBuf: + if (auto dma_buf = tensor_buffer.GetDmaBuf(); dma_buf) { + NeuronMemory* neuron_memory; + if (neuron_adapter_.api().memory_create_from_fd( + *tensor_buffer_size, /*protect*/ PROT_READ | PROT_WRITE, + dma_buf->fd, *tensor_buffer_offset, + &neuron_memory) != NEURON_NO_ERROR) { + return litert::Unexpected( + kLiteRtStatusErrorRuntimeFailure, + "Failed to create NeuronMemory from DMA-BUF"); + } + return neuron_memory_registry_.Register( + neuron_memory, *tensor_buffer_size, *tensor_buffer_offset); + } else { + return dma_buf.Error(); + } + break; + + default: + LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d", + *tensor_buffer_type); + return litert::Unexpected(kLiteRtStatusErrorUnsupported); + } } LiteRtDispatchDeviceContextT::NeuronMemoryRegistry::~NeuronMemoryRegistry() { diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc index f8d8fe911dfe70..48885703b6a96d 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_invocation_context.cc @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include #include "tensorflow/lite/experimental/litert/c/litert_common.h" @@ -30,26 +32,38 @@ #include "tensorflow/lite/experimental/litert/vendors/mediatek/dispatch/litert_dispatch_device_context.h" #include "tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h" +using litert::Error; +using litert::Expected; using litert::mediatek::NEURON_NO_ERROR; using litert::mediatek::NEURON_PREFER_SUSTAINED_SPEED; using litert::mediatek::NEURON_PRIORITY_HIGH; using litert::mediatek::NEURON_TENSOR_FLOAT32; using litert::mediatek::NeuronCompilation; +using litert::mediatek::NeuronCompilationPtr; using litert::mediatek::NeuronExecution; +using litert::mediatek::NeuronExecutionPtr; using litert::mediatek::NeuronModel; +using litert::mediatek::NeuronModelPtr; using litert::mediatek::NeuronOperandType; using litert::mediatek::NeuronOperationType; using litert::mediatek::NeuronRuntimeVersion; namespace { -bool LoadFromCachedNetwork( - const litert::mediatek::NeuronAdapter& neuron_adapter, NeuronModel*& model, - NeuronCompilation*& compilation, const void* bytecode_addr, - size_t bytecode_size) { - return neuron_adapter.api().model_restore_from_compiled_network( - &model, &compilation, bytecode_addr, bytecode_size) == - NEURON_NO_ERROR; +Expected> LoadFromCachedNetwork( + const litert::mediatek::NeuronAdapter& neuron_adapter, + const void* bytecode_addr, size_t bytecode_size) { + NeuronModel* model; + NeuronCompilation* compilation; + if (neuron_adapter.api().model_restore_from_compiled_network( + &model, &compilation, bytecode_addr, bytecode_size) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to restore model from compiled network"); + } + return std::make_pair( + NeuronModelPtr{model, neuron_adapter.api().model_free}, + NeuronCompilationPtr{compilation, neuron_adapter.api().compilation_free}); } uint16_t GetRestoreDlaExtensionOperandType( @@ -64,14 +78,13 @@ uint16_t GetRestoreDlaExtensionOperandType( } } -bool LoadFromDlaBytecode(const litert::mediatek::NeuronAdapter& neuron_adapter, - NeuronModel*& model, NeuronCompilation*& compilation, - const void* bytecode_addr, size_t bytecode_size, - int num_inputs, int num_outputs) { - LITERT_LOG(LITERT_INFO, "Creating model..."); - if (neuron_adapter.api().model_create(&model) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to create model"); - return false; +Expected> LoadFromDlaBytecode( + const litert::mediatek::NeuronAdapter& neuron_adapter, + const void* bytecode_addr, size_t bytecode_size, int num_inputs, + int num_outputs) { + Expected model = neuron_adapter.CreateModel(); + if (!model) { + return model.Error(); } // fake input, the real outputs are loaded by compiled network. @@ -85,10 +98,10 @@ bool LoadFromDlaBytecode(const litert::mediatek::NeuronAdapter& neuron_adapter, std::vector input_op_number; input_op_number.reserve(num_inputs); for (auto i = 0; i < num_inputs; i++) { - if (neuron_adapter.api().model_add_operand(model, &fake_io_operand_type) != - NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to add input operand %d", i); - return false; + if (neuron_adapter.api().model_add_operand( + model->get(), &fake_io_operand_type) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to add input operand"); } input_op_number.emplace_back(i); } @@ -101,10 +114,10 @@ bool LoadFromDlaBytecode(const litert::mediatek::NeuronAdapter& neuron_adapter, int32_t operand_type; if (neuron_adapter.api().model_get_extension_operand_type( - model, kExtensionRestoreCompiledNetwork, kNetworkOperandRestoreData, - &operand_type) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to get extension operand"); - return false; + model->get(), kExtensionRestoreCompiledNetwork, + kNetworkOperandRestoreData, &operand_type) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to getextension operand"); } const NeuronOperandType extension_operand_type{ @@ -113,147 +126,141 @@ bool LoadFromDlaBytecode(const litert::mediatek::NeuronAdapter& neuron_adapter, .scale = 0.0f, .zeroPoint = 0, }; - if (neuron_adapter.api().model_add_operand(model, &extension_operand_type) != - NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to add extension operand"); - return false; + if (neuron_adapter.api().model_add_operand( + model->get(), &extension_operand_type) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to add extension operand"); } input_op_number.emplace_back(input_op_number.size()); if (neuron_adapter.api().model_set_operand_value( - model, input_op_number.back(), bytecode_addr, bytecode_size) != + model->get(), input_op_number.back(), bytecode_addr, bytecode_size) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to set extension operand value"); - return false; + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set extension operand value"); } std::vector output_op_number; for (auto i = 0; i < num_outputs; i++) { - if (neuron_adapter.api().model_add_operand(model, &fake_io_operand_type) != - NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to add output operand %d", i); - return false; + if (neuron_adapter.api().model_add_operand( + model->get(), &fake_io_operand_type) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to add output operand"); } output_op_number.emplace_back(input_op_number.size() + i); } int32_t operation_type; if (neuron_adapter.api().model_get_extension_operation_type( - model, kExtensionRestoreCompiledNetwork, + model->get(), kExtensionRestoreCompiledNetwork, kRestoreDlaExtensionOperationType, &operation_type) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to get extension operation"); - return false; + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get extension operation"); } // Add extension operation if (neuron_adapter.api().model_add_operation( - model, static_cast(operation_type), + model->get(), static_cast(operation_type), input_op_number.size(), input_op_number.data(), output_op_number.size(), output_op_number.data()) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to add extension operation"); - return false; + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to add extension operation"); } if (neuron_adapter.api().model_identify_inputs_and_outputs( - model, input_op_number.size() - 1, input_op_number.data(), + model->get(), input_op_number.size() - 1, input_op_number.data(), output_op_number.size(), output_op_number.data()) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to identify I/Os"); - return false; + return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to identify I/Os"); } - if (neuron_adapter.api().model_finish(model) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to finish model"); - return false; + if (neuron_adapter.api().model_finish(model->get()) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, "Failed to finish model"); } - if (neuron_adapter.api().compilation_create(model, &compilation) != - NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to create compilation"); - return false; + auto compilation = neuron_adapter.CreateCompilation(model->get()); + if (!compilation) { + return compilation.Error(); } if (neuron_adapter.api().compilation_set_priority( - compilation, NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to set compilation priority"); - return false; + compilation->get(), NEURON_PRIORITY_HIGH) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set compilation priority"); } if (neuron_adapter.api().compilation_set_preference( - compilation, NEURON_PREFER_SUSTAINED_SPEED) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to set compilation preference"); - return false; + compilation->get(), NEURON_PREFER_SUSTAINED_SPEED) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set compilation preference"); } - if (neuron_adapter.api().compilation_finish(compilation) != NEURON_NO_ERROR) { - LITERT_LOG(LITERT_ERROR, "Failed to finish compilation"); - return false; + // We use AOT compile options since the DLA file was compiled ahead of time. + const auto compile_options = std::string(neuron_adapter.AotCompileOptions()); + if (!compile_options.empty()) { + if (neuron_adapter.api().compilation_set_optimization_string( + compilation->get(), compile_options.c_str()) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set optimization string"); + } } - return true; + if (neuron_adapter.api().compilation_finish(compilation->get()) != + NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to finish compilation"); + } + + return std::make_pair(std::move(*model), std::move(*compilation)); } -bool LoadModelAndCompilation( - const litert::mediatek::NeuronAdapter& neuron_adapter, NeuronModel*& model, - NeuronCompilation*& compilation, const void* bytecode_addr, - size_t bytecode_size, int num_inputs, int num_outputs) { - if (!LoadFromDlaBytecode(neuron_adapter, model, compilation, bytecode_addr, - bytecode_size, num_inputs, num_outputs)) { - return LoadFromCachedNetwork(neuron_adapter, model, compilation, - bytecode_addr, bytecode_size); +Expected> +LoadModelAndCompilation(const litert::mediatek::NeuronAdapter& neuron_adapter, + const void* bytecode_addr, size_t bytecode_size, + int num_inputs, int num_outputs) { + if (auto result = LoadFromDlaBytecode(neuron_adapter, bytecode_addr, + bytecode_size, num_inputs, num_outputs); + !result) { + return LoadFromCachedNetwork(neuron_adapter, bytecode_addr, bytecode_size); + } else { + return result; } - return true; } } // namespace -litert::Expected +Expected LiteRtDispatchInvocationContextT::Create( litert::mediatek::NeuronAdapter& neuron_adapter, LiteRtDispatchDeviceContext device_context, LiteRtDispatchExecutableType exec_type, const void* bytecode_ptr, size_t bytecode_size, const char* function_name, int num_inputs, int num_outputs) { - NeuronModel* model; - NeuronCompilation* compilation; - if (!LoadModelAndCompilation(neuron_adapter, model, compilation, bytecode_ptr, - bytecode_size, num_inputs, num_outputs)) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to load compiled model"); + auto model_and_compilation = LoadModelAndCompilation( + neuron_adapter, bytecode_ptr, bytecode_size, num_inputs, num_outputs); + if (!model_and_compilation) { + return model_and_compilation.Error(); } - NeuronExecution* execution; - if (neuron_adapter.api().execution_create(compilation, &execution) != - NEURON_NO_ERROR) { - if (compilation) { - neuron_adapter.api().compilation_free(compilation); - } - if (model) { - neuron_adapter.api().model_free(model); - } - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create execution"); + auto& model = model_and_compilation->first; + auto& compilation = model_and_compilation->second; + + auto execution = neuron_adapter.CreateExecution(compilation.get()); + if (!execution) { + return execution.Error(); } - if (neuron_adapter.api().execution_set_boost_hint(execution, 100) != + if (neuron_adapter.api().execution_set_boost_hint(execution->get(), 100) != NEURON_NO_ERROR) { - if (execution) { - neuron_adapter.api().execution_free(execution); - } - if (compilation) { - neuron_adapter.api().compilation_free(compilation); - } - if (model) { - neuron_adapter.api().model_free(model); - } - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution boost hint"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set execution boost hint"); } return Ptr(new LiteRtDispatchInvocationContextT( - neuron_adapter, device_context, model, compilation, execution, num_inputs, - num_outputs)); + neuron_adapter, device_context, model.release(), compilation.release(), + execution->release(), num_inputs, num_outputs)); } LiteRtDispatchInvocationContextT::~LiteRtDispatchInvocationContextT() { @@ -279,12 +286,12 @@ LiteRtDispatchInvocationContextT::IoRequirementsBuilder::IoRequirementsBuilder( } } -litert::Expected +Expected LiteRtDispatchInvocationContextT::IoRequirementsBuilder::Create() { - static constexpr std::array - kSupportedTensorBufferTypes = { - kLiteRtTensorBufferTypeAhwb, - }; + static constexpr std::array kSupportedTensorBufferTypes = { + kLiteRtTensorBufferTypeAhwb, + kLiteRtTensorBufferTypeDmaBuf, + }; LiteRtTensorBufferRequirements requirements; if (auto status = LiteRtCreateTensorBufferRequirements( @@ -292,30 +299,30 @@ LiteRtDispatchInvocationContextT::IoRequirementsBuilder::Create() { kSupportedTensorBufferTypes.data(), buffer_size_, strides_.size(), strides_.data(), &requirements); status != kLiteRtStatusOk) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to create tensor buffer requirements"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to create tensor buffer requirements"); } return requirements; } -litert::Expected +Expected LiteRtDispatchInvocationContextT::GetInputRequirements( int input_index, const LiteRtRankedTensorType& tensor_type) { if (!input_requirements_builders_[input_index]) { size_t buffer_size; if (neuron_adapter_.api().compilation_get_input_padded_size( compilation_, input_index, &buffer_size) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get input padded size"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get input padded size"); } std::vector padded_dimensions(tensor_type.layout.rank); if (neuron_adapter_.api().compilation_get_input_padded_dimensions( compilation_, input_index, padded_dimensions.data()) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get input padded dimensions"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get input padded dimensions"); } input_requirements_builders_[input_index] = @@ -325,23 +332,23 @@ LiteRtDispatchInvocationContextT::GetInputRequirements( return input_requirements_builders_[input_index]->Create(); } -litert::Expected +Expected LiteRtDispatchInvocationContextT::GetOutputRequirements( int output_index, const LiteRtRankedTensorType& tensor_type) { if (!output_requirements_builders_[output_index]) { size_t buffer_size; if (neuron_adapter_.api().compilation_get_output_padded_size( compilation_, output_index, &buffer_size) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get output padded size"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get output padded size"); } std::vector padded_dimensions(tensor_type.layout.rank); if (neuron_adapter_.api().compilation_get_output_padded_dimensions( compilation_, output_index, padded_dimensions.data()) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to get output padded dimensions"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to get output padded dimensions"); } output_requirements_builders_[output_index] = @@ -351,58 +358,58 @@ LiteRtDispatchInvocationContextT::GetOutputRequirements( return output_requirements_builders_[output_index]->Create(); } -litert::Expected LiteRtDispatchInvocationContextT::AttachInput( +Expected LiteRtDispatchInvocationContextT::AttachInput( int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { auto neuron_memory_info = device_context_->GetNeuronMemoryInfo(tensor_buffer_handle); if (!neuron_memory_info) { - return litert::Unexpected(neuron_memory_info.Error()); + return litert::Error(neuron_memory_info.Error()); } if (neuron_adapter_.api().execution_set_input_from_memory( execution_, graph_input_index, nullptr, neuron_memory_info->neuron_memory, neuron_memory_info->offset, neuron_memory_info->size) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution input from memory"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set execution input from memory"); } return {}; } -litert::Expected LiteRtDispatchInvocationContextT::AttachOutput( +Expected LiteRtDispatchInvocationContextT::AttachOutput( int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { auto neuron_memory_info = device_context_->GetNeuronMemoryInfo(tensor_buffer_handle); if (!neuron_memory_info) { - return litert::Unexpected(neuron_memory_info.Error()); + return litert::Error(neuron_memory_info.Error()); } if (neuron_adapter_.api().execution_set_output_from_memory( execution_, graph_output_index, nullptr, neuron_memory_info->neuron_memory, neuron_memory_info->offset, neuron_memory_info->size) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to set execution output from memory"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to set execution output from memory"); } return {}; } -litert::Expected LiteRtDispatchInvocationContextT::DetachInput( +Expected LiteRtDispatchInvocationContextT::DetachInput( int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { // Nothing to do. return {}; } -litert::Expected LiteRtDispatchInvocationContextT::DetachOutput( +Expected LiteRtDispatchInvocationContextT::DetachOutput( int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { // Nothing to do. return {}; } -litert::Expected LiteRtDispatchInvocationContextT::Invoke() { +Expected LiteRtDispatchInvocationContextT::Invoke() { if (neuron_adapter_.api().execution_compute(execution_) != NEURON_NO_ERROR) { - return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, - "Failed to execute network"); + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to execute network"); } return {}; } diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.cc b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.cc index 1b048f1ffc8e74..f8de79dc1ad842 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.cc +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.cc @@ -19,7 +19,9 @@ #include #include #include +#include +#include "absl/strings/str_cat.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" @@ -55,30 +57,28 @@ litert::Expected NeuronAdapter::Create( litert::Expected NeuronAdapter::LoadSymbols( std::optional shared_library_dir) { - // The following preinstalled library is for system partition applications. - if (litert::internal::OpenLib("libneuronusdk_adapter.mtk.so", - &dlib_handle_) != kLiteRtStatusOk) { - // The next preinstalled library is in the vendor partition. - if (litert::internal::OpenLib("libneuron_adapter_mgvi.so", &dlib_handle_) != - kLiteRtStatusOk) { + constexpr auto kLibNeuronAdapterLib = "libneuron_adapter.so"; + + const std::vector so_paths = { + // The following preinstalled library is for system partition + // applications. + "libneuronusdk_adapter.mtk.so", + // The next preinstalled library is in the vendor partition. + "libneuron_adapter_mgvi.so", // Finally, the app may want to provide their own version of the library. - constexpr auto kLibNeuronAdapterLib = "libneuron_adapter.so"; - std::string library_path = - shared_library_dir.has_value() - ? *shared_library_dir + kLibNeuronAdapterLib - : kLibNeuronAdapterLib; - if (litert::internal::OpenLib(library_path, &dlib_handle_) != - kLiteRtStatusOk) { - return litert::Unexpected( - kLiteRtStatusErrorRuntimeFailure, - "Failed to load NeuronAdapter shared library"); - } - } + shared_library_dir.has_value() + ? absl::StrCat(*shared_library_dir, "/", kLibNeuronAdapterLib) + : kLibNeuronAdapterLib}; + if (litert::internal::OpenLib(so_paths, &dlib_handle_) != kLiteRtStatusOk) { + return litert::Unexpected(kLiteRtStatusErrorRuntimeFailure, + "Failed to load NeuronAdapter shared library"); } // Binds all supported symbols from the shared library to the function // pointers. LOAD_SYMB(NeuronCompilation_create, api_->compilation_create); + LOAD_SYMB(NeuronCompilation_createWithOptions, + api_->compilation_create_with_options); LOAD_SYMB(NeuronCompilation_finish, api_->compilation_finish); LOAD_SYMB(NeuronCompilation_free, api_->compilation_free); LOAD_SYMB(NeuronCompilation_getInputPaddedDimensions, @@ -96,6 +96,10 @@ litert::Expected NeuronAdapter::LoadSymbols( LOAD_SYMB(NeuronExecution_compute, api_->execution_compute); LOAD_SYMB(NeuronExecution_create, api_->execution_create); LOAD_SYMB(NeuronExecution_free, api_->execution_free); + LOAD_SYMB(NeuronCompilation_getCompiledNetworkSize, + api_->compilation_get_compiled_network_size); + LOAD_SYMB(NeuronCompilation_storeCompiledNetwork, + api_->compilation_store_compiled_network); LOAD_SYMB(NeuronExecution_setBoostHint, api_->execution_set_boost_hint); LOAD_SYMB(NeuronExecution_setInputFromMemory, api_->execution_set_input_from_memory); @@ -103,6 +107,7 @@ litert::Expected NeuronAdapter::LoadSymbols( api_->execution_set_output_from_memory); LOAD_SYMB(NeuronMemory_createFromAHardwareBuffer, api_->memory_create_from_ahwb); + LOAD_SYMB(NeuronMemory_createFromFd, api_->memory_create_from_fd); LOAD_SYMB(NeuronMemory_free, api_->memory_free); LOAD_SYMB(NeuronModel_addOperand, api_->model_add_operand); LOAD_SYMB(NeuronModel_addOperation, api_->model_add_operation); @@ -117,6 +122,7 @@ litert::Expected NeuronAdapter::LoadSymbols( api_->model_identify_inputs_and_outputs); LOAD_SYMB(NeuronModel_restoreFromCompiledNetwork, api_->model_restore_from_compiled_network); + LOAD_SYMB(NeuronModel_setName, api_->model_set_name); LOAD_SYMB(NeuronModel_setOperandValue, api_->model_set_operand_value); LOAD_SYMB(Neuron_getVersion, api_->get_version); @@ -124,5 +130,45 @@ litert::Expected NeuronAdapter::LoadSymbols( return {}; } +Expected NeuronAdapter::CreateModel() const { + NeuronModel* model; + if (api().model_create(&model) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to create NeuroModel"); + } + return NeuronModelPtr{model, api().model_free}; +} + +Expected NeuronAdapter::CreateCompilation( + NeuronModel* model) const { + NeuronCompilation* compilation; + if (api().compilation_create(model, &compilation) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to create NeuronCompilation"); + } + return NeuronCompilationPtr{compilation, api().compilation_free}; +} + +Expected NeuronAdapter::CreateCompilation( + NeuronModel* model, const std::string& compile_options) const { + NeuronCompilation* compilation; + if (api().compilation_create_with_options( + model, &compilation, compile_options.c_str()) != NEURON_NO_ERROR) { + return Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to create NeuronCompilation"); + } + return NeuronCompilationPtr{compilation, api().compilation_free}; +} + +Expected NeuronAdapter::CreateExecution( + NeuronCompilation* compilation) const { + NeuronExecution* execution; + if (api().execution_create(compilation, &execution) != NEURON_NO_ERROR) { + return litert::Error(kLiteRtStatusErrorRuntimeFailure, + "Failed to create execution"); + } + return NeuronExecutionPtr{execution, api().execution_free}; +} + } // namespace mediatek } // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h index d29234eb469758..47809716eb0c59 100644 --- a/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h +++ b/tensorflow/lite/experimental/litert/vendors/mediatek/neuron_adapter.h @@ -20,6 +20,7 @@ #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #if LITERT_HAS_AHWB_SUPPORT @@ -64,13 +65,21 @@ struct NeuronMemory; static constexpr int NEURON_NO_ERROR = 0; static constexpr int NEURON_FLOAT32 = 0; +static constexpr int NEURON_INT32 = 1; +static constexpr int NEURON_BOOL = 6; static constexpr int NEURON_TENSOR_FLOAT32 = 3; +static constexpr int NEURON_TENSOR_INT32 = 4; static constexpr int NEURON_PRIORITY_HIGH = 110; static constexpr int NEURON_PREFER_SUSTAINED_SPEED = 2; int NeuronCompilation_create(NeuronModel* model, NeuronCompilation** compilation); +int NeuronCompilation_createWithOptions(NeuronModel* model, + NeuronCompilation** compilation, + const char* options); int NeuronCompilation_finish(NeuronCompilation* compilation); +int NeuronCompilation_getCompiledNetworkSize(NeuronCompilation* compilation, + size_t* size); int NeuronCompilation_getInputPaddedDimensions(NeuronCompilation* compilation, int32_t index, uint32_t* dimensions); @@ -86,6 +95,8 @@ int NeuronCompilation_setOptimizationString(NeuronCompilation* compilation, int NeuronCompilation_setPreference(NeuronCompilation* compilation, int32_t preference); int NeuronCompilation_setPriority(NeuronCompilation* compilation, int priority); +int NeuronCompilation_storeCompiledNetwork(NeuronCompilation* compilation, + void* buffer, size_t size); int NeuronExecution_compute(NeuronExecution* execution); int NeuronExecution_create(NeuronCompilation* compilation, NeuronExecution** execution); @@ -103,6 +114,8 @@ int NeuronExecution_setOutputFromMemory(NeuronExecution* execution, size_t offset, size_t length); int NeuronMemory_createFromAHardwareBuffer(const AHardwareBuffer* ahwb, NeuronMemory** memory); +int NeuronMemory_createFromFd(size_t size, int protect, int fd, size_t offset, + NeuronMemory** memory); int NeuronModel_addOperand(NeuronModel* model, const NeuronOperandType* type); int NeuronModel_addOperation(NeuronModel* model, NeuronOperationType type, uint32_t inputCount, const uint32_t* inputs, @@ -125,6 +138,7 @@ int NeuronModel_identifyInputsAndOutputs(NeuronModel* model, int NeuronModel_restoreFromCompiledNetwork(NeuronModel** model, NeuronCompilation** compilation, const void* buffer, size_t size); +int NeuronModel_setName(NeuronModel* model, const char* name); int NeuronModel_setOperandValue(NeuronModel* model, int32_t index, const void* buffer, size_t length); int Neuron_getVersion(NeuronRuntimeVersion* version); @@ -135,6 +149,12 @@ void NeuronModel_free(NeuronModel* model); // ///////////////////////////////////////////////////////////////////////////// +using NeuronModelPtr = std::unique_ptr; +using NeuronCompilationPtr = + std::unique_ptr; +using NeuronExecutionPtr = + std::unique_ptr; + class NeuronAdapter { public: using Ptr = std::unique_ptr; @@ -147,11 +167,28 @@ class NeuronAdapter { ~NeuronAdapter(); - static litert::Expected Create( - std::optional shared_library_dir); + static Expected Create(std::optional shared_library_dir); const Api& api() const { return *api_; } + absl::string_view AotCompileOptions() const { + // Option `import_forever` has been recommended by MediaTek to reduce memory + // footprint when using the same I/O buffers across multiple invocations. + return "--apusys-config \"{ \\\"import_forever\\\": true }\""; + } + + absl::string_view JitCompileOptions() const { return ""; } + + Expected CreateModel() const; + + Expected CreateCompilation(NeuronModel* model) const; + + Expected CreateCompilation( + NeuronModel* model, const std::string& compile_options) const; + + Expected CreateExecution( + NeuronCompilation* compilation) const; + private: NeuronAdapter(); litert::Expected LoadSymbols( @@ -166,8 +203,12 @@ class NeuronAdapter { // device during runtime. struct NeuronAdapter::Api { decltype(&NeuronCompilation_create) compilation_create = nullptr; + decltype(&NeuronCompilation_createWithOptions) + compilation_create_with_options = nullptr; decltype(&NeuronCompilation_finish) compilation_finish = nullptr; decltype(&NeuronCompilation_free) compilation_free = nullptr; + decltype(&NeuronCompilation_getCompiledNetworkSize) + compilation_get_compiled_network_size = nullptr; decltype(&NeuronCompilation_getInputPaddedDimensions) compilation_get_input_padded_dimensions = nullptr; decltype(&NeuronCompilation_getInputPaddedSize) @@ -181,6 +222,8 @@ struct NeuronAdapter::Api { decltype(&NeuronCompilation_setPreference) compilation_set_preference = nullptr; decltype(&NeuronCompilation_setPriority) compilation_set_priority = nullptr; + decltype(&NeuronCompilation_storeCompiledNetwork) + compilation_store_compiled_network = nullptr; decltype(&NeuronExecution_compute) execution_compute = nullptr; decltype(&NeuronExecution_create) execution_create = nullptr; decltype(&NeuronExecution_free) execution_free = nullptr; @@ -191,6 +234,7 @@ struct NeuronAdapter::Api { execution_set_output_from_memory = nullptr; decltype(&NeuronMemory_createFromAHardwareBuffer) memory_create_from_ahwb = nullptr; + decltype(&NeuronMemory_createFromFd) memory_create_from_fd = nullptr; decltype(&NeuronMemory_free) memory_free = nullptr; decltype(&NeuronModel_addOperand) model_add_operand = nullptr; decltype(&NeuronModel_addOperation) model_add_operation = nullptr; @@ -205,6 +249,7 @@ struct NeuronAdapter::Api { model_identify_inputs_and_outputs = nullptr; decltype(&NeuronModel_restoreFromCompiledNetwork) model_restore_from_compiled_network = nullptr; + decltype(&NeuronModel_setName) model_set_name = nullptr; decltype(&NeuronModel_setOperandValue) model_set_operand_value = nullptr; decltype(&Neuron_getVersion) get_version = nullptr; }; diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD index 63b02cdba090a4..d4eb799ce5f3ad 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD @@ -99,6 +99,12 @@ litert_test( # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. "nosan", ], + # This test can be run only on Android and Linux. + target_compatible_with = select({ + "@platforms//os:android": [], + "@platforms//os:linux": [], + "//conditions:default": ["@platforms//:incompatible"], + }), deps = [ ":qnn_manager", "//tensorflow/lite/experimental/litert/test:common", @@ -113,7 +119,6 @@ cc_library( deps = [ ":qnn_manager", ":qnn_tensor", - "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", "//tensorflow/lite/experimental/litert/c:litert_common", @@ -127,9 +132,7 @@ cc_library( srcs = ["qnn_tensor.cc"], hdrs = ["qnn_tensor.h"], deps = [ - "@com_google_absl//absl/log", "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/strings:string_view", # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", "//tensorflow/lite/experimental/litert/c:litert_common", diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD index f58fb0bca83397..0c6582262cbc18 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD @@ -67,6 +67,12 @@ litert_test( # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. "nosan", ], + # This test can be run only on Android and Linux. + target_compatible_with = select({ + "@platforms//os:android": [], + "@platforms//os:linux": [], + "//conditions:default": ["@platforms//:incompatible"], + }), use_sys_malloc = True, deps = [ ":qnn_compiler_plugin", # buildcleaner: keep @@ -119,6 +125,7 @@ litert_lib( "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:div_op_legalization", "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:embedding_lookup_op_legalization", "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:fully_connected_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:gelu_op_legalization", "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:greater_op_legalization", "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:legalization", "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:less_op_legalization", @@ -157,6 +164,7 @@ litert_lib( "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_logging", "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_element_type", "//tensorflow/lite/experimental/litert/cc:litert_expected", "//tensorflow/lite/experimental/litert/cc:litert_macros", "//tensorflow/lite/experimental/litert/cc:litert_model", diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD index 23dff704ca4799..6fcc85b43770d9 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD @@ -42,7 +42,10 @@ cc_library( cc_test( name = "qnn_tensor_test", srcs = ["qnn_tensor_test.cc"], - data = ["//tensorflow/lite/experimental/litert/test:mlir_test_data"], + data = [ + "//tensorflow/lite/experimental/litert/test:mlir_test_data", + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], tags = [ # Don't build/test in OS until qnn is available. "nobuilder", @@ -53,8 +56,11 @@ cc_test( "@com_google_googletest//:gtest_main", "@com_google_absl//absl/types:span", # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", - "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model", "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:test_macros", + "//tensorflow/lite/experimental/litert/test:test_models", ], ) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc index 8b5221f268920c..17ac5cf553ddf8 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc @@ -22,7 +22,6 @@ #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" -#include "tensorflow/lite/experimental/litert/cc/litert_expected.h" #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" @@ -64,6 +63,15 @@ void FreeTensorDims(Qnn_Tensor_t& tensor) { } } +void FreePerChannelQuantization(Qnn_Tensor_t& tensor) { + if (tensor.v2.quantizeParams.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + delete[] tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = nullptr; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = 0; + } +} + } // namespace void SetInputTensorAttrs(Qnn_Tensor_t& tensor) { @@ -86,10 +94,12 @@ void SetResultTensorAttrs(Qnn_Tensor_t& tensor) { void ResetTensor(Qnn_Tensor_t& tensor) { FreeTensorDims(tensor); + FreePerChannelQuantization(tensor); tensor = QNN_TENSOR_INIT; tensor.version = QNN_TENSOR_VERSION_2; tensor.v2 = QNN_TENSOR_V2_INIT; tensor.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_DENSE; + tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; } Qnn_Tensor_t BuildDefaultTensor(uint32_t id) { @@ -127,6 +137,57 @@ uint32_t MoveToId(Qnn_Tensor_t& tensor) { return id; } +void SetPerChannelQuantization( + Qnn_Tensor_t& tensor, + const LiteRtQuantizationPerChannel& lite_rt_quantization_per_channel) { + tensor.v2.quantizeParams.quantizationEncoding = + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + + tensor.v2.quantizeParams.axisScaleOffsetEncoding = QNN_AXIS_SCALE_OFFSET_INIT; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis = + lite_rt_quantization_per_channel.quantized_dimension; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = + lite_rt_quantization_per_channel.num_channels; + + // Allocates memory for scaleOffset array. + tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset = + new Qnn_ScaleOffset_t[lite_rt_quantization_per_channel.num_channels]; + + for (int i = 0; i < lite_rt_quantization_per_channel.num_channels; ++i) { + tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].scale = + lite_rt_quantization_per_channel.scales[i]; + tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].offset = + lite_rt_quantization_per_channel.zero_points[i]; + } +} + +void SetPerTensorQuantization( + Qnn_Tensor_t& tensor, + const LiteRtQuantizationPerTensor& lite_rt_quantization_per_tensor) { + tensor.v2.quantizeParams.quantizationEncoding = + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + tensor.v2.quantizeParams.scaleOffsetEncoding.scale = + lite_rt_quantization_per_tensor.scale; + tensor.v2.quantizeParams.scaleOffsetEncoding.offset = + lite_rt_quantization_per_tensor.zero_point; +} + +LiteRtStatus LegalizeQuntizationParameter(const litert::Tensor& src, + Qnn_Tensor_t& dest) { + LiteRtQuantizationTypeId lite_rt_quantization_type_id = src.QTypeId(); + switch (lite_rt_quantization_type_id) { + case kLiteRtQuantizationPerTensor: + SetPerTensorQuantization(dest, src.PerTensorQuantization()); + return kLiteRtStatusOk; + case kLiteRtQuantizationPerChannel: + SetPerChannelQuantization(dest, src.PerChannelQuantization()); + return kLiteRtStatusOk; + default: + LITERT_LOG(LITERT_ERROR, "Unsupported quantization type."); + return kLiteRtStatusErrorInvalidArgument; + } +} + LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest) { if (src.TypeId() != kLiteRtRankedTensorType) { return kLiteRtStatusErrorInvalidArgument; @@ -134,12 +195,23 @@ LiteRtStatus LegalizeTensor(const litert::Tensor& src, Qnn_Tensor_t& dest) { ResetTensor(dest); + if (src.HasQuantization()) { + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeQuntizationParameter(src, dest)); + } + + auto src_ranked_tensor_type = src.RankedTensorType(); + if (!src_ranked_tensor_type) { + LITERT_LOG(LITERT_ERROR, "%s", + src_ranked_tensor_type.Error().Message().data()); + return src_ranked_tensor_type.Error().Status(); + } + Qnn_DataType_t* qnn_data_type = &dest.v2.dataType; - LITERT_RETURN_STATUS_IF_NOT_OK( - LegalizeElementType(src.RankedTensorType().ElementType(), qnn_data_type)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeElementType( + src_ranked_tensor_type->ElementType(), qnn_data_type)); LITERT_RETURN_STATUS_IF_NOT_OK( - LegalizeShapeInfo(src.RankedTensorType().Layout(), dest)); + LegalizeShapeInfo(src_ranked_tensor_type->Layout(), dest)); const bool is_subgraph_in = src.IsSubgraphInput(); const bool is_subgraph_out = src.IsSubgraphOutput(); diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc index ee0f0dc12b2c49..b03b32eab9379e 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc @@ -18,11 +18,17 @@ #include #include "absl/types/span.h" #include "third_party/qairt/latest/include/QNN/QnnTypes.h" -#include "tensorflow/lite/experimental/litert/cc/litert_model_predicates.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/test_macros.h" +#include "tensorflow/lite/experimental/litert/test/test_models.h" namespace { +constexpr float kSimpleMulQuantModelOutputScale = 0.00028621565f; +constexpr float kSimpleMulQuantModelOutputOffset = 0; + TEST(TestInitQnnTensor, BuildDefaultTensor) { Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); @@ -130,4 +136,68 @@ TEST(TestLegalizeTensor, SimpleSupportedTensor) { litert::qnn::ResetTensor(qnn_tensor); } +TEST(TestLegalizeTensor, SimpleQuantizedTensor) { + auto model = litert::testing::LoadTestFileModel(kQSimpleMul16x16Model); + + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + auto ops = subgraph->Ops(); + auto op_outs = ops.at(0).Outputs(); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + const auto& op_out = op_outs.front(); + LITERT_ASSERT_STATUS_OK(litert::qnn::LegalizeTensor(op_out, qnn_tensor)); + + ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_16); + EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); + + ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET); + ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.scale, + kSimpleMulQuantModelOutputScale); + + ASSERT_FLOAT_EQ(qnn_tensor.v2.quantizeParams.scaleOffsetEncoding.offset, + kSimpleMulQuantModelOutputOffset); + litert::qnn::ResetTensor(qnn_tensor); +} + +TEST(TestLegalizeTensor, PerChannelQuantizedTensor) { + auto model = litert::testing::LoadTestFileModel(kQKeyEinsum16x8Model); + + auto subgraph = model.MainSubgraph(); + EXPECT_TRUE(subgraph); + auto ops = subgraph->Ops(); + auto op_ins = ops.at(1).Inputs(); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + const auto& per_channel_quant_tensor = op_ins[1]; + LITERT_ASSERT_STATUS_OK( + litert::qnn::LegalizeTensor(per_channel_quant_tensor, qnn_tensor)); + + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_INT_8); + + LiteRtQuantizationPerChannel per_channel_quant_params = + per_channel_quant_tensor.PerChannelQuantization(); + + ASSERT_EQ(qnn_tensor.v2.quantizeParams.quantizationEncoding, + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET); + EXPECT_EQ(qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.axis, + per_channel_quant_params.quantized_dimension); + EXPECT_EQ( + qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets, + per_channel_quant_params.num_channels); + for (int i = 0; i < per_channel_quant_params.num_channels; ++i) { + ASSERT_FLOAT_EQ( + qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i] + .scale, + per_channel_quant_params.scales[i]); + ASSERT_EQ( + qnn_tensor.v2.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i] + .offset, + per_channel_quant_params.zero_points[i]); + } + litert::qnn::ResetTensor(qnn_tensor); +} + } // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc index 109ec5720fa811..d0b628b9811076 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -29,6 +30,7 @@ #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_logging.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_element_type.h" #include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" @@ -38,7 +40,7 @@ namespace litert::qnn { // Get empty configurations for graph building. -inline absl::Span GetDefaultGraphConfigs() { +inline absl::Span GetFp32GraphConfigs() { static QnnHtpGraph_CustomConfig_t htp_graph_config = QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT; htp_graph_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; @@ -52,6 +54,25 @@ inline absl::Span GetDefaultGraphConfigs() { return absl::MakeSpan(configs); } +inline absl::Span GetDefaultGraphConfigs() { + static const QnnGraph_Config_t* configs[] = {nullptr}; + return absl::MakeSpan(configs); +} + +absl::Span GraphMapper::PickGraphConfigHeuristic() { + for (const auto& input : subgraph_.Inputs()) { + if (input.ElementType() == ElementType::Float32) { + return GetFp32GraphConfigs(); + } + } + for (const auto& output : subgraph_.Outputs()) { + if (output.ElementType() == ElementType::Float32) { + return GetFp32GraphConfigs(); + } + } + return GetDefaultGraphConfigs(); +} + LiteRtStatus GraphMapper::AssignTensorName(Qnn_Tensor_t& qnn_tensor) { char* name = nullptr; const int written = asprintf(&name, "Tensor_%d", cur_tensor_num_++); @@ -129,7 +150,7 @@ LiteRtStatus GraphMapper::IsLiteRtSubgraphSupported() { LiteRtStatus GraphMapper::InitQnnGraph(absl::string_view qnn_graph_name) { LITERT_RETURN_STATUS_IF_QNN_NOT_OK( qnn_.Api()->graphCreate(context_handle_, qnn_graph_name.data(), - GetDefaultGraphConfigs().data(), &QnnGraph())); + PickGraphConfigHeuristic().data(), &QnnGraph())); return kLiteRtStatusOk; } diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h index 85414356218fad..0469fbdb4b5966 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h @@ -20,7 +20,9 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnGraph.h" #include "third_party/qairt/latest/include/QNN/QnnTypes.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_model.h" @@ -84,6 +86,9 @@ class GraphMapper { // Finalize QNN Graph. Call this after all ops have been mapped. LiteRtStatus Finalize(); + // Pick graph config based on subgraph. + absl::Span PickGraphConfigHeuristic(); + inline void RegisterOutput(LiteRtTensor litert_tensor) { graph_outpus_.insert(litert_tensor); } diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD index 011b2a6ac0a0d9..dc0c61aaacbb9e 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD @@ -329,6 +329,39 @@ litert_lib( ], ) +litert_lib( + name = "gelu_op_legalization", + srcs = ["gelu_op_legalization.cc"], + hdrs = ["gelu_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_common", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_model", + "//tensorflow/lite/experimental/litert/c:litert_op_code", + "//tensorflow/lite/experimental/litert/c:litert_options", + "//tensorflow/lite/experimental/litert/cc:litert_expected", + "//tensorflow/lite/experimental/litert/cc:litert_macros", + "//tensorflow/lite/experimental/litert/cc:litert_model", + "//tensorflow/lite/experimental/litert/cc:litert_model_predicates", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + litert_lib( name = "greater_op_legalization", srcs = ["greater_op_legalization.cc"], diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc new file mode 100644 index 00000000000000..361e42187527a5 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.cc @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnGeluOpTypeName = "Gelu"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kGeluOpFmt = "gelu_%d"; + +LiteRtStatus GeluOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflGelu) { + return kLiteRtStatusLegalizeNoMatch; + } + const std::string op_name = absl::StrFormat(kGeluOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnGeluOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized gelu op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h new file mode 100644 index 00000000000000..fdb31f5300d07c --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class GeluOpLegalization : public Legalization { + public: + GeluOpLegalization() = default; + ~GeluOpLegalization() = default; + using Ptr = std::unique_ptr; + static Ptr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_GELU_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc index 21cf7e80da5f97..2a961f86e319b0 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc @@ -77,8 +77,14 @@ LiteRtStatus SliceOpLegalization::LegalizeOp(const Op& src, graph_mapper.PushToScope(op_outs.front().Get(), qnn_op_outs[0])); const auto& src_input_tensor = op_ins.front(); - auto src_input_tensor_rank = - src_input_tensor.RankedTensorType().Layout().Rank(); + auto src_input_tensor_type = src_input_tensor.RankedTensorType(); + if (!src_input_tensor_type) { + LITERT_LOG(LITERT_ERROR, "%s", + src_input_tensor_type.Error().Message().data()); + return src_input_tensor_type.Error().Status(); + } + + auto src_input_tensor_rank = src_input_tensor_type->Layout().Rank(); // Prepare qnn strided slice parameters. @@ -104,7 +110,8 @@ LiteRtStatus SliceOpLegalization::LegalizeOp(const Op& src, // Copy begin, end, and stride values from src_begin_indices and // src_size_indices to range_tensor_data. Stride is always 1. range_tensor_data[i * kRangesParamArgSize] = src_begin_indices->at(i); - range_tensor_data[i * kRangesParamArgSize + 1] = src_size_indices->at(i); + range_tensor_data[i * kRangesParamArgSize + 1] = + src_begin_indices->at(i) + src_size_indices->at(i); range_tensor_data[i * kRangesParamArgSize + 2] = 1; } diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc index d198feea7bc77c..034d0be6312db8 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sum_op_legalization.cc @@ -76,9 +76,18 @@ LiteRtStatus SumOpLegalization::LegalizeOp(const Op& src, Qnn_OpConfig_t& dest, // Check if src_axes are weights tensors. if (!src_axes.HasWeights()) { + LITERT_LOG(LITERT_ERROR, "Sum op axes are not weights tensors"); return kLiteRtStatusErrorInvalidLegalization; } - int32_t dest_axes_size = src_axes.RankedTensorType().Layout().Dimensions()[0]; + + auto src_axes_tensor_type = src_axes.RankedTensorType(); + if (!src_axes_tensor_type) { + LITERT_LOG(LITERT_ERROR, "%s", + src_axes_tensor_type.Error().Message().data()); + return src_axes_tensor_type.Error().Status(); + } + + int32_t dest_axes_size = src_axes_tensor_type->Layout().Dimensions()[0]; auto src_axes_data = src_axes.Weights().Bytes(); Qnn_ClientBuffer_t axes_tensor_client_buf = BuildDefaultClientBuffer(); axes_tensor_client_buf.data = (void*)src_axes_data.data(); diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc index 57a163c17a275d..d4f92af2f97747 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc @@ -37,6 +37,7 @@ using ::litert::internal::DumpOptions; // Dump source Op details. void DumpLegalization(const LiteRtOpT& op) { std::ostringstream dump; + // TODO Make dump tools part of stable api. Dump(op, dump); DumpOptions(op, dump); std::string s = dump.str(); diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc index 08ea35cc089727..988aaa17f254bd 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc @@ -76,6 +76,7 @@ constexpr LiteRtOpCode kSupportedOps[] = { kLiteRtOpCodeTflLogicalAnd, kLiteRtOpCodeTflLess, kLiteRtOpCodeTflGreater, + kLiteRtOpCodeTflGelu, }; // clang-format on @@ -110,6 +111,16 @@ const char* LiteRtGetCompilerPluginSocManufacturer() { return kPluginManufacturer; } +LiteRtStatus LiteRtGetCompilerPluginSupportedHardware( + LiteRtCompilerPlugin compiler_plugin, + LiteRtHwAccelerators* supported_hardware) { + if (!compiler_plugin || !supported_hardware) { + return kLiteRtStatusErrorInvalidArgument; + } + *supported_hardware = kLiteRtHwAccelatorNpu; + return kLiteRtStatusOk; +} + LiteRtStatus LiteRtGetNumCompilerPluginSupportedSocModels( LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex* num_supported_soc_models) { @@ -144,6 +155,9 @@ struct LiteRtCompiledResultT { LiteRtStatus LiteRtGetCompiledResultByteCode( LiteRtCompiledResult compiled_result, const void** byte_code, size_t* byte_code_size) { + if (!compiled_result || !byte_code || !byte_code_size) { + return kLiteRtStatusErrorInvalidArgument; + } *byte_code = compiled_result->context_bin.data(); *byte_code_size = compiled_result->context_bin.size(); return kLiteRtStatusOk; @@ -152,7 +166,9 @@ LiteRtStatus LiteRtGetCompiledResultByteCode( LiteRtStatus LiteRtGetCompiledResultCallInfo( LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, const void** call_info, size_t* call_info_size) { - if (call_idx >= compiled_result->graph_names.size()) { + if (!compiled_result || !call_info || !call_info_size) { + return kLiteRtStatusErrorInvalidArgument; + } else if (call_idx >= compiled_result->graph_names.size()) { return kLiteRtStatusErrorIndexOOB; } @@ -164,6 +180,9 @@ LiteRtStatus LiteRtGetCompiledResultCallInfo( LiteRtStatus LiteRtGetNumCompiledResultCalls( LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { + if (!compiled_result || !num_calls) { + return kLiteRtStatusErrorInvalidArgument; + } *num_calls = compiled_result->graph_names.size(); return kLiteRtStatusOk; } @@ -208,16 +227,11 @@ bool IsOpSupported(const litert::Op& op) { } // namespace -LiteRtStatus LiteRtCompilerPluginPartitionModel( - LiteRtCompilerPlugin compiler_plugin, LiteRtModel model, - LiteRtOpList selected_ops) { - auto m = litert::Model::CreateFromNonOwnedHandle(model); - auto subgraph = m.MainSubgraph(); - if (!subgraph) { - return subgraph.Error().Status(); - } - - for (const auto& op : subgraph->Ops()) { +LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin, + LiteRtSubgraph subgraph, + LiteRtOpList selected_ops) { + ::litert::Subgraph graph(subgraph); + for (const auto& op : graph.Ops()) { if (!IsOpSupported(op)) { continue; } @@ -230,7 +244,7 @@ LiteRtStatus LiteRtCompilerPluginPartitionModel( LiteRtStatus LiteRtCompilerPluginCompile( LiteRtCompilerPlugin compiler_plugin, const char* soc_model, - LiteRtSubgraphArray partitions, LiteRtParamIndex num_partitions, + LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result) { LITERT_LOG(LITERT_INFO, "Starting QNN Compilation for %d subgraphs, soc_model=%s", diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc index 90ec9f2e9461bf..bf50a47d41d36e 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc @@ -62,6 +62,7 @@ const auto kSupportedOps = "simple_logical_and_op.tflite", "simple_less_op.tflite", "simple_greater_op.tflite", + "simple_gelu_op.tflite", kFeedForwardModel, kKeyEinsumModel, kQueryEinsumModel, @@ -72,7 +73,13 @@ const auto kSupportedOps = kRMSNormModel, kSDPAModel, kAttentionModel, - kTransformerBlockModel + kTransformerBlockModel, + kQSimpleMul16x16Model, + kQMulAdd16x16Model, + kQQueryEinsum16x8Model, + kQKeyEinsum16x8Model, + kQVauleEinsum16x8Model, + kQAttnVecEinsum16x8Model ); // clang-format on @@ -97,12 +104,12 @@ TEST(TestQnnPlugin, PartitionMulOps) { auto model = testing::LoadTestFileModel("one_mul.tflite"); LiteRtOpListT selected_op_list; - LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartitionModel( - plugin.get(), model.Get(), &selected_op_list)); + LITERT_ASSERT_STATUS_OK(LiteRtCompilerPluginPartition( + plugin.get(), model.Subgraph(0)->Get(), &selected_op_list)); const auto selected_ops = selected_op_list.Vec(); ASSERT_EQ(selected_ops.size(), 1); - EXPECT_EQ(selected_ops[0]->op_code, kLiteRtOpCodeTflMul); + EXPECT_EQ(selected_ops[0]->OpCode(), kLiteRtOpCodeTflMul); } TEST(TestQnnPlugin, CompileMulSubgraph) { diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc index 7fe83fc2274205..ff1f7ca47e24ec 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc @@ -40,6 +40,7 @@ #include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/embedding_lookup_op_legalization.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/fully_connected_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/gelu_op_legalization.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/greater_op_legalization.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" #include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/less_op_legalization.h" @@ -85,6 +86,7 @@ LiteRtStatus RegisterAllLegalizations( legalizations.push_back(LogicalAndOpLegalization::Create()); legalizations.push_back(LessOpLegalization::Create()); legalizations.push_back(GreaterOpLegalization::Create()); + legalizations.push_back(GeluOpLegalization::Create()); LITERT_LOG(LITERT_INFO, "Scheduling %lu legalizations", legalizations.size()); return kLiteRtStatusOk; } diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD index f603a3a57ff836..2094db69436c30 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD @@ -80,7 +80,6 @@ cc_test( "notap", ], deps = [ - ":dispatch_api", "//tensorflow/lite/experimental/litert/c:litert_common", "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", "//tensorflow/lite/experimental/litert/core:filesystem", diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc index e6dcc0c4cb0fe7..e9fe08b3ca534f 100644 --- a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc @@ -24,6 +24,7 @@ #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/core/filesystem.h" +#include "tensorflow/lite/experimental/litert/test/common.h" #include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" #include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" @@ -60,9 +61,10 @@ TEST(Qualcomm, DispatchApiWithFastRpc) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kQualcommModelFileName; + auto model_file_name = + litert::testing::GetTestFilePath(kQualcommModelFileName); auto model = litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; @@ -311,9 +313,10 @@ TEST(Qualcomm, DispatchApiWithDmaBuf) { kLiteRtStatusOk); ABSL_LOG(INFO) << "device_context: " << device_context; - auto model_file_name = kQualcommModelFileName; - auto model = ::litert::internal::LoadBinaryFile(model_file_name); - EXPECT_TRUE(model); + auto model_file_name = + litert::testing::GetTestFilePath(kQualcommModelFileName); + auto model = litert::internal::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model) << model.Error(); ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->Size() << " bytes"; diff --git a/tensorflow/lite/experimental/ml_adjacent/tflite/tfl_tensor_ref.h b/tensorflow/lite/experimental/ml_adjacent/tflite/tfl_tensor_ref.h index 3eca83a43b143c..2f37d71606ea11 100644 --- a/tensorflow/lite/experimental/ml_adjacent/tflite/tfl_tensor_ref.h +++ b/tensorflow/lite/experimental/ml_adjacent/tflite/tfl_tensor_ref.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ML_ADJACENT_TFLITE_TFL_TENSOR_REF_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_ML_ADJACENT_TFLITE_TFL_TENSOR_REF_H_ +#include + #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/experimental/ml_adjacent/lib.h" diff --git a/tensorflow/lite/experimental/shlo/legacy/test/BUILD b/tensorflow/lite/experimental/shlo/legacy/test/BUILD index 7adc819d12da6a..494e90fde90735 100644 --- a/tensorflow/lite/experimental/shlo/legacy/test/BUILD +++ b/tensorflow/lite/experimental/shlo/legacy/test/BUILD @@ -88,6 +88,7 @@ cc_test( ":util", "//tensorflow/lite/experimental/shlo/legacy:debug", "//tensorflow/lite/experimental/shlo/legacy:shlo", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc b/tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc index 3494ad9940a58f..200490bef54f10 100644 --- a/tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc +++ b/tensorflow/lite/experimental/shlo/legacy/test/concatenate_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include #include -#include #include #include #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/experimental/shlo/legacy/include/shlo.h" #include "tensorflow/lite/experimental/shlo/legacy/src/debug.h" // IWYU pragma: keep, b/321245930 @@ -39,7 +39,7 @@ struct TensorConst { }; template -std::string ToString(std::string_view name, +std::string ToString(absl::string_view name, const std::vector& tensors) { std::ostringstream result; for (size_t i = 0; i < tensors.size(); ++i) { diff --git a/tensorflow/lite/g3doc/examples/text_classification/overview.md b/tensorflow/lite/g3doc/examples/text_classification/overview.md index 5e836468c780ea..26c143269e5c0d 100644 --- a/tensorflow/lite/g3doc/examples/text_classification/overview.md +++ b/tensorflow/lite/g3doc/examples/text_classification/overview.md @@ -5,7 +5,7 @@ Use a TensorFlow Lite model to category a paragraph into predefined groups. Note: (1) To integrate an existing model, try [TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier). (2) To customize a model, try -[TensorFlow Lite Model Maker](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification). +[TensorFlow Lite Model Maker](https://ai.google.dev/edge/litert/libraries/modify/text_classification). ## Get started @@ -13,10 +13,10 @@ Note: (1) To integrate an existing model, try If you are new to TensorFlow Lite and are working with Android, we recommend exploring the guide of -[TensorFLow Lite Task Library](../../inference_with_metadata/task_library/nl_classifier) +[TensorFLow Lite Task Library](../../inference_with_metadata/task_library/nl_classifier.md) to integrate text classification models within just a few lines of code. You can also integrate the model using the -[TensorFlow Lite Interpreter Java API](../../guide/inference#load_and_run_a_model_in_java). +[TensorFlow Lite Interpreter Java API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/guide/inference.md#load-and-run-a-model-in-java). The Android example below demonstrates the implementation for both methods as [lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android/lib_task_api) @@ -108,7 +108,7 @@ Performance benchmark numbers are generated with the tool ## Use your training dataset Follow this -[tutorial](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) +[tutorial](https://ai.google.dev/edge/litert/libraries/modify/text_classification) to apply the same technique used here to train a text classification model using your own datasets. With the right dataset, you can create a model for use cases such as document categorization or toxic comments detection. diff --git a/tensorflow/lite/g3doc/examples/video_classification/overview.md b/tensorflow/lite/g3doc/examples/video_classification/overview.md index a86ccf825fb6a7..2a15b884bfff06 100644 --- a/tensorflow/lite/g3doc/examples/video_classification/overview.md +++ b/tensorflow/lite/g3doc/examples/video_classification/overview.md @@ -26,7 +26,7 @@ already familiar with the [TensorFlow Lite APIs](https://www.tensorflow.org/api_docs/python/tf/lite), download the starter video classification model and the supporting files. You can also build your own custom inference pipeline using the -[TensorFlow Lite Support Library](../../inference_with_metadata/lite_support). +[TensorFlow Lite Support Library](../../inference_with_metadata/lite_support.md). Download starter model with metadata diff --git a/tensorflow/lite/g3doc/guide/ops_compatibility.md b/tensorflow/lite/g3doc/guide/ops_compatibility.md index 898481c74954c3..d5de333c7eaf58 100644 --- a/tensorflow/lite/g3doc/guide/ops_compatibility.md +++ b/tensorflow/lite/g3doc/guide/ops_compatibility.md @@ -1,21 +1,19 @@ # TensorFlow Lite and TensorFlow operator compatibility -The machine learning (ML) operators you use in your model can impact the -process of converting a -TensorFlow model to TensorFlow Lite format. The TensorFlow Lite converter -supports a limited number of TensorFlow operations used in common -inference models, which means that not every model is directly convertible. -The converter tool allows you to include additional operators, but converting -a model this way also requires you to modify the TensorFlow Lite runtime -environment you use to execute your model, which can limit your ability -use standard runtime deployment options, such as -[Google Play services](../android/play_services). - -The TensorFlow Lite Converter is designed to analyze model -structure and apply optimizations in order to make it compatible with the -directly supported operators. For example, depending on the ML operators in -your model, the converter may -[elide or fuse](../models/convert/operation_fusion) those +The machine learning (ML) operators you use in your model can impact the process +of converting a TensorFlow model to TensorFlow Lite format. The TensorFlow Lite +converter supports a limited number of TensorFlow operations used in common +inference models, which means that not every model is directly convertible. The +converter tool allows you to include additional operators, but converting a +model this way also requires you to modify the TensorFlow Lite runtime +environment you use to execute your model, which can limit your ability use +standard runtime deployment options, such as +[Google Play services](../android/play_services.md). + +The TensorFlow Lite Converter is designed to analyze model structure and apply +optimizations in order to make it compatible with the directly supported +operators. For example, depending on the ML operators in your model, the +converter may [elide or fuse](../models/convert/operation_fusion.md) those operators in order to map them to their TensorFlow Lite counterparts. Even for supported operations, specific usage patterns are sometimes expected, @@ -43,14 +41,13 @@ models supported by the conversion process: 1. Models with the built-in operators, TensorFlow core operators and/or custom operators. -If your model only contains operations that are natively supported by -TensorFlow Lite, you do not need any additional flags to convert it. This -is the recommended path because this type of model will convert smoothly -and is simpler to optimize and run using the default TensorFlow Lite runtime. -You also have more deployment options for your model such as -[Google Play services](../android/play_services). -You can get started with the -[TensorFlow Lite converter guide](../models/convert/convert_models). See +If your model only contains operations that are natively supported by TensorFlow +Lite, you do not need any additional flags to convert it. This is the +recommended path because this type of model will convert smoothly and is simpler +to optimize and run using the default TensorFlow Lite runtime. You also have +more deployment options for your model such as +[Google Play services](../android/play_services.md). You can get started with +the [TensorFlow Lite converter guide](../models/convert/convert_models.md). See the [TensorFlow Lite Ops page](https://www.tensorflow.org/mlir/tfl_ops) for a list of built-in operators. @@ -61,19 +58,17 @@ detailed steps. Whenever possible, avoid the last option of including custom operators in your converted model. [Custom operators](https://www.tensorflow.org/guide/create_op) -are either operators created by combining -multiple primitive TensorFlow core operators or defining a completely new one. -When custom operators are converted, they can increase the size of the overall -model by incurring dependencies outside of the built-in TensorFlow Lite library. -Custom ops, if not specifically created for mobile or device deployment, -can result in worse performance when deployed to -resource constrained devices compared to a server environment. +are either operators created by combining multiple primitive TensorFlow core +operators or defining a completely new one. When custom operators are converted, +they can increase the size of the overall model by incurring dependencies +outside of the built-in TensorFlow Lite library. Custom ops, if not specifically +created for mobile or device deployment, can result in worse performance when +deployed to resource constrained devices compared to a server environment. Finally, just like including select TensorFlow core operators, custom operators requires you to -[modify the model runtime environment](ops_custom#create_and_register_the_operator) -which limits you from taking advantage of standard runtime services such as -the [Google Play services](../android/play_services). - +[modify the model runtime environment](ops_custom.md#create-and-register-the-operator) +which limits you from taking advantage of standard runtime services such as the +[Google Play services](../android/play_services.md). ## Supported types diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md index 5f62b56c0fde2c..0ca6603ce883ec 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md @@ -29,7 +29,7 @@ The following models are guaranteed to be compatible with the `AudioClassifier` API. * Models created by - [TensorFlow Lite Model Maker for Audio Classification](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/audio_classifier). + [TensorFlow Lite Model Maker for Audio Classification](https://ai.google.dev/edge/litert/libraries/modify/audio_classification). * The [pretrained audio event classification models on TensorFlow Hub](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1). @@ -237,9 +237,10 @@ for more options to configure `TFLAudioClassifier`. pip install tflite-support ``` -Note: Task Library's Audio APIs rely on [PortAudio](http://www.portaudio.com/docs/v19-doxydocs/index.html) -to record audio from the device's microphone. If you intend to use Task -Library's [AudioRecord](/lite/api_docs/python/tflite_support/task/audio/AudioRecord) +Note: Task Library's Audio APIs rely on +[PortAudio](http://www.portaudio.com/docs/v19-doxydocs/index.html) to record +audio from the device's microphone. If you intend to use Task Library's +[AudioRecord](https://ai.google.dev/edge/api/tflite/python/tflite_support/task/audio/AudioRecord) for audio recording, you need to install PortAudio on your system. * Linux: Run `sudo apt-get update && apt-get install libportaudio2` diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md index f156880316ebb5..c1ce83285046c5 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md @@ -21,7 +21,7 @@ Sentencepiece tokenizations outside the TFLite model. The following models are compatible with the `BertNLClassifier` API. * Bert Models created by - [TensorFlow Lite Model Maker for text Classfication](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification). + [TensorFlow Lite Model Maker for text Classfication](https://ai.google.dev/edge/litert/libraries/modify/text_classification). * Custom models that meet the [model compatibility requirements](#model-compatibility-requirements). @@ -148,7 +148,7 @@ for more options to configure `BertNLClassifier`. ## Example results Here is an example of the classification results of movie reviews using the -[MobileBert](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification) +[MobileBert](https://ai.google.dev/edge/litert/libraries/modify/text_classification) model from Model Maker. Input: "it's a charming and often affecting journey" diff --git a/tensorflow/lite/g3doc/performance/best_practices.md b/tensorflow/lite/g3doc/performance/best_practices.md index 616583d353fe74..4d839b5c23d7be 100644 --- a/tensorflow/lite/g3doc/performance/best_practices.md +++ b/tensorflow/lite/g3doc/performance/best_practices.md @@ -38,7 +38,7 @@ help in understanding performance bottlenecks and which operators dominate the computation time. You can also use -[TensorFlow Lite tracing](measurement#trace_tensorflow_lite_internals_in_android) +[TensorFlow Lite tracing](measurement.md#trace-tensorflow-lite-internals-in-android) to profile the model in your Android application, using standard Android system tracing, and to visualize the operator invocations by time with GUI based profiling tools. @@ -51,7 +51,7 @@ look into optimizing that operator. This scenario should be rare as TensorFlow Lite has optimized versions for most operators. However, you may be able to write a faster version of a custom op if you know the constraints in which the operator is executed. Check out the -[custom operators guide](../guide/ops_custom). +[custom operators guide](../guide/ops_custom.md). ## Optimize your model @@ -59,7 +59,9 @@ Model optimization aims to create smaller models that are generally faster and more energy efficient, so that they can be deployed on mobile devices. TensorFlow Lite supports multiple optimization techniques, such as quantization. -Check out the [model optimization docs](model_optimization) for details. +Check out the +[model optimization docs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/model_optimization.md) +for details. ## Tweak the number of threads @@ -100,8 +102,10 @@ specific profiling tools and best practices for your platform. TensorFlow Lite has added new ways to accelerate models with faster hardware like GPUs, DSPs, and neural accelerators. Typically, these accelerators are -exposed through [delegate](delegates) submodules that take over parts of the -interpreter execution. TensorFlow Lite can use delegates by: +exposed through +[delegate](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/delegates.md) +submodules that take over parts of the interpreter execution. TensorFlow Lite +can use delegates by: * Using Android's [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/). @@ -110,23 +114,27 @@ interpreter execution. TensorFlow Lite can use delegates by: [NNAPI delegate](https://www.tensorflow.org/lite/android/delegates/nnapi) guide. * GPU delegate is available on Android and iOS, using OpenGL/OpenCL and Metal, - respectively. To try them out, see the [GPU delegate tutorial](gpu) and - [documentation](gpu_advanced). + respectively. To try them out, see the + [GPU delegate tutorial](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu.md) + and + [documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu.md#advanced-gpu-support). * Hexagon delegate is available on Android. It leverages the Qualcomm Hexagon DSP if it is available on the device. See the [Hexagon delegate tutorial](https://www.tensorflow.org/lite/android/delegates/hexagon) for more information. * It is possible to create your own delegate if you have access to - non-standard hardware. See [TensorFlow Lite delegates](delegates) for more - information. + non-standard hardware. See + [TensorFlow Lite delegates](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/delegates.md) + for more information. Be aware that some accelerators work better for different types of models. Some delegates only support float models or models optimized in a specific way. It is -important to [benchmark](measurement) each delegate to see if it is a good -choice for your application. For example, if you have a very small model, it may -not be worth delegating the model to either the NN API or the GPU. Conversely, -accelerators are a great choice for large models that have high arithmetic -intensity. +important to +[benchmark](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/measurement.md) +each delegate to see if it is a good choice for your application. For example, +if you have a very small model, it may not be worth delegating the model to +either the NN API or the GPU. Conversely, accelerators are a great choice for +large models that have high arithmetic intensity. ## Need more help diff --git a/tensorflow/lite/graph_info.cc b/tensorflow/lite/graph_info.cc index 5f7b466a7c10ca..59b750fc4d6581 100644 --- a/tensorflow/lite/graph_info.cc +++ b/tensorflow/lite/graph_info.cc @@ -45,13 +45,14 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { PartitionGraphIntoIndependentNodeSubsetsImpl( const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, std::vector* node_subsets, bool greedily, - const ControlEdges& control_edges) + const ControlEdges& control_edges, bool disable_node_fusion) : info_(info), node_subsets_(node_subsets), node_type_(info_->num_total_nodes(), NodeSubset::kTfNonPartition), greedily_(greedily), control_edges_(control_edges), - num_incoming_control_edges_(info_->num_execution_nodes(), 0) { + num_incoming_control_edges_(info_->num_execution_nodes(), 0), + disable_node_fusion_(disable_node_fusion) { // Populate the node_type_ map. for (auto node_index : TfLiteIntArrayView(nodes_to_partition)) { node_type_[node_index] = NodeSubset::kTfPartition; @@ -134,6 +135,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { bool UpdateNode(int node_index) { const TfLiteNode& node = info_->node(node_index); NodeSubset& current_subset = node_subsets_->back(); + if (disable_node_fusion_ && !current_subset.nodes.empty()) return false; int current_epoch = node_subsets_->size() - 1; // Check if node is already done. if (node_epochs_[node_index] != kEpochNotReady) { @@ -257,6 +259,8 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { ControlEdges control_edges_; // Number of incoming control edges for each node. std::vector num_incoming_control_edges_; + // Whether to disable node fusion. + const bool disable_node_fusion_; }; // LINT.ThenChange(//tensorflow/lite/delegates/utils.h) @@ -265,7 +269,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl { TfLiteStatus PartitionGraphIntoIndependentNodeSubsets( const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, std::vector* node_subsets, bool greedily, - const ControlEdges* control_edges) { + const ControlEdges* control_edges, bool disable_node_fusion) { ControlEdges my_control_edges; if (control_edges == nullptr) { control_edges = &my_control_edges; @@ -284,7 +288,8 @@ TfLiteStatus PartitionGraphIntoIndependentNodeSubsets( } } PartitionGraphIntoIndependentNodeSubsetsImpl( - info, nodes_to_partition, node_subsets, greedily, *control_edges) + info, nodes_to_partition, node_subsets, greedily, *control_edges, + disable_node_fusion) .Partition(); return kTfLiteOk; } diff --git a/tensorflow/lite/graph_info.h b/tensorflow/lite/graph_info.h index c72c5c3efe620f..9b7a6acedfb01e 100644 --- a/tensorflow/lite/graph_info.h +++ b/tensorflow/lite/graph_info.h @@ -154,7 +154,8 @@ using ControlEdges = std::vector; TfLiteStatus PartitionGraphIntoIndependentNodeSubsets( const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, std::vector* node_subsets, bool greedily, - const ControlEdges* control_edges = nullptr); + const ControlEdges* control_edges = nullptr, + bool disable_node_fusion = false); } // namespace tflite diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 6ae76e4cb1ce21..9b18bbafa8b5d3 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -40,77 +40,14 @@ config_setting( define_values = {"tflite_with_ruy": "false"}, ) -###### Beginning of config_setting's to match aarch64 ###### -# -# We need to identify the aarch64 instruction set to decide whether to enable -# TFLITE_WITH_RUY by default. This is surprisingly hard to do because select() -# can only consume config_setting's, these config_settings are not centralized, -# and the "cpu" value which they define are free-form strings and there is no -# standardization of the strings that we need to match for the aarch64 architecture. -# -# First, we have the case of --config=chromiumos_arm, which defines cpu=arm but is -# actually aarch64. For it, we name our config_setting chromiumos_arm64 to avoid -# adding to the confusion, at the cost of diverging from the --config name. -# This example shows that we can never hope to match aarch64 by looking only at -# "cpu", since the value "arm" would be used to mean the (32-bit) ARM instruction set -# in other configs. config_setting( - name = "chromiumos_arm64", - values = { - "crosstool_top": "//external:chromiumos/crosstool", - "cpu": "arm", - }, - visibility = ["//visibility:private"], -) - -# Next, several "cpu" values that unambiguously mean aarch64, that are observed in -# practice with --config's that we care to support: - -# This is defined by the tensorflow:linux_aarch64 config_setting. -config_setting( - name = "cpu_aarch64", - values = {"cpu": "aarch64"}, - visibility = ["//visibility:private"], -) - -# This is defined by some config_setting's in the wild and is a reasonable value to -# support anyway. -config_setting( - name = "cpu_arm64", - values = {"cpu": "arm64"}, - visibility = ["//visibility:private"], -) - -# This is the value defined by --config=ios_arm64. -config_setting( - name = "cpu_ios_arm64", - values = {"cpu": "ios_arm64"}, - visibility = ["//visibility:private"], -) - -# arm64e variants of the above two. See: -# https://stackoverflow.com/questions/52624308/xcode-arm64-vs-arm64e -config_setting( - name = "cpu_arm64e", - values = {"cpu": "arm64e"}, - visibility = ["//visibility:private"], -) - -config_setting( - name = "cpu_ios_arm64e", - values = {"cpu": "ios_arm64e"}, - visibility = ["//visibility:private"], -) - -# This is the value defined by --config=android_arm64 -config_setting( - name = "cpu_arm64_v8a", - values = {"cpu": "arm64-v8a"}, + name = "aarch64", + constraint_values = [ + "@platforms//cpu:aarch64", + ], visibility = ["//visibility:private"], ) -###### End of config_setting's to match aarch64 ###### - # Suppress warnings that are introduced by Eigen Tensor. EXTRA_EIGEN_COPTS = select({ "//tensorflow:ios": [ @@ -340,13 +277,7 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = select({ - ":chromiumos_arm64": [":tflite_with_ruy_enabled"], - ":cpu_aarch64": [":tflite_with_ruy_enabled"], - ":cpu_arm64": [":tflite_with_ruy_enabled"], - ":cpu_arm64e": [":tflite_with_ruy_enabled"], - ":cpu_ios_arm64": [":tflite_with_ruy_enabled"], - ":cpu_ios_arm64e": [":tflite_with_ruy_enabled"], - ":cpu_arm64_v8a": [":tflite_with_ruy_enabled"], + ":aarch64": [":tflite_with_ruy_enabled"], "//tensorflow:android_arm": ["tflite_with_ruy_enabled"], "//conditions:default": [], }), diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h index af91b0a6de7336..af100cb204b1df 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm.h @@ -37,18 +37,21 @@ namespace cpu_backend_gemm { // The main entry point for CpuBackendGemm::Gemm. // // If TFLITE_WITH_RUY is set, CpuBackendGemm::Gemm will always go to Ruy aka -// GemmImplUsingRuy. Other cases are as follows: +// GemmImplUsingRuy. The behavior is as follows: // // |Quantized (uint8)|Quantized (int8)| Float | // TFLITE_WITH_RUY | Ruy | Ruy | Ruy | // !TFLITE_WITH_RUY | gemmlowp | Ruy/gemmlowp* | eigen | // * - Ruy if NEON is not available. - -// On x86 platforms: +// +// On most ARM32/ARM64 platforms, the default is TFLITE_WITH_RUY: +// (default) | Ruy | Ruy | Ruy | +// +// On other platforms (including x86), the default is !TFLITE_WITH_RUY: // (default) | gemmlowp | Ruy | eigen | -// TFLITE_X86_RUY_\ | Ruy | Ruy | Ruy | -// ENABLED && (AVX -// or above available) +// +// Use --define=tflite_with_ruy=true or --define=tflite_with_ruy=false to +// override the default. #if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM) /* GEMM dispatch implementation for x86. diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc index d1ccdd7fad8c45..eb49f823558d1e 100644 --- a/tensorflow/lite/kernels/detection_postprocess.cc +++ b/tensorflow/lite/kernels/detection_postprocess.cc @@ -517,7 +517,7 @@ void InplaceMergeBoxInfo(std::vector& boxes, int mid_index, int end_index) { std::inplace_merge( boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index, - [](const BoxInfo& a, const BoxInfo& b) { return a.score >= b.score; }); + [](const BoxInfo& a, const BoxInfo& b) { return a.score > b.score; }); } TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin, diff --git a/tensorflow/lite/kernels/detection_postprocess_test.cc b/tensorflow/lite/kernels/detection_postprocess_test.cc index 856a577013f870..938d47eb3e20f0 100644 --- a/tensorflow/lite/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/kernels/detection_postprocess_test.cc @@ -379,7 +379,8 @@ class DetectionPostprocessOpModelwithRegularNMS : public SingleOpModel { const TensorData& input1, const TensorData& input2, const TensorData& input3, const TensorData& output1, const TensorData& output2, const TensorData& output3, - const TensorData& output4, bool use_regular_nms, int num_threads = 1) { + const TensorData& output4, bool use_regular_nms, int num_threads = 1, + int max_detections = 3, int detection_per_class = 1) { input1_ = AddInput(input1); input2_ = AddInput(input2); input3_ = AddInput(input3); @@ -390,9 +391,9 @@ class DetectionPostprocessOpModelwithRegularNMS : public SingleOpModel { flexbuffers::Builder fbb; fbb.Map([&]() { - fbb.Int("max_detections", 3); + fbb.Int("max_detections", max_detections); fbb.Int("max_classes_per_detection", 1); - fbb.Int("detections_per_class", 1); + fbb.Int("detections_per_class", detection_per_class); fbb.Bool("use_regular_nms", use_regular_nms); fbb.Float("nms_score_threshold", 0.0); fbb.Float("nms_iou_threshold", 0.5); @@ -702,6 +703,234 @@ TEST_P(DetectionPostprocessOpRegularTest, RegularNMS) { } } +TEST_P(DetectionPostprocessOpRegularTest, RegularNMSWithEqualScores) { + TensorData input1, input2, input3; + if (tensor_type_ == TensorType_UINT8) { + input1 = {tensor_type_, {1, 6, 4}, -1.0, 1.0}; + input2 = {tensor_type_, {1, 6, 3}, 0.0, 1.0}; + input3 = {tensor_type_, {6, 4}, 0.0, 100.5}; + } else { + input1 = {tensor_type_, {1, 6, 4}}; + input2 = {tensor_type_, {1, 6, 3}}; + input3 = {tensor_type_, {6, 4}}; + } + DetectionPostprocessOpModelwithRegularNMS m( + input1, input2, input3, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, true, num_threads_, /*max_detections=*/4, + /*detection_per_class=*/2); + auto inputs1 = { + 0.0f, 0.0f, 0.0f, 0.0f, // box #1 (0, 0, 1, 1) + 0.0f, 0.0f, 0.0f, 0.0f, // box #2 (0, 1, 1, 2) + 0.0f, 0.0f, 0.0f, 0.0f, // box #3 (0, 5, 1, 6) + 0.0f, 0.0f, 0.0f, 0.0f, // box #4 (0, 10, 1, 11) + 0.0f, 0.0f, 0.0f, 0.0f, // box #5 (0, 20, 1, 21) + 0.0f, 0.0f, 0.0f, 0.0f // box #6 (0, 100, 1, 101) + }; + + if (tensor_type_ == TensorType_UINT8) { + m.QuantizeAndPopulate(m.input1(), std::vector{inputs1}); + } else { + m.SetInput1(inputs1); + } + // class scores - two classes with background + auto inputs2 = { + 0.f, .1f, 0.1f, // box #1 + 0.f, .1f, 0.96f, // box #2 + 0.f, .1f, 0.9f, // box #3 + 0.f, .95f, 0.1f, // box #4 + 0.f, .9f, 0.1f, // box #5 + 0.f, .1f, 0.1f // box #6 + }; + if (tensor_type_ == TensorType_UINT8) { + m.QuantizeAndPopulate(m.input2(), std::vector{inputs2}); + } else { + m.SetInput2(inputs2); + } + // six anchors in center-size encoding + auto inputs3 = { + 0.5f, 0.5f, 1.0f, 1.0f, // box #1 + 0.5f, 1.5f, 1.0f, 1.0f, // box #2 + 0.5f, 5.5f, 1.0f, 1.0f, // box #3 + 0.5f, 10.5f, 1.0f, 1.0f, // box #4 + 0.5f, 20.5f, 1.0f, 1.0f, // box #5 + 0.5f, 100.5f, 1.0f, 1.0f // box #6 + }; + if (tensor_type_ == TensorType_UINT8) { + m.QuantizeAndPopulate(m.input3(), std::vector{inputs3}); + } else { + m.SetInput3(inputs3); + } + ASSERT_EQ(m.Invoke(), kTfLiteOk); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 4, 4)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput1(), ElementsAreArray(ArrayFloatNear( + { + 0, 1, 1, 2, // box #2 + 0, 10, 1, 11, // box #4 + 0, 20, 1, 21, // box #5 + 0, 5, 1, 6 // box #3 + }, + 3e-1))); + } else { + EXPECT_THAT(m.GetOutput1(), ElementsAreArray(ArrayFloatNear( + { + 0, 1, 1, 2, // box #2 + 0, 10, 1, 11, // box #4 + 0, 20, 1, 21, // box #5 + 0, 5, 1, 6 // box #3 + }, + 3e-4))); + } + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 4)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0, 1}, 1e-1))); + } else { + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0, 1}, 1e-4))); + } + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 4)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.96, 0.95, 0.9, 0.9}, 1e-1))); + } else { + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.96, 0.95, 0.9, 0.9}, 1e-4))); + } + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({4.0}, 1e-1))); + } else { + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({4.0}, 1e-4))); + } +} + +TEST_P(DetectionPostprocessOpRegularTest, FastNMSWithEqualScores) { + TensorData input1, input2, input3; + if (tensor_type_ == TensorType_UINT8) { + input1 = {tensor_type_, {1, 6, 4}, -1.0, 1.0}; + input2 = {tensor_type_, {1, 6, 3}, 0.0, 1.0}; + input3 = {tensor_type_, {6, 4}, 0.0, 100.5}; + } else { + input1 = {tensor_type_, {1, 6, 4}}; + input2 = {tensor_type_, {1, 6, 3}}; + input3 = {tensor_type_, {6, 4}}; + } + DetectionPostprocessOpModelwithRegularNMS m( + input1, input2, input3, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, false, num_threads_, /*max_detections=*/4, + /*detection_per_class=*/2); + auto inputs1 = { + 0.0f, 0.0f, 0.0f, 0.0f, // box #1 (0, 0, 1, 1) + 0.0f, 0.0f, 0.0f, 0.0f, // box #2 (0, 1, 1, 2) + 0.0f, 0.0f, 0.0f, 0.0f, // box #3 (0, 5, 1, 6) + 0.0f, 0.0f, 0.0f, 0.0f, // box #4 (0, 10, 1, 11) + 0.0f, 0.0f, 0.0f, 0.0f, // box #5 (0, 20, 1, 21) + 0.0f, 0.0f, 0.0f, 0.0f // box #6 (0, 100, 1, 101) + }; + + if (tensor_type_ == TensorType_UINT8) { + m.QuantizeAndPopulate(m.input1(), std::vector{inputs1}); + } else { + m.SetInput1(inputs1); + } + // class scores - two classes with background + auto inputs2 = { + 0.f, .1f, 0.1f, // box #1 + 0.f, .1f, 0.96f, // box #2 + 0.f, .1f, 0.9f, // box #3 + 0.f, .95f, 0.1f, // box #4 + 0.f, .9f, 0.1f, // box #5 + 0.f, .1f, 0.1f // box #6 + }; + if (tensor_type_ == TensorType_UINT8) { + m.QuantizeAndPopulate(m.input2(), std::vector{inputs2}); + } else { + m.SetInput2(inputs2); + } + // six anchors in center-size encoding + auto inputs3 = { + 0.5f, 0.5f, 1.0f, 1.0f, // box #1 + 0.5f, 1.5f, 1.0f, 1.0f, // box #2 + 0.5f, 5.5f, 1.0f, 1.0f, // box #3 + 0.5f, 10.5f, 1.0f, 1.0f, // box #4 + 0.5f, 20.5f, 1.0f, 1.0f, // box #5 + 0.5f, 100.5f, 1.0f, 1.0f // box #6 + }; + if (tensor_type_ == TensorType_UINT8) { + m.QuantizeAndPopulate(m.input3(), std::vector{inputs3}); + } else { + m.SetInput3(inputs3); + } + ASSERT_EQ(m.Invoke(), kTfLiteOk); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 4, 4)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput1(), ElementsAreArray(ArrayFloatNear( + { + 0, 1, 1, 2, // box #2 + 0, 10, 1, 11, // box #4 + 0, 5, 1, 6, // box #3 + 0, 20, 1, 21 // box #5 + }, + 3e-1))); + } else { + EXPECT_THAT(m.GetOutput1(), ElementsAreArray(ArrayFloatNear( + { + 0, 1, 1, 2, // box #2 + 0, 10, 1, 11, // box #4 + 0, 5, 1, 6, // box #3 + 0, 20, 1, 21 // box #5 + }, + 3e-4))); + } + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 4)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 1, 0}, 1e-1))); + } else { + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 1, 0}, 1e-4))); + } + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 4)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.96, 0.95, 0.9, 0.9}, 1e-1))); + } else { + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.96, 0.95, 0.9, 0.9}, 1e-4))); + } + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + if (tensor_type_ == TensorType_UINT8) { + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({4.0}, 1e-1))); + } else { + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({4.0}, 1e-4))); + } +} + TEST(DetectionPostprocessOpTest, FloatTestwithNoBackgroundClassAndNoKeypoints) { DetectionPostprocessOpModelwithRegularNMS m( {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 2}}, diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index ea4a04b0482220..1239a3888677f8 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -787,6 +787,54 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt8) { EXPECT_THAT(m.GetOutput(), ElementsAre(23, 24, 25, 57, 58, 59)); } +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestPerChannelQuantizedOutputShape3DInt8) { + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + + PerChannelQuantizedFullyConnectedOpModel m( + GetRegistration(), /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_INT8, {2, 2, 5}, -63.5, 64}, + /*per_channel_quantization_scales=*/{0.2, 0.25, 0.5}, + /*output=*/{TensorType_INT8, {}, -127, 128}, + /*bias_type=*/TensorType_INT32, + /*keep_num_dims=*/true, /*bias_tensor_optional=*/false, + /*activation_func=*/ActivationFunctionType_RELU, + /*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT, + /*input_size=*/5); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, // u = 0 + 1, 2, 3, 4, 5, // u = 1 + 1, 2, 3, 4, 5, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, -5, // b = 0, i = 0 + 1, 2, 3, -4, 5, // b = 0, i = 1 + 1, 2, -3, 4, 5, // b = 1, i = 0 + 1, -2, 3, 4, 5, // b = 1, i = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 6, 7, 8, // b = 0, i = 0 + 24, 25, 26, // b = 0, i = 1 + 38, 39, 40, // b = 1, i = 0 + 48, 49, 50 // b = 1, i = 1 + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(5, 6, 7, // b = 0, i = 0 + 23, 24, 25, // b = 0, i = 1 + 37, 38, 39, // b = 1, i = 0 + 47, 48, 49 // b = 1, i = 1 + )); +} + TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt4) { PerChannelQuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index 9761a8cc07a8ec..4d990d70aa0c7c 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/lite/kernels/internal/runtime_shape.h" #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.h b/tensorflow/lite/kernels/internal/reference/comparisons.h index a9f1e42c0a6c94..e40e4045cc7ff4 100644 --- a/tensorflow/lite/kernels/internal/reference/comparisons.h +++ b/tensorflow/lite/kernels/internal/reference/comparisons.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ +#include + #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/kernels/internal/common.h" diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h index 3a74402ed98a1c..c6d06077934839 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h @@ -42,12 +42,14 @@ void FullyConnectedPerChannel( const int32_t output_activation_min = params.quantized_activation_min; const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int filter_dim_count = filter_shape.DimensionsCount(); - const int batches = output_shape.Dims(0); - const int output_depth = output_shape.Dims(1); + + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); const int accum_depth = filter_shape.Dims(filter_dim_count - 1); for (int b = 0; b < batches; ++b) { diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h index f2cc1603c652fe..510ffa30498319 100644 --- a/tensorflow/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/runtime_shape.h" diff --git a/tensorflow/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc index 0b3366e30c88e5..41fdf735dbb1ee 100644 --- a/tensorflow/lite/kernels/pad.cc +++ b/tensorflow/lite/kernels/pad.cc @@ -96,7 +96,9 @@ bool CheckPaddingOverflow(PadContext* op_context) { static_cast(std::numeric_limits::min()); int64_t int32_max = static_cast(std::numeric_limits::max()); - for (int idx = 0; idx < op_context->dims; ++idx) { + const int paddings_total = + GetTensorShape(op_context->paddings).FlatSize(); + for (int idx = 0; idx < paddings_total; ++idx) { int64_t padding = paddings_data[idx]; if (padding < int32_min || padding > int32_max) { return true; diff --git a/tensorflow/lite/kernels/pad_test.cc b/tensorflow/lite/kernels/pad_test.cc index 6fc7e79719a093..c3655897022444 100644 --- a/tensorflow/lite/kernels/pad_test.cc +++ b/tensorflow/lite/kernels/pad_test.cc @@ -242,6 +242,12 @@ TEST_F(PadOpTest, Int64PaddingOverflow) { {TensorType_FLOAT32}), "INT64 padding overflow. Only support value between INT32_MIN " "and INT32_MAX."); + EXPECT_DEATH(PadOpConstModel( + {TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2}, + {0, 0, 1, -1, 2, -1, std::numeric_limits::max(), 0}, + {TensorType_FLOAT32}), + "INT64 padding overflow. Only support value between INT32_MIN " + "and INT32_MAX."); } #endif diff --git a/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc index 69d9fc84a5836e..45336ea3b67e36 100644 --- a/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc +++ b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.cc @@ -70,13 +70,14 @@ uint8 PeekTag(protobuf::io::CodedInputStream* stream) { return *static_cast(ptr); } -bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { +bool ParseString(protobuf::io::CodedInputStream* stream, + absl::string_view* result) { DCHECK(stream != nullptr); DCHECK(result != nullptr); uint32 length; if (!stream->ReadVarint32(&length)) return false; if (length == 0) { - *result = StringPiece(nullptr, 0); + *result = absl::string_view(nullptr, 0); return true; } const void* stream_alias; @@ -85,7 +86,7 @@ bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { return false; } if (static_cast(stream_size) < length) return false; - *result = StringPiece(static_cast(stream_alias), length); + *result = absl::string_view(static_cast(stream_alias), length); stream->Skip(length); return true; } @@ -100,7 +101,7 @@ bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream, if (!stream->ExpectTag(kDelimitedTag(1))) return false; if (!ParseString(stream, &feature_map_entry->first)) return false; if (!stream->ExpectTag(kDelimitedTag(2))) return false; - StringPiece feature_string_piece; + absl::string_view feature_string_piece; if (!ParseString(stream, &feature_string_piece)) return false; feature_map_entry->second = parsed::Feature(feature_string_piece); if (!stream->ExpectAtEnd()) return false; @@ -142,7 +143,7 @@ bool ParseExample(protobuf::io::CodedInputStream* stream, return true; } -bool ParseExample(StringPiece serialized, parsed::Example* example) { +bool ParseExample(absl::string_view serialized, parsed::Example* example) { DCHECK(example != nullptr); protobuf::io::CodedInputStream stream( reinterpret_cast(serialized.data()), serialized.size()); diff --git a/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h index da82f3c34199cf..018e813a498490 100644 --- a/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h +++ b/tensorflow/lite/kernels/parse_example/example_proto_fast_parsing.h @@ -113,9 +113,9 @@ namespace parsed { class Feature { public: Feature() {} - explicit Feature(StringPiece serialized) : serialized_(serialized) {} + explicit Feature(absl::string_view serialized) : serialized_(serialized) {} - Status ParseDataType(DataType* dtype) { + absl::Status ParseDataType(DataType* dtype) { DCHECK(dtype != nullptr); if (serialized_.empty()) { *dtype = DT_INVALID; @@ -315,13 +315,13 @@ class Feature { return true; } - StringPiece GetSerialized() const { return serialized_; } + absl::string_view GetSerialized() const { return serialized_; } private: - StringPiece serialized_; + absl::string_view serialized_; }; -using FeatureMapEntry = std::pair; +using FeatureMapEntry = std::pair; using Example = std::vector; } // namespace parsed @@ -351,7 +351,8 @@ inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) { return false; // unrecognized tag type } -bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result); +bool ParseString(protobuf::io::CodedInputStream* stream, + absl::string_view* result); bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream, parsed::FeatureMapEntry* feature_map_entry); @@ -362,7 +363,7 @@ bool ParseFeatures(protobuf::io::CodedInputStream* stream, bool ParseExample(protobuf::io::CodedInputStream* stream, parsed::Example* example); -bool ParseExample(StringPiece serialized, parsed::Example* example); +bool ParseExample(absl::string_view serialized, parsed::Example* example); using Config = FastParseExampleConfig; @@ -386,7 +387,7 @@ struct SparseBuffer { }; struct SeededHasher { - uint64 operator()(StringPiece s) const { + uint64 operator()(absl::string_view s) const { return Hash64(s.data(), s.size(), seed); } uint64 seed{0xDECAFCAFFE}; @@ -435,7 +436,7 @@ struct FeatureProtos { // Proto substrings from each serialized SequenceExample that correspond // with this feature. `protos_present` records whether the proto had a // value defined (even if that value is empty). - std::vector protos; + std::vector protos; std::vector protos_present; // Information derived from protos: @@ -448,7 +449,7 @@ struct FeatureProtos { }; // Map from feature name to FeatureProtos for that feature. -using FeatureProtosMap = absl::flat_hash_map; +using FeatureProtosMap = absl::flat_hash_map; string ExampleName(const absl::Span example_names, int n); diff --git a/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc b/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc index d1b924066b23e3..cb0eb842000821 100644 --- a/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc +++ b/tensorflow/lite/kernels/perception/max_pool_with_argmax.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include #include "flatbuffers/flexbuffers.h" // from @flatbuffers diff --git a/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc b/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc index 082851d59c4488..b87bbd8be4a8f3 100644 --- a/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc +++ b/tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include diff --git a/tensorflow/lite/kernels/perception/max_unpooling_2d.cc b/tensorflow/lite/kernels/perception/max_unpooling_2d.cc index 869a9457a9f49d..7c99c1c72a69b9 100644 --- a/tensorflow/lite/kernels/perception/max_unpooling_2d.cc +++ b/tensorflow/lite/kernels/perception/max_unpooling_2d.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/runtime_shape.h" diff --git a/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc b/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc index cd5e96eceacfdd..ed36f12c3d676d 100644 --- a/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc +++ b/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "tensorflow/lite/kernels/perception/perception_ops.h" diff --git a/tensorflow/lite/kernels/shim/README.md b/tensorflow/lite/kernels/shim/README.md index 5e7f852dced309..a517f87de5c0b6 100644 --- a/tensorflow/lite/kernels/shim/README.md +++ b/tensorflow/lite/kernels/shim/README.md @@ -35,7 +35,7 @@ This folder contains two pieces: ### TensorView This class is a *view* over an already allocated tensor in TF or TFLite without -taking any ownership. In that sense it is similar to `std::string_view` but with +taking any ownership. In that sense it is similar to `absl::string_view` but with the difference that the underlying buffer can be mutable. Example Usage: diff --git a/tensorflow/lite/kernels/shim/test_op/simple_tflite_op_test.cc b/tensorflow/lite/kernels/shim/test_op/simple_tflite_op_test.cc index 786cb755eefd6b..61c0ccfbb375dc 100644 --- a/tensorflow/lite/kernels/shim/test_op/simple_tflite_op_test.cc +++ b/tensorflow/lite/kernels/shim/test_op/simple_tflite_op_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/kernels/shim/test_op/simple_tflite_op.h" -#include +#include #include #include diff --git a/tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.cc b/tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.cc index b53b0f86f48aea..8c656ddfd7739b 100644 --- a/tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.cc +++ b/tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.h" +#include + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" #include "tensorflow/lite/kernels/shim/test_op/tmpl_op.h" diff --git a/tensorflow/lite/optional_debug_tools_test.cc b/tensorflow/lite/optional_debug_tools_test.cc index c581a5029014ef..66030815a1e017 100644 --- a/tensorflow/lite/optional_debug_tools_test.cc +++ b/tensorflow/lite/optional_debug_tools_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/optional_debug_tools.h" #include +#include #include #include "tensorflow/lite/core/interpreter.h" diff --git a/tensorflow/lite/profiling/telemetry/profiler_test.cc b/tensorflow/lite/profiling/telemetry/profiler_test.cc index d9d20d9f08f4fc..6168a57d693c24 100644 --- a/tensorflow/lite/profiling/telemetry/profiler_test.cc +++ b/tensorflow/lite/profiling/telemetry/profiler_test.cc @@ -15,9 +15,7 @@ limitations under the License. #include "tensorflow/lite/profiling/telemetry/profiler.h" #include -#include #include -#include #include #include diff --git a/tensorflow/lite/profiling/telemetry/telemetry_test.cc b/tensorflow/lite/profiling/telemetry/telemetry_test.cc index 73bb6b7a28b719..39ac1c4822e6df 100644 --- a/tensorflow/lite/profiling/telemetry/telemetry_test.cc +++ b/tensorflow/lite/profiling/telemetry/telemetry_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/profiling/telemetry/telemetry.h" +#include #include #include diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index cc633399dc352a..f1a80f0b92fe06 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -10,6 +10,7 @@ package( "//tensorflow:__subpackages__", "//tensorflow:internal", "//third_party/odml/model_customization/quantization:__subpackages__", + "//third_party/py/ai_edge_torch:__subpackages__", "//third_party/py/tensorflow_federated:__subpackages__", "//third_party/tflite_micro:__subpackages__", ], diff --git a/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc b/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc index fb8baaccd71f74..606d2192af3839 100644 --- a/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc +++ b/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include +#include +#include +#include +#include #include #include #include +#include #include "absl/strings/str_join.h" #include "flatbuffers/vector.h" // from @flatbuffers diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 16995e47c83b71..2519835376fe4b 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -447,6 +447,7 @@ def build_conversion_flags( use_buffer_offset=False, reduce_type_precision=False, qdq_conversion_mode=None, + strict_qdq_mode=False, disable_per_channel_quantization_for_dense_layers=False, enable_composite_direct_lowering=False, model_origin_framework=lite_constants.UNSET, @@ -578,6 +579,9 @@ def build_conversion_flags( This could have side effects e.g. reduced flatbuffer size. qdq_conversion_mode: If set, assume input model is a quantized model represented with QDQ ops and convert to quantized kernels. + strict_qdq_mode: If set, adheres to the QDQ annotations added by the + framework when possible rather than quantizing any op that is possible to + quantize. disable_per_channel_quantization_for_dense_layers: If set, disables per channel end enables per tensor integer quantization for weights in Dense layers. The flag works only for integer quantized model. @@ -706,6 +710,7 @@ def build_conversion_flags( conversion_flags.reduce_type_precision = reduce_type_precision if qdq_conversion_mode is not None: conversion_flags.qdq_conversion_mode = qdq_conversion_mode + conversion_flags.strict_qdq_mode = strict_qdq_mode conversion_flags.disable_per_channel_quantization_for_dense_layers = ( disable_per_channel_quantization_for_dense_layers ) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 391f36d15eb348..b66c0f1739004a 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -425,11 +425,11 @@ def __init__( in C++. experimental_preserve_all_tensors: If true, then intermediate tensors used during computation are preserved for inspection, and if the passed op - resolver type is AUTO or BUILTIN, the type will be changed to - BUILTIN_WITHOUT_DEFAULT_DELEGATES so that no Tensorflow Lite default - delegates are applied. If false, getting intermediate tensors could - result in undefined values or None, especially when the graph is - successfully modified by the Tensorflow Lite default delegate. + resolver type is AUTO or BUILTIN, the type will be changed to BUILTIN so + that Tensorflow Lite default delegates are applied. If false, getting + intermediate tensors could result in undefined values or None, + especially when the graph is successfully modified by the Tensorflow + Lite default delegate. experimental_disable_delegate_clustering: If true, don't perform delegate clustering during delegate graph partitioning phase. Disabling delegate clustering will make the execution order of ops respect the @@ -457,7 +457,13 @@ def __init__( if experimental_preserve_all_tensors and ( experimental_op_resolver_type == OpResolverType.AUTO or experimental_op_resolver_type == OpResolverType.BUILTIN): - actual_resolver_type = OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES + warnings.warn( + 'Warning: Enabling `experimental_preserve_all_tensors` with the' + ' BUILTIN or AUTO op resolver is intended for debugging purposes' + ' only. Be aware that this can significantly increase memory usage by' + ' storing all intermediate tensors. If you encounter memory problems' + ' or are not actively debugging, consider disabling this option.' + ) op_resolver_id = _get_op_resolver_id(actual_resolver_type) if op_resolver_id is None: raise ValueError('Unrecognized passed in op resolver type: {}'.format( diff --git a/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyi b/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyi index 5a2099d01c9e5f..c4a79168d6aa5b 100644 --- a/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyi +++ b/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyi @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -from typing import Any - class InterpreterWrapper: def __init__(self, *args, **kwargs) -> None: ... def AllocateTensors(self, subgraph_index: int = ...) -> object: ... @@ -45,5 +43,5 @@ class InterpreterWrapper: def interpreter(self) -> int: ... def tensor(self, base_object: object, tensor_index: int, subgraph_index: int = ...) -> object: ... -def CreateWrapperFromBuffer(*args, **kwargs) -> Any: ... -def CreateWrapperFromFile(*args, **kwargs) -> Any: ... +def CreateWrapperFromBuffer(*args, **kwargs): ... +def CreateWrapperFromFile(*args, **kwargs): ... diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 806310fff8ac9b..ff44831068ad8a 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -680,6 +680,7 @@ def __init__(self): self._experimental_enable_composite_direct_lowering = False self.model_origin_framework = constants.UNSET self.canonicalizing_inf_as_min_max_float = True + self._experimental_strict_qdq = False # Debug parameters self.ir_dump_dir = None @@ -779,6 +780,7 @@ def _quantize( activations_type, bias_type, disable_per_channel=self._experimental_disable_per_channel, + disable_per_channel_quantization_for_dense_layers=self._experimental_disable_per_channel_quantization_for_dense_layers, ) def _is_unknown_shapes_allowed(self): @@ -836,6 +838,7 @@ def _get_base_converter_args(self): self.experimental_stablehlo_quantizer_config ), "qdq_conversion_mode": self._experimental_qdq_conversion_mode, + "strict_qdq_mode": self._experimental_strict_qdq, "disable_per_channel_quantization_for_dense_layers": ( self._experimental_disable_per_channel_quantization_for_dense_layers ), diff --git a/tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyi b/tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyi index b020337da48ed9..11c53fe433789e 100644 --- a/tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyi +++ b/tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyi @@ -13,12 +13,10 @@ # limitations under the License. # ============================================================================== -from typing import Callable - -from typing import overload +from typing import Callable, overload class CalibrationWrapper: - def __init__(self, arg0: object, arg1: list[str], arg2: list[Callable[[int],None]]) -> None: ... + def __init__(self, arg0: object, arg1: list[str], arg2: list[Callable[[int], None]]) -> None: ... def Calibrate(self) -> object: ... @overload def FeedTensor(self, arg0: object, arg1: str) -> object: ... @@ -33,7 +31,7 @@ class CalibrationWrapper: @overload def Prepare(self) -> object: ... @overload - def QuantizeModel(self, arg0: int, arg1: int, arg2: bool, arg3: int, arg4: int, arg5: bool) -> object: ... + def QuantizeModel(self, arg0: int, arg1: int, arg2: bool, arg3: int, arg4: int, arg5: bool, arg6: bool) -> object: ... @overload def QuantizeModel(self, arg0: int, arg1: int, arg2: bool, arg3: str) -> object: ... diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index c6944fc9f9a757..6bce58ce3c4704 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -700,14 +700,17 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type, bool allow_float, int activations_py_type, int bias_py_type) { - return QuantizeModel(input_py_type, output_py_type, allow_float, - activations_py_type, bias_py_type, - /*disable_per_channel=*/false); + return QuantizeModel( + input_py_type, output_py_type, allow_float, activations_py_type, + bias_py_type, + /*disable_per_channel=*/false, + /*disable_per_channel_quantization_for_dense_layers=*/false); } PyObject* CalibrationWrapper::QuantizeModel( int input_py_type, int output_py_type, bool allow_float, - int activations_py_type, int bias_py_type, bool disable_per_channel) { + int activations_py_type, int bias_py_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers) { if (NoOpModel(*model_)) { return ConvertToPyString(model_str_->data(), model_str_->size()); } @@ -732,7 +735,7 @@ PyObject* CalibrationWrapper::QuantizeModel( TfLiteTypeToSchemaType(output_type), allow_float, TfLiteTypeToSchemaType(activations_type), TfLiteTypeToSchemaType(bias_type), disable_per_channel, - error_reporter_.get()); + disable_per_channel_quantization_for_dense_layers, error_reporter_.get()); if (status != kTfLiteOk) { error_reporter_->exception(); diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h index ec5c706eca2149..832fd7b6047007 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.h +++ b/tensorflow/lite/python/optimize/calibration_wrapper.h @@ -98,9 +98,10 @@ class CalibrationWrapper { // Disables per-channel quantization, can be used to produce smaller // models but may cause accuracy issues. - PyObject* QuantizeModel(int input_py_type, int output_py_type, - bool allow_float, int activations_py_type, - int bias_py_type, bool disable_per_channel); + PyObject* QuantizeModel( + int input_py_type, int output_py_type, bool allow_float, + int activations_py_type, int bias_py_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers); // Writes the in-memory calibration results to the model flatbuffer. The // produced model is as same as the original input model, but the min/max diff --git a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc index f829867a63c7f4..067f57fd0b4947 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper_pybind11.cc @@ -79,10 +79,12 @@ PYBIND11_MODULE(_pywrap_tensorflow_lite_calibration_wrapper, m) { .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, int output_py_type, bool allow_float, int activations_py_type, int bias_py_type, - bool disable_per_channel) { + bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers) { return tensorflow::PyoOrThrow(self.QuantizeModel( input_py_type, output_py_type, allow_float, - activations_py_type, bias_py_type, disable_per_channel)); + activations_py_type, bias_py_type, disable_per_channel, + disable_per_channel_quantization_for_dense_layers)); }) .def("QuantizeModel", [](CalibrationWrapper& self, int input_py_type, int output_py_type, diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index 136890589a09fc..b5b494ebba69ff 100644 --- a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -165,6 +165,7 @@ def calibrate_and_quantize( bias_type=dtypes.int32, resize_input=True, disable_per_channel=False, + disable_per_channel_quantization_for_dense_layers=False, ): """Calibrates the model with specified generator and then quantizes it. @@ -189,6 +190,8 @@ def calibrate_and_quantize( from the input. disable_per_channel: A boolean. True if disabling per-channel quantization. + disable_per_channel_quantization_for_dense_layers: A boolean. True if + disabling per-channel quantization only in Dense layers. """ self._feed_tensors(dataset_gen, resize_input) return self._calibrator.QuantizeModel( @@ -198,6 +201,7 @@ def calibrate_and_quantize( np.dtype(activations_type.as_numpy_dtype()).num, np.dtype(bias_type.as_numpy_dtype()).num, disable_per_channel, + disable_per_channel_quantization_for_dense_layers, ) @convert_phase( diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index c0692655c3f127..5f05881764a0fb 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -1000,7 +1000,7 @@ def get_sparsity_modes(model_object): # Block map is the list if indexes where the block size is larger than 1. # So empty block map means it is random sparsity. - if not tensor.sparsity.blockMap: + if tensor.sparsity.blockMap.size == 0 or not tensor.sparsity.blockMap: result.add( conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY) else: diff --git a/tensorflow/lite/schema/builtin_ops_list/consistency_test.cc b/tensorflow/lite/schema/builtin_ops_list/consistency_test.cc index e2e74a7cd21a1f..575444f9eabef7 100644 --- a/tensorflow/lite/schema/builtin_ops_list/consistency_test.cc +++ b/tensorflow/lite/schema/builtin_ops_list/consistency_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include #include diff --git a/tensorflow/lite/schema/builtin_ops_list/generator.cc b/tensorflow/lite/schema/builtin_ops_list/generator.cc index bfbefa1d06b4a3..215b9e0eb776f1 100644 --- a/tensorflow/lite/schema/builtin_ops_list/generator.cc +++ b/tensorflow/lite/schema/builtin_ops_list/generator.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/schema/builtin_ops_list/generator.h" +#include #include #include diff --git a/tensorflow/lite/schema/builtin_ops_list/generator_test.cc b/tensorflow/lite/schema/builtin_ops_list/generator_test.cc index 3cb4b0fee4a7d2..3cc0689da1cc55 100644 --- a/tensorflow/lite/schema/builtin_ops_list/generator_test.cc +++ b/tensorflow/lite/schema/builtin_ops_list/generator_test.cc @@ -16,8 +16,6 @@ limitations under the License. #include "tensorflow/lite/schema/builtin_ops_list/generator.h" -#include - #include namespace { diff --git a/tensorflow/lite/simple_memory_arena_debug_dump.cc b/tensorflow/lite/simple_memory_arena_debug_dump.cc index 0cf8005124dadc..52bbd3bbd7de97 100644 --- a/tensorflow/lite/simple_memory_arena_debug_dump.cc +++ b/tensorflow/lite/simple_memory_arena_debug_dump.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include diff --git a/tensorflow/lite/simple_planner.cc b/tensorflow/lite/simple_planner.cc index 9e24ad0660c7b8..f850e7ba7f4d05 100644 --- a/tensorflow/lite/simple_planner.cc +++ b/tensorflow/lite/simple_planner.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/simple_planner.h" +#include #include #include #include diff --git a/tensorflow/lite/simple_planner.h b/tensorflow/lite/simple_planner.h index db658839ac3672..32ee4584a2fd9d 100644 --- a/tensorflow/lite/simple_planner.h +++ b/tensorflow/lite/simple_planner.h @@ -16,7 +16,9 @@ limitations under the License. #define TENSORFLOW_LITE_SIMPLE_PLANNER_H_ #include +#include #include +#include #include #include diff --git a/tensorflow/lite/simple_planner_test.cc b/tensorflow/lite/simple_planner_test.cc index 08fd7debcee38a..08adf895ba0fd6 100644 --- a/tensorflow/lite/simple_planner_test.cc +++ b/tensorflow/lite/simple_planner_test.cc @@ -16,12 +16,16 @@ limitations under the License. #include #include +#include +#include #include #include #include #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/graph_info.h" diff --git a/tensorflow/lite/string_util_test.cc b/tensorflow/lite/string_util_test.cc index 746bf4ac8ee78e..b12241c0fa54ea 100644 --- a/tensorflow/lite/string_util_test.cc +++ b/tensorflow/lite/string_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include #include #include diff --git a/tensorflow/lite/tensorflow_profiler_logger.h b/tensorflow/lite/tensorflow_profiler_logger.h index 61ac0bff966bdd..3575107281ed55 100644 --- a/tensorflow/lite/tensorflow_profiler_logger.h +++ b/tensorflow/lite/tensorflow_profiler_logger.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TENSORFLOW_PROFILER_LOGGER_H_ #define TENSORFLOW_LITE_TENSORFLOW_PROFILER_LOGGER_H_ +#include #include #include diff --git a/tensorflow/lite/tensorflow_profiler_logger_shim.cc b/tensorflow/lite/tensorflow_profiler_logger_shim.cc index 72bf179f7e095a..489474ca8f4cb4 100644 --- a/tensorflow/lite/tensorflow_profiler_logger_shim.cc +++ b/tensorflow/lite/tensorflow_profiler_logger_shim.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/tensorflow_profiler_logger.h" diff --git a/tensorflow/lite/test_util_test.cc b/tensorflow/lite/test_util_test.cc index 36b45eed18d7ca..6d93a5817b97a0 100644 --- a/tensorflow/lite/test_util_test.cc +++ b/tensorflow/lite/test_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 0477ae1ea8bc4f..ce7af49ef0b21c 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -229,6 +229,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/lite/testing/kernel_test/diff_analyzer.cc b/tensorflow/lite/testing/kernel_test/diff_analyzer.cc index 7ba14062fb9689..3bf562634bb7b1 100644 --- a/tensorflow/lite/testing/kernel_test/diff_analyzer.cc +++ b/tensorflow/lite/testing/kernel_test/diff_analyzer.cc @@ -16,8 +16,12 @@ limitations under the License. #include #include +#include #include #include +#include +#include +#include #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/testing/split.h" diff --git a/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc b/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc index 3406cdf5c46b16..7d2ce72b38e535 100644 --- a/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc +++ b/tensorflow/lite/testing/kernel_test/diff_analyzer_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/lite/testing/kernel_test/input_generator.cc b/tensorflow/lite/testing/kernel_test/input_generator.cc index ec8fc239086975..bc365ed2317142 100644 --- a/tensorflow/lite/testing/kernel_test/input_generator.cc +++ b/tensorflow/lite/testing/kernel_test/input_generator.cc @@ -14,13 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/testing/kernel_test/input_generator.h" +#include #include +#include #include +#include #include #include #include -#include #include +#include #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/testing/kernel_test/input_generator_test.cc b/tensorflow/lite/testing/kernel_test/input_generator_test.cc index f6f1248d8e5195..650d39690e6817 100644 --- a/tensorflow/lite/testing/kernel_test/input_generator_test.cc +++ b/tensorflow/lite/testing/kernel_test/input_generator_test.cc @@ -15,11 +15,10 @@ limitations under the License. #include "tensorflow/lite/testing/kernel_test/input_generator.h" #include -#include #include -#include +#include +#include -#include #include namespace tflite { diff --git a/tensorflow/lite/testing/kernel_test/util_test.cc b/tensorflow/lite/testing/kernel_test/util_test.cc index 59d75931079600..3149350f9a5c08 100644 --- a/tensorflow/lite/testing/kernel_test/util_test.cc +++ b/tensorflow/lite/testing/kernel_test/util_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/testing/tflite_driver.h" diff --git a/tensorflow/lite/testing/matchers.h b/tensorflow/lite/testing/matchers.h index 17646ffb811eb4..3293519d871946 100644 --- a/tensorflow/lite/testing/matchers.h +++ b/tensorflow/lite/testing/matchers.h @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -33,6 +32,7 @@ limitations under the License. #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -132,7 +132,7 @@ class TensorMatcher { return false; } - void Describe(std::ostream* os, std::string_view prefix) const { + void Describe(std::ostream* os, absl::string_view prefix) const { *os << prefix; if (comp_.float_comp == FloatComparison::kApproximate) { *os << "approximately "; diff --git a/tensorflow/lite/testing/op_tests/fully_connected_4bit_hybrid.py b/tensorflow/lite/testing/op_tests/fully_connected_4bit_hybrid.py index a5611e2d5af604..ea3d4cda8bdd4d 100644 --- a/tensorflow/lite/testing/op_tests/fully_connected_4bit_hybrid.py +++ b/tensorflow/lite/testing/op_tests/fully_connected_4bit_hybrid.py @@ -37,11 +37,11 @@ def make_fully_connected_4bit_hybrid_tests(options): "dynamic_range_quantize": [True], }, # No optimization. - { - "shape1": [[1, 40]], - "shape2": [[40, 3]], - "dynamic_range_quantize": [True], - }, + # { + # "shape1": [[1, 40]], + # "shape2": [[40, 3]], + # "dynamic_range_quantize": [True], + # }, ] def build_graph(parameters): diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 2c2e5e41081a9c..1daaa368f0db0f 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -125,6 +125,8 @@ cc_library( ":types_proto_cc", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -342,7 +344,10 @@ cc_library( "//tensorflow/lite/toco/tensorflow_graph_matching:resolve_cluster", "//tensorflow/lite/toco/tflite:export", "//tensorflow/lite/toco/tflite:import", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf_headers", ], @@ -387,7 +392,10 @@ cc_library( ":types_proto_cc", "//tensorflow/core:lib", "//tensorflow/lite/kernels/internal:types", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", ], @@ -402,6 +410,7 @@ tf_cc_test( ":tooling_util", "//tensorflow/core:lib", "//tensorflow/lite/testing:util", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", ], ) @@ -423,6 +432,8 @@ cc_library( ":toco_port", ":toco_tooling", ":types_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "//tensorflow/core:lib", # We cannot embed the core:ops dependency directly into :toco_tooling as @@ -445,6 +456,7 @@ tf_cc_binary( ":toco_port", ":toco_tooling", ":types_proto_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "//tensorflow/core:lib", # We cannot embed the core:ops dependency directly into :toco_tooling as diff --git a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index d6932b73138c94..3c1666f068674d 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -28,9 +28,8 @@ limitations under the License. namespace toco { -::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto expand_it = model->operators.begin() + op_index; if (expand_it->get()->type != OperatorType::kExpandDims) { diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc index 6d2b5ca4c4a582..b582641ec4618d 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc @@ -23,9 +23,9 @@ namespace toco { // V3 is only different from V2 because it has an extra attribute (align). // This attribute doesn't affect V1 so we don't have to keep track of it here. -::tensorflow::Status ConvertMatrixDiagV2OrV3ToV1::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertMatrixDiagV2OrV3ToV1::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; const auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc index 84e84aabce74d3..d4dafaa7ed678d 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc @@ -28,9 +28,9 @@ namespace toco { // V3 is only different from V2 because it has an extra attribute (align). // This attribute doesn't affect V1 so we don't have to keep track of it here. -::tensorflow::Status ConvertMatrixSetDiagV2OrV3ToV1::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertMatrixSetDiagV2OrV3ToV1::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; const auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index b7763e1ff98fe3..f8c7e0130e7272 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -25,9 +25,8 @@ limitations under the License. namespace toco { -::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto conv_it = model->operators.begin() + op_index; if (conv_it->get()->type != OperatorType::kConv) { diff --git a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc index 60dcf00f8d5693..cd5684bfbaf583 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc @@ -88,8 +88,8 @@ TransposeOperator* CreateTransposeFromReorderAxes( // Converts ReorderAxes into Transpose and Reshape which are compatible with the // TFLite interpreter. -::tensorflow::Status ConvertReorderAxes::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status ConvertReorderAxes::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto reorder_it = model->operators.begin() + op_index; if (reorder_it->get()->type != OperatorType::kReorderAxes) diff --git a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc index c98d64d389aacb..7d64a30b5d1483 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc @@ -31,9 +31,8 @@ namespace toco { // means that the data layout will never change with this op, just the shape. // By converting these to reshapes once we have run shape propagation we allow // standard reshape optimization transforms to do their magic. -::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto squeeze_it = model->operators.begin() + op_index; if (squeeze_it->get()->type != OperatorType::kSqueeze) { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc index c60ddff8a9284f..bc8d88999acd27 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc @@ -23,9 +23,8 @@ namespace toco { // This pass will convert an AddN operator with only 2 inputs into a regular Add // operator, to which more optimizations may apply. -::tensorflow::Status ConvertTrivialAddNToAdd::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto addn_it = model->operators.begin() + op_index; if (addn_it->get()->type != OperatorType::kAddN) { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc index c945615c1fb319..7aa694395fc18c 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc @@ -26,9 +26,9 @@ limitations under the License. namespace toco { -::tensorflow::Status ConvertTrivialPackToReshape::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertTrivialPackToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto pack_it = model->operators.begin() + op_index; if (pack_it->get()->type != OperatorType::kPack) { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc index 71a7d92d2e2b0e..bfd97311c587a5 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -23,9 +23,8 @@ limitations under the License. namespace toco { -::tensorflow::Status ConvertTrivialTileToConcat::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto tile_it = model->operators.begin() + op_index; if (tile_it->get()->type != OperatorType::kTile) { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index 8a33ad575bcf12..4871439f925812 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -51,9 +51,9 @@ bool TransposeAffectsMemoryOrder(std::vector perm, } // namespace -::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status ConvertTrivialTransposeToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto transpose_it = model->operators.begin() + op_index; if (transpose_it->get()->type != OperatorType::kTranspose) { diff --git a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc index 380cdf216efb70..bb3ac3a5c94bd2 100644 --- a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -74,8 +74,8 @@ bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { return true; } -::tensorflow::Status CreateIm2colArrays::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status CreateIm2colArrays::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/dequantize.cc b/tensorflow/lite/toco/graph_transformations/dequantize.cc index 5dd4d2e8750377..4dad4679e5f1a2 100644 --- a/tensorflow/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/lite/toco/graph_transformations/dequantize.cc @@ -188,8 +188,8 @@ bool DequantizeArray(const std::string& array_name, } // namespace -::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status Dequantize::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto op_it = model->operators.begin() + op_index; auto* op = op_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc index cdd748ac371075..62968789dfb241 100644 --- a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc @@ -26,8 +26,8 @@ limitations under the License. namespace toco { -::tensorflow::Status DropFakeQuant::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status DropFakeQuant::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto fakequant_it = model->operators.begin() + op_index; auto* fakequant_base_op = fakequant_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc index d3cfae07faebbd..3c5340544ce819 100644 --- a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc +++ b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -21,8 +21,8 @@ limitations under the License. namespace toco { -::tensorflow::Status DropIm2colArrays::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status DropIm2colArrays::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto conv_it = model->operators.begin() + op_index; if (conv_it->get()->type != OperatorType::kConv) { diff --git a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc index f8d639cc396e25..a1dda5c93f8bc6 100644 --- a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -76,8 +76,8 @@ bool ProcessLinearOperator(Model* model, Operator* op) { } } // namespace -::tensorflow::Status EnsureBiasVectors::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status EnsureBiasVectors::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto* op = model->operators[op_index].get(); if (op->type == OperatorType::kConv || diff --git a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index ed3a89a70123ad..3d84bfa0bbbe0c 100644 --- a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -111,8 +111,9 @@ namespace toco { // we can foresee these 'fast int8 kernels' to remain important to have into // the 2020s. // -::tensorflow::Status EnsureUint8WeightsSafeForFastInt8Kernels::Run( - Model* model, std::size_t op_index, bool* modified) { +absl::Status EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto& op = *model->operators[op_index]; int weights_index = 0; diff --git a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc index 64b91ccf62878a..3c9a6b968d6e41 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -27,9 +27,8 @@ limitations under the License. namespace toco { -::tensorflow::Status FuseActivationFunctions::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status FuseActivationFunctions::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto ac_it = model->operators.begin() + op_index; const auto* ac_op = ac_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index 3afa9c44a59e5c..c6b4b6fa228b9f 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -152,9 +152,9 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op, } // namespace -::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status FuseBinaryIntoFollowingAffine::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index fa0baf97dbd9c5..b9c3b7e7c2d33c 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -205,9 +205,9 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, } } // namespace -::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status FuseBinaryIntoPrecedingAffine::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto binary_it = model->operators.begin() + op_index; const auto* binary_op = binary_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc index ba57090e2eff6a..66fa1a8ffe9147 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -52,9 +52,9 @@ bool IsBroadcastingOp(const Model& model, Operator* op) { // Finds an operation that looks like a broadcast (concat of the same sources // along the last dimension) and drops it by relying on the ability of certain // binary ops to perform an implicit broadcast. -::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status FuseBroadcastIntoFollowingBinary::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc index 125e5597a49f35..3a31f69982f633 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc @@ -132,7 +132,7 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { bool GraphTransformationsPass(int increment, Model* model, const GraphTransformationsSet& transformations, - tensorflow::Status* status) { + absl::Status* status) { CHECK(increment == 1 || increment == -1); bool changed = false; if (model->operators.empty()) { @@ -193,12 +193,12 @@ bool GraphTransformationsPass(int increment, Model* model, } // namespace -tensorflow::Status RunGraphTransformationsWithStatus( +absl::Status RunGraphTransformationsWithStatus( Model* model, const std::string& msg, const GraphTransformationsSet& transformations) { PrintModelStats(toco::port::StringF("Before %s", msg), *model); int pass_index = 0; - tensorflow::Status status; + absl::Status status; while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model, transformations, &status)) { pass_index++; diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index c7e2c9de186f97..7e0b57c8dd5d60 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -30,8 +30,8 @@ namespace toco { class GraphTransformation { public: - virtual ::tensorflow::Status Run(Model* model, std::size_t op_index, - bool* modified) = 0; + virtual absl::Status Run(Model* model, std::size_t op_index, + bool* modified) = 0; virtual const char* Name() const = 0; virtual ~GraphTransformation() {} // Returns the list of messages that this graph transformation @@ -105,7 +105,7 @@ class GraphTransformationsSet { // construct GraphTransformation objects by using 'new', pass us // the resulting raw pointers, and this RunGraphTransformations // takes care of delete'ing these pointers. -tensorflow::Status RunGraphTransformationsWithStatus( +absl::Status RunGraphTransformationsWithStatus( Model* model, const std::string& msg, const GraphTransformationsSet& transformations); @@ -222,8 +222,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyNearestUpsample) class PropagateDefaultMinMax : public GraphTransformation { public: - ::tensorflow::Status Run(Model* model, std::size_t op_index, - bool* modified) override; + absl::Status Run(Model* model, std::size_t op_index, bool* modified) override; const char* Name() const override { return "PropagateDefaultMinMax"; } bool has_any_ranges_defined() const { return !type_ranges_.empty(); } @@ -241,8 +240,7 @@ class PropagateDefaultMinMax : public GraphTransformation { class RemoveTrivialReshape : public GraphTransformation { public: - ::tensorflow::Status Run(Model* model, std::size_t op_index, - bool* modified) override; + absl::Status Run(Model* model, std::size_t op_index, bool* modified) override; const char* Name() const override { return "RemoveTrivialReshape"; } bool treat_expand_dims_as_trivial() const { return treat_expand_dims_as_trivial_; @@ -257,8 +255,7 @@ class RemoveTrivialReshape : public GraphTransformation { class ResolveConstantFakeQuant : public GraphTransformation { public: - ::tensorflow::Status Run(Model* model, std::size_t op_index, - bool* modified) override; + absl::Status Run(Model* model, std::size_t op_index, bool* modified) override; const char* Name() const override { return "ResolveConstantFakeQuant"; } // True if the num_bits should adjust the final data type. @@ -275,8 +272,7 @@ class ResolveConstantFakeQuant : public GraphTransformation { class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { public: - ::tensorflow::Status Run(Model* model, std::size_t op_index, - bool* modified) override; + absl::Status Run(Model* model, std::size_t op_index, bool* modified) override; const char* Name() const override { return "EnsureUint8WeightsSafeForFastInt8Kernels"; } @@ -293,8 +289,7 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { class IdentifyDilatedConv : public GraphTransformation { public: - ::tensorflow::Status Run(Model* model, std::size_t op_index, - bool* modified) override; + absl::Status Run(Model* model, std::size_t op_index, bool* modified) override; const char* Name() const override { return "IdentifyDilatedConv"; } bool identify_depthwise_conv() const { return identify_depthwise_conv_; } void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; } diff --git a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc index 2da6fbe6cfe76f..1765ce7e184560 100644 --- a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc +++ b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc @@ -403,9 +403,9 @@ void RemoveUnidirectionalSequenceOps(std::stack uni_sequence_ops, } template -::tensorflow::Status GroupDynamicSequenceOps(Model* model, std::size_t op_index, - OperatorType operator_type, - bool* modified) { +absl::Status GroupDynamicSequenceOps(Model* model, std::size_t op_index, + OperatorType operator_type, + bool* modified) { *modified = false; // We assume there's a concatenation right after the bidirectional sequence @@ -477,9 +477,9 @@ ::tensorflow::Status GroupDynamicSequenceOps(Model* model, std::size_t op_index, } // namespace -::tensorflow::Status GroupBidirectionalSequenceLstm::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status GroupBidirectionalSequenceLstm::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; // Bidirectional sequence lstm will generate two separate unidirectional // sequence lstm ops, for static bidirectional sequence lstm, there will be @@ -554,9 +554,9 @@ ::tensorflow::Status GroupBidirectionalSequenceLstm::Run(Model* model, return absl::OkStatus(); } -::tensorflow::Status GroupBidirectionalSequenceRnn::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status GroupBidirectionalSequenceRnn::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; // Bidirectional sequence rnn will generate two separate unidirectional // sequence rnn ops, for static bidirectional sequence rnn, there will be @@ -629,14 +629,16 @@ ::tensorflow::Status GroupBidirectionalSequenceRnn::Run(Model* model, return absl::OkStatus(); } -::tensorflow::Status GroupDynamicBidirectionalSequenceRnn::Run( - Model* model, std::size_t op_index, bool* modified) { +absl::Status GroupDynamicBidirectionalSequenceRnn::Run(Model* model, + std::size_t op_index, + bool* modified) { return GroupDynamicSequenceOps( model, op_index, OperatorType::kBidirectionalSequenceRnn, modified); } -::tensorflow::Status GroupDynamicBidirectionalSequenceLstm::Run( - Model* model, std::size_t op_index, bool* modified) { +absl::Status GroupDynamicBidirectionalSequenceLstm::Run(Model* model, + std::size_t op_index, + bool* modified) { return GroupDynamicSequenceOps( model, op_index, OperatorType::kBidirectionalSequenceLstm, modified); } diff --git a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc index 6f142a447f60d8..a6681d8da76aae 100644 --- a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc @@ -425,8 +425,8 @@ bool HardcodeMinMaxForPack(Model* model, Operator* op) { } // namespace -::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status HardcodeMinMax::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc index 985e588072136e..1686ee9c1eb8ea 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -168,9 +168,8 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op, return true; } -::tensorflow::Status IdentifyDilatedConv::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status IdentifyDilatedConv::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto it = model->operators.begin() + op_index; auto* stb_op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc index 437147f8b55d81..10b548db2d373e 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc @@ -37,8 +37,8 @@ namespace toco { using util::IsBinaryOp; -::tensorflow::Status IdentifyHardSwish::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status IdentifyHardSwish::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto add_with_relu6_op_it = (model->operators.begin() + op_index); const auto add_with_relu6_op = add_with_relu6_op_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc index e8a5d209d64a6f..a410d90294f8ff 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -27,9 +27,8 @@ limitations under the License. namespace toco { -::tensorflow::Status IdentifyL2Normalization::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status IdentifyL2Normalization::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto div_it = model->operators.begin() + op_index; const auto* div_or_mul_op = div_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc index a980995a870280..48511419cb87e1 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc @@ -26,8 +26,8 @@ limitations under the License. namespace toco { -::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto sqrt_it = model->operators.begin() + op_index; const auto* sqrt_op = sqrt_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc index df0aa9ff3ddba7..38b63469f49486 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc @@ -136,8 +136,8 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // namespace -::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; // This LSTM cell identification method is not invariant to commutation of // commutative operator inputs. For example, if input[0] and input[1] of the diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 24299d557551c8..2fea3f4d357512 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -28,9 +28,8 @@ limitations under the License. namespace toco { -::tensorflow::Status MergeLstmCellInputs::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status MergeLstmCellInputs::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; // Find lstm cell. auto op_it = model->operators.begin() + op_index; diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index aea6d93d00a04a..bc79bd5602a63c 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -27,9 +27,8 @@ limitations under the License. namespace toco { -::tensorflow::Status SplitLstmCellInputs::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status SplitLstmCellInputs::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; // Find lstm cell. auto op_it = model->operators.begin() + op_index; diff --git a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc index 1d1d67bd253a75..76d45982d32dd4 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc @@ -80,9 +80,8 @@ std::vector>::iterator FindOperator( // It's possible the model uses mul-broadcast to implement nearest neighbor // upsample which may involve 5-d, 6-d tensors. We can actually change this // pattern to be pack-based which is easier for us to handle. -::tensorflow::Status IdentifyNearestUpsample::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status IdentifyNearestUpsample::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto op_it = model->operators.begin() + op_index; auto* op = op_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc index 0f28cb1cd26ef6..dbf33a1fb58223 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc @@ -45,8 +45,8 @@ limitations under the License. namespace toco { -::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status IdentifyPRelu::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto add_op_it = model->operators.begin() + op_index; const auto* add_op = add_op_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc index 6f2e22439f7e44..a25ad134e62b97 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc @@ -28,8 +28,8 @@ namespace toco { using util::GetSingleScalarInputIndexOfBinaryOp; -::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status IdentifyRelu1::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; diff --git a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index 0726b32632668f..84e6d877eab225 100644 --- a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -99,9 +99,9 @@ bool AddDequantizeOperatorToInput(const std::string& input_name, return true; } -::tensorflow::Status MakeInitialDequantizeOperator::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status MakeInitialDequantizeOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; // This is effectively a transformation applied to edges. We iterate over the // specified node (op) and proceed for input edges. diff --git a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index a292b97f002010..860c0094434eb7 100644 --- a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -104,8 +104,9 @@ std::vector ReshapeToTranspose(const Model& model, // to be merged if the reshape does not affect memory ordering and does not // affects the number of dimensions. This only occurs when only unary dimensions // are shifting position. -::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run( - Model* model, std::size_t op_index, bool* modified) { +absl::Status MergeReshapeIntoPrecedingTranspose::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator( diff --git a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc index 588a03445d4df8..47bd4268800898 100644 --- a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -58,9 +58,9 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) { // // Note we are testing for one particular case of a broader set of possible // binary-reshape op transformations. This transformation could be generalized. -::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status MoveBinaryOperatorBeforeReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto binary_it = model->operators.begin() + op_index; Operator* binary_op = binary_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc index fffdde0a571cf9..240d0ae90232cf 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc @@ -27,8 +27,9 @@ limitations under the License. namespace toco { -::tensorflow::Status PropagateActivationFunctionIntoConstants::Run( - Model* model, std::size_t op_index, bool* modified) { +absl::Status PropagateActivationFunctionIntoConstants::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto ac_it = model->operators.begin() + op_index; const auto* ac_op = ac_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index ef0a5205bd867a..f0bd980fbdc35b 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -34,9 +34,8 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, } } // namespace -::tensorflow::Status PropagateArrayDataTypes::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status PropagateArrayDataTypes::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc index 54b76fb89bbbda..e577194cb46940 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -41,9 +41,8 @@ bool SupportsMinMax(const Array& array) { // When provided a set of min/max values for uint8 arrays this will rescale // the values for other data types as required and preserving the floating point // range within the new type. -::tensorflow::Status PropagateDefaultMinMax::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status PropagateDefaultMinMax::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto it = model->operators.begin() + op_index; const auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 62d8715b808491..a80c96bf1a5a5a 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -279,9 +279,8 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, // nice logging and integration with the graphviz video dumping mode. // In general you should not copy this style of transformation and stick to // local-only changes as seen in the other transformations. -::tensorflow::Status PropagateFakeQuantNumBits::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5136bc0012a8af..0ecc475a12149d 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -2147,9 +2147,8 @@ void ProcessScatterNdOperator(Model* model, ScatterNdOperator* op) { } // namespace -::tensorflow::Status PropagateFixedSizes::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status PropagateFixedSizes::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index 9e5e58017afd00..6c619e78b65143 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -500,8 +500,7 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation, } // namespace -::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index, - bool* modified) { +absl::Status Quantize::Run(Model* model, std::size_t op_index, bool* modified) { *modified = false; // Our general "quantization" graph transformation consists in replacing // QuantizedInputArrays[] -> diff --git a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc index bf9334f2a86793..b61189eba627f2 100644 --- a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc @@ -52,7 +52,7 @@ bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model, } // end namespace -::tensorflow::Status ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run( +absl::Status ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run( Model* model, std::size_t op_index, bool* modified) { *modified = false; const auto fakequant_it = model->operators.begin() + op_index; diff --git a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc index fc15e8ed7cd406..3600ead2489250 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -26,9 +26,8 @@ limitations under the License. namespace toco { -::tensorflow::Status RemoveFinalDequantizeOp::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto dequantize_it = model->operators.begin() + op_index; const auto* dequantize_op = dequantize_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc index 79e6b68c99978a..d13006b14f4bfd 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc @@ -58,9 +58,8 @@ void ReplaceOpInputsWith(Model* model, const std::string& lookfor, } // namespace -::tensorflow::Status RemoveSuccessiveTranspose::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status RemoveSuccessiveTranspose::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; auto op = model->operators.begin() + op_index; if (op->get()->type != OperatorType::kTranspose) { diff --git a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 45de603fdc20a7..627abba6ad199f 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -25,9 +25,8 @@ limitations under the License. namespace toco { -::tensorflow::Status RemoveTensorFlowAssert::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto assert_it = model->operators.begin() + op_index; const auto* assert_op = assert_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 0ce8628899e750..1fd133e2bd4d23 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -25,9 +25,8 @@ limitations under the License. namespace toco { -::tensorflow::Status RemoveTensorFlowIdentity::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto passthru_it = model->operators.begin() + op_index; const auto* passthru_op = passthru_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc index eff06cb4a2791b..77e0b54073c7c4 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -49,9 +49,9 @@ bool AreAllBufferElementsEqualTo(const std::vector& buffer_data, // For example, an Add operator is trivial if // one of its operands is constant 0, a Mul operator is trivial // if one of its operands is constant 1, etc. -::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status RemoveTrivialBinaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc index 99f369e16300bc..900bc09af91917 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc @@ -25,9 +25,8 @@ limitations under the License. namespace toco { -::tensorflow::Status RemoveTrivialConcatenation::Run(Model* model, - std::size_t op_index, - bool* modified) { +absl::Status RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index, + bool* modified) { *modified = false; const auto concat_it = model->operators.begin() + op_index; auto* concat_op = concat_it->get(); diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc index c21118f4df7e2e..8c8dbd601c9b9a 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include diff --git a/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc index ae9006af978237..a743b26414bf31 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" +#include #include -#include #include #include diff --git a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 405c79b8d52c40..159e24743f6147 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include diff --git a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc index af26eef7ff6922..3dfa9244c09bc8 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include #include -#include #include #include diff --git a/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc b/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc index 3a22849b949955..0b12905cdb16ad 100755 --- a/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index 8e3159c8646f34..56957a51c9e7f7 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -142,9 +142,9 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } -tensorflow::Status CheckOptionalAttr(const NodeDef& node, - const std::string& attr_name, - const std::string& expected_value) { +absl::Status CheckOptionalAttr(const NodeDef& node, + const std::string& attr_name, + const std::string& expected_value) { if (HasAttr(node, attr_name)) { const std::string& value = GetStringAttr(node, attr_name); if (value != expected_value) { @@ -156,9 +156,9 @@ tensorflow::Status CheckOptionalAttr(const NodeDef& node, return absl::OkStatus(); } -tensorflow::Status CheckOptionalAttr( - const NodeDef& node, const std::string& attr_name, - const tensorflow::DataType& expected_value) { +absl::Status CheckOptionalAttr(const NodeDef& node, + const std::string& attr_name, + const tensorflow::DataType& expected_value) { if (HasAttr(node, attr_name)) { const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); if (value != expected_value) { @@ -171,8 +171,8 @@ tensorflow::Status CheckOptionalAttr( } template -tensorflow::Status ExpectValue(const T1& v1, const T2& v2, - const std::string& description) { +absl::Status ExpectValue(const T1& v1, const T2& v2, + const std::string& description) { if (v1 == v2) return absl::OkStatus(); return tensorflow::errors::InvalidArgument(absl::StrCat( "Unexpected ", description, ": got ", v1, ", expected ", v2)); @@ -204,10 +204,9 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -tensorflow::Status ImportShape( - const TFLITE_PROTO_NS::RepeatedPtrField& - input_dims, - int* input_flat_size, Shape* shape) { +absl::Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< + tensorflow::TensorShapeProto_Dim>& input_dims, + int* input_flat_size, Shape* shape) { std::vector input_dims_only_sizes; bool zero_sized_shape = false; for (auto& d : input_dims) { @@ -344,9 +343,9 @@ struct TensorTraits { }; template -tensorflow::Status ImportTensorData(const TensorProto& input_tensor, - int input_flat_size, - std::vector* output_data) { +absl::Status ImportTensorData(const TensorProto& input_tensor, + int input_flat_size, + std::vector* output_data) { CHECK_GE(output_data->size(), input_flat_size); int num_elements_in_tensor = TensorTraits::size(input_tensor); if (num_elements_in_tensor == input_flat_size) { @@ -384,8 +383,8 @@ tensorflow::Status ImportTensorData(const TensorProto& input_tensor, return absl::OkStatus(); } -tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -402,8 +401,8 @@ tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, &output_float_data); } -tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportComplex64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -420,8 +419,8 @@ tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor, &output_complex_data); } -tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -437,8 +436,8 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -454,8 +453,8 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportUint32Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportUint32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_UINT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -471,8 +470,8 @@ tensorflow::Status ImportUint32Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -488,8 +487,8 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, &output_int_data); } -tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -515,8 +514,8 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, return status; } -tensorflow::Status ImportStringArray(const TensorProto& input_tensor, - Array* output_array) { +absl::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 6); @@ -556,9 +555,9 @@ int GetInputsCount(const NodeDef& node, return node.input_size(); } -tensorflow::Status CheckInputsCount( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - int expected_input_count) { +absl::Status CheckInputsCount(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + int expected_input_count) { if (GetInputsCount(node, tf_import_flags) != expected_input_count) { return tensorflow::errors::FailedPrecondition( node.op(), " node expects ", expected_input_count, @@ -689,7 +688,7 @@ void GetOutputTypesFromNodeDef(const NodeDef& node, } } -tensorflow::Status ConvertUnsupportedOperator( +absl::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // Names of special attributes in TF graph that are used by Toco. @@ -777,14 +776,14 @@ tensorflow::Status ConvertUnsupportedOperator( return absl::OkStatus(); } -tensorflow::Status ConvertConstOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertConstOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -833,9 +832,9 @@ tensorflow::Status ConvertConstOperator( return absl::OkStatus(); } -tensorflow::Status ConvertConvOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertConvOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Conv2D"); TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); @@ -914,7 +913,7 @@ tensorflow::Status ConvertConvOperator( return absl::OkStatus(); } -tensorflow::Status ConvertDepthwiseConvOperator( +absl::Status ConvertDepthwiseConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); @@ -992,7 +991,7 @@ tensorflow::Status ConvertDepthwiseConvOperator( return absl::OkStatus(); } -tensorflow::Status ConvertDepthToSpaceOperator( +absl::Status ConvertDepthToSpaceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); @@ -1015,7 +1014,7 @@ tensorflow::Status ConvertDepthToSpaceOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSpaceToDepthOperator( +absl::Status ConvertSpaceToDepthOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); @@ -1038,7 +1037,7 @@ tensorflow::Status ConvertSpaceToDepthOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBiasAddOperator( +absl::Status ConvertBiasAddOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BiasAdd"); @@ -1055,9 +1054,9 @@ tensorflow::Status ConvertBiasAddOperator( return absl::OkStatus(); } -tensorflow::Status ConvertRandomUniform( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertRandomUniform(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "RandomUniform"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); @@ -1073,7 +1072,7 @@ tensorflow::Status ConvertRandomUniform( return absl::OkStatus(); } -tensorflow::Status ConvertIdentityOperator( +absl::Status ConvertIdentityOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || @@ -1096,7 +1095,7 @@ tensorflow::Status ConvertIdentityOperator( return absl::OkStatus(); } -tensorflow::Status ConvertIdentityNOperator( +absl::Status ConvertIdentityNOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "IdentityN"); @@ -1114,7 +1113,7 @@ tensorflow::Status ConvertIdentityNOperator( return absl::OkStatus(); } -tensorflow::Status ConvertFakeQuantWithMinMaxArgs( +absl::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); @@ -1135,7 +1134,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs( return absl::OkStatus(); } -tensorflow::Status ConvertFakeQuantWithMinMaxVars( +absl::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); @@ -1157,7 +1156,7 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars( return absl::OkStatus(); } -tensorflow::Status ConvertSqueezeOperator( +absl::Status ConvertSqueezeOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Squeeze"); @@ -1178,9 +1177,9 @@ tensorflow::Status ConvertSqueezeOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSplitOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSplitOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Split"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; @@ -1196,9 +1195,10 @@ tensorflow::Status ConvertSplitOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSplitVOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSplitVOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "SplitV"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TensorFlowSplitVOperator; @@ -1215,9 +1215,10 @@ tensorflow::Status ConvertSplitVOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSwitchOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSwitchOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "Switch"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; @@ -1230,7 +1231,7 @@ tensorflow::Status ConvertSwitchOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSoftmaxOperator( +absl::Status ConvertSoftmaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Softmax"); @@ -1250,9 +1251,9 @@ tensorflow::Status ConvertSoftmaxOperator( return absl::OkStatus(); } -tensorflow::Status ConvertLRNOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertLRNOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "LRN"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); @@ -1267,7 +1268,7 @@ tensorflow::Status ConvertLRNOperator( return absl::OkStatus(); } -tensorflow::Status ConvertMaxPoolOperator( +absl::Status ConvertMaxPoolOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "MaxPool"); @@ -1310,7 +1311,7 @@ tensorflow::Status ConvertMaxPoolOperator( return absl::OkStatus(); } -tensorflow::Status ConvertAvgPoolOperator( +absl::Status ConvertAvgPoolOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "AvgPool"); @@ -1349,7 +1350,7 @@ tensorflow::Status ConvertAvgPoolOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBatchMatMulOperator( +absl::Status ConvertBatchMatMulOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1372,9 +1373,10 @@ tensorflow::Status ConvertBatchMatMulOperator( return absl::OkStatus(); } -tensorflow::Status ConvertMatMulOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertMatMulOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); CHECK(!HasAttr(node, "adjoint_a") || @@ -1396,9 +1398,10 @@ tensorflow::Status ConvertMatMulOperator( return absl::OkStatus(); } -tensorflow::Status ConvertConcatOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertConcatOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -1421,7 +1424,7 @@ tensorflow::Status ConvertConcatOperator( return absl::OkStatus(); } -tensorflow::Status ConvertMirrorPadOperator( +absl::Status ConvertMirrorPadOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { if (node.op() != "MirrorPad") { @@ -1456,7 +1459,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk }; // kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator // will be eligible for being exported as a flex op. template -tensorflow::Status ConvertSimpleOperatorGeneric( +absl::Status ConvertSimpleOperatorGeneric( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { if (NumInputs != kAnyNumInputs) { @@ -1484,16 +1487,17 @@ tensorflow::Status ConvertSimpleOperatorGeneric( // Convert a simple operator which is not valid as a flex op. template -tensorflow::Status ConvertSimpleOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSimpleOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { return ConvertSimpleOperatorGeneric( node, tf_import_flags, model_flags, model); } // Convert a simple operator which is valid as a flex op. template -tensorflow::Status ConvertSimpleOperatorFlexOk( +absl::Status ConvertSimpleOperatorFlexOk( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { return ConvertSimpleOperatorGeneric( @@ -1503,7 +1507,7 @@ tensorflow::Status ConvertSimpleOperatorFlexOk( // Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if // the types are not supported. Converting Const operators here avoids // expensive copies of the protocol buffers downstream in the flex delegate. -tensorflow::Status ConditionallyConvertConstOperator( +absl::Status ConditionallyConvertConstOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // We avoid incomplete and zero shapes because the resulting arrays @@ -1531,7 +1535,7 @@ tensorflow::Status ConditionallyConvertConstOperator( } } -tensorflow::Status ConvertStridedSliceOperator( +absl::Status ConvertStridedSliceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "StridedSlice"); @@ -1560,7 +1564,7 @@ tensorflow::Status ConvertStridedSliceOperator( return absl::OkStatus(); } -tensorflow::Status ConvertPlaceholderOperator( +absl::Status ConvertPlaceholderOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); @@ -1600,15 +1604,15 @@ tensorflow::Status ConvertPlaceholderOperator( return absl::OkStatus(); } -tensorflow::Status ConvertNoOpOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertNoOpOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { return absl::OkStatus(); } -tensorflow::Status ConvertCastOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertCastOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Cast"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); @@ -1622,9 +1626,9 @@ tensorflow::Status ConvertCastOperator( return absl::OkStatus(); } -tensorflow::Status ConvertFloorOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertFloorOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Floor"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1636,9 +1640,9 @@ tensorflow::Status ConvertFloorOperator( return absl::OkStatus(); } -tensorflow::Status ConvertCeilOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertCeilOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Ceil"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1650,9 +1654,9 @@ tensorflow::Status ConvertCeilOperator( return absl::OkStatus(); } -tensorflow::Status ConvertRoundOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertRoundOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Round"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); @@ -1664,9 +1668,10 @@ tensorflow::Status ConvertRoundOperator( return absl::OkStatus(); } -tensorflow::Status ConvertGatherOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertGatherOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); if (node.op() == "Gather") TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1693,7 +1698,7 @@ tensorflow::Status ConvertGatherOperator( return absl::OkStatus(); } -tensorflow::Status ConvertGatherNdOperator( +absl::Status ConvertGatherNdOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "GatherNd"); @@ -1709,7 +1714,7 @@ tensorflow::Status ConvertGatherNdOperator( } template -tensorflow::Status ConvertArgMinMaxOperator( +absl::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); @@ -1729,23 +1734,25 @@ tensorflow::Status ConvertArgMinMaxOperator( return absl::OkStatus(); } -tensorflow::Status ConvertArgMaxOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertArgMaxOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMax"); return ConvertArgMinMaxOperator(node, tf_import_flags, model_flags, model); } -tensorflow::Status ConvertArgMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertArgMinOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMin"); return ConvertArgMinMaxOperator(node, tf_import_flags, model_flags, model); } -tensorflow::Status ConvertResizeBilinearOperator( +absl::Status ConvertResizeBilinearOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); @@ -1768,7 +1775,7 @@ tensorflow::Status ConvertResizeBilinearOperator( return absl::OkStatus(); } -tensorflow::Status ConvertResizeNearestNeighborOperator( +absl::Status ConvertResizeNearestNeighborOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ResizeNearestNeighbor"); @@ -1791,7 +1798,7 @@ tensorflow::Status ConvertResizeNearestNeighborOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( +absl::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); @@ -1841,7 +1848,7 @@ tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( return absl::OkStatus(); } -tensorflow::Status ConvertFusedBatchNormOperator( +absl::Status ConvertFusedBatchNormOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK((node.op() == "FusedBatchNorm") || (node.op() == "FusedBatchNormV3")); @@ -1896,7 +1903,7 @@ tensorflow::Status ConvertFusedBatchNormOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSpaceToBatchNDOperator( +absl::Status ConvertSpaceToBatchNDOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); @@ -1912,7 +1919,7 @@ tensorflow::Status ConvertSpaceToBatchNDOperator( return absl::OkStatus(); } -tensorflow::Status ConvertBatchToSpaceNDOperator( +absl::Status ConvertBatchToSpaceNDOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); @@ -1929,9 +1936,10 @@ tensorflow::Status ConvertBatchToSpaceNDOperator( } template -tensorflow::Status ConvertReduceOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertReduceOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new T; op->inputs.push_back(node.input(0)); @@ -1947,9 +1955,9 @@ tensorflow::Status ConvertReduceOperator( } // TODO(b/139320642): Add test when fused op is supported. -tensorflow::Status ConvertSvdfOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertSvdfOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 4 || input_size == 5) @@ -1977,7 +1985,7 @@ tensorflow::Status ConvertSvdfOperator( } // This is just bare bones support to get the shapes to propagate. -tensorflow::Status ConvertTransposeConvOperator( +absl::Status ConvertTransposeConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); @@ -2048,9 +2056,9 @@ tensorflow::Status ConvertTransposeConvOperator( return absl::OkStatus(); } -tensorflow::Status ConvertRangeOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertRangeOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Range"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; @@ -2073,9 +2081,9 @@ tensorflow::Status ConvertRangeOperator( // they aren't the same thing. tf.stack results in a "Pack" operator. "Stack" // operators also exist, but involve manipulating the TF runtime stack, and are // not directly related to tf.stack() usage. -tensorflow::Status ConvertPackOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertPackOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Pack"); auto op = std::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -2095,9 +2103,10 @@ tensorflow::Status ConvertPackOperator( return absl::OkStatus(); } -tensorflow::Status ConvertUnpackOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertUnpackOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "Unpack"); auto op = std::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -2125,7 +2134,7 @@ tensorflow::Status ConvertUnpackOperator( // such ops as RNN back-edges, which is technically incorrect (does not // allow representing the op's semantics) but good enough to get a // graph visualization. -tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( +absl::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is @@ -2144,9 +2153,9 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( return absl::OkStatus(); } -tensorflow::Status ConvertShapeOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertShapeOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "Shape"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto out_type = @@ -2160,7 +2169,7 @@ tensorflow::Status ConvertShapeOperator( return absl::OkStatus(); } -tensorflow::Status ConvertReverseSequenceOperator( +absl::Status ConvertReverseSequenceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "ReverseSequence"); @@ -2327,9 +2336,10 @@ bool InlineAllFunctions(GraphDef* graphdef) { return graph_modified; } -tensorflow::Status ConvertTopKV2Operator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertTopKV2Operator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = std::make_unique(); op->inputs.push_back(node.input(0)); @@ -2349,7 +2359,7 @@ tensorflow::Status ConvertTopKV2Operator( return absl::OkStatus(); } -tensorflow::Status ConvertDynamicPartitionOperator( +absl::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { auto op = std::make_unique(); @@ -2367,7 +2377,7 @@ tensorflow::Status ConvertDynamicPartitionOperator( return absl::OkStatus(); } -tensorflow::Status ConvertDynamicStitchOperator( +absl::Status ConvertDynamicStitchOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { // The parallel and non-parallel variants are the same besides whether they @@ -2386,7 +2396,7 @@ tensorflow::Status ConvertDynamicStitchOperator( return absl::OkStatus(); } -tensorflow::Status ConvertSparseToDenseOperator( +absl::Status ConvertSparseToDenseOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "SparseToDense"); @@ -2405,9 +2415,10 @@ tensorflow::Status ConvertSparseToDenseOperator( return absl::OkStatus(); } -tensorflow::Status ConvertOneHotOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - const ModelFlags& model_flags, Model* model) { +absl::Status ConvertOneHotOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, + Model* model) { CHECK_EQ(node.op(), "OneHot"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); @@ -2426,7 +2437,7 @@ tensorflow::Status ConvertOneHotOperator( return absl::OkStatus(); } -tensorflow::Status ConvertCTCBeamSearchDecoderOperator( +absl::Status ConvertCTCBeamSearchDecoderOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); @@ -2456,7 +2467,7 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator( // This isn't a TensorFlow builtin op. Currently this node can only be generated // with TfLite OpHint API. -tensorflow::Status ConvertUnidirectionalSequenceLstm( +absl::Status ConvertUnidirectionalSequenceLstm( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); @@ -2512,7 +2523,7 @@ tensorflow::Status ConvertUnidirectionalSequenceLstm( return absl::OkStatus(); } -tensorflow::Status ConvertLeakyReluOperator( +absl::Status ConvertLeakyReluOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { CHECK_EQ(node.op(), "LeakyRelu"); @@ -2527,7 +2538,7 @@ tensorflow::Status ConvertLeakyReluOperator( return absl::OkStatus(); } -tensorflow::Status ConvertUnidirectionalSequenceRnn( +absl::Status ConvertUnidirectionalSequenceRnn( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model) { DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn"); @@ -2552,7 +2563,7 @@ tensorflow::Status ConvertUnidirectionalSequenceRnn( namespace internal { -using ConverterType = tensorflow::Status (*)( +using ConverterType = absl::Status (*)( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model); using ConverterMapType = std::unordered_map; @@ -2721,10 +2732,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { }); } -tensorflow::Status ImportTensorFlowNode( - const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, - Model* model, const ConverterMapType& converter_map) { +absl::Status ImportTensorFlowNode(const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + const ModelFlags& model_flags, Model* model, + const ConverterMapType& converter_map) { auto converter = converter_map.find(node.op()); if (converter == converter_map.end()) { return ConvertUnsupportedOperator(node, tf_import_flags, model_flags, diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc index e39ae062f8dfcc..a9943e3323121b 100644 --- a/tensorflow/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/lite/toco/import_tensorflow_test.cc @@ -47,7 +47,7 @@ using tensorflow::Status; using ::testing::ElementsAre; namespace internal { -using ConverterType = tensorflow::Status (*)( +using ConverterType = absl::Status (*)( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags, Model* model); using ConverterMapType = std::unordered_map; diff --git a/tensorflow/lite/toco/model_cmdline_flags.cc b/tensorflow/lite/toco/model_cmdline_flags.cc index 7aaa742e183086..b916d80c43baa6 100644 --- a/tensorflow/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/lite/toco/model_cmdline_flags.cc @@ -14,19 +14,24 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/model_cmdline_flags.h" +#include +#include +#include +#include +#include #include #include #include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "absl/strings/strip.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/toco/args.h" +#include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/types.pb.h" // "batch" flag only exists internally #ifdef PLATFORM_GOOGLE diff --git a/tensorflow/lite/toco/model_cmdline_flags_test.cc b/tensorflow/lite/toco/model_cmdline_flags_test.cc index b87e200095c49a..5bdb7e95d18e72 100644 --- a/tensorflow/lite/toco/model_cmdline_flags_test.cc +++ b/tensorflow/lite/toco/model_cmdline_flags_test.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/toco/model_cmdline_flags.h" + #include #include +#include -#include #include #include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/toco/args.h" -#include "tensorflow/lite/toco/model_cmdline_flags.h" namespace toco { namespace { diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index 64199cdbf5778e..2dc1032f7213ef 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -152,7 +152,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, int64_t arithmetic_ops_count; // Convert model. - tensorflow::Status status = + absl::Status status = Convert(input_contents_txt, toco_flags, model_flags, &output_file_contents_txt, &arithmetic_ops_count); @@ -257,8 +257,7 @@ PyObject* RegisterCustomOpdefs(PyObject* list) { // Register extra opdefs to TensorFlow global op registry. tensorflow::OpRegistry::Global()->Register( - [opdef]( - tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status { + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> absl::Status { *op_reg_data = tensorflow::OpRegistrationData(opdef); return absl::OkStatus(); }); diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index e43b1bfe71fd39..5fe3901c83195e 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/export.h" #include +#include #include +#include #include #include +#include #include #include diff --git a/tensorflow/lite/toco/tflite/import.cc b/tensorflow/lite/toco/tflite/import.cc index 4659fbfb89ee7e..7285635d02ba4d 100644 --- a/tensorflow/lite/toco/tflite/import.cc +++ b/tensorflow/lite/toco/tflite/import.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/tflite/import.h" +#include +#include +#include #include #include diff --git a/tensorflow/lite/toco/tflite/import.h b/tensorflow/lite/toco/tflite/import.h index 30930fdc1e33a6..21a003a977d3fc 100644 --- a/tensorflow/lite/toco/tflite/import.h +++ b/tensorflow/lite/toco/tflite/import.h @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOCO_TFLITE_IMPORT_H_ #define TENSORFLOW_LITE_TOCO_TFLITE_IMPORT_H_ +#include #include +#include #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/tflite/import_test.cc b/tensorflow/lite/toco/tflite/import_test.cc index b73c673c9199d3..0eb5a8329113f8 100644 --- a/tensorflow/lite/toco/tflite/import_test.cc +++ b/tensorflow/lite/toco/tflite/import_test.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/tflite/import.h" +#include #include #include +#include #include #include diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index c73e30781faf09..06cd8728549d9b 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/tflite/operator.h" +#include #include #include #include #include +#include #include "absl/log/check.h" #include "absl/log/log.h" diff --git a/tensorflow/lite/toco/tflite/operator.h b/tensorflow/lite/toco/tflite/operator.h index 836c287674e084..7dd941adc860ce 100644 --- a/tensorflow/lite/toco/tflite/operator.h +++ b/tensorflow/lite/toco/tflite/operator.h @@ -15,7 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ #define TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ +#include +#include +#include #include +#include #include "flatbuffers/flatbuffers.h" #include "flatbuffers/flexbuffers.h" diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 8f1d42ad8fb9d8..6e021dd3538809 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/tflite/operator.h" +#include +#include +#include #include #include diff --git a/tensorflow/lite/toco/tflite/types.cc b/tensorflow/lite/toco/tflite/types.cc index f67aad1f7f7b0d..b84312ddcf0eec 100644 --- a/tensorflow/lite/toco/tflite/types.cc +++ b/tensorflow/lite/toco/tflite/types.cc @@ -14,7 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/tflite/types.h" +#include +#include +#include #include +#include #include "absl/log/log.h" #include "flatbuffers/buffer.h" // from @flatbuffers diff --git a/tensorflow/lite/toco/tflite/types.h b/tensorflow/lite/toco/tflite/types.h index cccba6a45db5c5..ef655b60b1dc20 100644 --- a/tensorflow/lite/toco/tflite/types.h +++ b/tensorflow/lite/toco/tflite/types.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOCO_TFLITE_TYPES_H_ #define TENSORFLOW_LITE_TOCO_TFLITE_TYPES_H_ +#include + #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers diff --git a/tensorflow/lite/toco/tflite/types_test.cc b/tensorflow/lite/toco/tflite/types_test.cc index 5ed493c2ac066f..505cb2284214fb 100644 --- a/tensorflow/lite/toco/tflite/types_test.cc +++ b/tensorflow/lite/toco/tflite/types_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/types.h" #include +#include +#include +#include +#include #include #include diff --git a/tensorflow/lite/toco/toco.cc b/tensorflow/lite/toco/toco.cc index bd3cedb947867c..5c93f737f0b612 100644 --- a/tensorflow/lite/toco/toco.cc +++ b/tensorflow/lite/toco/toco.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include #include +#include "absl/status/status.h" #include "tensorflow/lite/toco/model_cmdline_flags.h" #include "tensorflow/lite/toco/toco_cmdline_flags.h" #include "tensorflow/lite/toco/toco_convert.h" diff --git a/tensorflow/lite/toco/toco_cmdline_flags.cc b/tensorflow/lite/toco/toco_cmdline_flags.cc index 505b9ec6301ba0..55030247d2efae 100644 --- a/tensorflow/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/lite/toco/toco_cmdline_flags.cc @@ -15,18 +15,21 @@ limitations under the License. #include "tensorflow/lite/toco/toco_cmdline_flags.h" +#include +#include #include #include #include -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_split.h" -#include "absl/strings/strip.h" #include "absl/types/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/types.pb.h" namespace toco { diff --git a/tensorflow/lite/toco/toco_convert.cc b/tensorflow/lite/toco/toco_convert.cc index f3c0e46e5786db..9cfdc9cb34e814 100644 --- a/tensorflow/lite/toco/toco_convert.cc +++ b/tensorflow/lite/toco/toco_convert.cc @@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include -#include "absl/strings/string_view.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_cmdline_flags.h" #include "tensorflow/lite/toco/model_flags.pb.h" @@ -25,8 +27,6 @@ limitations under the License. #include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_tooling.h" #include "tensorflow/lite/toco/toco_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" namespace toco { namespace { diff --git a/tensorflow/lite/toco/toco_convert_test.cc b/tensorflow/lite/toco/toco_convert_test.cc index 8206ca15c9924a..cc7ec096ff4900 100644 --- a/tensorflow/lite/toco/toco_convert_test.cc +++ b/tensorflow/lite/toco/toco_convert_test.cc @@ -16,9 +16,10 @@ limitations under the License. #include -#include #include #include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/toco_port.h" namespace toco { diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index 5b38d535c7e8ac..6e2ab030c3e49d 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -14,12 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/toco_tooling.h" -#include #include #include #include -#include "absl/memory/memory.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_join.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/allocate_transient_arrays.h" @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/import.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/types.pb.h" namespace toco { namespace { diff --git a/tensorflow/lite/toco/toco_tooling.h b/tensorflow/lite/toco/toco_tooling.h index 6fe4fb064af1d4..64d78f5bbe09a0 100644 --- a/tensorflow/lite/toco/toco_tooling.h +++ b/tensorflow/lite/toco/toco_tooling.h @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index 4e548f92e46c4b..51c2732058c4ef 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -15,25 +15,31 @@ limitations under the License. #include "tensorflow/lite/toco/tooling_util.h" #include +#include #include #include +#include #include #include #include #include #include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "re2/re2.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/dump_graphviz.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/lite/toco/types.pb.h" namespace toco { diff --git a/tensorflow/lite/toco/tooling_util.h b/tensorflow/lite/toco/tooling_util.h index b9419f19dbf649..f87982e40dd44e 100644 --- a/tensorflow/lite/toco/tooling_util.h +++ b/tensorflow/lite/toco/tooling_util.h @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include #include #include @@ -24,6 +26,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/lite/toco/tooling_util_test.cc b/tensorflow/lite/toco/tooling_util_test.cc index f0da510c69540a..ef2364fecfc6cd 100644 --- a/tensorflow/lite/toco/tooling_util_test.cc +++ b/tensorflow/lite/toco/tooling_util_test.cc @@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include "tensorflow/lite/toco/tooling_util.h" + +#include #include #include +#include "absl/status/status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/toco_port.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 24473fa296142a..9347b7bb127afb 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -125,6 +125,7 @@ cc_test( "//tensorflow/lite/tools:logging", "//tensorflow/lite/tools/delegates:delegate_provider_hdr", "@com_google_absl//absl/algorithm", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", @@ -183,6 +184,7 @@ cc_library( "//tensorflow/lite/tools/delegates:tflite_execution_providers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@ruy//ruy/profiler", ], diff --git a/tensorflow/lite/tools/benchmark/benchmark_main.cc b/tensorflow/lite/tools/benchmark/benchmark_main.cc index 76ae68fe98e13d..43b249080a2a07 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_main.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" #include "tensorflow/lite/tools/logging.h" diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h index 3192c741cc71d7..93072ffcdddf34 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_model.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc index 5f97b32deb1e37..053035aa752702 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -12,6 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + +#include "absl/log/log.h" #ifndef _WIN32 #include #endif // !defined(_WIN32) @@ -26,7 +32,6 @@ limitations under the License. #include #include #include "absl/algorithm/algorithm.h" -#include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/core/c/c_api_types.h" diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index 468d9037b52c35..489fe75da793ed 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ #include +#include #include #include #include diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model_test.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model_test.cc index b5cf1f425d67d5..c2ddcb76bd3781 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" #include "tensorflow/lite/core/c/c_api_types.h" diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc index 7c8c6b39f78093..fe46a9a603fb8e 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_performance_options_main.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/lite/tools/benchmark/benchmark_performance_options.h" #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" #include "tensorflow/lite/tools/logging.h" diff --git a/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc b/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc index cb1517293f7507..adaa239e8b0b35 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_utils_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/profiling/time.h" diff --git a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc index 52162ca0d96cb3..cc2fa6d886d3f1 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc +++ b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h" +#include +#include #include #include "xla/tsl/util/stats_calculator.h" diff --git a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h index 30328aa5c7e383..a84e17931d56ac 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_EXPERIMENTAL_C_BENCHMARK_C_API_H_ #define TENSORFLOW_LITE_TOOLS_BENCHMARK_EXPERIMENTAL_C_BENCHMARK_C_API_H_ +#include + #include "tensorflow/lite/core/c/c_api_types.h" // ----------------------------------------------------------------------------- diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.cc b/tensorflow/lite/tools/benchmark/profiling_listener.cc index 3faa54cc9a3cf1..9ffc1c0fa98246 100644 --- a/tensorflow/lite/tools/benchmark/profiling_listener.cc +++ b/tensorflow/lite/tools/benchmark/profiling_listener.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/lite/tools/benchmark/profiling_listener.h" -#include +#include +#include #include #include "tensorflow/lite/interpreter.h" diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.h b/tensorflow/lite/tools/benchmark/profiling_listener.h index a09667ccbcc4d3..cc1fd3d774e6bb 100644 --- a/tensorflow/lite/tools/benchmark/profiling_listener.h +++ b/tensorflow/lite/tools/benchmark/profiling_listener.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_ #define TENSORFLOW_LITE_TOOLS_BENCHMARK_PROFILING_LISTENER_H_ +#include #include #include diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 983e68ca6da3a9..677ae1f59a6035 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG 983d013300f19fd3f4e33220b6401408e97a8d12 + GIT_TAG 02764b305b430aec42c3df85ba32b9a3f8d6e3d4 GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/lite/tools/command_line_flags.h b/tensorflow/lite/tools/command_line_flags.h index a853f552f9fd89..2d729f59b6639e 100644 --- a/tensorflow/lite/tools/command_line_flags.h +++ b/tensorflow/lite/tools/command_line_flags.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ #define TENSORFLOW_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#include + #include #include #include diff --git a/tensorflow/lite/tools/delegates/compatibility/common/delegate_compatibility_checker_base.cc b/tensorflow/lite/tools/delegates/compatibility/common/delegate_compatibility_checker_base.cc index 88a76f67d5fdb9..2cf62c1090e349 100644 --- a/tensorflow/lite/tools/delegates/compatibility/common/delegate_compatibility_checker_base.cc +++ b/tensorflow/lite/tools/delegates/compatibility/common/delegate_compatibility_checker_base.cc @@ -15,11 +15,7 @@ limitations under the License. #include "tensorflow/lite/tools/delegates/compatibility/common/delegate_compatibility_checker_base.h" -#include -#include -#include -#include -#include +#include #include "absl/status/status.h" #include "tensorflow/lite/model_builder.h" diff --git a/tensorflow/lite/tools/delegates/compatibility/gpu/gpu_delegate_compatibility_checker.cc b/tensorflow/lite/tools/delegates/compatibility/gpu/gpu_delegate_compatibility_checker.cc index 2f328a4b5f5394..38a7a7be1a97d9 100644 --- a/tensorflow/lite/tools/delegates/compatibility/gpu/gpu_delegate_compatibility_checker.cc +++ b/tensorflow/lite/tools/delegates/compatibility/gpu/gpu_delegate_compatibility_checker.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/lite/tools/delegates/compatibility/gpu/gpu_delegate_compatibility_checker.h" -#include -#include #include #include diff --git a/tensorflow/lite/tools/delegates/default_execution_provider.cc b/tensorflow/lite/tools/delegates/default_execution_provider.cc index 113e20d3b47ffc..22373c9483eb29 100644 --- a/tensorflow/lite/tools/delegates/default_execution_provider.cc +++ b/tensorflow/lite/tools/delegates/default_execution_provider.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include #include "tensorflow/lite/tools/delegates/delegate_provider.h" diff --git a/tensorflow/lite/tools/delegates/external_delegate_provider.cc b/tensorflow/lite/tools/delegates/external_delegate_provider.cc index 2a8ba20fffe9b0..74cd7d2f10f9ce 100644 --- a/tensorflow/lite/tools/delegates/external_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/external_delegate_provider.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include diff --git a/tensorflow/lite/tools/delegates/nnapi_delegate_provider.cc b/tensorflow/lite/tools/delegates/nnapi_delegate_provider.cc index c1d14c91f59f57..a281fbd8166288 100644 --- a/tensorflow/lite/tools/delegates/nnapi_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/nnapi_delegate_provider.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" diff --git a/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc b/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc index ee7d30d48833c2..c6cbcf8e7aab6a 100644 --- a/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/xnnpack_delegate_provider.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/tools/delegates/delegate_provider.h" diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index 07bf204f113a1f..4f72d01366960e 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "//tensorflow/lite/tools/evaluation/proto:preprocessing_steps_cc_proto", "@com_google_absl//absl/base", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@libjpeg_turbo//:jpeg", "@local_xla//xla/tsl/util:stats_calculator_portable", @@ -90,6 +91,7 @@ cc_library( "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_absl//absl/log", ], ) @@ -126,6 +128,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -164,6 +167,7 @@ cc_library( "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_absl//absl/log", ], ) @@ -181,6 +185,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@FP16", + "@com_google_absl//absl/log", "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -214,6 +219,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "//tensorflow/lite/tools/evaluation/stages/utils:image_metrics", + "@com_google_absl//absl/log", ], ) @@ -249,5 +255,6 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", ], ) diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc index bc5158c8e4d9b3..7dc62f0811b531 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" diff --git a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc index 7b03ec2b139790..068a98247e4f0b 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/log/log.h" #include "absl/strings/ascii.h" #include "jpeglib.h" // from @libjpeg_turbo #include "tensorflow/core/lib/jpeg/jpeg_mem.h" diff --git a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h index f16fda5b9a027a..289eda627943a8 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/c_api_types.h" diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc index f79089129285cc..ae8d06cae88fd8 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "fp16.h" // from @FP16 +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" diff --git a/tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.cc b/tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.cc index 6e4bfe595ef2d6..65827849be67e2 100644 --- a/tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/object_detection_average_precision_stage.cc @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" diff --git a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc index cd7b04931765bb..7e50efa7f84807 100644 --- a/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/object_detection_stage.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc index fdf80a5f2c03cf..9f11a45fb8ce4e 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/interpreter_builder.h" diff --git a/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.cc b/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.cc index e25d2aa95b9ceb..7b6b9f5ff9b322 100644 --- a/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/topk_accuracy_eval_stage.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" diff --git a/tensorflow/lite/tools/evaluation/stages/utils/BUILD b/tensorflow/lite/tools/evaluation/stages/utils/BUILD index 2548d88a3d849f..f2443ad678b2db 100644 --- a/tensorflow/lite/tools/evaluation/stages/utils/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/utils/BUILD @@ -29,6 +29,7 @@ cc_library( deps = [ "//tensorflow/core:tflite_portable_logging", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", ], ) diff --git a/tensorflow/lite/tools/evaluation/stages/utils/image_metrics.cc b/tensorflow/lite/tools/evaluation/stages/utils/image_metrics.cc index ae12fcad58ca85..df9918db718611 100644 --- a/tensorflow/lite/tools/evaluation/stages/utils/image_metrics.cc +++ b/tensorflow/lite/tools/evaluation/stages/utils/image_metrics.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "tensorflow/core/platform/logging.h" namespace tflite { diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index 1dbb26a0176d91..ab1a88c413ad08 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include +#include +#include #include #include #include diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc index e962742425bdea..ef7a532d5f93dd 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -30,6 +30,7 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" #include "absl/strings/str_cat.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h" #include "tensorflow/lite/context.h" #include "tensorflow/lite/core/api/error_reporter.h" @@ -150,7 +151,8 @@ bool IsFloatTensor(const SubGraphT* subgraph, int32_t tensor_idx) { operator_property::OperatorProperty GetOperatorProperty( const std::unordered_set& operator_names, const ModelT* model, int subgraph_index, int op_idx, const string& operator_name, - const TensorType& activations_type, bool disable_per_channel = false) { + const TensorType& activations_type, bool disable_per_channel = false, + bool disable_per_channel_quantization_for_dense_layers = false) { operator_property::OperatorProperty property = operator_property::GetOperatorProperty(model, subgraph_index, op_idx); const SubGraphT* subgraph = model->subgraphs[subgraph_index].get(); @@ -175,6 +177,14 @@ operator_property::OperatorProperty GetOperatorProperty( } } } + if (disable_per_channel_quantization_for_dense_layers && + op_code == BuiltinOperator_FULLY_CONNECTED) { + for (auto& input : property.inputs) { + if (input.second.per_axis) { + input.second.per_axis = false; + } + } + } return property; } @@ -1513,6 +1523,7 @@ TfLiteStatus QuantizeWeightsInputOutput( const std::unordered_set& operator_names, const std::unordered_set& real_value_op_set, const TensorType& activations_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, ErrorReporter* error_reporter) { // Flag to track unsupported ops. bool quantization_not_supported = false; @@ -1533,7 +1544,8 @@ TfLiteStatus QuantizeWeightsInputOutput( : subgraph->tensors[op->inputs[0]]->name; operator_property::OperatorProperty property = GetOperatorProperty( operator_names, model, subgraph_idx, op_idx, operator_name, - activations_type, disable_per_channel); + activations_type, disable_per_channel, + disable_per_channel_quantization_for_dense_layers); if (!IsRealValueOp(real_value_op_set, operator_name)) { continue; } @@ -1583,13 +1595,13 @@ TfLiteStatus QuantizeWeightsInputOutput( } // Quantize bias. -TfLiteStatus QuantizeBiases(ModelT* model, - const std::unordered_set& operator_names, - const std::unordered_set& real_value_op_set, - const TensorType& activations_type, - const TensorType& bias_type, - bool disable_per_channel, - ErrorReporter* error_reporter) { +TfLiteStatus QuantizeBiases( + ModelT* model, const std::unordered_set& operator_names, + const std::unordered_set& real_value_op_set, + const TensorType& activations_type, const TensorType& bias_type, + bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); @@ -1603,7 +1615,8 @@ TfLiteStatus QuantizeBiases(ModelT* model, const string operator_name = subgraph->tensors[op->outputs[0]]->name; operator_property::OperatorProperty property = GetOperatorProperty( operator_names, model, subgraph_idx, op_idx, operator_name, - activations_type, disable_per_channel); + activations_type, disable_per_channel, + disable_per_channel_quantization_for_dense_layers); if (!property.quantizable || !IsRealValueOp(real_value_op_set, operator_name)) { continue; @@ -1684,6 +1697,7 @@ TfLiteStatus FillQuantizationParams( ModelT* model, const std::unordered_set& operator_names, const std::unordered_set& real_value_op_set, const TensorType& activations_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { @@ -1697,9 +1711,10 @@ TfLiteStatus FillQuantizationParams( } if (!op->outputs.empty()) { const string operator_name = subgraph->tensors[op->outputs[0]]->name; - property = GetOperatorProperty(operator_names, model, subgraph_idx, - op_idx, operator_name, activations_type, - disable_per_channel); + property = GetOperatorProperty( + operator_names, model, subgraph_idx, op_idx, operator_name, + activations_type, disable_per_channel, + disable_per_channel_quantization_for_dense_layers); if (!IsRealValueOp(real_value_op_set, operator_name)) { continue; } @@ -1783,8 +1798,8 @@ TfLiteStatus FillQuantizationParams( return kTfLiteError; } } // loop over op inputs - } // loop over ops - } // loop over subgraphs + } // loop over ops + } // loop over subgraphs return kTfLiteOk; } @@ -1793,6 +1808,7 @@ TfLiteStatus EnsureBiasScaleCompatibility( ModelT* model, const std::unordered_set& operator_names, const std::unordered_set& real_value_op_set, const TensorType& activations_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, ErrorReporter* error_reporter) { for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs.size(); subgraph_idx++) { @@ -1805,7 +1821,8 @@ TfLiteStatus EnsureBiasScaleCompatibility( const string operator_name = subgraph->tensors[op->outputs[0]]->name; operator_property::OperatorProperty property = GetOperatorProperty( operator_names, model, subgraph_idx, op_idx, operator_name, - activations_type, disable_per_channel); + activations_type, disable_per_channel, + disable_per_channel_quantization_for_dense_layers); if (!IsRealValueOp(real_value_op_set, operator_name)) { continue; } @@ -1939,24 +1956,25 @@ TfLiteStatus EnsureBiasScaleCompatibility( } // namespace // Assumes that the operators in the model have been topologically sorted. -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, - const TensorType& output_type, bool allow_float, - const std::unordered_set& operator_names, - const TensorType& activations_type, - const TensorType& bias_type, - bool disable_per_channel, - ErrorReporter* error_reporter, - bool handle_external_state = false) { +TfLiteStatus QuantizeModel( + flatbuffers::FlatBufferBuilder* builder, ModelT* model, + const TensorType& input_type, const TensorType& output_type, + bool allow_float, const std::unordered_set& operator_names, + const TensorType& activations_type, const TensorType& bias_type, + bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter, bool handle_external_state = false) { auto real_value_op_set = PopulateRealValueOpSet(model, operator_names, activations_type); TF_LITE_ENSURE_STATUS(DuplicateBiasesWithMultipleUses(model, error_reporter)); TF_LITE_ENSURE_STATUS(FillQuantizationParams( model, operator_names, real_value_op_set, activations_type, - disable_per_channel, error_reporter)); + disable_per_channel, disable_per_channel_quantization_for_dense_layers, + error_reporter)); TF_LITE_ENSURE_STATUS(EnsureBiasScaleCompatibility( model, operator_names, real_value_op_set, activations_type, - disable_per_channel, error_reporter)); + disable_per_channel, disable_per_channel_quantization_for_dense_layers, + error_reporter)); TF_LITE_ENSURE_STATUS( QuantizeIntermediateTensors(model, activations_type, error_reporter)); TF_LITE_ENSURE_STATUS(QuantizeSharedRange(model, error_reporter)); @@ -1964,14 +1982,16 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, QuantizeResources(model, activations_type, error_reporter)); TF_LITE_ENSURE_STATUS(QuantizeWeightsInputOutput( model, allow_float, operator_names, real_value_op_set, activations_type, - disable_per_channel, error_reporter)); + disable_per_channel, disable_per_channel_quantization_for_dense_layers, + error_reporter)); TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names, real_value_op_set, activations_type, error_reporter)); SetOperatorPropertyBiasType(model, bias_type); - TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, real_value_op_set, - activations_type, bias_type, - disable_per_channel, error_reporter)); + TF_LITE_ENSURE_STATUS(QuantizeBiases( + model, operator_names, real_value_op_set, activations_type, bias_type, + disable_per_channel, disable_per_channel_quantization_for_dense_layers, + error_reporter)); utils::SetOperatorCodeVersion(model); TF_LITE_ENSURE_STATUS( SetInputAndOutputTypes(model, input_type, output_type, activations_type, @@ -1992,10 +2012,13 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, const TensorType& activations_type, const TensorType& bias_type, ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, - operator_names, activations_type, - /*bias_type=*/bias_type, - /*disable_per_channel=*/false, error_reporter); + return QuantizeModel( + builder, model, input_type, output_type, allow_float, operator_names, + activations_type, + /*bias_type=*/bias_type, + /*disable_per_channel=*/false, + /*disable_per_channel_quantization_for_dense_layers=*/false, + error_reporter); } TfLiteStatus QuantizeModelAllOperators( @@ -2003,10 +2026,12 @@ TfLiteStatus QuantizeModelAllOperators( const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, const TensorType& bias_type, ErrorReporter* error_reporter) { - return QuantizeModel(builder, model, input_type, output_type, allow_float, - GetAllOperatorOutputs(model), activations_type, - bias_type, - /*disable_per_channel=*/false, error_reporter); + return QuantizeModel( + builder, model, input_type, output_type, allow_float, + GetAllOperatorOutputs(model), activations_type, bias_type, + /*disable_per_channel=*/false, + /*disable_per_channel_quantization_for_dense_layers=*/false, + error_reporter); } TfLiteStatus QuantizeModelAllOperators( @@ -2014,10 +2039,13 @@ TfLiteStatus QuantizeModelAllOperators( const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, const TensorType& bias_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, ErrorReporter* error_reporter) { return QuantizeModel(builder, model, input_type, output_type, allow_float, GetAllOperatorOutputs(model), activations_type, - bias_type, disable_per_channel, error_reporter); + bias_type, disable_per_channel, + disable_per_channel_quantization_for_dense_layers, + error_reporter); } TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, @@ -2029,30 +2057,35 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, /*activations_type=*/TensorType_INT8, /*bias_type=*/TensorType_INT32, error_reporter); } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, - const TensorType& output_type, bool allow_float, - bool disable_per_channel, - ErrorReporter* error_reporter) { +TfLiteStatus QuantizeModel( + flatbuffers::FlatBufferBuilder* builder, ModelT* model, + const TensorType& input_type, const TensorType& output_type, + bool allow_float, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter) { return QuantizeModel(builder, model, input_type, output_type, allow_float, GetAllOperatorOutputs(model), /*activations_type=*/TensorType_INT8, /*bias_type=*/TensorType_INT32, /*disable_per_channel=*/disable_per_channel, + /*disable_per_channel_quantization_for_dense_layers=*/ + disable_per_channel_quantization_for_dense_layers, error_reporter); } -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, - const TensorType& output_type, bool allow_float, - bool disable_per_channel, - ErrorReporter* error_reporter, - bool handle_external_state) { +TfLiteStatus QuantizeModel( + flatbuffers::FlatBufferBuilder* builder, ModelT* model, + const TensorType& input_type, const TensorType& output_type, + bool allow_float, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter, bool handle_external_state) { return QuantizeModel(builder, model, input_type, output_type, allow_float, GetAllOperatorOutputs(model), /*activations_type=*/TensorType_INT8, /*bias_type=*/TensorType_INT32, /*disable_per_channel=*/disable_per_channel, + /*disable_per_channel_quantization_for_dense_layers=*/ + disable_per_channel_quantization_for_dense_layers, error_reporter, handle_external_state); } diff --git a/tensorflow/lite/tools/optimize/quantize_model.h b/tensorflow/lite/tools/optimize/quantize_model.h index e117cdf4ded409..77c94f430c003b 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.h +++ b/tensorflow/lite/tools/optimize/quantize_model.h @@ -58,23 +58,24 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, // Same as above but with added option of disabling per channel quantization // // Note: This is a private API, subject to change. -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* input_model, const TensorType& input_type, - const TensorType& output_type, bool allow_float, - bool disable_per_channel, - ErrorReporter* error_reporter); +TfLiteStatus QuantizeModel( + flatbuffers::FlatBufferBuilder* builder, ModelT* input_model, + const TensorType& input_type, const TensorType& output_type, + bool allow_float, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter); // Same as above but with added option of handling quantization of external // state tensors. This assumes first input and output tensors are ouputs and // rest are state tensors which are quantized later with type as // activation type (hence no fake quant ops). // Note: This is a private API, subject to change. -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* input_model, const TensorType& input_type, - const TensorType& output_type, bool allow_float, - bool disable_per_channel, - ErrorReporter* error_reporter, - bool handle_external_state); +TfLiteStatus QuantizeModel( + flatbuffers::FlatBufferBuilder* builder, ModelT* input_model, + const TensorType& input_type, const TensorType& output_type, + bool allow_float, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter, bool handle_external_state); // Same as above, but enables only quantizing an allowlist of operations, // specified by their operator output name. @@ -115,6 +116,7 @@ TfLiteStatus QuantizeModelAllOperators( const TensorType& input_type, const TensorType& output_type, bool allow_float, const TensorType& activations_type, const TensorType& bias_type, bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, ErrorReporter* error_reporter); // Quantizes input_model and populates the provided builder with the new model @@ -122,15 +124,14 @@ TfLiteStatus QuantizeModelAllOperators( // quantization. // // All functions above call this function underneath. -TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, - ModelT* model, const TensorType& input_type, - const TensorType& output_type, bool allow_float, - const std::unordered_set& operator_names, - const TensorType& activations_type, - const TensorType& bias_type, - bool disable_per_channel, - ErrorReporter* error_reporter, - bool handle_external_state); +TfLiteStatus QuantizeModel( + flatbuffers::FlatBufferBuilder* builder, ModelT* model, + const TensorType& input_type, const TensorType& output_type, + bool allow_float, const std::unordered_set& operator_names, + const TensorType& activations_type, const TensorType& bias_type, + bool disable_per_channel, + bool disable_per_channel_quantization_for_dense_layers, + ErrorReporter* error_reporter, bool handle_external_state); } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 467dd009b9cee2..8a0013b09e6851 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -165,6 +165,7 @@ TEST_P(QuantizeConvModelTest, AvoidQuantOpForExternalStates) { auto status = QuantizeModel(&builder_, &model_, TensorType_FLOAT32, TensorType_FLOAT32, /*allow_float=*/true, /*disable_per_channel=*/true, + /*disable_per_channel_quantization_for_dense_layers=*/true, &error_reporter_, /*handle_external_state=*/true); EXPECT_EQ(status, kTfLiteOk); for (const auto& subgraph : model_.subgraphs) { @@ -846,10 +847,12 @@ TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { } TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { - auto status = - QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_, - false, tensor_type_, bias_type_, - /*disable_per_channel=*/true, &error_reporter_); + auto status = QuantizeModelAllOperators( + &builder_, &model_, tensor_type_, tensor_type_, false, tensor_type_, + bias_type_, + /*disable_per_channel=*/true, + /*disable_per_channel_quantization_for_dense_layers=*/true, + &error_reporter_); ASSERT_EQ(kTfLiteOk, status); const auto& subgraph = model_.subgraphs[0]; auto conv_op = subgraph->operators[0].get(); diff --git a/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi b/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi index 8010487155cd69..f8fcba460ed808 100644 --- a/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi +++ b/tensorflow/lite/tools/optimize/sparsity/format_converter_wrapper_pybind11.pyi @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import ClassVar +from typing import ClassVar, overload -from typing import overload TF_LITE_DIM_DENSE: TfLiteDimensionType TF_LITE_DIM_SPARSE_CSR: TfLiteDimensionType TF_LITE_ERROR: TfLiteStatus @@ -38,12 +37,10 @@ class TfLiteDimensionType: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -59,12 +56,10 @@ class TfLiteStatus: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property diff --git a/tensorflow/lite/tools/serialization/writer.cc b/tensorflow/lite/tools/serialization/writer.cc index 2997736aee049e..5caf6577d88ce3 100644 --- a/tensorflow/lite/tools/serialization/writer.cc +++ b/tensorflow/lite/tools/serialization/writer.cc @@ -18,7 +18,8 @@ limitations under the License. // Usage: // writer -#include +#include +#include #include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" diff --git a/tensorflow/lite/tools/serialization/writer_lib.h b/tensorflow/lite/tools/serialization/writer_lib.h index baa31872aa8692..a9648265192919 100644 --- a/tensorflow/lite/tools/serialization/writer_lib.h +++ b/tensorflow/lite/tools/serialization/writer_lib.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ #define TENSORFLOW_LITE_TOOLS_SERIALIZATION_WRITER_LIB_H_ +#include +#include #include #include #include diff --git a/tensorflow/lite/tools/serialization/writer_lib_test.cc b/tensorflow/lite/tools/serialization/writer_lib_test.cc index ecacd90f0a10d1..7744544d50bbc2 100644 --- a/tensorflow/lite/tools/serialization/writer_lib_test.cc +++ b/tensorflow/lite/tools/serialization/writer_lib_test.cc @@ -15,13 +15,15 @@ limitations under the License. #include "tensorflow/lite/tools/serialization/writer_lib.h" +#include #include +#include #include +#include #include #include #include #include -#include #include #include diff --git a/tensorflow/lite/tools/serialization/writer_test.cc b/tensorflow/lite/tools/serialization/writer_test.cc index 50326074bcc1f2..46787d560fea2b 100644 --- a/tensorflow/lite/tools/serialization/writer_test.cc +++ b/tensorflow/lite/tools/serialization/writer_test.cc @@ -19,7 +19,10 @@ limitations under the License. // Usage: // writer_test -#include +#include +#include +#include +#include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/interpreter_builder.h" diff --git a/tensorflow/lite/tools/signature/_pywrap_signature_def_util_wrapper.pyi b/tensorflow/lite/tools/signature/_pywrap_signature_def_util_wrapper.pyi index 53ebca78d7cf1f..3cf76e338b14b7 100644 --- a/tensorflow/lite/tools/signature/_pywrap_signature_def_util_wrapper.pyi +++ b/tensorflow/lite/tools/signature/_pywrap_signature_def_util_wrapper.pyi @@ -14,5 +14,5 @@ # ============================================================================== def ClearSignatureDefs(arg0: list[int]) -> bytes: ... -def GetSignatureDefMap(arg0: list[int]) -> dict[str,bytes]: ... -def SetSignatureDefMap(arg0: list[int], arg1: dict[str,str]) -> bytes: ... +def GetSignatureDefMap(arg0: list[int]) -> dict[str, bytes]: ... +def SetSignatureDefMap(arg0: list[int], arg1: dict[str, str]) -> bytes: ... diff --git a/tensorflow/lite/tools/signature/signature_def_util.cc b/tensorflow/lite/tools/signature/signature_def_util.cc index c2d971e67d7151..5cd7ef8ffd15d5 100644 --- a/tensorflow/lite/tools/signature/signature_def_util.cc +++ b/tensorflow/lite/tools/signature/signature_def_util.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/tools/signature/signature_def_util.h" +#include +#include #include #include #include diff --git a/tensorflow/lite/tools/signature/signature_def_util.h b/tensorflow/lite/tools/signature/signature_def_util.h index bc8e8d1b65e3c6..c55600ccad47ef 100644 --- a/tensorflow/lite/tools/signature/signature_def_util.h +++ b/tensorflow/lite/tools/signature/signature_def_util.h @@ -15,8 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_SIGNATURE_SIGNATURE_DEF_UTIL_H_ #define TENSORFLOW_LITE_TOOLS_SIGNATURE_SIGNATURE_DEF_UTIL_H_ +#include #include +#include "absl/status/status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -42,7 +44,7 @@ constexpr char kSignatureDefsMetadataName[] = "signature_defs_metadata"; // // On success, returns tensorflow::OkStatus() or error otherwise. // On error, `model_data_with_signature_defs` is unchanged. -tensorflow::Status SetSignatureDefMap( +absl::Status SetSignatureDefMap( const Model* model, const std::map& signature_def_map, std::string* model_data_with_signature_defs); @@ -65,8 +67,7 @@ absl::Status GetSignatureDefMap( // The function `ClearSignatureDefs` results in `model_data` // containing a serialized Model identical to `model` omitting any // SignatureDef-related metadata or buffers. -tensorflow::Status ClearSignatureDefMap(const Model* model, - std::string* model_data); +absl::Status ClearSignatureDefMap(const Model* model, std::string* model_data); } // namespace tflite diff --git a/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc b/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc index d1c3ed6beb62a2..61a4e0c945ab08 100644 --- a/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc +++ b/tensorflow/lite/tools/signature/signature_def_util_wrapper_pybind11.cc @@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include #include +#include #include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 diff --git a/tensorflow/lite/tools/versioning/gpu_compatibility.cc b/tensorflow/lite/tools/versioning/gpu_compatibility.cc index 3070ab342ca3fd..fc4c2b48a777b0 100644 --- a/tensorflow/lite/tools/versioning/gpu_compatibility.cc +++ b/tensorflow/lite/tools/versioning/gpu_compatibility.cc @@ -743,15 +743,6 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, OpSignatureTensorSpec operand = op_sig.inputs[0]; OpSignatureTensorSpec update_slice = op_sig.inputs[1]; OpSignatureTensorSpec start_indices = op_sig.inputs[2]; - if (operand.dims.size() == 4 && operand.dims[0] != 1) { - return absl::UnimplementedError( - "DynamicUpdateSlice only support 4D operand with batch size 1."); - } - - if (start_indices.dims.size() > 1) { - return absl::UnimplementedError( - "DynamicUpdateSlice only support 1D start_indices."); - } if (operand.type != update_slice.type) { return absl::InternalError( @@ -761,9 +752,8 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, } if (start_indices.dims.size() != 1) { - return absl::InternalError( - absl::StrCat("Start indices must have be 1D, but got: ", - start_indices.dims.size())); + return absl::InternalError(absl::StrCat( + "Start indices must be 1D, but got: ", start_indices.dims.size())); } if (start_indices.type != kTfLiteInt32) { diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index b239b23e71a4b0..f36f467b977a10 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -189,6 +189,8 @@ tf_staging/tensorflow/tools/toolchains/win/20240424/BUILD: tf_staging/tensorflow/tools/toolchains/win/BUILD: tf_staging/tensorflow/tools/toolchains/win/bazel_211/BUILD: tf_staging/tensorflow/tools/toolchains/win/tf_win_05022023/BUILD: +tf_staging/tensorflow/tools/toolchains/win2022/20241118/BUILD: +tf_staging/tensorflow/tools/toolchains/win2022/BUILD: tf_staging/tensorflow/tools/toolchains/win_1803/py38/BUILD: tf_staging/tensorflow/tools/toolchains/win_1803/py39/BUILD: tf_staging/tensorflow/virtual_root_template_v1.__init__:.py @@ -360,7 +362,6 @@ tf_staging/third_party/systemlibs/boringssl.BUILD: tf_staging/third_party/systemlibs/build_defs.bzl.tpl: tf_staging/third_party/systemlibs/curl.BUILD: tf_staging/third_party/systemlibs/cython.BUILD: -tf_staging/third_party/systemlibs/double_conversion.BUILD: tf_staging/third_party/systemlibs/gif.BUILD: tf_staging/third_party/systemlibs/google_cloud_cpp.BUILD: tf_staging/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5d131359fb4966..8a781badc6897a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -686,7 +686,6 @@ pywrap_tensorflow_macro( "//:__subpackages__", "@com_google_absl//:__subpackages__", "@com_google_protobuf//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@local_tsl//tsl:__subpackages__", "@local_xla//xla:__subpackages__", @@ -710,7 +709,6 @@ pywrap_tensorflow_macro( "@cpuinfo//:__subpackages__", "@curl//:__subpackages__", "@dlpack//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@farmhash_archive//:__subpackages__", "@farmhash_gpu_archive//:__subpackages__", @@ -876,6 +874,7 @@ pywrap_tensorflow_macro( # be brought in via other dependencies. "@local_xla//xla/tsl/cuda:cudnn", "@local_xla//xla/tsl/cuda:cufft", + "@local_xla//xla/tsl/cuda:cupti", "@local_xla//xla/tsl/cuda:nccl_rpath", ])) + if_xla_available([ "//tensorflow/compiler/aot:tfcompile_lib", @@ -1446,72 +1445,123 @@ pytype_strict_library( ], ) +pybind_extension( + name = "_pywrap_tensorflow_internal", + srcs = ["pywrap_tensorflow_internal.cc"], + pywrap_only = True, + deps = [], +) + +pybind_extension( + name = "_pywrap_tensorflow_cc_only", + srcs = [], + deps = [ + ":_protobuf_inline_symbols_enforcer", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/core/kernels:data_service_ops", + "//tensorflow/core/kernels:reader_ops", + "//tensorflow/distribute/experimental/rpc/kernels:rpc_ops", + "//tensorflow/dtensor/cc:tensor_layout", + "@local_xla//xla/backends/profiler/cpu:python_tracer", + ], +) + +cc_library( + name = "_protobuf_inline_symbols_enforcer", + srcs = ["protobuf_inline_symbols_enforcer.cc"], + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:function_proto_cc", + "//tensorflow/core/framework:graph_proto_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/dtensor/proto:layout_proto_cc", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + ], +) + +cc_library( + name = "_pywrap_lib_filter", + deps = if_pywrap( + if_true = [ + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_abseil//pybind11_abseil:status_casters", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], + ), +) + +cc_library( + name = "_pywrap_lib_exclusion_filter", + deps = if_pywrap( + if_true = [ + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:protobuf_lite", + "@zlib//:zlib", + ], + ), +) + pywrap_library( name = "_pywrap_tensorflow", - cc_deps_filter = [ - "@com_google_protobuf//:protobuf", - "@com_google_protobuf//:protobuf_lite", - "@zlib//:zlib", - ], - linkopts = select({ - "//tensorflow:windows": [ - "-DEFAULTLIB:ws2_32.lib", - "-DEFAULTLIB:advapi32.lib", - "-DEFAULTLIB:crypt32.lib", - "-DEFAULTLIB:Normaliz.lib", - "-DEFAULTLIB:ntdll.lib", + # buildifier: disable=unsorted-dict-items + # @unsorted-dict-items + common_lib_filters = { + "tensorflow/libtensorflow_framework.so.2": "//tensorflow:tensorflow_framework_pywrap_filter", + "tensorflow/libtensorflow_cc.so.2": "//tensorflow:tensorflow_cc_pywrap_filter", + }, + # buildifier: disable=unsorted-dict-items + # @unsorted-dict-items + common_lib_linkopts = { + "tensorflow/libtensorflow_framework.so.2": [ + "-z defs", + "-lpthread", + "-ldl", + "-lm", ], - "//conditions:default": [], - }), - py_cc_deps_filter = select({ - "//tensorflow:windows": [], - "//conditions:default": [ - "@local_xla//xla/tsl/python/lib/core:ml_dtypes_lib", - "@local_xla//xla/tsl/python/lib/core:numpy", - "@local_xla//xla/backends/profiler/cpu:python_tracer_impl", - "@local_xla//xla/backends/profiler/cpu:python_tracer", - "@local_xla//xla/python/profiler/internal:python_hooks", - "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", - "//tensorflow/lite/python/interpreter_wrapper:python_utils", - "//tensorflow/lite/toco/python:toco_python_api", - "//tensorflow/python/client:tf_session_helper", - "//tensorflow/python/eager:pywrap_tfe_lib", - "//tensorflow/python/framework:op_def_util_cc", - "//tensorflow/python/framework:py_context_manager", - "//tensorflow/python/framework:python_api_info", - "//tensorflow/python/framework:python_api_parameter_converter", - "//tensorflow/python/framework:python_tensor_converter", - "//tensorflow/python/framework:python_api_dispatcher", - "//tensorflow/python/lib/core:ndarray_tensor_bridge", - "//tensorflow/python/lib/core:ndarray_tensor", - "//tensorflow/python/lib/core:py_seq_tensor", - "//tensorflow/python/lib/core:py_util", - "//tensorflow/python/lib/core:py_exception_registry", - "//tensorflow/python/lib/core:py_func_lib", - "//tensorflow/python/util:cpp_python_util", - "//tensorflow/python/util:function_parameter_canonicalizer", - "//tensorflow/python/util:stack_trace", - "//tensorflow/python/util:cpp_nest", - "//tensorflow/compiler/mlir/lite/python:converter_python_api", - "//tensorflow/lite/python/metrics:metrics_wrapper_lib", - "//tensorflow/lite/python/interpreter_wrapper:interpreter_wrapper_lib", - "//tensorflow/lite/python/interpreter_wrapper:numpy", - "//tensorflow/lite/python/optimize:calibration_wrapper_lib", + "tensorflow/libtensorflow_cc.so.2": [ + "-z defs", + "-lpthread", + "-ldl", + "-lm", ], - }), + }, + # buildifier: disable=unsorted-dict-items + # @unsorted-dict-items + common_lib_version_scripts = { + "tensorflow/libtensorflow_cc.so.2": "//tensorflow:tf_version_script.lds", + }, + pywrap_lib_exclusion_filter = ":_pywrap_lib_exclusion_filter", + pywrap_lib_filter = ":_pywrap_lib_filter", + starlark_only_deps = [ + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization", + "//tensorflow/compiler/mlir/tfr:tfr_wrapper", + "//tensorflow/python/framework:_errors_test_helper", + "//tensorflow/python/framework/experimental:_math_ops", + "//tensorflow/python/framework/experimental:_nn_ops", + "//tensorflow/python/framework/experimental:_tape", + "//tensorflow/python/framework/experimental:_unified_api", + "//tensorflow/python/framework:_op_def_util", + "//tensorflow/python/framework:_py_context_manager", + "//tensorflow/python/framework:_pywrap_python_api_info", + "//tensorflow/python/framework:_pywrap_python_api_parameter_converter", + "//tensorflow/python/framework:_pywrap_python_tensor_converter", + "//tensorflow/python/grappler:_pywrap_cost_analyzer", + "//tensorflow/python/grappler:_pywrap_model_analyzer", + "//tensorflow/python/util:_function_parameter_canonicalizer_binding_for_test", + ], visibility = ["//visibility:public"], win_def_file = "_pywrap_tensorflow.def", - # win_def_file = "_pywrap_tensorflow.def", deps = [ - ":_pywrap_quantize_training", - ":_pywrap_tensorflow_cc_only", "//tensorflow/compiler/mlir/lite/python:_pywrap_converter_api", - "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", - "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization", "//tensorflow/compiler/mlir/quantization/tensorflow/python:pywrap_function_lib", "//tensorflow/compiler/mlir/quantization/tensorflow/python:pywrap_quantize_model", - "//tensorflow/compiler/mlir/stablehlo:stablehlo_extension", - "//tensorflow/compiler/mlir/tfr:tfr_wrapper", + "//tensorflow/compiler/mlir/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo", "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", "//tensorflow/lite/python/analyzer_wrapper:_pywrap_analyzer_wrapper", "//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper", @@ -1521,7 +1571,10 @@ pywrap_library( "//tensorflow/python:_pywrap_mlir", "//tensorflow/python:_pywrap_parallel_device", "//tensorflow/python:_pywrap_py_exception_registry", + "//tensorflow/python:_pywrap_quantize_training", "//tensorflow/python:_pywrap_sanitizers", + "//tensorflow/python:_pywrap_tensorflow_cc_only", + "//tensorflow/python:_pywrap_tensorflow_internal", "//tensorflow/python:_pywrap_tfcompile", "//tensorflow/python:_pywrap_tfe", "//tensorflow/python:_pywrap_toco_api", @@ -1535,25 +1588,13 @@ pywrap_library( "//tensorflow/python/data/experimental/service:_pywrap_snapshot_utils", "//tensorflow/python/data/experimental/service:_pywrap_utils_exp", "//tensorflow/python/framework:_dtypes", - "//tensorflow/python/framework:_errors_test_helper", "//tensorflow/python/framework:_op_def_library_pybind", "//tensorflow/python/framework:_op_def_registry", - "//tensorflow/python/framework:_op_def_util", "//tensorflow/python/framework:_proto_comparators", - "//tensorflow/python/framework:_py_context_manager", "//tensorflow/python/framework:_python_memory_checker_helper", "//tensorflow/python/framework:_pywrap_python_api_dispatcher", - "//tensorflow/python/framework:_pywrap_python_api_info", - "//tensorflow/python/framework:_pywrap_python_api_parameter_converter", "//tensorflow/python/framework:_pywrap_python_op_gen", - "//tensorflow/python/framework:_pywrap_python_tensor_converter", "//tensorflow/python/framework:_test_metrics_util", - "//tensorflow/python/framework/experimental:_math_ops", - "//tensorflow/python/framework/experimental:_nn_ops", - "//tensorflow/python/framework/experimental:_tape", - "//tensorflow/python/framework/experimental:_unified_api", - "//tensorflow/python/grappler:_pywrap_cost_analyzer", - "//tensorflow/python/grappler:_pywrap_model_analyzer", "//tensorflow/python/grappler:_pywrap_tf_cluster", "//tensorflow/python/grappler:_pywrap_tf_item", "//tensorflow/python/grappler:_pywrap_tf_optimizer", @@ -1568,11 +1609,11 @@ pywrap_library( "//tensorflow/python/saved_model:pywrap_saved_model", "//tensorflow/python/tpu:_pywrap_sparse_core_layout", "//tensorflow/python/tpu:_pywrap_tpu_embedding", - "//tensorflow/python/util:_function_parameter_canonicalizer_binding_for_test", "//tensorflow/python/util:_pywrap_checkpoint_reader", "//tensorflow/python/util:_pywrap_determinism", "//tensorflow/python/util:_pywrap_kernel_registry", "//tensorflow/python/util:_pywrap_nest", + "//tensorflow/python/util:_pywrap_stat_summarizer", "//tensorflow/python/util:_pywrap_tensor_float_32_execution", "//tensorflow/python/util:_pywrap_tfprof", "//tensorflow/python/util:_pywrap_transform_graph", @@ -1584,43 +1625,13 @@ pywrap_library( ], ) -pybind_extension( - name = "_pywrap_tensorflow_cc_only", - srcs = [], - deps = [ - ":_protobuf_inline_symbols_enforcer", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_session", - "//tensorflow/core/kernels:data_service_ops", - "//tensorflow/core/kernels:reader_ops", - "//tensorflow/distribute/experimental/rpc/kernels:rpc_ops", - "//tensorflow/dtensor/cc:tensor_layout", - "@local_xla//xla/backends/profiler/cpu:python_tracer", - ], -) - -cc_library( - name = "_protobuf_inline_symbols_enforcer", - srcs = ["protobuf_inline_symbols_enforcer.cc"], - deps = [ - "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", - "//tensorflow/core/framework:attr_value_proto_cc", - "//tensorflow/core/framework:function_proto_cc", - "//tensorflow/core/framework:graph_proto_cc", - "//tensorflow/core/protobuf:for_core_protos_cc", - "//tensorflow/dtensor/proto:layout_proto_cc", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - ], -) - pywrap_common_library( - name = "_pywrap_tensorflow_common", + name = "tensorflow_common_framework", dep = ":_pywrap_tensorflow", + filter_name = "libtensorflow_framework.so.2", ) pywrap_binaries( - name = "_pywrap_tensorflow_binaries", + name = "pywrap_tensorflow_binaries", dep = ":_pywrap_tensorflow", ) diff --git a/tensorflow/python/_pywrap_dtensor_device.pyi b/tensorflow/python/_pywrap_dtensor_device.pyi index 0362a8c0f59a99..7657ceb332eae5 100644 --- a/tensorflow/python/_pywrap_dtensor_device.pyi +++ b/tensorflow/python/_pywrap_dtensor_device.pyi @@ -13,12 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Any, ClassVar - -from typing import overload +from typing import ClassVar, overload class Layout: - __hash__: ClassVar[None] = ... @overload def __init__(self, layout: Layout) -> None: ... @overload @@ -33,7 +30,7 @@ class Layout: def __init__(self, mesh: Mesh, rank: int, batch_dim: str, axis: int) -> None: ... @overload def __init__(self, mesh: Mesh) -> None: ... - def as_proto(self, *args, **kwargs) -> Any: ... + def as_proto(self, *args, **kwargs): ... def global_shape_from_local_shape(self, local_shape: list[int]) -> tuple: ... def is_batch_parallel(self) -> bool: ... def is_fully_replicated(self) -> bool: ... @@ -60,19 +57,16 @@ class LayoutType: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... class Mesh: - __hash__: ClassVar[None] = ... @overload def __init__(self, mesh: Mesh) -> None: ... @overload @@ -83,7 +77,7 @@ class Mesh: def __init__(self, mesh_proto) -> None: ... @overload def __init__(self, mesh_str: str) -> None: ... - def as_proto(self, *args, **kwargs) -> Any: ... + def as_proto(self, *args, **kwargs): ... def contains_dim(self, dim_name: str) -> bool: ... def device_location(self, arg0: int) -> list[int]: ... def device_type(self) -> str: ... @@ -119,7 +113,7 @@ def ExperimentalClearDefaultMesh(arg0) -> None: ... def ExperimentalSetDefaultLayout(arg0, arg1: str) -> None: ... def ExperimentalSetDefaultMesh(arg0, arg1: str) -> None: ... def FetchLayout(arg0: object, arg1: object, arg2) -> object: ... -def GetStats(arg0: object, arg1) -> dict[str,int]: ... +def GetStats(arg0: object, arg1) -> dict[str, int]: ... def IsDTensor(arg0: object, arg1: object, arg2) -> bool: ... def IsSparseDTensor(arg0: object, arg1: object, arg2) -> bool: ... def Pack(arg0: object, arg1: object, arg2: str, arg3, arg4: bool) -> object: ... diff --git a/tensorflow/python/_pywrap_mlir.pyi b/tensorflow/python/_pywrap_mlir.pyi index d1375e15159c31..86411b1ef9407e 100644 --- a/tensorflow/python/_pywrap_mlir.pyi +++ b/tensorflow/python/_pywrap_mlir.pyi @@ -19,7 +19,6 @@ def ExperimentalConvertSavedModelToMlir(arg0: str, arg1: str, arg2: bool) -> str def ExperimentalConvertSavedModelV1ToMlir(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: bool, arg5: bool, arg6: bool) -> str: ... def ExperimentalConvertSavedModelV1ToMlirLite(arg0: str, arg1: str, arg2: str, arg3: bool, arg4: bool) -> str: ... def ExperimentalRunPassPipeline(arg0: str, arg1: str, arg2: bool) -> str: ... -def ExperimentalTFLiteToTosaBytecode(arg0: str, arg1: str, arg2: bool, arg3: list[str], arg4: list[str]) -> None: ... def ExperimentalWriteBytecode(arg0: str, arg1: str) -> None: ... def ImportFunction(arg0: object, arg1: str, arg2: str, arg3: bool) -> str: ... @overload diff --git a/tensorflow/python/_pywrap_py_exception_registry.pyi b/tensorflow/python/_pywrap_py_exception_registry.pyi index 2fe8027309ee3b..502fcb249dbbd1 100644 --- a/tensorflow/python/_pywrap_py_exception_registry.pyi +++ b/tensorflow/python/_pywrap_py_exception_registry.pyi @@ -49,12 +49,10 @@ class TF_Code: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property diff --git a/tensorflow/python/_pywrap_tfe.pyi b/tensorflow/python/_pywrap_tfe.pyi index 1385ae69244d58..0c272a999f869c 100644 --- a/tensorflow/python/_pywrap_tfe.pyi +++ b/tensorflow/python/_pywrap_tfe.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Any, ClassVar +from typing import ClassVar TFE_DEVICE_PLACEMENT_EXPLICIT: TFE_ContextDevicePlacementPolicy TFE_DEVICE_PLACEMENT_SILENT: TFE_ContextDevicePlacementPolicy @@ -52,12 +52,10 @@ class TFE_ContextDevicePlacementPolicy: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -152,12 +150,10 @@ class TF_AttrType: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -192,7 +188,7 @@ def TFE_ContextGetDevicePlacementPolicy(arg0: object) -> TFE_ContextDevicePlacem def TFE_ContextGetExecutorForThread(arg0: object) -> TFE_Executor: ... def TFE_ContextGetFunction(arg0: object, arg1: str) -> TF_Function: ... def TFE_ContextGetFunctionDef(arg0: object, arg1: str, arg2: TF_Buffer) -> None: ... -def TFE_ContextGetFunctionDefNoSerialization(*args, **kwargs) -> Any: ... +def TFE_ContextGetFunctionDefNoSerialization(*args, **kwargs): ... def TFE_ContextGetGraphDebugInfo(arg0: object, arg1: str, arg2: TF_Buffer) -> None: ... def TFE_ContextHasFunction(arg0: object, arg1: str) -> int: ... def TFE_ContextListDevices(arg0: object) -> TF_DeviceList: ... @@ -225,7 +221,7 @@ def TFE_ExecutorWaitForAllPendingNodes(arg0: TFE_Executor) -> None: ... def TFE_FromDlpackCapsule(arg0, arg1: object) -> object: ... def TFE_GetConfigKeyValue(arg0: object, arg1: str, arg2: int, arg3: TF_Buffer) -> None: ... def TFE_GetContextId(arg0: object) -> int: ... -def TFE_GetMemoryInfo(arg0: object, arg1: str) -> dict[str,int]: ... +def TFE_GetMemoryInfo(arg0: object, arg1: str) -> dict[str, int]: ... def TFE_GetTaskStates(arg0: object, arg1: list[str], arg2: list[int]) -> object: ... def TFE_HostAddressSpace(arg0: object, arg1: TF_Buffer) -> None: ... def TFE_InsertConfigKeyValue(arg0: object, arg1: str, arg2: str) -> None: ... @@ -359,7 +355,7 @@ def TF_DeviceListType(arg0: TF_DeviceList, arg1: int) -> str: ... def TF_EnableMlirBridge(arg0: bool) -> None: ... def TF_EnableXlaDevices() -> None: ... def TF_GetCompilerIr(arg0: object, arg1: str, arg2: str, arg3: str, arg4: object, arg5: object, arg6: str) -> bytes: ... -def TF_GetDeviceDetails(arg0: int) -> dict[str,str]: ... +def TF_GetDeviceDetails(arg0: int) -> dict[str, str]: ... def TF_GetXlaConstantFoldingDisabled() -> int: ... def TF_IsMlirBridgeEnabled() -> int: ... def TF_ListPhysicalDevices() -> object: ... diff --git a/tensorflow/python/autograph/STYLE_GUIDE.md b/tensorflow/python/autograph/STYLE_GUIDE.md index 1c23eacd8fd89c..12bceffdaec74c 100644 --- a/tensorflow/python/autograph/STYLE_GUIDE.md +++ b/tensorflow/python/autograph/STYLE_GUIDE.md @@ -17,8 +17,8 @@ Naming conventions: ## AutoGraph Style -Below are AutoGraph-specific conventions. In the event of conflict, -it supercedes all previous conventions. +Below are AutoGraph-specific conventions. In the event of conflict, it +supersedes all previous conventions. 1. __Types in docstrings.__ Use [PEP 484][https://www.python.org/dev/peps/pep-0484/] notation to describe the type for args, return values and attributes. diff --git a/tensorflow/python/client/_pywrap_events_writer.pyi b/tensorflow/python/client/_pywrap_events_writer.pyi index 92da35bcfe093b..04d73399a3234c 100644 --- a/tensorflow/python/client/_pywrap_events_writer.pyi +++ b/tensorflow/python/client/_pywrap_events_writer.pyi @@ -20,7 +20,6 @@ class EventsWriter: def Flush(self) -> Status: ... def InitWithSuffix(self, arg0: str) -> Status: ... def WriteEvent(self, arg0: object) -> None: ... - def _WriteSerializedEvent(self, arg0: str) -> None: ... class Status: def __init__(self, *args, **kwargs) -> None: ... diff --git a/tensorflow/python/client/_pywrap_tf_session.pyi b/tensorflow/python/client/_pywrap_tf_session.pyi index 14645b34c5f5be..2be74af74875d8 100644 --- a/tensorflow/python/client/_pywrap_tf_session.pyi +++ b/tensorflow/python/client/_pywrap_tf_session.pyi @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import Any, ClassVar, Iterator, Optional +from typing import ClassVar, Iterator, overload -from typing import overload TF_ABORTED: TF_Code TF_BFLOAT16: TF_DataType TF_BOOL: TF_DataType @@ -102,25 +101,13 @@ class PyGraph: @classmethod def __init__(cls, *args, **kwargs) -> None: ... @classmethod - def Dismantle(cls, *args, **kwargs) -> Any: ... + def Dismantle(cls, *args, **kwargs): ... @classmethod - def _add_op(cls, *args, **kwargs) -> Any: ... + def get_operations(cls, *args, **kwargs): ... @classmethod - def _get_operation_by_name(cls, *args, **kwargs) -> Any: ... + def new_operations(cls, *args, **kwargs): ... @classmethod - def _op_def_for_type(cls, *args, **kwargs) -> Any: ... - @classmethod - def get_operations(cls, *args, **kwargs) -> Any: ... - @classmethod - def new_operations(cls, *args, **kwargs) -> Any: ... - @classmethod - def num_operations(cls, *args, **kwargs) -> Any: ... - @property - def _nodes_by_id(self) -> OpsById: ... - @property - def _nodes_by_name(self) -> OpsByName: ... - @property - def _version_def(self) -> bytes: ... + def num_operations(cls, *args, **kwargs): ... @property def operations(self) -> list: ... @property @@ -130,32 +117,6 @@ class PyOperation: graph: object @classmethod def __init__(cls, *args, **kwargs) -> None: ... - @classmethod - def _add_control_input(cls, *args, **kwargs) -> Any: ... - @classmethod - def _add_control_inputs(cls, *args, **kwargs) -> Any: ... - @classmethod - def _add_outputs(cls, *args, **kwargs) -> Any: ... - @classmethod - def _init_outputs(cls, *args, **kwargs) -> Any: ... - @classmethod - def _remove_all_control_inputs(cls, *args, **kwargs) -> Any: ... - @classmethod - def _set_device_from_string(cls, *args, **kwargs) -> Any: ... - @classmethod - def _tf_input(cls, *args, **kwargs) -> Any: ... - @classmethod - def _tf_output(cls, *args, **kwargs) -> Any: ... - @property - def _c_op(self) -> TF_Operation: ... - @property - def _control_outputs(self) -> list: ... - @property - def _is_stateful(self) -> bool: ... - @property - def _node_def(self) -> bytes: ... - @property - def _op_def(self) -> bytes: ... @property def control_inputs(self) -> list: ... @property @@ -168,25 +129,10 @@ class PyOperation: def type(self) -> str: ... class PyTensor: - _id: object - _name: object - _shape_val: object @classmethod def __init__(cls, *args, **kwargs) -> None: ... @classmethod - def _as_tf_output(cls, *args, **kwargs) -> Any: ... - @classmethod - def _rank(cls, *args, **kwargs) -> Any: ... - @classmethod - def _set_shape(cls, *args, **kwargs) -> Any: ... - @classmethod - def consumers(cls, *args, **kwargs) -> Any: ... - @property - def _dtype(self) -> object: ... - @property - def _op(self) -> object: ... - @property - def _shape(self) -> object: ... + def consumers(cls, *args, **kwargs): ... @property def device(self) -> str: ... @property @@ -223,12 +169,10 @@ class TF_Code: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -263,12 +207,10 @@ class TF_DataType: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -375,10 +317,10 @@ def TF_GraphImportGraphDefWithResults(arg0: PyGraph, arg1: TF_Buffer, arg2: TF_I def TF_GraphImportGraphDefWithResultsNoSerialization(arg0: PyGraph, arg1, arg2: TF_ImportGraphDefOptions) -> TF_ImportGraphDefResults: ... def TF_GraphNextOperation(arg0: PyGraph, arg1: int) -> tuple: ... def TF_GraphRemoveFunction(arg0: PyGraph, arg1: str) -> None: ... -def TF_GraphSetOutputHandleShapesAndTypes_wrapper(arg0: PyGraph, arg1: TF_Output, arg2: list[Optional[list[int]]], arg3: list[int], arg4: object) -> None: ... -def TF_GraphToFunction_wrapper(arg0: PyGraph, arg1: str, arg2: bool, arg3: Optional[list[TF_Operation]], arg4: list[TF_Output], arg5: list[TF_Output], arg6: list[bytes], arg7: list[TF_Operation], arg8: list[bytes], arg9: None, arg10: str) -> TF_Function: ... +def TF_GraphSetOutputHandleShapesAndTypes_wrapper(arg0: PyGraph, arg1: TF_Output, arg2: list[list[int] | None], arg3: list[int], arg4: object) -> None: ... +def TF_GraphToFunction_wrapper(arg0: PyGraph, arg1: str, arg2: bool, arg3: list[TF_Operation] | None, arg4: list[TF_Output], arg5: list[TF_Output], arg6: list[bytes], arg7: list[TF_Operation], arg8: list[bytes], arg9: None, arg10: str) -> TF_Function: ... def TF_GraphToGraphDef(arg0: PyGraph, arg1: TF_Buffer) -> None: ... -def TF_GraphToGraphDefPybind(*args, **kwargs) -> Any: ... +def TF_GraphToGraphDefPybind(*args, **kwargs): ... def TF_ImportGraphDefOptionsAddInputMapping(arg0: TF_ImportGraphDefOptions, arg1: str, arg2: int, arg3: TF_Output) -> None: ... def TF_ImportGraphDefOptionsAddReturnOperation(arg0: TF_ImportGraphDefOptions, arg1: str) -> None: ... def TF_ImportGraphDefOptionsAddReturnOutput(arg0: TF_ImportGraphDefOptions, arg1: str, arg2: int) -> None: ... @@ -439,9 +381,6 @@ def TF_SetXlaEnableLazyCompilation(arg0: int) -> int: ... def TF_SetXlaMinClusterSize(arg0: int) -> None: ... def TF_TryEvaluateConstant_wrapper(arg0: PyGraph, arg1: TF_Output) -> object: ... def UpdateEdge(arg0: PyGraph, arg1: TF_Output, arg2: TF_Input) -> None: ... -def _TF_NewSessionOptions() -> TF_SessionOptions: ... -def _TF_SetConfig(arg0: TF_SessionOptions, arg1: bytes) -> None: ... -def _TF_SetTarget(arg0: TF_SessionOptions, arg1: str) -> None: ... def get_compiler_version() -> str: ... def get_cxx11_abi_flag() -> int: ... def get_cxx_version() -> int: ... diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index b475c74a7ec3d0..c04d99fcc2fb91 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 12, 10) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 1, 12) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/mlir/mlir_test.py b/tensorflow/python/compiler/mlir/mlir_test.py index 9b4a54729b1d4f..9c7f75950f4190 100644 --- a/tensorflow/python/compiler/mlir/mlir_test.py +++ b/tensorflow/python/compiler/mlir/mlir_test.py @@ -14,7 +14,6 @@ # ============================================================================= """Tests for python.compiler.mlir.""" -import os from tensorflow.python.compiler.mlir import mlir from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes @@ -23,9 +22,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops -from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test -from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode from tensorflow.python.pywrap_mlir import import_graphdef @@ -161,19 +158,5 @@ def logging(): self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control') -class MLIRFlatbufferImportTest(test.TestCase): - - def testImport(self): - """Tests the basic flow of `experimental_tflite_to_tosa_bytecode`.""" - filename = os.path.join(self.get_temp_dir(), "multi_add_tosa.mlirbc") - experimental_tflite_to_tosa_bytecode( - resource_loader.get_path_to_datafile("multi_add.tflite"), filename - ) - with open(filename, mode="rb") as f: - chunk = f.read(4) - # Just verify output is bytecode. - self.assertEqual(b"ML\xefR", chunk) - - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/experimental/ops/from_list.py b/tensorflow/python/data/experimental/ops/from_list.py index 5f10c3fb252508..008442c76cc372 100644 --- a/tensorflow/python/data/experimental/ops/from_list.py +++ b/tensorflow/python/data/experimental/ops/from_list.py @@ -28,7 +28,10 @@ class _ListDataset(dataset_ops.DatasetSource): def __init__(self, elements, name=None): if not elements: - raise ValueError("Invalid `elements`. `elements` should not be empty.") + raise ValueError( + "Invalid `elements`. `elements` should not be empty. If you want an" + " empty dataset, use `tf.data.Dataset.range(0)`." + ) if not isinstance(elements, list): raise ValueError("Invalid `elements`. `elements` must be a list.") diff --git a/tensorflow/python/data/experimental/service/_pywrap_server_lib.pyi b/tensorflow/python/data/experimental/service/_pywrap_server_lib.pyi index d39c6ac8225da8..d3443e95d52376 100644 --- a/tensorflow/python/data/experimental/service/_pywrap_server_lib.pyi +++ b/tensorflow/python/data/experimental/service/_pywrap_server_lib.pyi @@ -13,14 +13,12 @@ # limitations under the License. # ============================================================================== -from typing import Any - class DispatchGrpcDataServer: def __init__(self, *args, **kwargs) -> None: ... def bound_port(self) -> int: ... def join(self) -> None: ... def num_workers(self) -> int: ... - def snapshot_streams(self, *args, **kwargs) -> Any: ... + def snapshot_streams(self, *args, **kwargs): ... def start(self) -> Status: ... def stop(self) -> None: ... @@ -45,10 +43,10 @@ class WorkerGrpcDataServer: def bound_port(self) -> int: ... def join(self) -> None: ... def num_tasks(self) -> int: ... - def snapshot_task_progresses(self, *args, **kwargs) -> Any: ... + def snapshot_task_progresses(self, *args, **kwargs): ... def start(self) -> Status: ... def stop(self) -> None: ... -def TF_DATA_GetDataServiceMetadataByID(*args, **kwargs) -> Any: ... +def TF_DATA_GetDataServiceMetadataByID(*args, **kwargs): ... def TF_DATA_NewDispatchServer(arg0: str) -> DispatchGrpcDataServer: ... def TF_DATA_NewWorkerServer(arg0: str) -> WorkerGrpcDataServer: ... diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 38801be5158b22..b91437f60240c6 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -174,7 +174,7 @@ def deserialize(self, serialized): def _calculate_replicas_with_values(strategy, input_workers, optional_list): - """Calculates the number of replicas that have values. + """Computes the number of replicas that have values. Args: strategy: the `tf.distribute.Strategy`. diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py index 216efa88aa4bef..838cc88670a3da 100644 --- a/tensorflow/python/distribute/integration_test/saved_model_test.py +++ b/tensorflow/python/distribute/integration_test/saved_model_test.py @@ -122,7 +122,7 @@ class SaveAndLoadForServingTest(test.TestCase, parameterized.TestCase): # # Note that distributed variables have different behavior in the replica # context and the cross-replica context. Saving happens in the cross replica - # context or the default startegy's replica context. + # context or the default strategy's replica context. def test_read_sync_on_read_variable(self, strategy): # synchronizaiton=ON_READ variables are typically used in Keras metrics and diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 8c49758d560dcd..1fab92f551e33b 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -105,7 +105,12 @@ def testNumpyValueWithCast(self): ctx = context.context() # Bad dtype value. with self.assertRaisesRegex(TypeError, "Invalid dtype argument value"): - ops.EagerTensor(values, device=ctx.device_name, dtype=12345) + # The chosen `dtype` value here needs to be both not defined as a value of + # TF_DataType but also representable in the same number of bits as the max + # value of TF_DataType. At 12/20/24, where the max value of TF_DataType is + # 30, so using e.g. 63 would fail ASAN due to 63 not being representable + # in 5 bits. + ops.EagerTensor(values, device=ctx.device_name, dtype=31) def testNumpyOrderHandling(self): n = np.array([[1, 2], [3, 4]], order="F") diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 730a3377feff4d..9f380cfb82cf6d 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -66,7 +66,7 @@ tf_cc_shared_object( ], if_true = [ ":test_file_system_stripped", - "//tensorflow/python:_pywrap_tensorflow_common", + "//tensorflow/python:tensorflow_common_framework", ], ) + ["@com_google_protobuf//:protobuf_headers"], ) @@ -787,6 +787,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_py_context_manager.pyi", ], + starlark_only = True, deps = [ ":py_context_manager", "//third_party/python_runtime:headers", @@ -836,6 +837,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_op_def_util.pyi", ], + starlark_only = True, deps = if_pywrap( if_false = [":op_def_util_headers"], if_true = [":op_def_util_cc"], @@ -918,6 +920,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_python_api_parameter_converter.pyi", ], + starlark_only = True, deps = [ "//tensorflow/c:pywrap_required_hdrs", "//tensorflow/core:framework", @@ -1028,6 +1031,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_python_api_info.pyi", ], + starlark_only = True, deps = [ "//tensorflow/c:pywrap_required_hdrs", "//tensorflow/core:framework", @@ -1185,6 +1189,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_python_tensor_converter.pyi", ], + starlark_only = True, deps = [ "//tensorflow/c:pywrap_required_hdrs", "//tensorflow/core:framework", @@ -2174,14 +2179,20 @@ pytype_strict_library( name = "test_lib", srcs = ["test_util.py"], srcs_version = "PY3", - visibility = visibility + [ - "//tensorflow:internal", - "//tensorflow_model_optimization:__subpackages__", - "//third_party/cloud_tpu/convergence_tools:__subpackages__", - "//third_party/py/neural_structured_learning:__subpackages__", - "//third_party/py/tf_agents:__subpackages__", - "//third_party/py/tf_keras:__subpackages__", + # copybara:uncomment_begin(google-only) + # visibility = visibility + [ + # "//third_party/cloud_tpu/convergence_tools:__subpackages__", + # "//third_party/py/neural_structured_learning:__subpackages__", + # "//third_party/py/tf_agents:__subpackages__", + # "//third_party/py/tf_keras:__subpackages__", + # "//tensorflow:internal", + # "//tensorflow_model_optimization:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + visibility = [ + "//visibility:public", ], + # copybara:comment_end deps = [ ":_test_metrics_util", ":config", @@ -3188,6 +3199,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_errors_test_helper.pyi", ], + starlark_only = True, deps = [ "//tensorflow/c:tf_status_headers", "//tensorflow/core/platform:status", diff --git a/tensorflow/python/framework/_dtypes.pyi b/tensorflow/python/framework/_dtypes.pyi index b3514b9ea55bb6..7d49508ef446af 100644 --- a/tensorflow/python/framework/_dtypes.pyi +++ b/tensorflow/python/framework/_dtypes.pyi @@ -18,8 +18,6 @@ class DType: def __hash__(self) -> int: ... def __int__(self) -> int: ... @property - def _type_enum(self) -> int: ... - @property def as_datatype_enum(self) -> int: ... @property def is_bool(self) -> bool: ... diff --git a/tensorflow/python/framework/_pywrap_python_api_dispatcher.pyi b/tensorflow/python/framework/_pywrap_python_api_dispatcher.pyi index b4451d9e8c926b..fd9416cccbba9f 100644 --- a/tensorflow/python/framework/_pywrap_python_api_dispatcher.pyi +++ b/tensorflow/python/framework/_pywrap_python_api_dispatcher.pyi @@ -27,19 +27,17 @@ class MatchType: __entries: ClassVar[dict] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... class PySignatureChecker: - def __init__(self, arg0: list[tuple[int,PyTypeChecker]]) -> None: ... + def __init__(self, arg0: list[tuple[int, PyTypeChecker]]) -> None: ... def CheckCanonicalizedArgs(self, arg0: tuple) -> bool: ... class PyTypeChecker: diff --git a/tensorflow/python/framework/_pywrap_python_api_info.pyi b/tensorflow/python/framework/_pywrap_python_api_info.pyi index e64b364a5d970c..759ee3c64b848e 100644 --- a/tensorflow/python/framework/_pywrap_python_api_info.pyi +++ b/tensorflow/python/framework/_pywrap_python_api_info.pyi @@ -13,16 +13,14 @@ # limitations under the License. # ============================================================================== -from typing import Any - class InferredAttributes: def __init__(self, *args, **kwargs) -> None: ... @property def lengths(self) -> list[int]: ... @property - def type_lists(self) -> Any: ... + def type_lists(self): ... @property - def types(self) -> Any: ... + def types(self): ... class PythonAPIInfo: def __init__(self, arg0: str) -> None: ... @@ -30,5 +28,5 @@ class PythonAPIInfo: def InferredLengthAttrs(self) -> list[str]: ... def InferredTypeAttrs(self) -> list[str]: ... def InferredTypeListAttrs(self) -> list[str]: ... - def InitializeFromParamSpecs(self, arg0: dict[str,str], arg1: dict[str,str], arg2: list[str], arg3: object) -> None: ... + def InitializeFromParamSpecs(self, arg0: dict[str, str], arg1: dict[str, str], arg2: list[str], arg3: object) -> None: ... def InitializeFromRegisteredOp(self, arg0: str) -> None: ... diff --git a/tensorflow/python/framework/_pywrap_python_api_parameter_converter.pyi b/tensorflow/python/framework/_pywrap_python_api_parameter_converter.pyi index 7f5cf048c43669..e1eafd69c90a63 100644 --- a/tensorflow/python/framework/_pywrap_python_api_parameter_converter.pyi +++ b/tensorflow/python/framework/_pywrap_python_api_parameter_converter.pyi @@ -13,6 +13,4 @@ # limitations under the License. # ============================================================================== -from typing import Any - -def Convert(*args, **kwargs) -> Any: ... +def Convert(*args, **kwargs): ... diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index c4ae5fd573ef28..4c359c511a1bcf 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -296,175 +296,188 @@ def __reduce__(self): np.uint16: (0, 65535), np.int8: (-128, 127), np.int16: (-32768, 32767), - np.int64: (-2**63, 2**63 - 1), + np.int64: (-(2**63), 2**63 - 1), np.uint64: (0, 2**64 - 1), - np.int32: (-2**31, 2**31 - 1), + np.int32: (-(2**31), 2**31 - 1), np.uint32: (0, 2**32 - 1), np.float32: (-1, 1), - np.float64: (-1, 1) + np.float64: (-1, 1), } # Define standard wrappers for the types_pb2.DataType enum. resource = DType(types_pb2.DT_RESOURCE) doc_typealias.document( - obj=resource, - doc="Handle to a mutable, dynamically allocated resource.") + obj=resource, doc="Handle to a mutable, dynamically allocated resource." +) tf_export("dtypes.resource", "resource").export_constant(__name__, "resource") variant = DType(types_pb2.DT_VARIANT) doc_typealias.document( - obj=variant, - doc="Data of arbitrary type (known at runtime).") + obj=variant, doc="Data of arbitrary type (known at runtime)." +) tf_export("dtypes.variant", "variant").export_constant(__name__, "variant") uint8 = DType(types_pb2.DT_UINT8) -doc_typealias.document( - obj=uint8, - doc="Unsigned 8-bit (byte) integer.") +doc_typealias.document(obj=uint8, doc="Unsigned 8-bit (byte) integer.") tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8") uint16 = DType(types_pb2.DT_UINT16) -doc_typealias.document( - obj=uint16, - doc="Unsigned 16-bit (word) integer.") +doc_typealias.document(obj=uint16, doc="Unsigned 16-bit (word) integer.") tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16") uint32 = DType(types_pb2.DT_UINT32) -doc_typealias.document( - obj=uint32, - doc="Unsigned 32-bit (dword) integer.") +doc_typealias.document(obj=uint32, doc="Unsigned 32-bit (dword) integer.") tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32") uint64 = DType(types_pb2.DT_UINT64) -doc_typealias.document( - obj=uint64, - doc="Unsigned 64-bit (qword) integer.") +doc_typealias.document(obj=uint64, doc="Unsigned 64-bit (qword) integer.") tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64") int8 = DType(types_pb2.DT_INT8) -doc_typealias.document( - obj=int8, - doc="Signed 8-bit integer.") +doc_typealias.document(obj=int8, doc="Signed 8-bit integer.") tf_export("dtypes.int8", "int8").export_constant(__name__, "int8") int16 = DType(types_pb2.DT_INT16) -doc_typealias.document( - obj=int16, - doc="Signed 16-bit integer.") +doc_typealias.document(obj=int16, doc="Signed 16-bit integer.") tf_export("dtypes.int16", "int16").export_constant(__name__, "int16") int32 = DType(types_pb2.DT_INT32) -doc_typealias.document( - obj=int32, - doc="Signed 32-bit integer.") +doc_typealias.document(obj=int32, doc="Signed 32-bit integer.") tf_export("dtypes.int32", "int32").export_constant(__name__, "int32") int64 = DType(types_pb2.DT_INT64) -doc_typealias.document( - obj=int64, - doc="Signed 64-bit integer.") +doc_typealias.document(obj=int64, doc="Signed 64-bit integer.") tf_export("dtypes.int64", "int64").export_constant(__name__, "int64") float16 = DType(types_pb2.DT_HALF) half = float16 doc_typealias.document( - obj=float16, - doc="16-bit (half precision) floating-point.") + obj=float16, doc="16-bit (half precision) floating-point." +) tf_export("dtypes.float16", "float16").export_constant(__name__, "float16") tf_export("dtypes.half", "half").export_constant(__name__, "half") float32 = DType(types_pb2.DT_FLOAT) doc_typealias.document( - obj=float32, - doc="32-bit (single precision) floating-point.") + obj=float32, doc="32-bit (single precision) floating-point." +) tf_export("dtypes.float32", "float32").export_constant(__name__, "float32") float64 = DType(types_pb2.DT_DOUBLE) doc_typealias.document( - obj=float64, - doc="64-bit (double precision) floating-point.") + obj=float64, doc="64-bit (double precision) floating-point." +) tf_export("dtypes.float64", "float64").export_constant(__name__, "float64") double = float64 tf_export("dtypes.double", "double").export_constant(__name__, "double") complex64 = DType(types_pb2.DT_COMPLEX64) -doc_typealias.document( - obj=complex64, - doc="64-bit complex.") -tf_export("dtypes.complex64", - "complex64").export_constant(__name__, "complex64") +doc_typealias.document(obj=complex64, doc="64-bit complex.") +tf_export("dtypes.complex64", "complex64").export_constant( + __name__, "complex64" +) complex128 = DType(types_pb2.DT_COMPLEX128) -doc_typealias.document( - obj=complex128, - doc="128-bit complex.") -tf_export("dtypes.complex128", - "complex128").export_constant(__name__, "complex128") +doc_typealias.document(obj=complex128, doc="128-bit complex.") +tf_export("dtypes.complex128", "complex128").export_constant( + __name__, "complex128" +) string = DType(types_pb2.DT_STRING) doc_typealias.document( - obj=string, - doc="Variable-length string, represented as byte array.") + obj=string, doc="Variable-length string, represented as byte array." +) tf_export("dtypes.string", "string").export_constant(__name__, "string") bool = DType(types_pb2.DT_BOOL) # pylint: disable=redefined-builtin -doc_typealias.document( - obj=bool, - doc="Boolean.") +doc_typealias.document(obj=bool, doc="Boolean.") tf_export("dtypes.bool", "bool").export_constant(__name__, "bool") qint8 = DType(types_pb2.DT_QINT8) -doc_typealias.document( - obj=qint8, - doc="Signed quantized 8-bit integer.") +doc_typealias.document(obj=qint8, doc="Signed quantized 8-bit integer.") tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8") qint16 = DType(types_pb2.DT_QINT16) -doc_typealias.document( - obj=qint16, - doc="Signed quantized 16-bit integer.") +doc_typealias.document(obj=qint16, doc="Signed quantized 16-bit integer.") tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16") qint32 = DType(types_pb2.DT_QINT32) -doc_typealias.document( - obj=qint32, - doc="signed quantized 32-bit integer.") +doc_typealias.document(obj=qint32, doc="signed quantized 32-bit integer.") tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32") quint8 = DType(types_pb2.DT_QUINT8) -doc_typealias.document( - obj=quint8, - doc="Unsigned quantized 8-bit integer.") +doc_typealias.document(obj=quint8, doc="Unsigned quantized 8-bit integer.") tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8") quint16 = DType(types_pb2.DT_QUINT16) -doc_typealias.document( - obj=quint16, - doc="Unsigned quantized 16-bit integer.") +doc_typealias.document(obj=quint16, doc="Unsigned quantized 16-bit integer.") tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16") bfloat16 = DType(types_pb2.DT_BFLOAT16) doc_typealias.document( - obj=bfloat16, - doc="16-bit bfloat (brain floating point).") + obj=bfloat16, doc="16-bit bfloat (brain floating point)." +) tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16") float8_e5m2 = DType(types_pb2.DT_FLOAT8_E5M2) doc_typealias.document( - obj=float8_e5m2, - doc="8-bit float with 5 exponent bits and 2 mantissa bits.") -tf_export("dtypes.experimental.float8_e5m2", - "experimental.float8_e5m2").export_constant(__name__, "float8_e5m2") + obj=float8_e5m2, doc="8-bit float with 5 exponent bits and 2 mantissa bits." +) +tf_export( + "dtypes.experimental.float8_e5m2", "experimental.float8_e5m2" +).export_constant(__name__, "float8_e5m2") float8_e4m3fn = DType(types_pb2.DT_FLOAT8_E4M3FN) doc_typealias.document( obj=float8_e4m3fn, - doc="8-bit float with 4 exponent bits and 3 mantissa bits, with extended " - "finite range. This type has no representation for inf, and only two NaN " - "values: 0xFF for negative NaN, and 0x7F for positive NaN.") -tf_export("dtypes.experimental.float8_e4m3fn", - "experimental.float8_e4m3fn").export_constant(__name__, - "float8_e4m3fn") + doc=( + "8-bit float with 4 exponent bits and 3 mantissa bits, with extended" + " finite range. This type has no representation for inf, and only two" + " NaN values: 0xFF for negative NaN, and 0x7F for positive NaN." + ), +) +tf_export( + "dtypes.experimental.float8_e4m3fn", "experimental.float8_e4m3fn" +).export_constant(__name__, "float8_e4m3fn") + +float8_e4m3fnuz = DType(types_pb2.DT_FLOAT8_E4M3FNUZ) +doc_typealias.document( + obj=float8_e4m3fnuz, + doc=( + "8-bit float with 4 exponent bits and 3 mantissa bits, with extended" + " finite range. This type has no representation for inf, and only one" + " NaN value: 0x80." + ), +) +tf_export( + "dtypes.experimental.float8_e4m3fnuz", "experimental.float8_e4m3fnuz" +).export_constant(__name__, "float8_e4m3fnuz") + +float8_e4m3b11fnuz = DType(types_pb2.DT_FLOAT8_E4M3B11FNUZ) +doc_typealias.document( + obj=float8_e4m3b11fnuz, + doc=( + "8-bit float with 4 exponent bits and 3 mantissa bits, with extended " + "finite range and 11 bits of bias. This type has no representation " + "for inf, and only one NaN value: 0x80." + ), +) +tf_export( + "dtypes.experimental.float8_e4m3b11fnuz", "experimental.float8_e4m3b11fnuz" +).export_constant(__name__, "float8_e4m3b11fnuz") + +float8_e5m2fnuz = DType(types_pb2.DT_FLOAT8_E5M2FNUZ) +doc_typealias.document( + obj=float8_e5m2fnuz, + doc=( + "8-bit float with 5 exponent bits and 2 mantissa bits, with extended " + "finite range. This type has no representation for inf, and only one " + "NaN value: 0x80." + ), +) +tf_export( + "dtypes.experimental.float8_e5m2fnuz", "experimental.float8_e5m2fnuz" +).export_constant(__name__, "float8_e5m2fnuz") int4 = DType(types_pb2.DT_INT4) doc_typealias.document(obj=int4, doc="Signed 4-bit integer.") @@ -505,6 +518,9 @@ def __reduce__(self): bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF) float8_e5m2_ref = DType(types_pb2.DT_FLOAT8_E5M2_REF) float8_e4m3fn_ref = DType(types_pb2.DT_FLOAT8_E4M3FN_REF) +float8_e4m3fnuz_ref = DType(types_pb2.DT_FLOAT8_E4M3FNUZ_REF) +float8_e4m3b11fnuz_ref = DType(types_pb2.DT_FLOAT8_E4M3B11FNUZ_REF) +float8_e5m2fnuz_ref = DType(types_pb2.DT_FLOAT8_E5M2FNUZ_REF) int4_ref = DType(types_pb2.DT_INT4_REF) uint4_ref = DType(types_pb2.DT_UINT4_REF) @@ -534,6 +550,9 @@ def __reduce__(self): types_pb2.DT_BFLOAT16: bfloat16, types_pb2.DT_FLOAT8_E5M2: float8_e5m2, types_pb2.DT_FLOAT8_E4M3FN: float8_e4m3fn, + types_pb2.DT_FLOAT8_E4M3FNUZ: float8_e4m3fnuz, + types_pb2.DT_FLOAT8_E4M3B11FNUZ: float8_e4m3b11fnuz, + types_pb2.DT_FLOAT8_E5M2FNUZ: float8_e5m2fnuz, types_pb2.DT_INT4: int4, types_pb2.DT_UINT4: uint4, types_pb2.DT_RESOURCE: resource, @@ -561,6 +580,9 @@ def __reduce__(self): types_pb2.DT_BFLOAT16_REF: bfloat16_ref, types_pb2.DT_FLOAT8_E5M2_REF: float8_e5m2_ref, types_pb2.DT_FLOAT8_E4M3FN_REF: float8_e4m3fn_ref, + types_pb2.DT_FLOAT8_E4M3FNUZ_REF: float8_e4m3fnuz_ref, + types_pb2.DT_FLOAT8_E4M3B11FNUZ_REF: float8_e4m3b11fnuz_ref, + types_pb2.DT_FLOAT8_E5M2FNUZ_REF: float8_e5m2fnuz_ref, types_pb2.DT_INT4_REF: int4_ref, types_pb2.DT_UINT4_REF: uint4_ref, types_pb2.DT_RESOURCE_REF: resource_ref, @@ -592,6 +614,9 @@ def __reduce__(self): types_pb2.DT_BFLOAT16: "bfloat16", types_pb2.DT_FLOAT8_E5M2: "float8_e5m2", types_pb2.DT_FLOAT8_E4M3FN: "float8_e4m3fn", + types_pb2.DT_FLOAT8_E4M3FNUZ: "float8_e4m3fnuz", + types_pb2.DT_FLOAT8_E4M3B11FNUZ: "float8_e4m3b11fnuz", + types_pb2.DT_FLOAT8_E5M2FNUZ: "float8_e5m2fnuz", types_pb2.DT_INT4: "int4", types_pb2.DT_UINT4: "uint4", types_pb2.DT_RESOURCE: "resource", @@ -619,6 +644,9 @@ def __reduce__(self): types_pb2.DT_BFLOAT16_REF: "bfloat16_ref", types_pb2.DT_FLOAT8_E5M2_REF: "float8_e5m2_ref", types_pb2.DT_FLOAT8_E4M3FN_REF: "float8_e4m3fn_ref", + types_pb2.DT_FLOAT8_E4M3FNUZ_REF: "float8_e4m3fnuz_ref", + types_pb2.DT_FLOAT8_E4M3B11FNUZ_REF: "float8_e4m3b11fnuz_ref", + types_pb2.DT_FLOAT8_E5M2FNUZ_REF: "float8_e5m2fnuz_ref", types_pb2.DT_INT4_REF: "int4_ref", types_pb2.DT_UINT4_REF: "uint4_ref", types_pb2.DT_RESOURCE_REF: "resource_ref", @@ -687,6 +715,9 @@ def __reduce__(self): _np_bfloat16: bfloat16, _np_float8_e5m2: float8_e5m2, _np_float8_e4m3fn: float8_e4m3fn, + _np_float8_e4m3fnuz: float8_e4m3fnuz, + _np_float8_e4m3b11fnuz: float8_e4m3b11fnuz, + _np_float8_e5m2fnuz: float8_e5m2fnuz, _np_int4: int4, _np_uint4: uint4, } @@ -734,6 +765,9 @@ def __reduce__(self): types_pb2.DT_BFLOAT16: _np_bfloat16, types_pb2.DT_FLOAT8_E5M2: _np_float8_e5m2, types_pb2.DT_FLOAT8_E4M3FN: _np_float8_e4m3fn, + types_pb2.DT_FLOAT8_E4M3FNUZ: _np_float8_e4m3fnuz, + types_pb2.DT_FLOAT8_E4M3B11FNUZ: _np_float8_e4m3b11fnuz, + types_pb2.DT_FLOAT8_E5M2FNUZ: _np_float8_e5m2fnuz, types_pb2.DT_INT4: _np_int4, types_pb2.DT_UINT4: _np_uint4, # Ref types @@ -760,6 +794,9 @@ def __reduce__(self): types_pb2.DT_BFLOAT16_REF: _np_bfloat16, types_pb2.DT_FLOAT8_E5M2_REF: _np_float8_e5m2, types_pb2.DT_FLOAT8_E4M3FN_REF: _np_float8_e4m3fn, + types_pb2.DT_FLOAT8_E4M3FNUZ_REF: _np_float8_e4m3fnuz, + types_pb2.DT_FLOAT8_E4M3B11FNUZ_REF: _np_float8_e4m3b11fnuz, + types_pb2.DT_FLOAT8_E5M2FNUZ_REF: _np_float8_e5m2fnuz, types_pb2.DT_INT4_REF: _np_int4, types_pb2.DT_UINT4_REF: _np_uint4, } diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py index 541acb85d61f58..047f81408471a8 100644 --- a/tensorflow/python/framework/dtypes_test.py +++ b/tensorflow/python/framework/dtypes_test.py @@ -96,6 +96,18 @@ def testNumpyConversion(self): self.assertIs(dtypes.float8_e5m2, dtypes.as_dtype(dtypes._np_float8_e5m2)) self.assertIs(dtypes.float8_e4m3fn, dtypes.as_dtype(dtypes._np_float8_e4m3fn)) + self.assertIs( + dtypes.float8_e4m3fnuz, dtypes.as_dtype(dtypes._np_float8_e4m3fnuz) + ) + self.assertIs( + dtypes.float8_e4m3b11fnuz, + dtypes.as_dtype(dtypes._np_float8_e4m3b11fnuz), + ) + self.assertIs( + dtypes.float8_e5m2fnuz, dtypes.as_dtype(dtypes._np_float8_e5m2fnuz) + ) + self.assertIs(dtypes.int4, dtypes.as_dtype(dtypes._np_int4)) + self.assertIs(dtypes.uint4, dtypes.as_dtype(dtypes._np_uint4)) with self.assertRaises(TypeError): dtypes.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)])) @@ -121,6 +133,9 @@ def testRealDtype(self): dtypes.int64, dtypes.float8_e5m2, dtypes.float8_e4m3fn, + dtypes.float8_e4m3fnuz, + dtypes.float8_e4m3b11fnuz, + dtypes.float8_e5m2fnuz, dtypes.int4, dtypes.uint4, ]: @@ -147,6 +162,11 @@ def testStringConversion(self): self.assertIs(dtypes.bfloat16, dtypes.as_dtype("bfloat16")) self.assertIs(dtypes.float8_e5m2, dtypes.as_dtype("float8_e5m2")) self.assertIs(dtypes.float8_e4m3fn, dtypes.as_dtype("float8_e4m3fn")) + self.assertIs(dtypes.float8_e4m3fnuz, dtypes.as_dtype("float8_e4m3fnuz")) + self.assertIs( + dtypes.float8_e4m3b11fnuz, dtypes.as_dtype("float8_e4m3b11fnuz") + ) + self.assertIs(dtypes.float8_e5m2fnuz, dtypes.as_dtype("float8_e5m2fnuz")) self.assertIs(dtypes.int4, dtypes.as_dtype("int4")) self.assertIs(dtypes.uint4, dtypes.as_dtype("uint4")) self.assertIs(dtypes.float32_ref, dtypes.as_dtype("float32_ref")) @@ -199,6 +219,9 @@ def testIsInteger(self): self.assertEqual(dtypes.as_dtype("bfloat16").is_integer, False) self.assertEqual(dtypes.as_dtype("float8_e5m2").is_integer, False) self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_integer, False) + self.assertEqual(dtypes.as_dtype("float8_e4m3fnuz").is_integer, False) + self.assertEqual(dtypes.as_dtype("float8_e4m3b11fnuz").is_integer, False) + self.assertEqual(dtypes.as_dtype("float8_e5m2fnuz").is_integer, False) self.assertEqual(dtypes.as_dtype("int4").is_integer, True) self.assertEqual(dtypes.as_dtype("uint4").is_integer, True) self.assertEqual(dtypes.as_dtype("qint8").is_integer, False) @@ -223,6 +246,9 @@ def testIsFloating(self): self.assertEqual(dtypes.as_dtype("bfloat16").is_floating, True) self.assertEqual(dtypes.as_dtype("float8_e5m2").is_floating, True) self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_floating, True) + self.assertEqual(dtypes.as_dtype("float8_e4m3fnuz").is_floating, True) + self.assertEqual(dtypes.as_dtype("float8_e4m3b11fnuz").is_floating, True) + self.assertEqual(dtypes.as_dtype("float8_e5m2fnuz").is_floating, True) self.assertEqual(dtypes.as_dtype("int4").is_floating, False) self.assertEqual(dtypes.as_dtype("uint4").is_floating, False) self.assertEqual(dtypes.as_dtype("qint8").is_floating, False) @@ -247,6 +273,9 @@ def testIsComplex(self): self.assertEqual(dtypes.as_dtype("bfloat16").is_complex, False) self.assertEqual(dtypes.as_dtype("float8_e5m2").is_complex, False) self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_complex, False) + self.assertEqual(dtypes.as_dtype("float8_e4m3fnuz").is_complex, False) + self.assertEqual(dtypes.as_dtype("float8_e4m3b11fnuz").is_complex, False) + self.assertEqual(dtypes.as_dtype("float8_e5m2fnuz").is_complex, False) self.assertEqual(dtypes.as_dtype("int4").is_complex, False) self.assertEqual(dtypes.as_dtype("uint4").is_complex, False) self.assertEqual(dtypes.as_dtype("qint8").is_complex, False) @@ -271,6 +300,9 @@ def testIsUnsigned(self): self.assertEqual(dtypes.as_dtype("bfloat16").is_unsigned, False) self.assertEqual(dtypes.as_dtype("float8_e5m2").is_unsigned, False) self.assertEqual(dtypes.as_dtype("float8_e4m3fn").is_unsigned, False) + self.assertEqual(dtypes.as_dtype("float8_e4m3fnuz").is_unsigned, False) + self.assertEqual(dtypes.as_dtype("float8_e4m3b11fnuz").is_unsigned, False) + self.assertEqual(dtypes.as_dtype("float8_e5m2fnuz").is_unsigned, False) self.assertEqual(dtypes.as_dtype("int4").is_unsigned, False) self.assertEqual(dtypes.as_dtype("uint4").is_unsigned, True) self.assertEqual(dtypes.as_dtype("qint8").is_unsigned, False) @@ -341,6 +373,15 @@ def testMinMax(self): if numpy_dtype == dtypes.float8_e4m3fn.as_numpy_dtype: self.assertEqual(dtype.min, -448.0) self.assertEqual(dtype.max, 448.0) + if numpy_dtype == dtypes.float8_e4m3fnuz.as_numpy_dtype: + self.assertEqual(dtype.min, -240.0) + self.assertEqual(dtype.max, 240.0) + if numpy_dtype == dtypes.float8_e4m3b11fnuz.as_numpy_dtype: + self.assertEqual(dtype.min, -30.0) + self.assertEqual(dtype.max, 30.0) + if numpy_dtype == dtypes.float8_e5m2fnuz.as_numpy_dtype: + self.assertEqual(dtype.min, -57344.0) + self.assertEqual(dtype.max, 57344.0) if numpy_dtype == dtypes.int4.as_numpy_dtype: self.assertEqual(dtype.min, -8) self.assertEqual(dtype.max, 7) diff --git a/tensorflow/python/framework/experimental/BUILD b/tensorflow/python/framework/experimental/BUILD index 273cf42c4e132c..dfa124d299ab93 100644 --- a/tensorflow/python/framework/experimental/BUILD +++ b/tensorflow/python/framework/experimental/BUILD @@ -21,16 +21,20 @@ tf_python_pybind_extension( pytype_srcs = [ "_unified_api.pyi", ], + starlark_only = True, visibility = [ "//tensorflow/python:__pkg__", ], deps = [ "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:refcount", "//tensorflow/python:unified_api_pywrap_required_headers", "//tensorflow/python/lib/core:pybind11_lib", + "@com_google_absl//absl/types:span", "@pybind11", ], ) @@ -39,6 +43,7 @@ tf_python_pybind_extension( name = "_tape", srcs = ["tape.cc"], features = ["-layering_check"], + starlark_only = True, visibility = [ "//tensorflow/python:__pkg__", ], @@ -49,6 +54,8 @@ tf_python_pybind_extension( "//tensorflow/core/lib/llvm_rtti", "//tensorflow/python:unified_api_pywrap_required_headers", "//tensorflow/python/lib/core:pybind11_lib", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@pybind11", ] + if_pywrap( if_true = [ @@ -66,6 +73,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_math_ops.pyi", ], + starlark_only = True, visibility = [ "//tensorflow/python:__pkg__", ], @@ -93,6 +101,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_nn_ops.pyi", ], + starlark_only = True, visibility = [ "//tensorflow/python:__pkg__", ], diff --git a/tensorflow/python/framework/experimental/_math_ops.pyi b/tensorflow/python/framework/experimental/_math_ops.pyi index 96ab1b898df088..0867853a4a59e6 100644 --- a/tensorflow/python/framework/experimental/_math_ops.pyi +++ b/tensorflow/python/framework/experimental/_math_ops.pyi @@ -13,12 +13,10 @@ # limitations under the License. # ============================================================================== -from typing import Any - -def add(*args, **kwargs) -> Any: ... -def div_no_nan(*args, **kwargs) -> Any: ... -def log1p(*args, **kwargs) -> Any: ... -def mat_mul(*args, **kwargs) -> Any: ... -def mul(*args, **kwargs) -> Any: ... -def neg(*args, **kwargs) -> Any: ... -def sub(*args, **kwargs) -> Any: ... +def add(*args, **kwargs): ... +def div_no_nan(*args, **kwargs): ... +def log1p(*args, **kwargs): ... +def mat_mul(*args, **kwargs): ... +def mul(*args, **kwargs): ... +def neg(*args, **kwargs): ... +def sub(*args, **kwargs): ... diff --git a/tensorflow/python/framework/experimental/_nn_ops.pyi b/tensorflow/python/framework/experimental/_nn_ops.pyi index 64d57f21fae113..919504720779cb 100644 --- a/tensorflow/python/framework/experimental/_nn_ops.pyi +++ b/tensorflow/python/framework/experimental/_nn_ops.pyi @@ -13,7 +13,5 @@ # limitations under the License. # ============================================================================== -from typing import Any - -def relu(*args, **kwargs) -> Any: ... -def sparse_softmax_cross_entropy_with_logits(*args, **kwargs) -> Any: ... +def relu(*args, **kwargs): ... +def sparse_softmax_cross_entropy_with_logits(*args, **kwargs): ... diff --git a/tensorflow/python/framework/experimental/_unified_api.pyi b/tensorflow/python/framework/experimental/_unified_api.pyi index eed51c09ae560a..5d4a3b33aac531 100644 --- a/tensorflow/python/framework/experimental/_unified_api.pyi +++ b/tensorflow/python/framework/experimental/_unified_api.pyi @@ -13,11 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Any - class AbstractContext: def __init__(self, *args, **kwargs) -> None: ... - def CreateOperation(self, *args, **kwargs) -> Any: ... + def CreateOperation(self, *args, **kwargs): ... def RegisterFunction(self, arg0) -> None: ... def RemoveFunction(self, arg0: str) -> None: ... @@ -28,7 +26,7 @@ class AbstractOperation: def __init__(self, *args, **kwargs) -> None: ... def AddInput(self, arg0) -> None: ... def DeviceName(self) -> str: ... - def Execute(self, *args, **kwargs) -> Any: ... + def Execute(self, *args, **kwargs): ... def Name(self) -> str: ... def Reset(self, arg0: str, arg1: str) -> None: ... def SetAttrType(self, arg0: str, arg1) -> None: ... @@ -37,7 +35,7 @@ class AbstractOperation: class AbstractTensorHandle: def __init__(self, *args, **kwargs) -> None: ... - def DataType(self, *args, **kwargs) -> Any: ... + def DataType(self, *args, **kwargs): ... def numpy(self) -> object: ... class ImmediateExecutionContext(AbstractContext): @@ -45,10 +43,10 @@ class ImmediateExecutionContext(AbstractContext): class TracingContext(AbstractContext): def __init__(self, *args, **kwargs) -> None: ... - def AddParameter(self, *args, **kwargs) -> Any: ... - def Finalize(self, *args, **kwargs) -> Any: ... + def AddParameter(self, *args, **kwargs): ... + def Finalize(self, *args, **kwargs): ... -def EagerContextToImmediateExecutionContext(*args, **kwargs) -> Any: ... +def EagerContextToImmediateExecutionContext(*args, **kwargs): ... def EagerTensorToImmediateExecutionTensorHandle(arg0: object) -> AbstractTensorHandle: ... -def NewTracingContext(*args, **kwargs) -> Any: ... +def NewTracingContext(*args, **kwargs): ... def SetTracingImplementation(arg0: str) -> None: ... diff --git a/tensorflow/python/framework/experimental/math_ops.cc b/tensorflow/python/framework/experimental/math_ops.cc index 8508bb58afd0da..7c9954eb18e326 100644 --- a/tensorflow/python/framework/experimental/math_ops.cc +++ b/tensorflow/python/framework/experimental/math_ops.cc @@ -17,9 +17,6 @@ limitations under the License. #include -#include - -#include "absl/types/span.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" diff --git a/tensorflow/python/framework/experimental/nn_ops.cc b/tensorflow/python/framework/experimental/nn_ops.cc index 7d2228532273ae..983bdb2b24b974 100644 --- a/tensorflow/python/framework/experimental/nn_ops.cc +++ b/tensorflow/python/framework/experimental/nn_ops.cc @@ -17,9 +17,6 @@ limitations under the License. #include -#include - -#include "absl/types/span.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" diff --git a/tensorflow/python/framework/experimental/tape.cc b/tensorflow/python/framework/experimental/tape.cc index 951e649df3b473..2b161a0f7d94ef 100644 --- a/tensorflow/python/framework/experimental/tape.cc +++ b/tensorflow/python/framework/experimental/tape.cc @@ -14,6 +14,10 @@ limitations under the License. ==============================================================================*/ #include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/experimental/gradients/math_grad.h" diff --git a/tensorflow/python/framework/experimental/unified_api.cc b/tensorflow/python/framework/experimental/unified_api.cc index dddc322610e823..ea1047ff8d9032 100644 --- a/tensorflow/python/framework/experimental/unified_api.cc +++ b/tensorflow/python/framework/experimental/unified_api.cc @@ -16,7 +16,10 @@ limitations under the License. #include #include +#include +#include +#include "absl/types/span.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_function.h" diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 848a4c8f23599f..7cc817d45ee32b 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -1385,6 +1385,9 @@ def _type_list_to_str(types): dtypes.bfloat16: "b16", dtypes.float8_e5m2: "f8e5m2", dtypes.float8_e4m3fn: "f8e4m3fn", + dtypes.float8_e4m3fnuz: "f8e4m3fnuz", + dtypes.float8_e4m3b11fnuz: "f8e4m3b11fnuz", + dtypes.float8_e5m2fnuz: "f8e5m2fnuz", dtypes.int4: "i4", dtypes.uint4: "u4", } diff --git a/tensorflow/python/framework/python_api_info.cc b/tensorflow/python/framework/python_api_info.cc index 7df48e4d1be528..cacee6c4591d5a 100644 --- a/tensorflow/python/framework/python_api_info.cc +++ b/tensorflow/python/framework/python_api_info.cc @@ -118,9 +118,9 @@ void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def, PythonAPIInfo::PythonAPIInfo(const std::string& api_name) : api_name_(InternPyString(api_name)) {} -Status PythonAPIInfo::Initialize(const OpDef& op_def, - const std::vector param_names, - PyObject* defaults_tuple) { +absl::Status PythonAPIInfo::Initialize(const OpDef& op_def, + const std::vector param_names, + PyObject* defaults_tuple) { // Intern the parameter names. param_names_.reserve(param_names.size()); for (const auto& param_name : param_names) { @@ -170,7 +170,7 @@ Status PythonAPIInfo::Initialize(const OpDef& op_def, return absl::OkStatus(); } -Status PythonAPIInfo::CheckParamNames() const { +absl::Status PythonAPIInfo::CheckParamNames() const { std::vector param_found(param_names_.size()); for (const auto& attr : attributes_) { if (attr.index != -1) { @@ -193,7 +193,8 @@ Status PythonAPIInfo::CheckParamNames() const { return absl::OkStatus(); } -Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) { +absl::Status PythonAPIInfo::InitializeFromRegisteredOp( + const std::string& op_name) { const tensorflow::OpDef* op_def = nullptr; TF_RETURN_IF_ERROR( tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def)); @@ -204,7 +205,7 @@ Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) { return absl::OkStatus(); } -Status PythonAPIInfo::InitializeFromParamSpecs( +absl::Status PythonAPIInfo::InitializeFromParamSpecs( const std::map& input_specs, const std::map& attr_specs, const std::vector param_names, PyObject* defaults_tuple) { @@ -226,7 +227,7 @@ Status PythonAPIInfo::InitializeFromParamSpecs( return absl::OkStatus(); } -Status PythonAPIInfo::InitializeAttribute( +absl::Status PythonAPIInfo::InitializeAttribute( const OpDef::AttrDef& attr_def, const std::map& param_name_to_index) { if (attr_def.name() == "name") { @@ -296,7 +297,7 @@ Status PythonAPIInfo::InitializeAttribute( return absl::OkStatus(); } -Status PythonAPIInfo::InitializeInput( +absl::Status PythonAPIInfo::InitializeInput( const OpDef::ArgDef& arg_def, const std::map& param_name_to_index) { if (arg_def.name() == "name") { diff --git a/tensorflow/python/framework/python_api_info.h b/tensorflow/python/framework/python_api_info.h index 0484531a8f9c6d..6372a9e2345c12 100644 --- a/tensorflow/python/framework/python_api_info.h +++ b/tensorflow/python/framework/python_api_info.h @@ -143,15 +143,16 @@ class PythonAPIInfo { // defaults_tuple: Tuple containing default values for the parameters, // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default // for `param_names[-i]`. - Status Initialize(const OpDef& op_def, const std::vector param_names, - PyObject* defaults_tuple); + absl::Status Initialize(const OpDef& op_def, + const std::vector param_names, + PyObject* defaults_tuple); // Initialize this PythonAPIInfo based on the registered OpDef for the given // operation. // // Args: // op_name: The registered name of the operation (e.g. "AddV2"). - Status InitializeFromRegisteredOp(const std::string& op_name); + absl::Status InitializeFromRegisteredOp(const std::string& op_name); // Initializes this PythonAPIInfo based on a set of parameter specifications. // @@ -167,7 +168,7 @@ class PythonAPIInfo { // // Note: the `name` parameter should not be included in `input_specs` or // `attr_specs`. - Status InitializeFromParamSpecs( + absl::Status InitializeFromParamSpecs( const std::map& input_specs, const std::map& attr_specs, const std::vector param_names, PyObject* defaults_tuple); @@ -226,7 +227,7 @@ class PythonAPIInfo { // If `attr_def` describes an int attribute, then adds a value to // inputs_with_number_attrs_ (to record any tensor inputs that use this // value as a list length). - Status InitializeAttribute( + absl::Status InitializeAttribute( const OpDef::AttrDef& attr_def, const std::map& param_name_to_index); @@ -241,12 +242,13 @@ class PythonAPIInfo { // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the // appropriate value in `inputs_with_type_list_attrs_` with information about // the `arg_def`. - Status InitializeInput(const OpDef::ArgDef& arg_def, - const std::map& param_name_to_index); + absl::Status InitializeInput( + const OpDef::ArgDef& arg_def, + const std::map& param_name_to_index); // Checks that the OpDef used to initialize this PythonAPIInfo // had an AttrDef or ArgDef specification for each parameter. - Status CheckParamNames() const; + absl::Status CheckParamNames() const; // Searches inputs_with_type_attrs_ for an input with the given name. InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name); diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index e064900002be16..6ca94896d51bfe 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -93,6 +93,9 @@ const std::unordered_map dtype_type{ {"_dtypes.variant", "_atypes.Variant"}, {"_dtypes.float8_e4m3fn", "_atypes.Float8e4m3fn"}, {"_dtypes.float8_e5m2", "_atypes.Float8e5m2"}, + {"_dtypes.float8_e4m3fnuz", "_atypes.Float8e4m3fnuz"}, + {"_dtypes.float8_e4m3b11fnuz", "_atypes.Float8e4m3b11fnuz"}, + {"_dtypes.float8_e5m2fnuz", "_atypes.Float8e5m2fnuz"}, {"_dtypes.int4", "_atypes.Int4"}, {"_dtypes.uint4", "_atypes.UInt4"}, }; @@ -326,7 +329,7 @@ class GenPythonOp { // with defaults, except "name" std::vector param_names_; - StringPiece op_name_; + absl::string_view op_name_; typedef std::unordered_map> AttrToArgMap; AttrToArgMap attr_to_args_; std::unordered_map attr_expressions_; @@ -411,7 +414,7 @@ string AvoidPythonReserved(const string& s) { // Indent the first line by "initial" spaces and all following lines // by "rest" spaces. -string Indent(int initial, int rest, StringPiece in) { +string Indent(int initial, int rest, absl::string_view in) { // TODO(josh11b): Also word-wrapping? string copy(in.data(), in.size()); absl::StripTrailingAsciiWhitespace(©); @@ -436,7 +439,7 @@ string Indent(int initial, int rest, StringPiece in) { // Adds append to *dest, with a space if the first line will be <= width, // or a newline otherwise. -void AppendWithinWidth(string* dest, StringPiece append, int width) { +void AppendWithinWidth(string* dest, absl::string_view append, int width) { auto first_line = append.find('\n'); if (first_line == string::npos) first_line = append.size(); if (dest->size() + first_line + 1 /* space */ > static_cast(width)) { @@ -585,7 +588,7 @@ string GetReturns(const OpDef& op_def, strings::StrAppend(&result, " The created Operation.\n"); } else { if (num_outs == 1) { - StringPiece description = op_def.output_arg(0).description(); + absl::string_view description = op_def.output_arg(0).description(); if (ConsumeEquals(&description)) { // Skip the generated type info. strings::StrAppend(&result, Indent(4, 4, description)); } else { @@ -621,7 +624,7 @@ string GetReturns(const OpDef& op_def, absl::StrJoin(out_names, ", "), ").\n\n"); for (int i = 0; i < num_outs; ++i) { string desc = strings::StrCat(out_names[i], ": "); - StringPiece description = op_def.output_arg(i).description(); + absl::string_view description = op_def.output_arg(i).description(); if (ConsumeEquals(&description)) { // Skip the generated type info. strings::StrAppend(&desc, description); } else { @@ -798,7 +801,7 @@ static void AddDelimiter(string* append_to, const string& delim) { if (!append_to->empty()) strings::StrAppend(append_to, delim); } -const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) { +const ApiDef::Attr* FindAttr(absl::string_view name, const ApiDef& api_def) { for (int i = 0; i < api_def.attr_size(); ++i) { if (api_def.attr(i).name() == name) { return &api_def.attr(i); @@ -889,7 +892,7 @@ void GenPythonOp::AddDocStringInputs() { for (int i = 0; i < api_def_.arg_order_size(); ++i) { const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); - StringPiece description = api_def_arg.description(); + absl::string_view description = api_def_arg.description(); string desc; if (ConsumeEquals(&description)) { // Skip the generated type info. desc = strings::StrCat(param_names_[i].GetRenameTo(), ": "); @@ -1321,7 +1324,9 @@ void GenPythonOp::GenerateTypeVars( it != allowed_types.end(); ++it) { if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", "); + strings::StrAppend(&typevar_dtypes, "\""); strings::StrAppend(&typevar_dtypes, *it); + strings::StrAppend(&typevar_dtypes, "\""); } } @@ -1512,7 +1517,7 @@ bool GenPythonOp::GetEagerFunctionSetup(const string& indentation, const auto& param = param_names_[i + op_def_.input_arg_size()]; const auto& attr = *FindAttr(attr_name, op_def_); const string& attr_api_name = param.GetRenameTo(); - StringPiece attr_type = attr.type(); + absl::string_view attr_type = attr.type(); attr_expressions_[attr_name] = attr_api_name; const int default_index = i - (attrs_.size() - params_with_default_.size()); if (default_index >= 0) { diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index 940c6a349b1c1f..948b320b3c3581 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -47,20 +47,20 @@ namespace { constexpr char kUsage[] = "This tool generates python wrapper for tensorflow ops."; -Status ReadOpListFromFile(const string& filename, - std::vector* op_list) { +absl::Status ReadOpListFromFile(const string& filename, + std::vector* op_list) { std::unique_ptr file; TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename, &file)); std::unique_ptr input_buffer( new io::InputBuffer(file.get(), 256 << 10)); string line_contents; - Status s = input_buffer->ReadLine(&line_contents); + absl::Status s = input_buffer->ReadLine(&line_contents); while (s.ok()) { // The parser assumes that the op name is the first string on each // line with no preceding whitespace, and ignores lines that do // not start with an op name as a comment. - strings::Scanner scanner{StringPiece(line_contents)}; - StringPiece op_name; + strings::Scanner scanner{absl::string_view(line_contents)}; + absl::string_view op_name; if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT) .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) .GetResult(nullptr, &op_name)) { @@ -72,8 +72,8 @@ Status ReadOpListFromFile(const string& filename, return absl::OkStatus(); } -Status ReadOpRegOffsetsFromFile(absl::string_view filename, - OpRegOffsets* op_reg_offsets) { +absl::Status ReadOpRegOffsetsFromFile(absl::string_view filename, + OpRegOffsets* op_reg_offsets) { std::unique_ptr file; TF_RETURN_IF_ERROR( Env::Default()->NewRandomAccessFile(std::string(filename), &file)); @@ -103,12 +103,12 @@ std::vector GetSourceFileListFromOpRegOffsets( // // If `source_file_name` is not empty, a comment block will be generated // to show the source file name that the generated file is generated from. -Status PrintAllPythonOps(absl::Span api_def_dirs, - absl::Span source_file_list, - const string& out_path, - const OpRegOffsets& op_reg_offsets, - absl::Span op_allowlist = {}, - absl::Span hidden_op_list = {}) { +absl::Status PrintAllPythonOps(absl::Span api_def_dirs, + absl::Span source_file_list, + const string& out_path, + const OpRegOffsets& op_reg_offsets, + absl::Span op_allowlist = {}, + absl::Span hidden_op_list = {}) { OpList ops; OpRegistry::Global()->Export(false, &ops); diff --git a/tensorflow/python/framework/python_op_gen_test.cc b/tensorflow/python/framework/python_op_gen_test.cc index bab9e087bb5400..d02861d2e12978 100644 --- a/tensorflow/python/framework/python_op_gen_test.cc +++ b/tensorflow/python/framework/python_op_gen_test.cc @@ -60,14 +60,22 @@ TEST(PythonOpGen, TypeAnnotateAllOps) { /* source_file_list= */ {}); const string all_types = - ", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, " - "_atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, " - "_atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, " - "_atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, " - "_atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, " - "_atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, " - "_atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, " - "_atypes.Variant)"; + ", \"_atypes.BFloat16\", \"_atypes.Bool\", \"_atypes.Complex128\", " + "\"_atypes.Complex64\", \"_atypes.Float16\", \"_atypes.Float32\", " + "\"_atypes.Float64\", " + "\"_atypes.Float8e4m3b11fnuz\", \"_atypes.Float8e4m3fn\", " + "\"_atypes.Float8e4m3fnuz\", \"_atypes.Float8e5m2\", " + "\"_atypes.Float8e5m2fnuz\", " + "\"_atypes.Half\", \"_atypes.Int16\", " + "\"_atypes.Int32\", \"_atypes.Int4\", \"_atypes.Int64\", " + "\"_atypes.Int8\", " + "\"_atypes.QInt16\", \"_atypes.QInt32\", \"_atypes.QInt8\", " + "\"_atypes.QUInt16\", " + "\"_atypes.QUInt8\", \"_atypes.Resource\", \"_atypes.String\", " + "\"_atypes.UInt16\", " + "\"_atypes.UInt32\", \"_atypes.UInt4\", \"_atypes.UInt64\", " + "\"_atypes.UInt8\", " + "\"_atypes.Variant\")"; const string fake_param_typevar = "TV_FakeParam_dtype = TypeVar(\"TV_FakeParam_dtype\"" + all_types; @@ -248,8 +256,8 @@ TEST(PythonOpGen, GenerateCorrectTypeVars) { /* source_file_list= */ {}); const string typevars_foo = R"( -TV_Foo_T = TypeVar("TV_Foo_T", _atypes.Int8, _atypes.UInt8) -TV_Foo_T2 = TypeVar("TV_Foo_T2", _atypes.Float32, _atypes.Float64, _atypes.String) +TV_Foo_T = TypeVar("TV_Foo_T", "_atypes.Int8", "_atypes.UInt8") +TV_Foo_T2 = TypeVar("TV_Foo_T2", "_atypes.Float32", "_atypes.Float64", "_atypes.String") )"; ExpectHasSubstr(code, typevars_foo); diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 9097784711fb42..35bf60a4d7bf6a 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -110,6 +110,63 @@ def FastAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values): np.uint8)) +def SlowAppendFloat8e4m3fnuzArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.float8_val += ( + numpy_compat.np_asarray( + proto_values, dtype=dtypes.float8_e4m3fnuz.as_numpy_dtype + ) + .view(np.uint8) + .tobytes() + ) + + +def FastAppendFloat8e4m3fnuzArrayToTensorProto(tensor_proto, proto_values): + fast_tensor_util.AppendFloat8ArrayToTensorProto( + tensor_proto, + numpy_compat.np_asarray( + proto_values, dtype=dtypes.float8_e4m3fnuz.as_numpy_dtype + ).view(np.uint8), + ) + + +def SlowAppendFloat8e4m3b11fnuzArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.float8_val += ( + numpy_compat.np_asarray( + proto_values, dtype=dtypes.float8_e4m3b11fnuz.as_numpy_dtype + ) + .view(np.uint8) + .tobytes() + ) + + +def FastAppendFloat8e4m3b11fnuzArrayToTensorProto(tensor_proto, proto_values): + fast_tensor_util.AppendFloat8ArrayToTensorProto( + tensor_proto, + numpy_compat.np_asarray( + proto_values, dtype=dtypes.float8_e4m3b11fnuz.as_numpy_dtype + ).view(np.uint8), + ) + + +def SlowAppendFloat8e5m2fnuzArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.float8_val += ( + numpy_compat.np_asarray( + proto_values, dtype=dtypes.float8_e5m2fnuz.as_numpy_dtype + ) + .view(np.uint8) + .tobytes() + ) + + +def FastAppendFloat8e5m2fnuzArrayToTensorProto(tensor_proto, proto_values): + fast_tensor_util.AppendFloat8ArrayToTensorProto( + tensor_proto, + numpy_compat.np_asarray( + proto_values, dtype=dtypes.float8_e5m2fnuz.as_numpy_dtype + ).view(np.uint8), + ) + + def SlowAppendInt4ArrayToTensorProto(tensor_proto, proto_values): # The actual bit representation of int4 as a bit-field is # implementation-defined, so we need to explicitly cast each @@ -165,6 +222,15 @@ def SlowAppendUInt4ArrayToTensorProto(tensor_proto, proto_values): dtypes.float8_e4m3fn.as_numpy_dtype: ( FastAppendFloat8e4m3fnArrayToTensorProto ), + dtypes.float8_e4m3fnuz.as_numpy_dtype: ( + FastAppendFloat8e4m3fnuzArrayToTensorProto + ), + dtypes.float8_e4m3b11fnuz.as_numpy_dtype: ( + FastAppendFloat8e4m3b11fnuzArrayToTensorProto + ), + dtypes.float8_e5m2fnuz.as_numpy_dtype: ( + FastAppendFloat8e5m2fnuzArrayToTensorProto + ), dtypes.int4.as_numpy_dtype: SlowAppendInt4ArrayToTensorProto, dtypes.uint4.as_numpy_dtype: SlowAppendUInt4ArrayToTensorProto, } @@ -288,30 +354,31 @@ def _FlattenToStrings(nested_strings): yield nested_strings -_TENSOR_CONTENT_TYPES = frozenset( - [ - dtypes.float16, - dtypes.float32, - dtypes.float64, - dtypes.int32, - dtypes.uint8, - dtypes.int16, - dtypes.int8, - dtypes.int64, - dtypes.qint8, - dtypes.quint8, - dtypes.qint16, - dtypes.quint16, - dtypes.qint32, - dtypes.uint32, - dtypes.uint64, - dtypes.float8_e5m2, - dtypes.float8_e4m3fn, - dtypes.bfloat16 - # int4/uint4 intentionally not listed, since their binary representation - # is implementation-dependent. - ] -) +_TENSOR_CONTENT_TYPES = frozenset([ + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.int32, + dtypes.uint8, + dtypes.int16, + dtypes.int8, + dtypes.int64, + dtypes.qint8, + dtypes.quint8, + dtypes.qint16, + dtypes.quint16, + dtypes.qint32, + dtypes.uint32, + dtypes.uint64, + dtypes.float8_e5m2, + dtypes.float8_e4m3fn, + dtypes.float8_e4m3fnuz, + dtypes.float8_e4m3b11fnuz, + dtypes.float8_e5m2fnuz, + dtypes.bfloat16, + # int4/uint4 intentionally not listed, since their binary representation + # is implementation-dependent. +]) # pylint: disable=invalid-name diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index f43ba2deb2663c..9949b706405fb1 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -321,6 +321,60 @@ def testFloat8e4m3fn(self): tensor_content: "RZ" """, t) + def testFloat8e4m3fnuz(self): + test_type = dtypes.float8_e4m3fnuz.as_numpy_dtype + t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type)) + # 10.0: "Z" = 90 = 1010 010: 2^(10 - 7) * (1 + 1/4) + 8 + # 20.0: "b" = 98 = 1011 010: 2^(11 - 7) * (1 + 1/4) + 8 + self.assertProtoEquals( + """ + dtype: DT_FLOAT8_E4M3FNUZ + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "Zb" + """, + t, + ) + + def testFloat8e4m3b11fnuz(self): + test_type = dtypes.float8_e4m3b11fnuz.as_numpy_dtype + t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type)) + # 10.0: "r" = 114 = 1010 010: 2^(10 - 7) * (1 + 1/4) + 36 + # 20.0: "z" = 126 = 1011 010: 2^(11 - 7) * (1 + 1/4) + 36 + self.assertProtoEquals( + """ + dtype: DT_FLOAT8_E4M3B11FNUZ + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "rz" + """, + t, + ) + + def testFloat8e5m2fnuz(self): + test_type = dtypes.float8_e5m2fnuz.as_numpy_dtype + t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type)) + # 10.0: "M" = 77 = 1010 010: 2^(10 - 7) * (1 + 1/4) - 3 + # 20.0: "Q" = 87 = 1011 010: 2^(11 - 7) * (1 + 1/4) - 3 + self.assertProtoEquals( + """ + dtype: DT_FLOAT8_E5M2FNUZ + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "MQ" + """, + t, + ) + def testInt(self): t = tensor_util.make_tensor_proto(10) self.assertProtoEquals(""" diff --git a/tensorflow/python/framework/test_file_system.cc b/tensorflow/python/framework/test_file_system.cc index 1bb3bff3520b10..ab68834712aed2 100644 --- a/tensorflow/python/framework/test_file_system.cc +++ b/tensorflow/python/framework/test_file_system.cc @@ -20,9 +20,9 @@ namespace tensorflow { class TestRandomAccessFile : public RandomAccessFile { // The file contents is 10 bytes of all A's - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { - Status s; + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, + char* scratch) const override { + absl::Status s; for (int i = 0; i < n; ++i) { if (offset + i >= 10) { n = i; @@ -31,22 +31,22 @@ class TestRandomAccessFile : public RandomAccessFile { } scratch[i] = 'A'; } - *result = StringPiece(scratch, n); + *result = absl::string_view(scratch, n); return s; } }; class TestFileSystem : public NullFileSystem { public: - Status NewRandomAccessFile( + absl::Status NewRandomAccessFile( const string& fname, TransactionToken* token, std::unique_ptr* result) override { result->reset(new TestRandomAccessFile); return absl::OkStatus(); } // Always return size of 10 - Status GetFileSize(const string& fname, TransactionToken* token, - uint64* file_size) override { + absl::Status GetFileSize(const string& fname, TransactionToken* token, + uint64* file_size) override { *file_size = 10; return absl::OkStatus(); } diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 46f981df64b6c6..ff0f5810563641 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -3233,7 +3233,10 @@ def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): a_dtype = a.dtype custom_dtypes = (dtypes.bfloat16.as_numpy_dtype, dtypes.float8_e5m2.as_numpy_dtype, - dtypes.float8_e4m3fn.as_numpy_dtype) + dtypes.float8_e4m3fn.as_numpy_dtype, + dtypes.float8_e4m3fnuz.as_numpy_dtype, + dtypes.float8_e4m3b11fnuz.as_numpy_dtype, + dtypes.float8_e5m2fnuz.as_numpy_dtype) a = a.astype(np.float32) if a.dtype in custom_dtypes else a b = b.astype(np.float32) if b.dtype in custom_dtypes else b if not np.allclose(a, b, rtol=rtol, atol=atol): diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 1e1d643602b5ba..baf0641b6fbc42 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -55,6 +55,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_cost_analyzer.pyi", ], + starlark_only = True, deps = [ ":cost_analyzer_headers", "//tensorflow/core:framework_headers_lib", @@ -91,6 +92,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_model_analyzer.pyi", ], + starlark_only = True, deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", diff --git a/tensorflow/python/grappler/_pywrap_tf_cluster.pyi b/tensorflow/python/grappler/_pywrap_tf_cluster.pyi index fa2a1086cac252..1f717165e1e040 100644 --- a/tensorflow/python/grappler/_pywrap_tf_cluster.pyi +++ b/tensorflow/python/grappler/_pywrap_tf_cluster.pyi @@ -16,12 +16,12 @@ class Cluster: def __init__(self, *args, **kwargs) -> None: ... -def TF_DeterminePeakMemoryUsage(arg0, arg1: Cluster) -> dict[str,tuple[int,list[tuple[str,int,int,int,int]]]]: ... +def TF_DeterminePeakMemoryUsage(arg0, arg1: Cluster) -> dict[str, tuple[int, list[tuple[str, int, int, int, int]]]]: ... def TF_EstimatePerformance(arg0: bytes) -> float: ... -def TF_GetSupportedDevices(arg0: Cluster, arg1) -> dict[str,list[str]]: ... +def TF_GetSupportedDevices(arg0: Cluster, arg1) -> dict[str, list[str]]: ... def TF_ListAvailableOps() -> list[str]: ... def TF_ListDevices(arg0: Cluster) -> list[bytes]: ... -def TF_MeasureCosts(arg0, arg1: Cluster, arg2: bool) -> tuple[list[bytes],float,bytes]: ... +def TF_MeasureCosts(arg0, arg1: Cluster, arg2: bool) -> tuple[list[bytes], float, bytes]: ... def TF_NewCluster(arg0: bool, arg1: bool) -> Cluster: ... def TF_NewVirtualCluster(arg0: list[bytes]) -> Cluster: ... def TF_ShutdownCluster(arg0: Cluster) -> None: ... diff --git a/tensorflow/python/grappler/_pywrap_tf_item.pyi b/tensorflow/python/grappler/_pywrap_tf_item.pyi index a087325eb642c0..259ffceeba7e9c 100644 --- a/tensorflow/python/grappler/_pywrap_tf_item.pyi +++ b/tensorflow/python/grappler/_pywrap_tf_item.pyi @@ -17,6 +17,6 @@ class GrapplerItem: def __init__(self, *args, **kwargs) -> None: ... def TF_GetColocationGroups(arg0: GrapplerItem) -> list[list[str]]: ... -def TF_GetOpProperties(arg0: GrapplerItem) -> dict[str,list[bytes]]: ... +def TF_GetOpProperties(arg0: GrapplerItem) -> dict[str, list[bytes]]: ... def TF_IdentifyImportantOps(arg0: GrapplerItem, arg1: bool) -> list[str]: ... def TF_NewItem(arg0: bytes, arg1: bool, arg2: bool) -> GrapplerItem: ... diff --git a/tensorflow/python/grappler/_pywrap_tf_optimizer.pyi b/tensorflow/python/grappler/_pywrap_tf_optimizer.pyi index 9eb2d0e7393c6f..7bacdcd18d6beb 100644 --- a/tensorflow/python/grappler/_pywrap_tf_optimizer.pyi +++ b/tensorflow/python/grappler/_pywrap_tf_optimizer.pyi @@ -13,7 +13,5 @@ # limitations under the License. # ============================================================================== -from typing import Any - -def TF_OptimizeGraph(*args, **kwargs) -> Any: ... +def TF_OptimizeGraph(*args, **kwargs): ... def TF_OptimizeGraphSerialized(arg0, arg1: str, arg2: str, arg3: bool, arg4: str, arg5: bool) -> bytes: ... diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 73f07f5b95b81b..af0e4464396a3e 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -259,6 +259,7 @@ cuda_py_strict_test( ], deps = [ "//tensorflow/python/client:device_lib", + "//tensorflow/python/eager:context", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:errors", "//tensorflow/python/framework:for_generated_wrappers", diff --git a/tensorflow/python/kernel_tests/array_ops/depthtospace_op_test.py b/tensorflow/python/kernel_tests/array_ops/depthtospace_op_test.py index d2a166a60136b1..2fac119599ad16 100644 --- a/tensorflow/python/kernel_tests/array_ops/depthtospace_op_test.py +++ b/tensorflow/python/kernel_tests/array_ops/depthtospace_op_test.py @@ -19,6 +19,7 @@ import numpy as np from tensorflow.python.client import device_lib +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -87,6 +88,14 @@ def testBlockSize2(self): [[11], [12], [15], [16]]]] self._testOne(x_np, block_size, x_out) + @test_util.run_deprecated_v1 + def testBlockSizeOverflow(self): + with context.eager_mode(): + x_np = [[[[1, 2, 3, 4]]]] + block_size = 100000 + with self.assertRaises(errors_impl.InvalidArgumentError): + self.evaluate(array_ops.depth_to_space(x_np, block_size)) + @test_util.run_deprecated_v1 def testBlockSize2Batch10(self): block_size = 2 diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 88cc811f08fa7d..e82c2fa44e73c5 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -752,7 +752,6 @@ cuda_py_strict_test( size = "medium", srcs = ["normalize_op_test.py"], shard_count = 20, - # TODO(b/117236102): Re-enable in msan build. tags = ["no_windows_gpu"], # TODO(b/208263392): Re-enable. tf.Squeeze op after tf.Where op doesn't reshape. xla_enable_strict_auto_jit = False, @@ -769,7 +768,6 @@ cuda_py_strict_test( size = "medium", srcs = ["norm_op_test.py"], shard_count = 20, - # TODO(b/117236102): Re-enable in msan build. tags = ["no_windows_gpu"], # TODO(b/208263392): Re-enable. tf.Squeeze op after tf.Where op doesn't reshape. xla_enable_strict_auto_jit = False, diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD index 129797173211ed..f66e92ddaecc22 100644 --- a/tensorflow/python/kernel_tests/proto/BUILD +++ b/tensorflow/python/kernel_tests/proto/BUILD @@ -117,7 +117,7 @@ tf_cc_shared_object( ":test_example_proto_cc", ], if_true = [ - "//tensorflow/python:_pywrap_tensorflow_common", + "//tensorflow/python:tensorflow_common_framework", ":test_example_proto_cc_stripped", ], ), diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 435387f0eb0fa4..b3a8c84adf21a2 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -216,6 +216,15 @@ absl::Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array, } else if (pyarray_type == custom_dtypes.float8_e4m3fn) { *out_tf_datatype = TF_FLOAT8_E4M3FN; break; + } else if (pyarray_type == custom_dtypes.float8_e4m3fnuz) { + *out_tf_datatype = TF_FLOAT8_E4M3FNUZ; + break; + } else if (pyarray_type == custom_dtypes.float8_e4m3b11fnuz) { + *out_tf_datatype = TF_FLOAT8_E4M3B11FNUZ; + break; + } else if (pyarray_type == custom_dtypes.float8_e5m2fnuz) { + *out_tf_datatype = TF_FLOAT8_E5M2FNUZ; + break; } else if (pyarray_type == custom_dtypes.int4) { *out_tf_datatype = TF_INT4; break; diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc index fc64f0ee8e05f3..92b176db9c7952 100644 --- a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc +++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc @@ -198,6 +198,15 @@ absl::Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype, case TF_FLOAT8_E4M3FN: *out_pyarray_type = custom_dtypes.float8_e4m3fn; break; + case TF_FLOAT8_E4M3FNUZ: + *out_pyarray_type = custom_dtypes.float8_e4m3fnuz; + break; + case TF_FLOAT8_E4M3B11FNUZ: + *out_pyarray_type = custom_dtypes.float8_e4m3b11fnuz; + break; + case TF_FLOAT8_E5M2FNUZ: + *out_pyarray_type = custom_dtypes.float8_e5m2fnuz; + break; case TF_INT4: *out_pyarray_type = custom_dtypes.int4; break; diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index ee8d79c107ff78..926625cb625658 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -915,7 +915,7 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, } if (!status.ok()) { - PyErr_SetString(PyExc_ValueError, tsl::NullTerminatedMessage(status)); + PyErr_SetString(PyExc_ValueError, absl::StatusMessageAsCStr(status)); return nullptr; } diff --git a/tensorflow/python/lib/core/pybind11_status.h b/tensorflow/python/lib/core/pybind11_status.h index b175837ffb001a..b00f38580fa1fe 100644 --- a/tensorflow/python/lib/core/pybind11_status.h +++ b/tensorflow/python/lib/core/pybind11_status.h @@ -44,7 +44,7 @@ inline PyObject* CodeToPyExc(const int code) { } } -inline PyObject* StatusToPyExc(const Status& status) { +inline PyObject* StatusToPyExc(const absl::Status& status) { return CodeToPyExc(status.raw_code()); } @@ -76,7 +76,7 @@ inline void MaybeRaiseFromStatus(const absl::Status& status) { } } -inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) { +inline void SetRegisteredErrFromStatus(const absl::Status& status) { PyErr_SetObject( tensorflow::PyExceptionRegistry::Lookup(status.raw_code()), pybind11::make_tuple(pybind11::none(), pybind11::none(), status.message(), @@ -92,15 +92,14 @@ inline void SetRegisteredErrFromTFStatus(TF_Status* status) { .ptr()); } -inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) { +inline void MaybeRaiseRegisteredFromStatus(const absl::Status& status) { if (!status.ok()) { SetRegisteredErrFromStatus(status); throw pybind11::error_already_set(); } } -inline void MaybeRaiseRegisteredFromStatusWithGIL( - const tensorflow::Status& status) { +inline void MaybeRaiseRegisteredFromStatusWithGIL(const absl::Status& status) { if (!status.ok()) { // Acquire GIL for throwing exception. pybind11::gil_scoped_acquire acquire; @@ -160,10 +159,10 @@ namespace detail { // by PyExceptionRegistry. Note that the registry should be initialized // in order to be used, see PyExceptionRegistry::Init. template <> -struct type_caster { +struct type_caster { public: - PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status")); - static handle cast(tensorflow::Status status, return_value_policy, handle) { + PYBIND11_TYPE_CASTER(absl::Status, _("Status")); + static handle cast(absl::Status status, return_value_policy, handle) { tensorflow::MaybeRaiseFromStatus(status); return none().inc_ref(); } @@ -177,7 +176,7 @@ template struct type_caster> { public: using PayloadCaster = make_caster; - using StatusCaster = make_caster; + using StatusCaster = make_caster; static constexpr auto name = PayloadCaster::name; static handle cast(const tensorflow::StatusOr* src, diff --git a/tensorflow/python/lib/io/_pywrap_record_io.pyi b/tensorflow/python/lib/io/_pywrap_record_io.pyi index 9939b15e7f01ed..cc06ddb6300a10 100644 --- a/tensorflow/python/lib/io/_pywrap_record_io.pyi +++ b/tensorflow/python/lib/io/_pywrap_record_io.pyi @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -from typing import Any - class RandomRecordReader: def __init__(self, arg0: str) -> None: ... def close(self) -> None: ... @@ -38,7 +36,7 @@ class RecordWriter: class RecordWriterOptions: def __init__(self, arg0: str) -> None: ... @property - def compression_type(self) -> Any: ... + def compression_type(self): ... @property def zlib_options(self) -> ZlibCompressionOptions: ... diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 158e345cf34709..662f70ba3e112b 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -126,17 +126,4 @@ PYBIND11_MODULE(_pywrap_mlir, m) { tensorflow::ExperimentalWriteBytecode(filename, mlir_txt, status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); }); - - m.def("ExperimentalTFLiteToTosaBytecode", - [](const std::string &flatbuffer_file, - const std::string &tosa_bytecode_file, bool use_external_constant, - const std::vector &ordered_input_arrays, - const std::vector &ordered_output_arrays) { - tensorflow::Safe_TF_StatusPtr status = - tensorflow::make_safe(TF_NewStatus()); - tensorflow::ExperimentalTFLiteToTosaBytecode( - flatbuffer_file, tosa_bytecode_file, use_external_constant, - ordered_input_arrays, ordered_output_arrays, status.get()); - tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); - }); }; diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 51dd3717ed86fc..6ff5a0459c80de 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -3139,6 +3139,7 @@ py_strict_library( ":state_ops_gen", "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/core:protos_all_py", + "//tensorflow/core/config:flags_py", "//tensorflow/core/function/trace_type", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python/checkpoint:tensor_callable", @@ -4832,3 +4833,14 @@ py_strict_library( "//third_party/py/numpy", ], ) + +py_strict_test( + name = "tensor_math_operator_overrides_test", + srcs = ["tensor_math_operator_overrides_test.py"], + python_version = "PY3", + deps = [ + ":math_ops", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/platform:client_testlib", + ], +) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index c9f67c5e59ffdd..7ee6d645915cf9 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -662,8 +662,7 @@ def _GatherV2Grad(op: ops.Operation, grad): # so it's fine to convert it back to int32 regardless of truncation. params = op.inputs[0] with ops.colocate_with(params): - params_shape = array_ops.shape(params, out_type=ops.dtypes.int64) - params_shape = math_ops.cast(params_shape, dtypes.int32) + params_shape = array_ops.shape(params) indices = op.inputs[1] indices_size = array_ops.expand_dims(array_ops.size(indices), 0) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index ef9206f1d646c1..b70fdaec4d7692 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -326,6 +326,17 @@ def enqueue(self, vals, name=None): `tf.Session.close`, `tf.errors.CancelledError` will be raised. + >>> q = tf.queue.FIFOQueue(capacity=3, dtypes=tf.int32) + >>> q.enqueue(1) + >>> q.enqueue(2) + >>> q.size() + + + >>> q = tf.queue.FIFOQueue(2, tf.int32, shapes=tf.TensorShape(4)) + >>> q.enqueue(tf.constant([1, 2, 3, 4], dtype=tf.int32)) + >>> q.size() + + Args: vals: A tensor, a list or tuple of tensors, or a dictionary containing the values to enqueue. @@ -369,6 +380,11 @@ def enqueue_many(self, vals, name=None): `tf.Session.close`, `tf.errors.CancelledError` will be raised. + >>> q = tf.queue.FIFOQueue(capacity=10, dtypes=tf.int32) + >>> q.enqueue_many(tf.constant([1, 2, 3, 4, 5], dtype=tf.int32)) + >>> q.size() + + Args: vals: A tensor, a list or tuple of tensors, or a dictionary from which the queue elements are taken. @@ -435,6 +451,14 @@ def dequeue(self, name=None): `tf.Session.close`, `tf.errors.CancelledError` will be raised. + >>> q = tf.queue.FIFOQueue(capacity=2, dtypes=tf.int32) + >>> q.enqueue(1) + >>> q.enqueue(2) + >>> q.dequeue() + + >>> q.dequeue() + + Args: name: A name for the operation (optional). @@ -477,6 +501,17 @@ def dequeue_many(self, n, name=None): session is `tf.Session.close`, `tf.errors.CancelledError` will be raised. + >>> q = tf.queue.FIFOQueue(10, tf.int32, shapes=tf.TensorShape(2)) + >>> q.enqueue(tf.constant([1, 2], dtype=tf.int32, shape=(2))) + >>> q.enqueue(tf.constant([3, 4], dtype=tf.int32, shape=(2))) + >>> q.enqueue(tf.constant([5, 6], dtype=tf.int32, shape=(2))) + >>> q.enqueue(tf.constant([7, 8], dtype=tf.int32, shape=(2))) + >>> q.dequeue_many(3) + + Args: n: A scalar `Tensor` containing the number of elements to dequeue. name: A name for the operation (optional). @@ -521,6 +556,15 @@ def dequeue_up_to(self, n, name=None): `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`. Otherwise the behavior is identical to `dequeue_many`. + >>> q = tf.queue.FIFOQueue(10, tf.int32, shapes=tf.TensorShape(2)) + >>> q.enqueue(tf.constant([1, 2], dtype=tf.int32, shape=(2))) + >>> q.enqueue(tf.constant([3, 4], dtype=tf.int32, shape=(2))) + >>> q.close() + >>> q.dequeue_up_to(5) + + Args: n: A scalar `Tensor` containing the number of elements to dequeue. name: A name for the operation (optional). @@ -557,6 +601,13 @@ def close(self, cancel_pending_enqueues=False, name=None): If `cancel_pending_enqueues` is `True`, all pending requests will also be canceled. + >>> q = tf.queue.FIFOQueue(capacity=3, dtypes=tf.int32) + >>> q.is_closed() + + >>> q.close() + >>> q.is_closed() + + Args: cancel_pending_enqueues: (Optional.) A boolean, defaulting to `False` (described above). @@ -584,6 +635,10 @@ def is_closed(self, name=None): This operation returns true if the queue is closed and false if the queue is open. + >>> q = tf.queue.FIFOQueue(capacity=3, dtypes=tf.int32) + >>> q.is_closed() + + Args: name: A name for the operation (optional). @@ -600,6 +655,11 @@ def is_closed(self, name=None): def size(self, name=None): """Compute the number of elements in this queue. + >>> q = tf.queue.FIFOQueue(capacity=10, dtypes=tf.int32) + >>> q.enqueue_many(tf.constant([1, 2, 3, 4], dtype=tf.int32)) + >>> q.size() + + Args: name: A name for the operation (optional). @@ -753,6 +813,10 @@ def __init__(self, shared_name: (Optional.) If non-empty, this queue will be shared under the given name across multiple sessions. name: Optional name for the queue operation. + + >>> q = tf.queue.FIFOQueue(capacity=10, dtypes=tf.int32) + >>> q.size() + """ dtypes = _as_type_list(dtypes) shapes = _as_shape_list(shapes, dtypes) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 51e6fc3c2988a9..9ddf38bda943bc 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -276,7 +276,8 @@ def _MeanGrad(op: ops.Operation, grad): else: input_shape = array_ops.shape(op.inputs[0]) input_rank = array_ops.size(input_shape) - axes = (op.inputs[1] + input_rank) % input_rank + axes = math_ops.cast(op.inputs[1], input_rank.dtype) + axes = (axes + input_rank) % input_rank factor = math_ops.reduce_prod(array_ops.gather(input_shape, axes)) return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None @@ -306,10 +307,10 @@ def _ProdGrad(op: ops.Operation, grad): # copying back and forth, and since listdiff is CPU only. with ops.device("/cpu:0"): rank = array_ops.rank(op.inputs[0]) - reduction_indices = (reduction_indices + rank) % rank - reduced = math_ops.cast(reduction_indices, dtypes.int32) + reduction_indices = math_ops.cast(reduction_indices, rank.dtype) + reduced = (reduction_indices + rank) % rank idx = math_ops.range(0, rank) - other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32) + other, _ = gen_array_ops.list_diff(idx, reduced, reduced.dtype) perm = array_ops.concat([reduced, other], 0) reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced)) other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other)) @@ -339,12 +340,12 @@ def _SegmentSumGrad(op: ops.Operation, grad): @ops.RegisterGradient("SegmentMean") def _SegmentMeanGrad(op: ops.Operation, grad): """Gradient for SegmentMean.""" - input_rank = array_ops.rank(op.inputs[0]) - ones_shape = array_ops.concat([ - array_ops.shape(op.inputs[1]), - array_ops.ones( - array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32) - ], 0) + data_rank = array_ops.rank(op.inputs[0]) + segment_ids_shape = array_ops.shape(op.inputs[1]) + remaining_shape = array_ops.ones( + array_ops.expand_dims(data_rank - 1, 0), dtype=segment_ids_shape.dtype + ) + ones_shape = array_ops.concat([segment_ids_shape, remaining_shape], 0) ones = array_ops.ones(ones_shape, dtype=grad.dtype) scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1])) return array_ops.gather(scaled_grad, op.inputs[1]), None @@ -353,18 +354,16 @@ def _SegmentMeanGrad(op: ops.Operation, grad): def _SparseSegmentReduceGradV2(op, grad, norm=None): """Sparse gradient for SparseSegment(Sum|Mean|SqrtN)[WithNumSegments].""" assert norm is None or norm == "mean" or norm == "sqrtn" - data = op.inputs[0] indices = op.inputs[1] segment_ids = op.inputs[2] data_shape = array_ops.shape(op.inputs[0]) dense_output_dim0 = data_shape[0] - grad_fn = ( - math_ops.sparse_segment_mean_grad_v2 - if norm == "mean" - else math_ops.sparse_segment_sqrt_n_grad_v2 - if norm == "sqrtn" - else math_ops.sparse_segment_sum_grad_v2 - ) + if norm == "mean": + grad_fn = math_ops.sparse_segment_mean_grad_v2 + elif norm == "sqrtn": + grad_fn = math_ops.sparse_segment_sqrt_n_grad_v2 + else: + grad_fn = math_ops.sparse_segment_sum_grad_v2 grad_values, sorted_unique_indices = grad_fn( grad, indices, segment_ids, dense_output_dim0 ) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 577b98e98b71e6..d3cce16cda681f 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -4341,15 +4341,15 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): @tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"]) @dispatch.add_dispatch_support def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None): - """Compute the cumulative log-sum-exp of the tensor `x` along `axis`. + """Compute the cumulative log-sum-exp of the tensor `x` along the `axis`. - By default, this op performs an inclusive cumulative log-sum-exp, which means - that the first element of the input is identical to the first element of + By default, this operation performs an inclusive cumulative log-sum-exp, which + means that the first element of the input is identical to the first element of the output. This operation is significantly more numerically stable than the equivalent - tensorflow operation `tf.math.log(tf.math.cumsum(tf.math.exp(x)))`, although - computes the same result given infinite numerical precision. However, note + Tensorflow operation `tf.math.log(tf.math.cumsum(tf.math.exp(x)))`, although + it computes the same result given infinite numerical precision. However, note that in some cases, it may be less stable than `tf.math.reduce_logsumexp` for a given element, as it applies the "log-sum-exp trick" in a different way. @@ -4476,30 +4476,35 @@ def reduced_shape(input_shape, axes): constant_input_shape[constant_axes] = 1 return constant_input_shape - # Example: - # cast needed for SparseTensor reductions - input_shape = cast(input_shape, dtypes.int32) # [2, 3, 5, 7] - axes = cast(axes, dtypes.int32) # [1, 2] - - input_rank = array_ops.size(input_shape) # 4 + axes = ops.convert_to_tensor(axes) + input_rank = array_ops.size(input_shape, out_type=axes.dtype) # 4 axes = (axes + input_rank) % input_rank axes_shape = array_ops.shape(axes) # [2] return gen_data_flow_ops.dynamic_stitch( # [2, 1, 1, 7] - [ - range(input_rank), # [0, 1, 2, 3] - axes - ], # [1, 2] + [range(input_rank), axes], # [0, 1, 2, 3] # [1, 2] [ input_shape, # [2, 3, 5, 7] - array_ops.ones(axes_shape, dtype=dtypes.int32) - ]) # [1, 1] + array_ops.ones(axes_shape, dtype=input_shape.dtype), + ], + ) # [1, 1] def _unsorted_segment_N(data, segment_ids, num_segments): - """ Helper function for unsorted_segment_mean/_sqrtN. + """Helper function for unsorted_segment_mean/_sqrtN. + + Computes the number of segment entries with 0-entries set to 1 to allow + division by N. + + Args: + data: A `Tensor` with data that will be assembled in the output. + segment_ids: An integer tensor whose shape is a prefix of `data.shape`. The + values must be in the range `[0, num_segments)`. The values are always + validated to be in range on CPU, never validated on TPU/GPU. + num_segments: An integer scalar `Tensor`. The number of distinct segment + IDs. - Computes the number - of segment entries with 0-entries set to 1 to allow division by N. + Returns: + A `Tensor` with the number of segment entries with 0-entries set to 1. """ num_segments = ops.convert_to_tensor(num_segments) # bincount doesn't support negative indices so we use unsorted_segment_sum @@ -4839,7 +4844,7 @@ def sampled_addmm( dense_shape: `tf.Tensor` defining the dense shape of the output. mat1: `tf.Tensor` to be multiplied. Must have rank > 1. mat2: `tf.Tensor` to be multiplied. Must have rank > 1. - beta: Number to be multipled with `values`. Defaults to 1.0. + beta: Number to be multiplied with `values`. Defaults to 1.0. alpha: Number to be multiplied with the sampled dot product of `mat1` and `mat2`. Defaults to 1.0. output_type: The output datatype if needed. Defaults to float32. diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index c2665986d18ab7..c3685de0c896d7 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -1360,6 +1360,14 @@ def testInputsNearInt64Max(self): self.assertAllEqual( (0,), self.evaluate(x)) # smallest input with potential overflow + def testInt32Overflow(self): + start = 1136033460 + end = -2110457150 + step = -1849827689 + expected = np.arange(start, end, step) + actual = math_ops.range(start, end, step) + self.assertAllEqual(expected, self.evaluate(actual)) + @test_util.run_all_in_graph_and_eager_modes class ErfcinvTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 9eb695109fda3c..bb2c171e30df57 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -661,7 +661,7 @@ def _LRNGrad(op: ops.Operation, grad): @ops.RegisterGradient("AvgPool") def _AvgPoolGrad(op: ops.Operation, grad): return gen_nn_ops.avg_pool_grad( - array_ops.shape(op.inputs[0]), + array_ops.shape(op.inputs[0], out_type=dtypes.int32), grad, op.get_attr("ksize"), op.get_attr("strides"), diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 88c9483edd7019..50ffe600480a6e 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -1368,8 +1368,12 @@ def __init__(self, self._all_indices_partitioned = all_indices_partitioned if all_indices_partitioned: assert all_indices is not None - self.all_indices = ( - math_ops.range(loop_len) if all_indices is None else all_indices) + if all_indices is None: + self.all_indices = math_ops.range( + loop_len, dtype=dtypes.int32, name="all_indices" + ) + else: + self.all_indices = all_indices self._conversion_map = object_identity.ObjectIdentityDictionary() self._conversion_map[loop_var] = wrap(self.all_indices, True) diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index fca58e2b13699f..b2700f88c32331 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -904,6 +904,7 @@ py_strict_test( "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:ragged_math_ops_gen", "//tensorflow/python/platform:test", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/ops/ragged/ragged_range_op_test.py b/tensorflow/python/ops/ragged/ragged_range_op_test.py index c759b8254ac167..61fbc48047e575 100644 --- a/tensorflow/python/ops/ragged/ragged_range_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_range_op_test.py @@ -14,6 +14,8 @@ # ============================================================================== """Tests for ragged_range op.""" +import numpy as np + from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util @@ -129,6 +131,14 @@ def testShape(self): self.assertAllEqual( ragged_math_ops.range([1, 2, 3], [4, 5, 6]).shape.as_list(), [3, None]) + def testInt32Overflow(self): + start = 1136033460 + end = -2110457150 + step = -1849827689 + expected = [np.arange(start, end, step)] + actual = ragged_math_ops.range(start, end, step) + self.assertAllEqual(expected, self.evaluate(actual)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 8db75f93970a73..566b4094e2c650 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -23,6 +23,7 @@ from absl import logging from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.core.config import flags from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 from tensorflow.core.function import trace_type @@ -447,7 +448,7 @@ def __init__( # pylint: disable=super-init-not-called deduplicate copying through `Switch` and other conditional statements. in_graph_mode: whether we are executing in TF1 graph mode. If None, will detect within the function. This is to avoid repeated init_scope() - conetxt entrances which can add up. + context entrances which can add up. validate_shape: If `False`, allows the variable to be initialized with a value of unknown shape. If `True`, the default, the shape of `initial_value` must be known. @@ -1675,8 +1676,8 @@ def get_gradient_components(self, value): For a ResourceVariable, its gradient component is its handle tensor. For now, we return the ResourceVariable because the gradient infrastructure - has special logics to handle ResourceVariables. We should remove those - special logics and return the handle tensor. + has special logic to handle ResourceVariables. We should remove the special + logic and return the handle tensor. Args: value: A `ResourceVariable`. @@ -2521,7 +2522,24 @@ def _ReadGrad(_, grad): return grad -def variable_shape(handle, out_type=dtypes.int32): +def variable_shape(handle, out_type=None): + """Returns the shape of the variable from the handle. + + If the output shape dtype is not specified, it will be set to int64 if + tf_shape_default_int64 is enabled, otherwise it will be set to int32. + + Args: + handle: The handle of the variable. + out_type: The dtype of the output shape. + + Returns: + The shape of the variable. + """ + if out_type is None: + if flags.config().tf_shape_default_int64.value(): + out_type = dtypes.int64 + else: + out_type = dtypes.int32 handle_data = get_eager_safe_handle_data(handle) if handle_data is None or not handle_data.is_set: return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) diff --git a/tensorflow/python/ops/tensor_math_operator_overrides.py b/tensorflow/python/ops/tensor_math_operator_overrides.py index f94d2a14da8faa..06947894840fd7 100644 --- a/tensorflow/python/ops/tensor_math_operator_overrides.py +++ b/tensorflow/python/ops/tensor_math_operator_overrides.py @@ -60,7 +60,19 @@ def _mod_factory(x, y, name=None): def _mul_dispatch_factory(x, y, name=None): from tensorflow.python.ops import math_ops - + from tensorflow.python.framework import dtypes + + if (isinstance(x, tensor_lib.Tensor) and x.dtype == dtypes.bool) or ( + isinstance(y, tensor_lib.Tensor) and y.dtype == dtypes.bool + ): + return gen_math_ops.cast( + math_ops._mul_dispatch( + gen_math_ops.cast(x, dtypes.int32), + gen_math_ops.cast(y, dtypes.int32), + name=name, + ), + dtypes.bool, + ) # pylint: disable=protected-access return math_ops._mul_dispatch(x, y, name=name) # pylint: disable=protected-access diff --git a/tensorflow/python/ops/tensor_math_operator_overrides_test.py b/tensorflow/python/ops/tensor_math_operator_overrides_test.py new file mode 100644 index 00000000000000..5fa27aeaeb6ab4 --- /dev/null +++ b/tensorflow/python/ops/tensor_math_operator_overrides_test.py @@ -0,0 +1,55 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the math operator overrides.""" + + +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import tensor_math_operator_overrides as tmoo +from tensorflow.python.platform import test + + +class SortTest(test.TestCase): + + def _test_mul_dispatch_factory(self, x, y, expected, name=None): + self.assertAllEqual(expected, tmoo._mul_dispatch_factory(x, y, name=name)) + + def testNonBooleanTensor(self): + x = constant_op.constant([1, 2, 3]) + y = constant_op.constant([4, 5, 6]) + expected = constant_op.constant([4, 10, 18]) + self._test_mul_dispatch_factory(x, y, expected) + + def testBooleanTensor(self): + x = constant_op.constant([True, False, True]) + y = constant_op.constant([False, True, True]) + expected = constant_op.constant([False, False, True]) + self._test_mul_dispatch_factory(x, y, expected) + + def testBooleanMix(self): + # Non-boolean tensor is first. + x = constant_op.constant([1, 2, 3]) + y = constant_op.constant([False, True, True]) + expected = constant_op.constant([False, True, True]) + self._test_mul_dispatch_factory(x, y, expected) + + # Boolean tensor is first. + x = constant_op.constant([False, True, True]) + y = constant_op.constant([1, 2, 3]) + expected = constant_op.constant([False, True, True]) + self._test_mul_dispatch_factory(x, y, expected) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD index ba85b2bb656e5a..c33177795daa08 100644 --- a/tensorflow/python/platform/BUILD +++ b/tensorflow/python/platform/BUILD @@ -220,13 +220,19 @@ py_strict_library( name = "client_testlib", srcs = ["test.py"], srcs_version = "PY3", - visibility = visibility + [ - "//tensorflow:internal", - "//tensorflow_models:__subpackages__", - "//third_party/cloud_tpu/convergence_tools:__subpackages__", - "//third_party/mlperf:__subpackages__", - "//third_party/py/tf_slim:__subpackages__", - ], + # copybara:uncomment_begin(google-only) + # visibility = visibility + [ + # "//third_party/cloud_tpu/convergence_tools:__subpackages__", + # "//third_party/mlperf:__subpackages__", + # "//third_party/py/tf_slim:__subpackages__", + # "//tensorflow:internal", + # "//tensorflow_models:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + visibility = [ + "//visibility:public", + ], + # copybara:comment_end deps = [ ":test", "//tensorflow/python/framework:test_lib", @@ -286,7 +292,13 @@ py_strict_library( py_strict_library( name = "gfile", srcs = ["gfile.py"], - visibility = visibility, + # copybara:uncomment_begin(google-only) + # visibility = visibility, + # copybara:uncomment_end_and_comment_begin + visibility = [ + "//visibility:public", + ], + # copybara:comment_end deps = [ "//tensorflow/python/lib/io:file_io", "//tensorflow/python/util:deprecation", diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index df084ae8795f56..bc9212bf102186 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -141,6 +141,7 @@ tf_python_pybind_extension( "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/core/profiler/rpc:profiler_server_for_pybind", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/status", "@pybind11", ], ) @@ -180,8 +181,10 @@ cc_library( "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/rpc:profiler_server_for_pybind", "//tensorflow/core/profiler/rpc/client:save_profile", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", @@ -209,6 +212,10 @@ tsl_pybind_extension( "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/python/lib/core:py_exception_registry", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl", diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc index 86401ce602ec39..39898c91ddbdbb 100644 --- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc +++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc @@ -19,14 +19,8 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/memory/memory.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" #include "absl/types/variant.h" #include "xla/tsl/profiler/convert/xplane_to_trace_events.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" @@ -39,6 +33,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/rpc/client/save_profile.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.h b/tensorflow/python/profiler/internal/profiler_pywrap_impl.h index 700565cc3d51f0..d99e36333432c4 100644 --- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.h +++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.h @@ -15,10 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_PROFILER_PYWRAP_IMPL_H_ #define TENSORFLOW_PYTHON_PROFILER_INTERNAL_PROFILER_PYWRAP_IMPL_H_ +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/types/variant.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/profiler_session.h" diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index 5be852a715fe11..8ec97b32799856 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -13,14 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include #include -#include +#include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/convert/tool_options.h" diff --git a/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc b/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc index dbf5208084bc60..acb6896e5f1e57 100644 --- a/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc +++ b/tensorflow/python/profiler/internal/pywrap_profiler_plugin.cc @@ -21,6 +21,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "pybind11/pybind11.h" // from @pybind11 #include "xla/pjrt/status_casters.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" @@ -29,6 +33,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/tool_options.h" #include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" #include "tensorflow/python/lib/core/pybind11_status.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace py = ::pybind11; @@ -87,7 +92,7 @@ PYBIND11_MODULE(_pywrap_profiler_plugin, m) { "trace", [](const char* service_addr, const char* logdir, const char* worker_list, bool include_dataset_ops, int duration_ms, int num_tracing_attempts, py::dict options) { - tensorflow::Status status; + absl::Status status; ToolOptions tool_options = ToolOptionsFromPythonDict(options); { py::gil_scoped_release release; diff --git a/tensorflow/python/pywrap_dtensor_device.cc b/tensorflow/python/pywrap_dtensor_device.cc index 8cd5fe8b5014aa..a055f784d382a3 100644 --- a/tensorflow/python/pywrap_dtensor_device.cc +++ b/tensorflow/python/pywrap_dtensor_device.cc @@ -414,7 +414,7 @@ PYBIND11_MODULE(_pywrap_dtensor_device, m) { return *mesh; }), py::arg("mesh_proto"), "Returns a Mesh from a MeshProto.") - .def(py::init([](std::string_view mesh_str) { + .def(py::init([](absl::string_view mesh_str) { auto mesh = Mesh::FromString(mesh_str); if (!mesh.ok()) { throw py::value_error(std::string(mesh.status().message())); @@ -436,7 +436,7 @@ PYBIND11_MODULE(_pywrap_dtensor_device, m) { "Returns True if a Mesh contains the given dimension name.") .def( "dim_size", - [](const Mesh& mesh, std::string_view name) { + [](const Mesh& mesh, absl::string_view name) { auto dim_size = mesh.dim_size(name); if (!dim_size.ok()) { throw py::value_error(std::string(dim_size.status().message())); @@ -512,7 +512,7 @@ PYBIND11_MODULE(_pywrap_dtensor_device, m) { return *layout; }), py::arg("layout_proto"), "Returns a Layout from a LayoutProto.") - .def(py::init([](std::string_view layout_str) { + .def(py::init([](absl::string_view layout_str) { auto layout = Layout::FromString(layout_str); if (!layout.ok()) { throw py::value_error(std::string(layout.status().message())); diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index bdec3926f4312f..76ef0e18b5e117 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -95,6 +95,40 @@ def _v1_single_metagraph_saved_model(self, use_resource): ) return path + @test_util.run_in_graph_and_eager_modes + def test_pretty_printed_signature(self): + imported = load.load( + self._v1_single_metagraph_saved_model(use_resource=True) + ) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(variables.local_variables_initializer()) + concrete_fn = imported.signatures["serving_default"] + + summary = ( + "(*, start: TensorSpec(shape=, dtype=tf.float32," + " name='start')) -> Dict[['output', TensorSpec(shape=," + " dtype=tf.float32, name=None)]]" + ) + details = ( + r"Input Parameters:\n" + r" start \(KEYWORD_ONLY\): TensorSpec\(shape=," + r" dtype=tf\.float32, name='start'\)\n" + r"Output Type:\n" + r" Dict\[\['output', TensorSpec\(shape=," + r" dtype=tf\.float32, name=None\)\]\]\n" + r"Captures:\n" + r" \d+: TensorSpec\(shape=\(\), dtype=tf\.resource, name=None\)\n" + r" \d+: TensorSpec\(shape=\(\), dtype=tf\.resource, name=None\)" + ) + self.assertEqual( + concrete_fn.pretty_printed_signature(verbose=False), summary + ) + self.assertRegex( + concrete_fn.pretty_printed_signature(verbose=True), details + ) + self.assertRegex(repr(concrete_fn), r" None: ... diff --git a/tensorflow/python/saved_model/pywrap_saved_model/merger.pyi b/tensorflow/python/saved_model/pywrap_saved_model/merger.pyi index 4023ce61ee5cb1..6905f3befb0f2b 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model/merger.pyi +++ b/tensorflow/python/saved_model/pywrap_saved_model/merger.pyi @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -from typing import Any - class MergerException(Exception): ... -def MergerRead(*args, **kwargs) -> Any: ... +def MergerRead(*args, **kwargs): ... diff --git a/tensorflow/python/saved_model/pywrap_saved_model/metrics.pyi b/tensorflow/python/saved_model/pywrap_saved_model/metrics.pyi index 6228fca0cebb97..460b0bbc73d69c 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model/metrics.pyi +++ b/tensorflow/python/saved_model/pywrap_saved_model/metrics.pyi @@ -13,50 +13,48 @@ # limitations under the License. # ============================================================================== -from typing import Any - kFingerprintError: str kFingerprintFound: str kFingerprintNotFound: str class MetricException(Exception): ... -def AddAsyncCheckpointWriteDuration(*args, **kwargs) -> Any: ... -def AddCheckpointReadDuration(*args, **kwargs) -> Any: ... -def AddCheckpointWriteDuration(*args, **kwargs) -> Any: ... -def AddNumCheckpointShardsWritten(*args, **kwargs) -> Any: ... -def AddShardingCallbackDuration(*args, **kwargs) -> Any: ... -def AddTrainingTimeSaved(*args, **kwargs) -> Any: ... +def AddAsyncCheckpointWriteDuration(*args, **kwargs): ... +def AddCheckpointReadDuration(*args, **kwargs): ... +def AddCheckpointWriteDuration(*args, **kwargs): ... +def AddNumCheckpointShardsWritten(*args, **kwargs): ... +def AddShardingCallbackDuration(*args, **kwargs): ... +def AddTrainingTimeSaved(*args, **kwargs): ... def CalculateFileSize(arg0: str) -> int: ... -def GetAsyncCheckpointWriteDurations(*args, **kwargs) -> Any: ... -def GetCheckpointReadDurations(*args, **kwargs) -> Any: ... -def GetCheckpointSize(*args, **kwargs) -> Any: ... -def GetCheckpointWriteDurations(*args, **kwargs) -> Any: ... +def GetAsyncCheckpointWriteDurations(*args, **kwargs): ... +def GetCheckpointReadDurations(*args, **kwargs): ... +def GetCheckpointSize(*args, **kwargs): ... +def GetCheckpointWriteDurations(*args, **kwargs): ... def GetFoundFingerprintOnLoad() -> str: ... def GetNumCheckpointShardsWritten() -> int: ... -def GetRead(*args, **kwargs) -> Any: ... +def GetRead(*args, **kwargs): ... def GetReadApi(arg0: str) -> int: ... def GetReadFingerprint() -> str: ... def GetReadPath() -> str: ... -def GetReadPathAndSingleprint() -> tuple[str,str]: ... +def GetReadPathAndSingleprint() -> tuple[str, str]: ... def GetShardingCallbackDescription() -> str: ... def GetShardingCallbackDuration() -> int: ... -def GetTrainingTimeSaved(*args, **kwargs) -> Any: ... -def GetWrite(*args, **kwargs) -> Any: ... +def GetTrainingTimeSaved(*args, **kwargs): ... +def GetWrite(*args, **kwargs): ... def GetWriteApi(arg0: str) -> int: ... def GetWriteFingerprint() -> str: ... def GetWritePath() -> str: ... -def GetWritePathAndSingleprint() -> tuple[str,str]: ... -def IncrementRead(*args, **kwargs) -> Any: ... +def GetWritePathAndSingleprint() -> tuple[str, str]: ... +def IncrementRead(*args, **kwargs): ... def IncrementReadApi(arg0: str) -> None: ... -def IncrementWrite(*args, **kwargs) -> Any: ... +def IncrementWrite(*args, **kwargs): ... def IncrementWriteApi(arg0: str) -> None: ... -def RecordCheckpointSize(*args, **kwargs) -> Any: ... -def SetFoundFingerprintOnLoad(*args, **kwargs) -> Any: ... -def SetReadFingerprint(*args, **kwargs) -> Any: ... -def SetReadPath(*args, **kwargs) -> Any: ... -def SetReadPathAndSingleprint(*args, **kwargs) -> Any: ... -def SetShardingCallbackDescription(*args, **kwargs) -> Any: ... -def SetWriteFingerprint(*args, **kwargs) -> Any: ... -def SetWritePath(*args, **kwargs) -> Any: ... -def SetWritePathAndSingleprint(*args, **kwargs) -> Any: ... +def RecordCheckpointSize(*args, **kwargs): ... +def SetFoundFingerprintOnLoad(*args, **kwargs): ... +def SetReadFingerprint(*args, **kwargs): ... +def SetReadPath(*args, **kwargs): ... +def SetReadPathAndSingleprint(*args, **kwargs): ... +def SetShardingCallbackDescription(*args, **kwargs): ... +def SetWriteFingerprint(*args, **kwargs): ... +def SetWritePath(*args, **kwargs): ... +def SetWritePathAndSingleprint(*args, **kwargs): ... diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 4a0dced235e542..dc5b1491e9aa7e 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -502,7 +502,6 @@ genrule( name = "create_models_for_aot_compile", outs = EMITTED_AOT_SAVE_MODEL_OBJECTS, cmd = ( - "PYWRAP_TARGET='//tensorflow/python:_pywrap_tensorflow' " + "$(location :make_aot_compile_models) --out_dir $(@D)" ), tags = ["cuda-only"], diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index 90f5752eebad00..c0b973c2d513cf 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -66,7 +66,7 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, tensor_name: Name of the tensor in the checkpoint file to print. all_tensors: Boolean indicating whether to print all tensors. all_tensor_names: Boolean indicating whether to print all tensor names. - count_exclude_pattern: Regex string, pattern to exclude tensors when count. + count_exclude_pattern: Regex string, pattern to exclude tensors from count. """ try: reader = py_checkpoint_reader.NewCheckpointReader(file_name) @@ -123,7 +123,7 @@ def parse_numpy_printoption(kv_str): Raises: argparse.ArgumentTypeError: If the string couldn't be used to set any - nump printoption. + numpy printoption. """ k_v_str = kv_str.split("=", 1) if len(k_v_str) != 2 or not k_v_str[0]: @@ -147,11 +147,14 @@ def parse_numpy_printoption(kv_str): def main(unused_argv): if not FLAGS.file_name: - print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " - "[--tensor_name=tensor_to_print] " - "[--all_tensors] " - "[--all_tensor_names] " - "[--printoptions]") + print( + "Usage: inspect_checkpoint --file_name=checkpoint_file_name " + "[--tensor_name=tensor_to_print] " + "[--all_tensors] " + "[--all_tensor_names] " + "[--count_exclude_pattern] " + "[--printoptions]" + ) sys.exit(1) else: print_tensors_in_checkpoint_file( diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl index 2e787be73af973..0255876c0fe322 100644 --- a/tensorflow/python/tools/tools.bzl +++ b/tensorflow/python/tools/tools.bzl @@ -132,7 +132,6 @@ def saved_model_compile_aot( "{}_makefile.inc".format(name), ], cmd = ( - "PYWRAP_TARGET='//tensorflow/python:_pywrap_tensorflow' " + "$(location {}) aot_compile_cpu ".format( clean_dep("//tensorflow/python/tools:saved_model_cli"), ) + diff --git a/tensorflow/python/tpu/_pywrap_sparse_core_layout.pyi b/tensorflow/python/tpu/_pywrap_sparse_core_layout.pyi index cf6aae1857f4f0..7a8fba85e9e6f5 100644 --- a/tensorflow/python/tpu/_pywrap_sparse_core_layout.pyi +++ b/tensorflow/python/tpu/_pywrap_sparse_core_layout.pyi @@ -13,12 +13,10 @@ # limitations under the License. # ============================================================================== -from typing import Any - class SparseCoreLayoutStacker: def __init__(self, num_partitions: int, disable_table_stacking: bool, sparse_cores_per_partition: int) -> None: ... def AddTable(self, table_name: str, table_height: int, table_width: int, group: str, output_samples: int) -> None: ... - def GetLayouts(self, *args, **kwargs) -> Any: ... + def GetLayouts(self, *args, **kwargs): ... def SetActivationMemoryBytesLimit(self, arg0: int) -> None: ... def SetStackingEnabled(self, arg0: bool) -> None: ... def SetVariableShardBytesLimit(self, arg0: int) -> None: ... diff --git a/tensorflow/python/tpu/tpu_embedding_v3.py b/tensorflow/python/tpu/tpu_embedding_v3.py index c822ee9ddae177..f0bad56f2042f3 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3.py +++ b/tensorflow/python/tpu/tpu_embedding_v3.py @@ -1883,7 +1883,7 @@ def _get_csr_wrapped_coo_from_sorted_coo_tensor( table_vocab_size=total_vocab_size, feature_width=feature_width, table_name=table_name, - allow_id_dropping=True, # TODO(pineapplejuice233): make this configurable. + allow_id_dropping=self._sparse_core_embedding_config.allow_id_dropping, ) table_to_csr_format_tensor[table_name] = ( PartitionedCsrFormatTensor( diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index c6b69285ed9ba7..a007579903efab 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -148,6 +148,7 @@ tf_python_pybind_extension( deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", @@ -233,6 +234,7 @@ tf_python_pybind_extension( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/python/lib/core:pybind11_status", + "@com_google_absl//absl/status", "@pybind11", ] + if_pywrap(["//tensorflow/tools/graph_transforms:transform_graph_lib"]), ) @@ -555,6 +557,7 @@ tf_python_pybind_extension( pytype_srcs = [ "_function_parameter_canonicalizer_binding_for_test.pyi", ], + starlark_only = True, deps = [ "//tensorflow/core:lib", "//third_party/python_runtime:headers", # buildcleaner: keep @@ -837,7 +840,13 @@ py_strict_library( py_strict_library( name = "lazy_loader", srcs = ["lazy_loader.py"], - visibility = util_subpackage_visibility, + # copybara:uncomment_begin(google-only) + # visibility = util_subpackage_visibility, + # copybara:uncomment_end_and_comment_begin + visibility = [ + "//visibility:public", + ], + # copybara:comment_end deps = [ "//tensorflow/python/platform:tf_logging", # global_test_configuration is added here because all major tests depend on this @@ -1063,7 +1072,13 @@ py_strict_library( py_strict_library( name = "dispatch", srcs = ["dispatch.py"], - visibility = util_subpackage_visibility, + # copybara:uncomment_begin(google-only) + # visibility = util_subpackage_visibility, + # copybara:uncomment_end_and_comment_begin + visibility = [ + "//visibility:public", + ], + # copybara:comment_end deps = [ ":tf_decorator_py", ":tf_inspect", @@ -1287,9 +1302,12 @@ tf_py_strict_test( tf_python_pybind_extension( name = "pywrap_xla_ops", srcs = ["tf2xla_opset_wrapper.cc"], - hdrs = [ - "//tensorflow/compiler/tf2xla:tf2xla_opset_hdrs", - ], + hdrs = if_pywrap( + if_false = [ + "//tensorflow/compiler/tf2xla:tf2xla_opset_hdrs", + ], + if_true = [], + ), enable_stub_generation = True, pytype_srcs = [ "pywrap_xla_ops.pyi", @@ -1299,7 +1317,7 @@ tf_python_pybind_extension( "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_abseil//pybind11_abseil:status_casters", - ], + ] + if_pywrap(["//tensorflow/compiler/tf2xla:tf2xla_opset"]), ) py_strict_library( diff --git a/tensorflow/python/util/_pywrap_checkpoint_reader.pyi b/tensorflow/python/util/_pywrap_checkpoint_reader.pyi index 1402d60148afeb..2a6f5e05a54777 100644 --- a/tensorflow/python/util/_pywrap_checkpoint_reader.pyi +++ b/tensorflow/python/util/_pywrap_checkpoint_reader.pyi @@ -13,13 +13,9 @@ # limitations under the License. # ============================================================================== -from typing import Any - class CheckpointReader: def __init__(self, arg0: str) -> None: ... - @classmethod - def CheckpointReader_GetTensor(cls, arg0: CheckpointReader, arg1: str) -> object: ... - def _GetVariableToDataTypeMap(self, *args, **kwargs) -> Any: ... - def _HasTensor(self, arg0: str) -> bool: ... + @staticmethod + def CheckpointReader_GetTensor(arg0: CheckpointReader, arg1: str) -> object: ... def debug_string(self) -> bytes: ... - def get_variable_to_shape_map(self, *args, **kwargs) -> Any: ... + def get_variable_to_shape_map(self, *args, **kwargs): ... diff --git a/tensorflow/python/util/_tf_stack.pyi b/tensorflow/python/util/_tf_stack.pyi index cc906680cbc705..be7f4969f0725a 100644 --- a/tensorflow/python/util/_tf_stack.pyi +++ b/tensorflow/python/util/_tf_stack.pyi @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import Iterator - -from typing import overload +import typing +from typing import Iterator, overload class GraphDebugInfoBuilder: def __init__(self) -> None: ... @@ -57,8 +56,9 @@ class StackTrace: def __getitem__(self, arg0: int) -> StackFrame: ... @overload def __getitem__(self, arg0: slice) -> StackTrace: ... + def __iter__(self) -> typing.Iterator[StackFrame]: ... def __hash__(self) -> int: ... def __len__(self) -> int: ... -def LoadTracesFromDebugInfo(debug_info_proto: bytes) -> dict[str,StackTrace]: ... +def LoadTracesFromDebugInfo(debug_info_proto: bytes) -> dict[str, StackTrace]: ... def extract_stack(source_map: PyBindSourceMap, file_set: PyBindFileSet, stacklevel: int = ...) -> StackTrace: ... diff --git a/tensorflow/python/util/function_parameter_canonicalizer.h b/tensorflow/python/util/function_parameter_canonicalizer.h index 512267595202e6..5a841f652ed2bf 100644 --- a/tensorflow/python/util/function_parameter_canonicalizer.h +++ b/tensorflow/python/util/function_parameter_canonicalizer.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include #include #include "absl/base/attributes.h" diff --git a/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc b/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc index 121c61dbf48bbf..0e8d95a815c7cb 100644 --- a/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc +++ b/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include #include #include "absl/types/span.h" diff --git a/tensorflow/python/util/kernel_registry_wrapper.cc b/tensorflow/python/util/kernel_registry_wrapper.cc index d3d303416961b5..8fa360e124c5a1 100644 --- a/tensorflow/python/util/kernel_registry_wrapper.cc +++ b/tensorflow/python/util/kernel_registry_wrapper.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/python/util/kernel_registry.h" diff --git a/tensorflow/python/util/nest.cc b/tensorflow/python/util/nest.cc index 4ee9497cb455f2..72a88697aee9dd 100644 --- a/tensorflow/python/util/nest.cc +++ b/tensorflow/python/util/nest.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/python/util/nest.h" -#include +#include +#include #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/stringpiece.h" @@ -40,7 +41,7 @@ std::string PyObject_ToString(PyObject* o, int length = -1) { if (length < 0 || str.size() <= length) { return str; } - tensorflow::StringPiece str_piece(str); + absl::string_view str_piece(str); return tensorflow::strings::StrCat(str_piece.substr(length), "..."); } diff --git a/tensorflow/python/util/stack_trace.h b/tensorflow/python/util/stack_trace.h index df55a206e022e0..4296c34979e418 100644 --- a/tensorflow/python/util/stack_trace.h +++ b/tensorflow/python/util/stack_trace.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/tensorflow/python/util/stat_summarizer_wrapper.cc b/tensorflow/python/util/stat_summarizer_wrapper.cc index 13f6d2330d4130..8224e52a0d932f 100644 --- a/tensorflow/python/util/stat_summarizer_wrapper.cc +++ b/tensorflow/python/util/stat_summarizer_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 diff --git a/tensorflow/python/util/tf2xla_opset_wrapper.cc b/tensorflow/python/util/tf2xla_opset_wrapper.cc index aa1f8f52e06863..53d9eb25b969fb 100644 --- a/tensorflow/python/util/tf2xla_opset_wrapper.cc +++ b/tensorflow/python/util/tf2xla_opset_wrapper.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include #include #include diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 5cfaf5145155b3..9d211ade47fcbb 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -34,7 +34,10 @@ limitations under the License. // clang-format on #include -#include +#include +#include +#include +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/python/util/transform_graph_wrapper.cc b/tensorflow/python/util/transform_graph_wrapper.cc index ec0ca2d78237ed..dc6c5cb18e3e13 100644 --- a/tensorflow/python/util/transform_graph_wrapper.cc +++ b/tensorflow/python/util/transform_graph_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 661ba0aed648d4..22136b7840bf28 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -120,7 +120,7 @@ bool IsString(PyObject* o) { // Note that '__class__' attribute is set only in new-style classes. // A lot of tensorflow code uses __class__ without checks, so it seems like // we only support new-style classes. -StringPiece GetClassName(PyObject* o) { +absl::string_view GetClassName(PyObject* o) { // __class__ is equivalent to type() for new style classes. // type() is equivalent to PyObject_Type() // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type) @@ -130,9 +130,9 @@ StringPiece GetClassName(PyObject* o) { // __name__ is the value of `tp_name` after the last '.' // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name) - StringPiece name(type->tp_name); + absl::string_view name(type->tp_name); size_t pos = name.rfind('.'); - if (pos != StringPiece::npos) { + if (pos != absl::string_view::npos) { name.remove_prefix(pos + 1); } return name; diff --git a/tensorflow/security/fuzzing/cc/consume_leading_digits_fuzz.cc b/tensorflow/security/fuzzing/cc/consume_leading_digits_fuzz.cc index 060535600bc1ae..32f56250bccecf 100644 --- a/tensorflow/security/fuzzing/cc/consume_leading_digits_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/consume_leading_digits_fuzz.cc @@ -25,7 +25,7 @@ limitations under the License. namespace { void FuzzTest(std::string data) { - tensorflow::StringPiece sp(data); + absl::string_view sp(data); tensorflow::uint64 val; const bool leading_digits = diff --git a/tensorflow/security/fuzzing/cc/parseURI_fuzz.cc b/tensorflow/security/fuzzing/cc/parseURI_fuzz.cc index fc538a9017559b..9dff089f22aa43 100644 --- a/tensorflow/security/fuzzing/cc/parseURI_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/parseURI_fuzz.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include "fuzztest/fuzztest.h" @@ -27,7 +26,7 @@ limitations under the License. namespace { void FuzzTest(std::string_view uri) { - tensorflow::StringPiece scheme, host, path; + absl::string_view scheme, host, path; tensorflow::io::ParseURI(uri, &scheme, &host, &path); // If a path is invalid. diff --git a/tensorflow/security/fuzzing/cc/status_fuzz.cc b/tensorflow/security/fuzzing/cc/status_fuzz.cc index 9e259fd4e8d4c9..7fdc96c94e41b8 100644 --- a/tensorflow/security/fuzzing/cc/status_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/status_fuzz.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include #include #include diff --git a/tensorflow/security/fuzzing/cc/status_group_fuzz.cc b/tensorflow/security/fuzzing/cc/status_group_fuzz.cc index a0273717367262..dd2169cb117ca3 100644 --- a/tensorflow/security/fuzzing/cc/status_group_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/status_group_fuzz.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "fuzztest/fuzztest.h" diff --git a/tensorflow/security/fuzzing/cc/string_replace_fuzz.cc b/tensorflow/security/fuzzing/cc/string_replace_fuzz.cc index ca280a057366f9..e41334529b52a2 100644 --- a/tensorflow/security/fuzzing/cc/string_replace_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/string_replace_fuzz.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include #include "fuzztest/fuzztest.h" @@ -26,9 +24,9 @@ namespace { void FuzzTest(bool all_flag, std::string s, std::string oldsub, std::string newsub) { - tensorflow::StringPiece sp(s); - tensorflow::StringPiece oldsubp(oldsub); - tensorflow::StringPiece newsubp(newsub); + absl::string_view sp(s); + absl::string_view oldsubp(oldsub); + absl::string_view newsubp(newsub); std::string subbed = tensorflow::str_util::StringReplace(sp, oldsubp, newsubp, all_flag); diff --git a/tensorflow/security/fuzzing/cc/stringprintf_fuzz.cc b/tensorflow/security/fuzzing/cc/stringprintf_fuzz.cc index a37c82a2490700..76a8ffe5f9bef7 100644 --- a/tensorflow/security/fuzzing/cc/stringprintf_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/stringprintf_fuzz.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include diff --git a/tensorflow/security/fuzzing/cc/tstring_fuzz.cc b/tensorflow/security/fuzzing/cc/tstring_fuzz.cc index e69aa09b4588ed..788191b7e8e952 100644 --- a/tensorflow/security/fuzzing/cc/tstring_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/tstring_fuzz.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include diff --git a/tensorflow/security/fuzzing/py/annotation_types.py b/tensorflow/security/fuzzing/py/annotation_types.py index 4ce6fa3cf85fb3..b03f66e5e29ca2 100644 --- a/tensorflow/security/fuzzing/py/annotation_types.py +++ b/tensorflow/security/fuzzing/py/annotation_types.py @@ -30,6 +30,15 @@ def _create_dtype_wrapper(name, underlying_dtype: _dtypes.DType): Complex64 = _create_dtype_wrapper("Complex64", _dtypes.complex64) Float8e4m3fn = _create_dtype_wrapper("Float8e4m3fn", _dtypes.float8_e4m3fn) Float8e5m2 = _create_dtype_wrapper("Float8e5m2", _dtypes.float8_e5m2) +Float8e4m3fnuz = _create_dtype_wrapper( + "Float8e4m3fnuz", _dtypes.float8_e4m3fnuz +) +Float8e4m3b11fnuz = _create_dtype_wrapper( + "Float8e4m3b11fnuz", _dtypes.float8_e4m3b11fnuz +) +Float8e5m2fnuz = _create_dtype_wrapper( + "Float8e5m2fnuz", _dtypes.float8_e5m2fnuz +) Float16 = _create_dtype_wrapper("Float16", _dtypes.float16) Float32 = _create_dtype_wrapper("Float32", _dtypes.float32) Float64 = _create_dtype_wrapper("Float64", _dtypes.float64) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 3225f6f6ba4163..6a77797b545098 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -918,6 +918,13 @@ def tf_cc_shared_library_opensource( """Configures the shared object file for TensorFlow.""" if use_pywrap_rules(): + # TODO(b/356020232): move to a simple top-level target once this macro is removed. + # This target is used solely for filtering purposes and not put directly into + # any final binary artifacts. + cc_library( + name = "%s_pywrap_filter" % name, + deps = roots, + ) return names = _get_shared_library_name_os_version_matrix( @@ -2313,7 +2320,7 @@ def tf_custom_op_library( gpu_deps = [] if use_pywrap_rules(): - deps = [clean_dep("//tensorflow/python:_pywrap_tensorflow_common")] + deps + deps = [clean_dep("//tensorflow/python:tensorflow_common_framework")] + deps else: deps = list(deps) @@ -2620,20 +2627,17 @@ def py_test( exec_properties = None, test_rule = _plain_py_test, env = {}, + extra_pywrap_deps = [clean_dep("//tensorflow/python:_pywrap_tensorflow")], **kwargs): if not exec_properties: exec_properties = tf_exec_properties(kwargs) if use_pywrap_rules(): - test_env = { - "PYWRAP_TARGET": clean_dep(Label("//tensorflow/python:_pywrap_tensorflow")), - } - test_env.update(env) actual_deps = deps.to_list() if hasattr(deps, "to_list") else deps test_rule( - deps = actual_deps + [test_env["PYWRAP_TARGET"]], + deps = actual_deps + extra_pywrap_deps, exec_properties = exec_properties, - env = test_env, + env = env, data = data, **kwargs ) @@ -3141,9 +3145,10 @@ def pybind_extension_opensource( srcs_version = "PY3", testonly = None, visibility = None, - win_def_file = None): + win_def_file = None, + starlark_only = False): """Builds a generic Python extension module.""" - _ignore = [enable_stub_generation, additional_stubgen_deps, module_name] # buildifier: disable=unused-variable + _ignore = [enable_stub_generation, additional_stubgen_deps, module_name, starlark_only] # buildifier: disable=unused-variable p = name.rfind("/") if p == -1: sname = name @@ -3326,7 +3331,23 @@ def pybind_extension_opensource( ) # Export open source version of pybind_extension under base name as well. -pybind_extension = _pybind_extension if use_pywrap_rules() else pybind_extension_opensource +def pybind_extension( + name, + common_lib_packages = [], + pywrap_only = False, + **kwargs): + if use_pywrap_rules(): + _pybind_extension( + name = name, + common_lib_packages = common_lib_packages + ["tensorflow", "tensorflow/python"], + **kwargs + ) + elif not pywrap_only: + pybind_extension_opensource( + name = name, + **kwargs + ) + stripped_cc_info = _stripped_cc_info # Note: we cannot add //third_party/tf_runtime:__subpackages__ here, @@ -3372,7 +3393,6 @@ def tf_python_pybind_static_deps(testonly = False): "@cpuinfo//:__subpackages__", "@curl//:__subpackages__", "@dlpack//:__subpackages__", - "@double_conversion//:__subpackages__", "@eigen_archive//:__subpackages__", "@farmhash_archive//:__subpackages__", "@farmhash_gpu_archive//:__subpackages__", @@ -3442,7 +3462,8 @@ def tf_python_pybind_extension_opensource( visibility = None, win_def_file = None, additional_exported_symbols = None, - linkopts = []): + linkopts = [], + starlark_only = False): """A wrapper macro for pybind_extension_opensource that is used in tensorflow/python/BUILD. Please do not use it anywhere else as it may behave unexpectedly. b/146445820 @@ -3472,10 +3493,11 @@ def tf_python_pybind_extension_opensource( visibility = visibility, win_def_file = win_def_file, linkopts = linkopts, + starlark_only = starlark_only, ) # Export open source version of tf_python_pybind_extension under base name as well. -tf_python_pybind_extension = _pybind_extension if use_pywrap_rules() else tf_python_pybind_extension_opensource +tf_python_pybind_extension = pybind_extension if use_pywrap_rules() else tf_python_pybind_extension_opensource def tf_pybind_cc_library_wrapper_opensource(name, deps, visibility = None, **kwargs): """Wrapper for cc_library and proto dependencies used by tf_python_pybind_extension_opensource. diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds index 968683fa698631..a1447e68a13e52 100644 --- a/tensorflow/tf_version_script.lds +++ b/tensorflow/tf_version_script.lds @@ -16,7 +16,14 @@ tensorflow { *tsl*; *lite*; *TFL*; + *TfLite*; *quantization*; + *mlir*detail*; + *mlir*func*; + *mlir*TF*; + *mlir*shape*; + *mlir*scf*; + *mlir*quant*; local: *; }; diff --git a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.experimental.pbtxt index 8b5291efaf7d60..54f9dbc4a6781d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.experimental.pbtxt @@ -1,13 +1,25 @@ path: "tensorflow.dtypes.experimental" tf_module { + member { + name: "float8_e4m3b11fnuz" + mtype: "" + } member { name: "float8_e4m3fn" mtype: "" } + member { + name: "float8_e4m3fnuz" + mtype: "" + } member { name: "float8_e5m2" mtype: "" } + member { + name: "float8_e5m2fnuz" + mtype: "" + } member { name: "int4" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt index d0805a722bfc21..649a60a67494f9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt @@ -36,14 +36,26 @@ tf_module { name: "extension_type" mtype: "" } + member { + name: "float8_e4m3b11fnuz" + mtype: "" + } member { name: "float8_e4m3fn" mtype: "" } + member { + name: "float8_e4m3fnuz" + mtype: "" + } member { name: "float8_e5m2" mtype: "" } + member { + name: "float8_e5m2fnuz" + mtype: "" + } member { name: "int4" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.experimental.pbtxt index 8b5291efaf7d60..54f9dbc4a6781d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.experimental.pbtxt @@ -1,13 +1,25 @@ path: "tensorflow.dtypes.experimental" tf_module { + member { + name: "float8_e4m3b11fnuz" + mtype: "" + } member { name: "float8_e4m3fn" mtype: "" } + member { + name: "float8_e4m3fnuz" + mtype: "" + } member { name: "float8_e5m2" mtype: "" } + member { + name: "float8_e5m2fnuz" + mtype: "" + } member { name: "int4" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt index 61d39f73849443..4f7f48b27ef3a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt @@ -44,14 +44,26 @@ tf_module { name: "extension_type" mtype: "" } + member { + name: "float8_e4m3b11fnuz" + mtype: "" + } member { name: "float8_e4m3fn" mtype: "" } + member { + name: "float8_e4m3fnuz" + mtype: "" + } member { name: "float8_e5m2" mtype: "" } + member { + name: "float8_e5m2fnuz" + mtype: "" + } member { name: "int4" mtype: "" diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 546eb464b5adf7..1d161320ac5356 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -198,6 +198,8 @@ def _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map): elif api_object.HasField('tf_class'): module_or_class = api_object.tf_class if module_or_class is not None: + if 'is_instance' in symbol_list: + del module_or_class.is_instance[:] for members in (module_or_class.member, module_or_class.member_method): filtered_members = [m for m in members if m.name not in symbol_list] # Two steps because protobuf repeated fields disallow slice assignment. @@ -404,6 +406,7 @@ def _ReadFileToProto(filename): } golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict, omit_golden_symbols_map) + proto_dict = _FilterGoldenProtoDict(proto_dict, omit_golden_symbols_map) # Diff them. Do not fail if called with update. # If the test is run to update goldens, only report diffs but do not fail. @@ -429,6 +432,9 @@ def testAPIBackwardsCompatibility(self): omit_golden_symbols_map['tensorflow.summary'] = [ 'audio', 'histogram', 'image', 'scalar', 'text' ] + omit_golden_symbols_map.update( + self._ignored_is_instance_types(['tensorflow.__internal__.FuncGraph']) + ) self._checkBackwardsCompatibility( tf, @@ -447,6 +453,10 @@ def testAPIBackwardsCompatibilityV1(self): golden_file_patterns = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) + omit_golden_symbols_map = {'tensorflow': ['pywrap_tensorflow']} + omit_golden_symbols_map.update( + self._ignored_is_instance_types(['tensorflow.python_io.TFRecordWriter']) + ) self._checkBackwardsCompatibility( tf.compat.v1, golden_file_patterns, @@ -455,7 +465,7 @@ def testAPIBackwardsCompatibilityV1(self): 'tf': ['pywrap_tensorflow'], 'tf.compat': ['v1', 'v2'], }, - omit_golden_symbols_map={'tensorflow': ['pywrap_tensorflow']}) + omit_golden_symbols_map=omit_golden_symbols_map) def testAPIBackwardsCompatibilityV2(self): api_version = 2 @@ -469,6 +479,10 @@ def testAPIBackwardsCompatibilityV2(self): omit_golden_symbols_map['tensorflow.summary'] = [ 'audio', 'histogram', 'image', 'scalar', 'text' ] + omit_golden_symbols_map.update( + self._ignored_is_instance_types(['tensorflow.__internal__.FuncGraph']) + ) + self._checkBackwardsCompatibility( tf.compat.v2, golden_file_patterns, @@ -476,6 +490,33 @@ def testAPIBackwardsCompatibilityV2(self): additional_private_map={'tf.compat': ['v1', 'v2']}, omit_golden_symbols_map=omit_golden_symbols_map) + def _ignored_is_instance_types(self, extra_types=None): + # In case a new type is defined within a pywrap_.so library, + # it will end up having proper type and location in distributed OSS wheel + # package eventually, but that conversion happens after this test is ran. + # + # Making this test depend on wheel itself also breaks because wheels use + # _upb as underlying protobuf implementation while internal TF uses cpp + # implementation (resulting in different is_instance values for protobuf + # metadata types in golden pbtxt depending on which protobuf implementation + # is being used during test execution). The cpp implementation is not even + # included anymore in protobuf oss wheels. + # + # We end up in a situation when we cannot make this test pass internally and + # externally on the same set of golden expected .pbtxt inputs. It is rare + # and minor discrepancy, so just ignore the is_instance checks for the few + # problematic types, they are guaraneed to have proper types in final wheel + # anyway. + ignored_is_instance_types = [ + 'tensorflow.DType', + 'tensorflow.dtypes.DType', + 'tensorflow.__internal__.SymbolicTensor', + 'tensorflow.Graph', + 'tensorflow.Operation', + 'tensorflow.io.TFRecordWriter' + ] + extra_types if extra_types else [] + return {k: 'is_instance' for k in ignored_is_instance_types} + if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index b135554bfaabba..fc4a4d05d996bd 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -555,7 +555,7 @@ int Main(int argc, char** argv) { str_util::Split(input_layer_shapes[n], ','); for (const string& layer_shape : split_layer_shapes) { int32_t tmp; - CHECK(strings::safe_strto32(layer_shape, &tmp)) + CHECK(absl::SimpleAtoi(layer_shape, &tmp)) << "Incorrect size string specified: " << input_layer_shapes[n]; if (tmp == -1) { LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" @@ -573,7 +573,7 @@ int Main(int argc, char** argv) { input.initialization_values.reserve(string_tokens.size()); for (const string& str_val : string_tokens) { float val; - CHECK(strings::safe_strtof(str_val, &val)) + CHECK(absl::SimpleAtof(str_val, &val)) << "Incorrect initialization values string specified: " << input_layer_values[n]; input.initialization_values.push_back(val); diff --git a/tensorflow/tools/ci_build/rel/windows/cpu_libtensorflow.bat b/tensorflow/tools/ci_build/rel/windows/cpu_libtensorflow.bat index b28c53a90bd078..ed3638379187f8 100644 --- a/tensorflow/tools/ci_build/rel/windows/cpu_libtensorflow.bat +++ b/tensorflow/tools/ci_build/rel/windows/cpu_libtensorflow.bat @@ -16,7 +16,7 @@ SET TF_DIR=%cd% SET TF_DOCKER_DIR=C:\src\tensorflow REM TODO(belitskiy): Switch to Artifact Registry -set TF_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" +set TF_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" docker pull %TF_DOCKER_IMAGE% || exit /b 1 @echo *****Finished docker image pull: %date% %time% diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile index 11bef82159f19f..549326b6e9e3f9 100644 --- a/tensorflow/tools/gcs_test/Dockerfile +++ b/tensorflow/tools/gcs_test/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:24.04@sha256:278628f08d4979fb9af9ead44277dbc9c92c2465922310916ad0c46ec9999295 +FROM ubuntu:24.04@sha256:80dd3c3b9c6cecb9f1667e9290b3bc61b78c2678c02cbdae5f0fea92cc6734ab LABEL maintainer="Shanqing Cai " diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc index 5e92435e482a14..9901d565adfc2e 100644 --- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc +++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc @@ -29,8 +29,8 @@ struct MinMaxRecord { // Try to parse a log file containing loosely-structured lines, some of which // are the min/max logs we want. -Status ExtractMinMaxRecords(const string& log_file_name, - std::vector* records) { +absl::Status ExtractMinMaxRecords(const string& log_file_name, + std::vector* records) { string file_data; TF_RETURN_IF_ERROR( ReadFileToString(Env::Default(), log_file_name, &file_data)); @@ -88,21 +88,21 @@ Status ExtractMinMaxRecords(const string& log_file_name, continue; } StringPiece name_string = line_parts[min_max_index - 1]; - if (!str_util::EndsWith(name_string, print_suffix)) { + if (!absl::EndsWith(name_string, print_suffix)) { continue; } string name( name_string.substr(0, name_string.size() - print_suffix.size())); records->push_back({name, min, max}); } - return OkStatus(); + return absl::OkStatus(); } // Uses the observed min/max values for requantization captured in a log file to // replace costly RequantizationRange ops with simple Consts. -Status FreezeRequantizationRanges(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def) { +absl::Status FreezeRequantizationRanges(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { string min_max_log_file; TF_RETURN_IF_ERROR( context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file)); diff --git a/tensorflow/tools/graph_transforms/insert_logging.cc b/tensorflow/tools/graph_transforms/insert_logging.cc index ccb96efdbd51bb..c138f346b9587d 100644 --- a/tensorflow/tools/graph_transforms/insert_logging.cc +++ b/tensorflow/tools/graph_transforms/insert_logging.cc @@ -79,7 +79,7 @@ absl::Status InsertLogging(const GraphDef& input_graph_def, NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix); const string output_index_string = suffix.substr(1, suffix.size() - 1); int32_t output_index; - if (!strings::safe_strto32(output_index_string, &output_index)) { + if (!absl::SimpleAtoi(output_index_string, &output_index)) { return errors::InvalidArgument("Couldn't understand output number in ", input); } diff --git a/tensorflow/tools/graph_transforms/transform_graph_test.cc b/tensorflow/tools/graph_transforms/transform_graph_test.cc index dde497436fc0a6..86cbb34e2ac406 100644 --- a/tensorflow/tools/graph_transforms/transform_graph_test.cc +++ b/tensorflow/tools/graph_transforms/transform_graph_test.cc @@ -32,15 +32,15 @@ namespace tensorflow { namespace graph_transforms { // Declared here so we don't have to expose it in the public header. -Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, - bool* ignore_errors); +absl::Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, + bool* ignore_errors); namespace { -Status test_empty_graph_transform(const GraphDef& graph_def, - const TransformFuncContext& context, - GraphDef* result) { +absl::Status test_empty_graph_transform(const GraphDef& graph_def, + const TransformFuncContext& context, + GraphDef* result) { result->Clear(); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -114,10 +114,10 @@ class TransformGraphTest : public ::testing::Test { for (const NodeDef& node : out_graph_def.node()) { const int occurrence_count = out_node_map.count(node.name()); - if (str_util::EndsWith(node.name(), "expect_removed")) { + if (absl::EndsWith(node.name(), "expect_removed")) { EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name(); } - if (str_util::EndsWith(node.name(), "expect_remains")) { + if (absl::EndsWith(node.name(), "expect_remains")) { EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name(); } } @@ -136,7 +136,7 @@ class TransformGraphTest : public ::testing::Test { EXPECT_EQ(0, graph_def.node().size()); TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - Status no_such_status = + absl::Status no_such_status = TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def); EXPECT_TRUE(absl::StrContains(no_such_status.ToString(), "not recognized")); } diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index d3cc3c85db2cfe..eb2760dfb548d8 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -181,8 +181,8 @@ void RemoveAttributes(const GraphDef& input_graph_def, } } -Status SortByExecutionOrder(const GraphDef& input_graph_def, - GraphDef* output_graph_def) { +absl::Status SortByExecutionOrder(const GraphDef& input_graph_def, + GraphDef* output_graph_def) { const int num_nodes = input_graph_def.node_size(); std::vector ready; std::vector pending_count; @@ -260,7 +260,7 @@ Status SortByExecutionOrder(const GraphDef& input_graph_def, } return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle"); } - return OkStatus(); + return absl::OkStatus(); } string OpTypePattern::DebugString() const { @@ -288,8 +288,8 @@ GraphMatcher::GraphMatcher(const GraphDef& graph_def) { MapNamesToNodes(graph_def_, &node_map_); } -Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern, - std::vector* matches) { +absl::Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern, + std::vector* matches) { std::set matched_nodes; for (const NodeDef& node : graph_def_.node()) { // Skip any nodes that are already part of a match. @@ -302,7 +302,7 @@ Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern, matches->push_back(match); } } - return OkStatus(); + return absl::OkStatus(); } bool GraphMatcher::DoesOpTypeMatch( @@ -360,11 +360,11 @@ bool GraphMatcher::DoesOpTypeMatch( return true; } -Status ReplaceMatchingOpTypes( +absl::Status ReplaceMatchingOpTypes( const GraphDef& input_graph_def, const OpTypePattern& pattern, - const std::function&, - const std::set&, std::vector*)>& - node_generator, + const std::function&, + const std::set&, + std::vector*)>& node_generator, const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) { // Start off by retrieving all the matching subgraphs. GraphMatcher matcher(input_graph_def); @@ -471,13 +471,13 @@ Status ReplaceMatchingOpTypes( } } - return OkStatus(); + return absl::OkStatus(); } -Status RenameNodeInputs(const GraphDef& input_graph_def, - const std::map& inputs_to_rename, - const std::unordered_set& nodes_to_ignore, - GraphDef* output_graph_def) { +absl::Status RenameNodeInputs(const GraphDef& input_graph_def, + const std::map& inputs_to_rename, + const std::unordered_set& nodes_to_ignore, + GraphDef* output_graph_def) { std::map>> canonical_inputs_to_rename; for (const auto& input_to_rename : inputs_to_rename) { @@ -512,7 +512,7 @@ Status RenameNodeInputs(const GraphDef& input_graph_def, const string& dest_name = input_to_rename.second; bool is_match; string match_name; - if (str_util::EndsWith(source_name, ":*")) { + if (absl::EndsWith(source_name, ":*")) { is_match = true; string prefix; string unused_node_name; @@ -537,7 +537,7 @@ Status RenameNodeInputs(const GraphDef& input_graph_def, *(new_node->mutable_input()->Add()) = new_input_name; } } - return OkStatus(); + return absl::OkStatus(); } void CopyOriginalMatch(const NodeMatch& match, @@ -569,7 +569,7 @@ void FindInvalidInputs(const GraphDef& graph_def, } } -Status IsGraphValid(const GraphDef& graph_def) { +absl::Status IsGraphValid(const GraphDef& graph_def) { std::vector> invalid_inputs; FindInvalidInputs(graph_def, &invalid_inputs); if (!invalid_inputs.empty()) { @@ -583,18 +583,19 @@ Status IsGraphValid(const GraphDef& graph_def) { return errors::Internal( "Invalid graph with inputs referring to nonexistent nodes"); } - return OkStatus(); + return absl::OkStatus(); } -Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, - DataTypeVector* outputs) { +absl::Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, + DataTypeVector* outputs) { const OpDef* op_def; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs)); - return OkStatus(); + return absl::OkStatus(); } -Status TensorShapeFromString(const string& shape_string, TensorShape* result) { +absl::Status TensorShapeFromString(const string& shape_string, + TensorShape* result) { if (shape_string.empty()) { return errors::InvalidArgument("Specified shape is empty."); } @@ -610,7 +611,7 @@ Status TensorShapeFromString(const string& shape_string, TensorShape* result) { } } *result = TensorShape(dims); - return OkStatus(); + return absl::OkStatus(); } int TransformFuncContext::CountParameters(const string& name) const { @@ -621,16 +622,15 @@ int TransformFuncContext::CountParameters(const string& name) const { } } -Status TransformFuncContext::GetOneStringParameter(const string& name, - const string& default_value, - string* result) const { +absl::Status TransformFuncContext::GetOneStringParameter( + const string& name, const string& default_value, string* result) const { const int params_count = CountParameters(name); if (params_count == 0) { *result = default_value; - return OkStatus(); + return absl::OkStatus(); } else if (params_count == 1) { *result = params.at(name).at(0); - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Expected a single '", name, "' parameter, but found ", params_count, @@ -638,13 +638,13 @@ Status TransformFuncContext::GetOneStringParameter(const string& name, } } -Status TransformFuncContext::GetOneInt32Parameter(const string& name, - int32_t default_value, - int32* result) const { +absl::Status TransformFuncContext::GetOneInt32Parameter(const string& name, + int32_t default_value, + int32* result) const { const int params_count = CountParameters(name); if (params_count == 0) { *result = default_value; - return OkStatus(); + return absl::OkStatus(); } string string_value; TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); @@ -652,16 +652,16 @@ Status TransformFuncContext::GetOneInt32Parameter(const string& name, return errors::InvalidArgument("Couldn't interpret the ", name, " argument as a number:", string_value); } - return OkStatus(); + return absl::OkStatus(); } -Status TransformFuncContext::GetOneInt64Parameter(const string& name, - int64_t default_value, - int64_t* result) const { +absl::Status TransformFuncContext::GetOneInt64Parameter(const string& name, + int64_t default_value, + int64_t* result) const { const int params_count = CountParameters(name); if (params_count == 0) { *result = default_value; - return OkStatus(); + return absl::OkStatus(); } string string_value; TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); @@ -669,16 +669,16 @@ Status TransformFuncContext::GetOneInt64Parameter(const string& name, return errors::InvalidArgument("Couldn't interpret the ", name, " argument as a number:", string_value); } - return OkStatus(); + return absl::OkStatus(); } -Status TransformFuncContext::GetOneFloatParameter(const string& name, - float default_value, - float* result) const { +absl::Status TransformFuncContext::GetOneFloatParameter(const string& name, + float default_value, + float* result) const { const int params_count = CountParameters(name); if (params_count == 0) { *result = default_value; - return OkStatus(); + return absl::OkStatus(); } string string_value; TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); @@ -687,16 +687,16 @@ Status TransformFuncContext::GetOneFloatParameter(const string& name, "Couldn't interpret the ", name, " argument as a float number:", string_value); } - return OkStatus(); + return absl::OkStatus(); } -Status TransformFuncContext::GetOneBoolParameter(const string& name, - bool default_value, - bool* result) const { +absl::Status TransformFuncContext::GetOneBoolParameter(const string& name, + bool default_value, + bool* result) const { const int params_count = CountParameters(name); if (params_count == 0) { *result = default_value; - return OkStatus(); + return absl::OkStatus(); } string string_value; TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); @@ -709,7 +709,7 @@ Status TransformFuncContext::GetOneBoolParameter(const string& name, " argument as a boolean:", string_value, " (expected true, false, 0 or 1)"); } - return OkStatus(); + return absl::OkStatus(); } } // namespace graph_transforms diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index a122a44de73be7..c77a7ad11c2153 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -155,6 +155,7 @@ genrule( "@tf_runtime//:LICENSE", "@local_tsl//:LICENSE", "@local_xla//:LICENSE", + "@XNNPACK//:LICENSE", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], @@ -198,6 +199,7 @@ genrule( "@tf_runtime//:LICENSE", "@local_tsl//:LICENSE", "@local_xla//:LICENSE", + "@XNNPACK//:LICENSE", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/tools/optimization/BUILD b/tensorflow/tools/optimization/BUILD index 928adec880d5cf..c43a8f34f93509 100644 --- a/tensorflow/tools/optimization/BUILD +++ b/tensorflow/tools/optimization/BUILD @@ -48,6 +48,7 @@ tf_cc_binary( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc b/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc index 300552914c230a..5801deb1b6f6f7 100644 --- a/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc +++ b/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc @@ -17,6 +17,10 @@ limitations under the License. // --output_file_path=/tmp/output.pbtxt // --optimization_pass=NameOfGraphOptimizationPass +#include +#include + +#include "absl/status/status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/tools/optimization/optimization_pass_runner.cc b/tensorflow/tools/optimization/optimization_pass_runner.cc index 008cf9a6f50a58..c14ccb68db4b61 100644 --- a/tensorflow/tools/optimization/optimization_pass_runner.cc +++ b/tensorflow/tools/optimization/optimization_pass_runner.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/tools/optimization/optimization_pass_runner.h" #include -#include +#include #include #include "absl/status/status.h" diff --git a/tensorflow/tools/optimization/optimization_pass_runner.h b/tensorflow/tools/optimization/optimization_pass_runner.h index 5c81f2a13a7396..cd4dcaa3eb42c4 100644 --- a/tensorflow/tools/optimization/optimization_pass_runner.h +++ b/tensorflow/tools/optimization/optimization_pass_runner.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/optimization_registry.h" diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 03ebe530a3b4fd..829fa4a0e0e15e 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -7,10 +7,19 @@ load( "@local_tsl//third_party/py:py_import.bzl", "py_import", ) -load("@local_xla//xla/tsl:tsl.bzl", "if_cuda_libs") +load( + "@local_tsl//third_party/py:py_manylinux_compliance_test.bzl", + "verify_manylinux_compliance_test", +) load("@local_xla//xla/tsl/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", "if_mkl_ml") load("//tensorflow:tensorflow.bzl", "if_wheel_dependency", "if_with_tpu_support", "transitive_hdrs") -load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap", "tf_additional_license_deps") +load( + "//tensorflow/core/platform:build_config_root.bzl", + "if_pywrap", + "tf_additional_license_deps", + "tf_cuda_tests_tags", + "tf_exec_properties", +) load("//tensorflow/tools/pip_package/utils:data_deps.bzl", "collect_data_files") load("//tensorflow/tools/pip_package/utils:py_deps.bzl", "transitive_py_deps") load("//tensorflow/tools/pip_package/utils:tf_wheel.bzl", "tf_wheel", "tf_wheel_dep") @@ -273,12 +282,6 @@ tf_wheel( ":licenses", "//tensorflow/core:protos_all_proto_srcs", ], - manylinux_compliance_tag = select({ - "@platforms//cpu:aarch64": "manylinux_2_17_aarch64", - "@platforms//cpu:arm64": "manylinux_2_17_aarch64", - "@platforms//cpu:x86_64": "manylinux_2_17_x86_64", - "//conditions:default": "", - }), platform_name = select({ "@platforms//os:osx": "macosx", "@platforms//os:macos": "macosx", @@ -311,7 +314,7 @@ tf_wheel( "//tensorflow:tensorflow_framework", ], if_true = [ - "//tensorflow/python:_pywrap_tensorflow_binaries", + "//tensorflow/python:pywrap_tensorflow_binaries", ], ), "//tensorflow:windows": if_pywrap( @@ -341,7 +344,7 @@ sh_binary( ) py_test( - name = "prebuilt_wheel_import_api_packages_test", + name = "prebuilt_wheel_import_api_packages_test_cpu", srcs = if_wheel_dependency( ["import_api_packages_test.py"], [":empty_test"], @@ -349,20 +352,57 @@ py_test( main = if_wheel_dependency("import_api_packages_test.py", "empty_test.py"), tags = [ "cpu", + "manual", + "windows_excluded", + ], + deps = if_wheel_dependency(tf_wheel_dep()), +) + +py_test( + name = "prebuilt_wheel_import_api_packages_test_gpu", + srcs = if_wheel_dependency( + ["import_api_packages_test.py"], + [":empty_test"], + ), + exec_properties = if_cuda( + tf_exec_properties({"tags": tf_cuda_tests_tags()}), + {}, + ), + main = if_wheel_dependency("import_api_packages_test.py", "empty_test.py"), + tags = [ "gpu", + "manual", "windows_excluded", ], deps = if_wheel_dependency(tf_wheel_dep()), ) py_test( - name = "import_api_packages_test", + name = "import_api_packages_test_cpu", srcs = ["import_api_packages_test.py"], main = "import_api_packages_test.py", tags = [ "cuda-only", #TODO(rocm): weekly-sync 24-12-10 "cpu", + "manual", + "windows_excluded", + ], + deps = [ + ":tf_py_import", + ], +) + +py_test( + name = "import_api_packages_test_gpu", + srcs = ["import_api_packages_test.py"], + exec_properties = if_cuda( + tf_exec_properties({"tags": tf_cuda_tests_tags()}), + {}, + ), + main = "import_api_packages_test.py", + tags = [ "gpu", + "manual", "windows_excluded", ], deps = [ @@ -370,23 +410,33 @@ py_test( ], ) +verify_manylinux_compliance_test( + name = "manylinux_compliance_test", + aarch64_compliance_tag = "manylinux_2_17_aarch64", + test_tags = [ + "mac_excluded", + "windows_excluded", + ], + wheel = ":wheel", + x86_64_compliance_tag = "manylinux_2_17_x86_64", +) + py_import( name = "tf_py_import", - cc_deps = if_cuda_libs([ - "@cuda_cublas//:cublas", - "@cuda_cublas//:cublasLt", - "@cuda_cudart//:cudart", - "@cuda_cudnn//:cudnn", - "@cuda_cufft//:cufft", - "@cuda_cupti//:cupti", - "@cuda_curand//:curand", - "@cuda_cusolver//:cusolver", - "@cuda_cusparse//:cusparse", - "@cuda_nccl//:nccl", - "@cuda_nvjitlink//:nvjitlink", - "@cuda_nvrtc//:nvrtc", - ]), wheel = ":wheel", + wheel_deps = if_cuda([ + "@pypi_nvidia_cublas_cu12//:whl", + "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_nvrtc_cu12//:whl", + "@pypi_nvidia_cuda_runtime_cu12//:whl", + "@pypi_nvidia_cudnn_cu12//:whl", + "@pypi_nvidia_cufft_cu12//:whl", + "@pypi_nvidia_curand_cu12//:whl", + "@pypi_nvidia_cusolver_cu12//:whl", + "@pypi_nvidia_cusparse_cu12//:whl", + "@pypi_nvidia_nccl_cu12//:whl", + "@pypi_nvidia_nvjitlink_cu12//:whl", + ]), deps = [ "@pypi_absl_py//:pkg", "@pypi_astunparse//:pkg", diff --git a/tensorflow/tools/pip_package/build_pip_package.py b/tensorflow/tools/pip_package/build_pip_package.py index e61204d8865c2f..4809d5ec7a7c50 100644 --- a/tensorflow/tools/pip_package/build_pip_package.py +++ b/tensorflow/tools/pip_package/build_pip_package.py @@ -120,6 +120,14 @@ def prepare_headers(headers: list[str], srcs_dir: str) -> None: "python_x86_64", "python_aarch64", "llvm-project/llvm/", + "external/cpuinfo", + "external/FXdiv", + "external/net_zstd", + "external/org_brotli/c", + "external/org_brotli/_virtual_includes", + "external/pthreadpool", + "external/riegeli/riegeli", + "external/XNNPACK/src/", ] path_to_replace = { diff --git a/tensorflow/tools/pip_package/utils/tf_wheel.bzl b/tensorflow/tools/pip_package/utils/tf_wheel.bzl index c8f31d38c6dd67..62bde9c5c02464 100644 --- a/tensorflow/tools/pip_package/utils/tf_wheel.bzl +++ b/tensorflow/tools/pip_package/utils/tf_wheel.bzl @@ -74,7 +74,6 @@ def _tf_wheel_impl(ctx): " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") executable = ctx.executable.wheel_binary - verify_manylinux = ctx.attr.verify_manylinux[BuildSettingInfo].value full_wheel_name = _get_full_wheel_name( platform_name = ctx.attr.platform_name, platform_tag = ctx.attr.platform_tag, @@ -120,23 +119,7 @@ def _tf_wheel_impl(ctx): outputs = [output_file], executable = executable, ) - auditwheel_show_log = None - if ctx.attr.platform_name == "linux": - auditwheel_show_log = ctx.actions.declare_file("auditwheel_show.log") - args = ctx.actions.args() - args.add("--wheel_path", output_file.path) - if verify_manylinux: - args.add("--compliance-tag", ctx.attr.manylinux_compliance_tag) - args.add("--auditwheel-show-log-path", auditwheel_show_log.path) - ctx.actions.run( - arguments = [args], - inputs = [output_file], - outputs = [auditwheel_show_log], - executable = ctx.executable.verify_manylinux_compliance_binary, - ) - - auditwheel_show_output = [auditwheel_show_log] if auditwheel_show_log else [] - return [DefaultInfo(files = depset(direct = [output_file] + auditwheel_show_output))] + return [DefaultInfo(files = depset(direct = [output_file]))] tf_wheel = rule( attrs = { @@ -153,13 +136,6 @@ tf_wheel = rule( "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), "platform_tag": attr.string(mandatory = True), "platform_name": attr.string(mandatory = True), - "verify_manylinux_compliance_binary": attr.label( - default = Label("@local_tsl//third_party/py:verify_manylinux_compliance"), - executable = True, - cfg = "exec", - ), - "verify_manylinux": attr.label(default = Label("@local_tsl//third_party/py:verify_manylinux")), - "manylinux_compliance_tag": attr.string(mandatory = True), }, implementation = _tf_wheel_impl, ) diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc index 3c2853a86b1b08..7bec60085a9ff5 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_base.cc @@ -225,7 +225,7 @@ ComposableSplitterBase::WriteToCord() { absl::Cord output; if (chunked_message->chunked_fields().empty()) { // Export regular pb. - if (!message_->SerializeToCord(&output)) + if (!message_->SerializeToString(&output)) return absl::InvalidArgumentError("Serialization to absl::Cord failed"); LOG(INFO) << "Splitter output written to absl::Cord"; return std::make_tuple(output, false); diff --git a/tensorflow/tools/tfg_graph_transforms/utils.cc b/tensorflow/tools/tfg_graph_transforms/utils.cc index 2fe2e9476e7659..4d7d191cc58508 100644 --- a/tensorflow/tools/tfg_graph_transforms/utils.cc +++ b/tensorflow/tools/tfg_graph_transforms/utils.cc @@ -39,7 +39,7 @@ absl::string_view GetNameWithoutExtension(absl::string_view filename) { } // namespace bool IsTextProto(const std::string& input_file) { - tensorflow::StringPiece extension = tensorflow::io::Extension(input_file); + absl::string_view extension = tensorflow::io::Extension(input_file); return !extension.compare("pbtxt"); } diff --git a/tensorflow/tools/tfg_graph_transforms/utils.h b/tensorflow/tools/tfg_graph_transforms/utils.h index 9ea59a385ad6ee..84b9f87ec84e91 100644 --- a/tensorflow/tools/tfg_graph_transforms/utils.h +++ b/tensorflow/tools/tfg_graph_transforms/utils.h @@ -38,7 +38,7 @@ namespace graph_transforms { template absl::Status ReadModelProto(const std::string& input_file, T& model_proto) { // Proto might be either in binary or text format. - tensorflow::StringPiece extension = tensorflow::io::Extension(input_file); + absl::string_view extension = tensorflow::io::Extension(input_file); bool binary_extenstion = !extension.compare("pb"); bool text_extension = !extension.compare("pbtxt"); diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index 1182e52997fce0..3fdf6704e2ff53 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -1,6 +1,6 @@ """Configurations of RBE builds used with remote config.""" -load("//tensorflow/tools/toolchains/remote_config:rbe_config.bzl", "sigbuild_tf_configs", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") +load("//tensorflow/tools/toolchains/remote_config:rbe_config.bzl", "ml_build_rbe_config", "sigbuild_tf_configs", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") def initialize_rbe_configs(): tensorflow_local_config( @@ -47,6 +47,11 @@ def initialize_rbe_configs(): python_bin_path = "C:/Python37/python.exe", ) + # The `ml-build-rbe` image is identical to the `ml-build` image except for the base image. + # The `ml-build`'s base image is a standard `ubuntu22.04` image. + # The `ml-build-rbe`'s base image is `nvidia/cuda:12.3.2-base-ubuntu22.04` which has nvidia driver installed. + ml_build_rbe_config("docker://us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe@sha256:aaeb29799463729092c05f5ac8393113b3bb5d1ecf085f9f1f2016e3a1ece11c") + # TF-Version-Specific SIG Build RBE Configs. The crosstool generated from these # configs are python-version-independent because they only care about the # tooling paths; the container mapping is useful only so that TF RBE users diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl index 8a6120efbbd69d..ddd87ae0cf9786 100644 --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl @@ -92,10 +92,24 @@ def _tensorflow_local_config(name): platform_constraint = "@%s_config_platform//:platform_constraint" % name, ) +def _ml_build_rbe_config(container_image): + exec_properties = { + "container-image": container_image, + "Pool": "default", + } + + remote_platform_configure( + name = "ml_build_config_platform", + platform = "linux", + platform_exec_properties = exec_properties, + ) + tensorflow_rbe_config = _tensorflow_rbe_config tensorflow_rbe_win_config = _tensorflow_rbe_win_config tensorflow_local_config = _tensorflow_local_config +ml_build_rbe_config = _ml_build_rbe_config +# TODO(b/369382309): Remove this once ml_build_rbe_config is used everywhere. # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles # These containers do not support ROCm and all have CUDA. diff --git a/tensorflow/tools/toolchains/win/20240424/BUILD b/tensorflow/tools/toolchains/win/20240424/BUILD index 93b3c90aff81d9..db4cf0eac92066 100644 --- a/tensorflow/tools/toolchains/win/20240424/BUILD +++ b/tensorflow/tools/toolchains/win/20240424/BUILD @@ -20,24 +20,6 @@ load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) -cc_library(name = "empty_lib") - -# Label flag for extra libraries to be linked into every binary. -# TODO(bazel-team): Support passing flag multiple times to build a list. -label_flag( - name = "link_extra_libs", - build_setting_default = ":empty_lib", -) - -# The final extra library to be linked into every binary target. This collects -# the above flag, but may also include more libraries depending on config. -cc_library( - name = "link_extra_lib", - deps = [ - ":link_extra_libs", - ], -) - cc_library( name = "malloc", ) @@ -228,7 +210,8 @@ cc_toolchain_config( compiler = "msvc-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -240,24 +223,24 @@ cc_toolchain_config( default_link_flags = ["/MACHINE:X64"], fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", host_system_name = "local", - msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", - msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/lib.exe", - msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/link.exe", - msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/ml64.exe", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", supports_parse_showincludes = True, target_libc = "msvcrt", target_system_name = "local", tool_paths = { - "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/lib.exe", - "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/ml64.exe", - "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", - "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", "gcov": "wrapper/bin/msvc_nop.bat", - "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/link.exe", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", "nm": "wrapper/bin/msvc_nop.bat", "objcopy": "wrapper/bin/msvc_nop.bat", "objdump": "wrapper/bin/msvc_nop.bat", @@ -303,7 +286,8 @@ cc_toolchain_config( compiler = "msvc-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -315,24 +299,24 @@ cc_toolchain_config( default_link_flags = ["/MACHINE:X86"], fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", host_system_name = "local", - msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", - msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/lib.exe", - msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/link.exe", - msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/ml.exe", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", supports_parse_showincludes = True, target_libc = "msvcrt", target_system_name = "local", tool_paths = { - "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/lib.exe", - "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/ml.exe", - "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", - "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", "gcov": "wrapper/bin/msvc_nop.bat", - "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/link.exe", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", "nm": "wrapper/bin/msvc_nop.bat", "objcopy": "wrapper/bin/msvc_nop.bat", "objdump": "wrapper/bin/msvc_nop.bat", @@ -511,7 +495,8 @@ cc_toolchain_config( compiler = "clang-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -521,13 +506,16 @@ cc_toolchain_config( "C:\\tools\\LLVM\\lib\\clang\\18\\include", ], dbg_mode_debug_flag = "/DEBUG", - default_link_flags = ["/MACHINE:X64"], + default_link_flags = [ + "/MACHINE:X64", + "/DEFAULTLIB:clang_rt.builtins-x86_64.lib", + ], fastbuild_mode_debug_flag = "/DEBUG", host_system_name = "local", msvc_cl_path = "C:/tools/LLVM/bin/clang-cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", msvc_lib_path = "C:/tools/LLVM/bin/llvm-lib.exe", msvc_link_path = "C:/tools/LLVM/bin/lld-link.exe", diff --git a/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl b/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl index 0a1fb6e0df84ce..f440b6083d71fb 100644 --- a/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl +++ b/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl @@ -3,3 +3,5 @@ that clang-cl reported. This file is a dependency of every compilation action an changes to it will be reflected in the action cache key. When some of these paths change, Bazel will make sure to rerun the action, even though none of declared action inputs or the action commandline changes. + + diff --git a/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc b/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc index 55ba44f761e2c1..1380bc62e15b60 100644 --- a/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc +++ b/tensorflow/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc @@ -4,3 +4,4 @@ changes to it will be reflected in the action cache key. When some of these paths change, Bazel will make sure to rerun the action, even though none of declared action inputs or the action commandline changes. + diff --git a/tensorflow/tools/toolchains/win/20240424/toolchain_image_info b/tensorflow/tools/toolchains/win/20240424/toolchain_image_info index 807a14bebbdb44..ffa6a8e33c7933 100644 --- a/tensorflow/tools/toolchains/win/20240424/toolchain_image_info +++ b/tensorflow/tools/toolchains/win/20240424/toolchain_image_info @@ -1,2 +1,2 @@ REPOSITORY TAG DIGEST IMAGE ID CREATED SIZE -gcr.io/tensorflow-testing/tf-win2019-docker-staging latest sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc b601adb43430 8 minutes ago 20.4GB \ No newline at end of file +gcr.io/tensorflow-testing/tf-win2019-rbe latest sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd b601adb43430 8 minutes ago 20.4GB \ No newline at end of file diff --git a/tensorflow/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl b/tensorflow/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl index 6d8e8af6d50e4a..03ff9b6b30078d 100644 --- a/tensorflow/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl +++ b/tensorflow/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl @@ -375,7 +375,6 @@ def _impl(ctx): compiler_param_file_feature = feature( name = "compiler_param_file", - enabled = True, ) copy_dynamic_libraries_to_binary_feature = feature( diff --git a/tensorflow/tools/toolchains/win/BUILD b/tensorflow/tools/toolchains/win/BUILD index 55ae6fb22b81f6..258ca032ecd1ea 100644 --- a/tensorflow/tools/toolchains/win/BUILD +++ b/tensorflow/tools/toolchains/win/BUILD @@ -17,7 +17,7 @@ platform( remote_execution_properties = """ properties:{ name: "container-image" - value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" } properties:{ name: "OSFamily" @@ -43,7 +43,7 @@ platform( remote_execution_properties = """ properties:{ name: "container-image" - value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" } properties:{ name: "OSFamily" diff --git a/tensorflow/tools/toolchains/win2022/20241118/BUILD b/tensorflow/tools/toolchains/win2022/20241118/BUILD new file mode 100644 index 00000000000000..7d1ac7d0dfa1f2 --- /dev/null +++ b/tensorflow/tools/toolchains/win2022/20241118/BUILD @@ -0,0 +1,647 @@ +# Copyright 2018 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This becomes the BUILD file for @local_config_cc// under Windows. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") +load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "malloc", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "mingw_compiler_files", + srcs = [":builtin_include_directory_paths_mingw"], +) + +filegroup( + name = "clangcl_compiler_files", + srcs = [":builtin_include_directory_paths_clangcl"], +) + +filegroup( + name = "msvc_compiler_files", + srcs = [":builtin_include_directory_paths_msvc"], +) + +# Hardcoded toolchain, legacy behaviour. +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a", + "x64_windows|msvc-cl": ":cc-compiler-x64_windows", + "x64_x86_windows|msvc-cl": ":cc-compiler-x64_x86_windows", + "x64_arm_windows|msvc-cl": ":cc-compiler-x64_arm_windows", + "x64_arm64_windows|msvc-cl": ":cc-compiler-arm64_windows", + "arm64_windows|msvc-cl": ":cc-compiler-arm64_windows", + "x64_windows|msys-gcc": ":cc-compiler-x64_windows_msys", + "x64_windows|mingw-gcc": ":cc-compiler-x64_windows_mingw", + "x64_windows|clang-cl": ":cc-compiler-x64_windows-clang-cl", + "x64_windows_msys": ":cc-compiler-x64_windows_msys", + "x64_windows": ":cc-compiler-x64_windows", + "x64_x86_windows": ":cc-compiler-x64_x86_windows", + "x64_arm_windows": ":cc-compiler-x64_arm_windows", + "x64_arm64_windows": ":cc-compiler-arm64_windows", + "arm64_windows": ":cc-compiler-arm64_windows", + "x64_arm64_windows|clang-cl": ":cc-compiler-arm64_windows-clang-cl", + "arm64_windows|clang-cl": ":cc-compiler-arm64_windows-clang-cl", + "armeabi-v7a": ":cc-compiler-armeabi-v7a", + }, +) + +cc_toolchain( + name = "cc-compiler-x64_windows_msys", + all_files = ":empty", + ar_files = ":empty", + as_files = ":mingw_compiler_files", + compiler_files = ":mingw_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msys_x64", + toolchain_identifier = "msys_x64", +) + +cc_toolchain_config( + name = "msys_x64", + abi_libc_version = "local", + abi_version = "local", + compiler = "msys-gcc", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "c:/tools/msys64/usr/", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + target_libc = "msys", + target_system_name = "local", + tool_bin_path = "c:/tools/msys64/usr/bin", + tool_paths = { + "ar": "c:/tools/msys64/usr/bin/ar", + "cpp": "c:/tools/msys64/usr/bin/cpp", + "dwp": "c:/tools/msys64/usr/bin/dwp", + "gcc": "c:/tools/msys64/usr/bin/gcc", + "gcov": "c:/tools/msys64/usr/bin/gcov", + "ld": "c:/tools/msys64/usr/bin/ld", + "nm": "c:/tools/msys64/usr/bin/nm", + "objcopy": "c:/tools/msys64/usr/bin/objcopy", + "objdump": "c:/tools/msys64/usr/bin/objdump", + "strip": "c:/tools/msys64/usr/bin/strip", + }, +) + +toolchain( + name = "cc-toolchain-x64_windows_msys", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:msys", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows_msys", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows_mingw", + all_files = ":empty", + ar_files = ":empty", + as_files = ":mingw_compiler_files", + compiler_files = ":mingw_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 0, + toolchain_config = ":msys_x64_mingw", + toolchain_identifier = "msys_x64_mingw", +) + +cc_toolchain_config( + name = "msys_x64_mingw", + abi_libc_version = "local", + abi_version = "local", + compiler = "mingw-gcc", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "c:/tools/msys64/mingw64/", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + target_libc = "mingw", + target_system_name = "local", + tool_bin_path = "c:/tools/msys64/mingw64/bin", + tool_paths = { + "ar": "c:/tools/msys64/mingw64/bin/ar", + "cpp": "c:/tools/msys64/mingw64/bin/cpp", + "dwp": "c:/tools/msys64/mingw64/bin/dwp", + "gcc": "c:/tools/msys64/mingw64/bin/gcc", + "gcov": "c:/tools/msys64/mingw64/bin/gcov", + "ld": "c:/tools/msys64/mingw64/bin/ld", + "nm": "c:/tools/msys64/mingw64/bin/nm", + "objcopy": "c:/tools/msys64/mingw64/bin/objcopy", + "objdump": "c:/tools/msys64/mingw64/bin/objdump", + "strip": "c:/tools/msys64/mingw64/bin/strip", + }, +) + +toolchain( + name = "cc-toolchain-x64_windows_mingw", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:mingw", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows_mingw", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64", + toolchain_identifier = "msvc_x64", +) + +cc_toolchain_config( + name = "msvc_x64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X64"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + default_link_flags = ["/MACHINE:X64"], + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64", +) + +toolchain( + name = "cc-toolchain-x64_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_x86_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64_x86", + toolchain_identifier = "msvc_x64_x86", +) + +cc_toolchain_config( + name = "msvc_x64_x86", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X86"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + default_link_flags = ["/MACHINE:X86"], + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64_x86", +) + +toolchain( + name = "cc-toolchain-x64_x86_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:x86_32", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_x86_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_arm_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64_arm", + toolchain_identifier = "msvc_x64_arm", +) + +cc_toolchain_config( + name = "msvc_x64_arm", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm.bat", + msvc_env_include = "msvc_not_found", + msvc_env_lib = "msvc_not_found", + msvc_env_path = "msvc_not_found", + msvc_env_tmp = "msvc_not_found", + msvc_lib_path = "vc_installation_error_arm.bat", + msvc_link_path = "vc_installation_error_arm.bat", + msvc_ml_path = "vc_installation_error_arm.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "vc_installation_error_arm.bat", + "ml": "vc_installation_error_arm.bat", + "cpp": "vc_installation_error_arm.bat", + "gcc": "vc_installation_error_arm.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64_arm", +) + +toolchain( + name = "cc-toolchain-x64_arm_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:arm", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_arm_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-arm64_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_arm64", + toolchain_identifier = "msvc_arm64", +) + +cc_toolchain_config( + name = "msvc_arm64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM64"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM64"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm64.bat", + msvc_env_include = "msvc_not_found", + msvc_env_lib = "msvc_not_found", + msvc_env_path = "msvc_not_found", + msvc_env_tmp = "msvc_not_found", + msvc_lib_path = "vc_installation_error_arm64.bat", + msvc_link_path = "vc_installation_error_arm64.bat", + msvc_ml_path = "vc_installation_error_arm64.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "vc_installation_error_arm64.bat", + "ml": "vc_installation_error_arm64.bat", + "cpp": "vc_installation_error_arm64.bat", + "gcc": "vc_installation_error_arm64.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm64.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_arm64", +) + +toolchain( + name = "cc-toolchain-arm64_windows", + exec_compatible_with = [ + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:arm64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-arm64_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows-clang-cl", + all_files = ":empty", + ar_files = ":empty", + as_files = ":clangcl_compiler_files", + compiler_files = ":clangcl_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":clang_cl_x64", + toolchain_identifier = "clang_cl_x64", +) + +cc_toolchain_config( + name = "clang_cl_x64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X64"], + compiler = "clang-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + "C:\\tools\\LLVM\\lib\\clang\\18\\include", + ], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = [ + "/MACHINE:X64", + "/DEFAULTLIB:clang_rt.builtins-x86_64.lib", + ], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "C:/tools/LLVM/bin/clang-cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/tools/LLVM/bin/llvm-lib.exe", + msvc_link_path = "C:/tools/LLVM/bin/lld-link.exe", + msvc_ml_path = "C:/tools/LLVM/bin/clang-cl.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/tools/LLVM/bin/llvm-lib.exe", + "ml": "C:/tools/LLVM/bin/clang-cl.exe", + "cpp": "C:/tools/LLVM/bin/clang-cl.exe", + "gcc": "C:/tools/LLVM/bin/clang-cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/tools/LLVM/bin/lld-link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "clang_cl_x64", +) + +toolchain( + name = "cc-toolchain-x64_windows-clang-cl", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows-clang-cl", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-arm64_windows-clang-cl", + all_files = ":empty", + ar_files = ":empty", + as_files = ":clangcl_compiler_files", + compiler_files = ":clangcl_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":clang_cl_arm64", + toolchain_identifier = "clang_cl_arm64", +) + +cc_toolchain_config( + name = "clang_cl_arm64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM64"], + compiler = "clang-cl", + cpu = "arm64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM64"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm64.bat", + msvc_env_include = "clang_cl_not_found", + msvc_env_lib = "clang_cl_not_found", + msvc_env_path = "clang_cl_not_found", + msvc_env_tmp = "clang_cl_not_found", + msvc_lib_path = "vc_installation_error_arm64.bat", + msvc_link_path = "vc_installation_error_arm64.bat", + msvc_ml_path = "vc_installation_error_arm64.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "aarch64-pc-windows-msvc", + tool_paths = { + "ar": "vc_installation_error_arm64.bat", + "ml": "vc_installation_error_arm64.bat", + "cpp": "vc_installation_error_arm64.bat", + "gcc": "vc_installation_error_arm64.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm64.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "clang_cl_arm64", +) + +toolchain( + name = "cc-toolchain-arm64_windows-clang-cl", + exec_compatible_with = [ + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + target_compatible_with = [ + "@platforms//cpu:arm64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-arm64_windows-clang-cl", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-armeabi-v7a", + all_files = ":empty", + ar_files = ":empty", + as_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":stub_armeabi-v7a", + toolchain_identifier = "stub_armeabi-v7a", +) + +armeabi_cc_toolchain_config(name = "stub_armeabi-v7a") + +toolchain( + name = "cc-toolchain-armeabi-v7a", + exec_compatible_with = [ + ], + target_compatible_with = [ + "@platforms//cpu:armv7", + "@platforms//os:android", + ], + toolchain = ":cc-compiler-armeabi-v7a", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/tensorflow/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl b/tensorflow/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl new file mode 100644 index 00000000000000..72ef48ae6d6dfc --- /dev/null +++ b/tensorflow/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl @@ -0,0 +1,82 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Starlark cc_toolchain configuration rule""" + +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "feature", + "tool_path", +) + +def _impl(ctx): + toolchain_identifier = "stub_armeabi-v7a" + host_system_name = "armeabi-v7a" + target_system_name = "armeabi-v7a" + target_cpu = "armeabi-v7a" + target_libc = "armeabi-v7a" + compiler = "compiler" + abi_version = "armeabi-v7a" + abi_libc_version = "armeabi-v7a" + cc_target_os = None + builtin_sysroot = None + action_configs = [] + + supports_pic_feature = feature(name = "supports_pic", enabled = True) + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + features = [supports_dynamic_linker_feature, supports_pic_feature] + + cxx_builtin_include_directories = [] + artifact_name_patterns = [] + make_variables = [] + + tool_paths = [ + tool_path(name = "ar", path = "/bin/false"), + tool_path(name = "cpp", path = "/bin/false"), + tool_path(name = "dwp", path = "/bin/false"), + tool_path(name = "gcc", path = "/bin/false"), + tool_path(name = "gcov", path = "/bin/false"), + tool_path(name = "ld", path = "/bin/false"), + tool_path(name = "llvm-profdata", path = "/bin/false"), + tool_path(name = "nm", path = "/bin/false"), + tool_path(name = "objcopy", path = "/bin/false"), + tool_path(name = "objdump", path = "/bin/false"), + tool_path(name = "strip", path = "/bin/false"), + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ) + +armeabi_cc_toolchain_config = rule( + implementation = _impl, + attrs = {}, + provides = [CcToolchainConfigInfo], +) diff --git a/tensorflow/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl b/tensorflow/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl new file mode 100644 index 00000000000000..f440b6083d71fb --- /dev/null +++ b/tensorflow/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl @@ -0,0 +1,7 @@ +This file is generated by cc_configure and contains builtin include directories +that clang-cl reported. This file is a dependency of every compilation action and +changes to it will be reflected in the action cache key. When some of these +paths change, Bazel will make sure to rerun the action, even though none of +declared action inputs or the action commandline changes. + + diff --git a/tensorflow/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc b/tensorflow/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc new file mode 100644 index 00000000000000..1380bc62e15b60 --- /dev/null +++ b/tensorflow/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc @@ -0,0 +1,7 @@ +This file is generated by cc_configure and contains builtin include directories +that msvc reported. This file is a dependency of every compilation action and +changes to it will be reflected in the action cache key. When some of these +paths change, Bazel will make sure to rerun the action, even though none of +declared action inputs or the action commandline changes. + + diff --git a/tensorflow/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl b/tensorflow/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl new file mode 100644 index 00000000000000..03ff9b6b30078d --- /dev/null +++ b/tensorflow/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl @@ -0,0 +1,1442 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Starlark cc_toolchain configuration rule for Windows""" + +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "artifact_name_pattern", + "env_entry", + "env_set", + "feature", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", + "with_feature_set", +) + +all_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, + ACTION_NAMES.lto_backend, +] + +all_cpp_compile_actions = [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, +] + +preprocessor_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, +] + +codegen_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, +] + +all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, +] + +def _use_msvc_toolchain(ctx): + return ctx.attr.cpu in ["x64_windows", "arm64_windows"] and (ctx.attr.compiler == "msvc-cl" or ctx.attr.compiler == "clang-cl") + +def _impl(ctx): + if _use_msvc_toolchain(ctx): + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "object_file", + prefix = "", + extension = ".obj", + ), + artifact_name_pattern( + category_name = "static_library", + prefix = "", + extension = ".lib", + ), + artifact_name_pattern( + category_name = "alwayslink_static_library", + prefix = "", + extension = ".lo.lib", + ), + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + artifact_name_pattern( + category_name = "dynamic_library", + prefix = "", + extension = ".dll", + ), + artifact_name_pattern( + category_name = "interface_library", + prefix = "", + extension = ".if.lib", + ), + ] + else: + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + ] + + if _use_msvc_toolchain(ctx): + cpp_link_nodeps_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_static_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_static_library, + implies = [ + "nologo", + "archiver_flags", + "input_param_flags", + "linker_param_file", + "msvc_env", + ], + tools = [tool(path = ctx.attr.msvc_lib_path)], + ) + + assemble_action = action_config( + action_name = ACTION_NAMES.assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + preprocess_assemble_action = action_config( + action_name = ACTION_NAMES.preprocess_assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + c_compile_action = action_config( + action_name = ACTION_NAMES.c_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + linkstamp_compile_action = action_config( + action_name = ACTION_NAMES.linkstamp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "default_compile_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_compile_action = action_config( + action_name = ACTION_NAMES.cpp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_link_executable_action = action_config( + action_name = ACTION_NAMES.cpp_link_executable, + implies = [ + "nologo", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + action_configs = [ + assemble_action, + preprocess_assemble_action, + c_compile_action, + linkstamp_compile_action, + cpp_compile_action, + cpp_link_executable_action, + cpp_link_dynamic_library_action, + cpp_link_nodeps_dynamic_library_action, + cpp_link_static_library_action, + ] + else: + action_configs = [] + + if _use_msvc_toolchain(ctx): + msvc_link_env_feature = feature( + name = "msvc_link_env", + env_sets = [ + env_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + env_entries = [env_entry(key = "LIB", value = ctx.attr.msvc_env_lib)], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["/DLL"])], + ), + ], + ) + + determinism_feature = feature( + name = "determinism", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "/wd4117", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ] + (["-Wno-builtin-macro-redefined"] if ctx.attr.compiler == "clang-cl" else []), + ), + ], + ), + ], + ) + + sysroot_feature = feature( + name = "sysroot", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + iterate_over = "sysroot", + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{unfiltered_compile_flags}"], + iterate_over = "unfiltered_compile_flags", + expand_if_available = "unfiltered_compile_flags", + ), + ], + ), + ], + ) + + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + compiler_param_file_feature = feature( + name = "compiler_param_file", + ) + + copy_dynamic_libraries_to_binary_feature = feature( + name = "copy_dynamic_libraries_to_binary", + ) + + input_param_flags_feature = feature( + name = "input_param_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{libopts}"], + iterate_over = "libopts", + expand_if_available = "libopts", + ), + ], + ), + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link.object_files", + flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + ], + ) + + fastbuild_feature = feature( + name = "fastbuild", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = [ctx.attr.fastbuild_mode_debug_flag, "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + flag_group( + flags = ctx.attr.archiver_flags, + ), + ], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ctx.attr.default_link_flags)], + ), + ], + ) + + static_link_msvcrt_feature = feature( + name = "static_link_msvcrt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MT"])], + with_features = [with_feature_set(not_features = ["dbg"])], + ), + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MTd"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + with_features = [with_feature_set(not_features = ["dbg"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + ], + ) + + dynamic_link_msvcrt_feature = feature( + name = "dynamic_link_msvcrt", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MD"])], + with_features = [with_feature_set(not_features = ["dbg", "static_link_msvcrt"])], + ), + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MDd"])], + with_features = [with_feature_set(features = ["dbg"], not_features = ["static_link_msvcrt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + with_features = [with_feature_set(not_features = ["dbg", "static_link_msvcrt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + with_features = [with_feature_set(features = ["dbg"], not_features = ["static_link_msvcrt"])], + ), + ], + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = [ctx.attr.dbg_mode_debug_flag, "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/O2"])], + ), + ], + implies = ["frame_pointer"], + ) + + supports_interface_shared_libraries_feature = feature( + name = "supports_interface_shared_libraries", + enabled = True, + ) + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0601", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/bigobj", + "/Zm500", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + msvc_compile_env_feature = feature( + name = "msvc_compile_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ], + env_entries = [env_entry(key = "INCLUDE", value = ctx.attr.msvc_env_include)], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + generate_pdb_file_feature = feature( + name = "generate_pdb_file", + ) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + disable_assertions_feature = feature( + name = "disable_assertions", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/DNDEBUG"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + has_configured_linker_path_feature = feature(name = "has_configured_linker_path") + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + no_stripping_feature = feature(name = "no_stripping") + + linker_param_file_feature = feature( + name = "linker_param_file", + flag_sets = [ + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + ], + ) + + ignore_noisy_warnings_feature = feature( + name = "ignore_noisy_warnings", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [flag_group(flags = ["/ignore:4221"])], + ), + ], + ) + + no_legacy_features_feature = feature(name = "no_legacy_features") + + parse_showincludes_feature = feature( + name = "parse_showincludes", + enabled = ctx.attr.supports_parse_showincludes, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + ], + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + # Force English (and thus a consistent locale) output so that Bazel can parse + # the /showIncludes output without having to guess the encoding. + env_entries = [env_entry(key = "VSLANG", value = "1033")], + ), + ], + ) + + # MSVC does not emit .d files. + no_dotd_file_feature = feature( + name = "no_dotd_file", + enabled = True, + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile] + all_link_actions, + flag_groups = [flag_group(flags = ["/WX"])], + ), + ], + ) + + windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") + + no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + external_include_paths_feature = feature( + name = "external_include_paths", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["/external:I", "%{external_include_paths}"], + iterate_over = "external_include_paths", + expand_if_available = "external_include_paths", + ), + ], + ), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + targets_windows_feature = feature( + name = "targets_windows", + enabled = True, + implies = ["copy_dynamic_libraries_to_binary"], + ) + + linker_subsystem_flag_feature = feature( + name = "linker_subsystem_flag", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], + ), + ], + ) + + frame_pointer_feature = feature( + name = "frame_pointer", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Oy-"])], + ), + ], + ) + + compiler_output_flags_feature = feature( + name = "compiler_output_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + expand_if_not_available = "output_preprocess_file", + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + nologo_feature = feature( + name = "nologo", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + flag_groups = [flag_group(flags = ["/nologo"])], + ), + ], + ) + + smaller_binary_feature = feature( + name = "smaller_binary", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Gy", "/Gw"])], + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/OPT:ICF", "/OPT:REF"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + compiler_input_flags_feature = feature( + name = "compiler_input_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ) + + def_file_feature = feature( + name = "def_file", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ) + + msvc_env_feature = feature( + name = "msvc_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.msvc_env_path), + env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), + env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), + ], + ), + ], + implies = ["msvc_compile_env", "msvc_link_env"], + ) + features = [ + no_legacy_features_feature, + nologo_feature, + has_configured_linker_path_feature, + no_stripping_feature, + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + default_compile_flags_feature, + msvc_env_feature, + msvc_compile_env_feature, + msvc_link_env_feature, + include_paths_feature, + external_include_paths_feature, + preprocessor_defines_feature, + parse_showincludes_feature, + no_dotd_file_feature, + generate_pdb_file_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + archiver_flags_feature, + input_param_flags_feature, + linker_subsystem_flag_feature, + user_link_flags_feature, + default_link_flags_feature, + linker_param_file_feature, + static_link_msvcrt_feature, + dynamic_link_msvcrt_feature, + dbg_feature, + fastbuild_feature, + opt_feature, + frame_pointer_feature, + disable_assertions_feature, + determinism_feature, + treat_warnings_as_errors_feature, + smaller_binary_feature, + ignore_noisy_warnings_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + archive_param_file_feature, + compiler_param_file_feature, + compiler_output_flags_feature, + compiler_input_flags_feature, + def_file_feature, + windows_export_all_symbols_feature, + no_windows_export_all_symbols_feature, + supports_dynamic_linker_feature, + supports_interface_shared_libraries_feature, + ] + else: + targets_windows_feature = feature( + name = "targets_windows", + implies = ["copy_dynamic_libraries_to_binary"], + enabled = True, + ) + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + gcc_env_feature = feature( + name = "gcc_env", + enabled = True, + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.tool_bin_path), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [flag_group(flags = ["-std=gnu++14"])], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lstdc++"])], + ), + ], + ) + + supports_dynamic_linker_feature = feature( + name = "supports_dynamic_linker", + enabled = True, + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-g", "-Og"])], + ), + ], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = [ + "-g0", + "-O3", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])], + ), + ], + ) + + if ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "mingw-gcc": + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + compiler_param_file_feature = feature( + name = "compiler_param_file", + ) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + default_compile_flags_feature, + archive_param_file_feature, + compiler_param_file_feature, + default_link_flags_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + ] + else: + supports_pic_feature = feature( + name = "supports_pic", + enabled = True, + ) + + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + fdo_optimize_feature = feature( + name = "fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-Werror"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,-fatal-warnings"])], + ), + ], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + supports_pic_feature, + default_compile_flags_feature, + default_link_flags_feature, + fdo_optimize_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + treat_warnings_as_errors_feature, + sysroot_feature, + ] + + tool_paths = [ + tool_path(name = name, path = path) + for name, path in ctx.attr.tool_paths.items() + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories, + toolchain_identifier = ctx.attr.toolchain_identifier, + host_system_name = ctx.attr.host_system_name, + target_system_name = ctx.attr.target_system_name, + target_cpu = ctx.attr.cpu, + target_libc = ctx.attr.target_libc, + compiler = ctx.attr.compiler, + abi_version = ctx.attr.abi_version, + abi_libc_version = ctx.attr.abi_libc_version, + tool_paths = tool_paths, + ) + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True), + "compiler": attr.string(), + "toolchain_identifier": attr.string(), + "host_system_name": attr.string(), + "target_system_name": attr.string(), + "target_libc": attr.string(), + "abi_version": attr.string(), + "abi_libc_version": attr.string(), + "tool_paths": attr.string_dict(), + "cxx_builtin_include_directories": attr.string_list(), + "archiver_flags": attr.string_list(default = []), + "default_link_flags": attr.string_list(default = []), + "msvc_env_tmp": attr.string(default = "msvc_not_found"), + "msvc_env_path": attr.string(default = "msvc_not_found"), + "msvc_env_include": attr.string(default = "msvc_not_found"), + "msvc_env_lib": attr.string(default = "msvc_not_found"), + "msvc_cl_path": attr.string(default = "vc_installation_error.bat"), + "msvc_ml_path": attr.string(default = "vc_installation_error.bat"), + "msvc_link_path": attr.string(default = "vc_installation_error.bat"), + "msvc_lib_path": attr.string(default = "vc_installation_error.bat"), + "dbg_mode_debug_flag": attr.string(), + "fastbuild_mode_debug_flag": attr.string(), + "tool_bin_path": attr.string(default = "not_found"), + "supports_parse_showincludes": attr.bool(), + }, + provides = [CcToolchainConfigInfo], +) diff --git a/tensorflow/tools/toolchains/win2022/BUILD b/tensorflow/tools/toolchains/win2022/BUILD new file mode 100644 index 00000000000000..82434f82ddbdd3 --- /dev/null +++ b/tensorflow/tools/toolchains/win2022/BUILD @@ -0,0 +1,37 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +java_runtime( + name = "windows_jdk8", + srcs = [], + java_home = "C:/openjdk", +) + +# Register a Windows 2022 (Clang) platform. +# Note that while this does support RBE, the current pool size is tiny, +# and this platform is meant to be used as a non-RBE one, for now. +platform( + name = "windows_ltsc2022_clang", + constraint_values = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + remote_execution_properties = """ + properties:{ + name: "container-image" + value: "docker://gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc" + } + properties:{ + name: "OSFamily" + value: "Windows" + } + properties:{ + name: "Pool" value: "win2022" + } + properties:{ + name: "dockerNetwork" value: "off" + } + """, +) diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index a12236b670377e..0e48711de1bbe8 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -40,6 +40,7 @@ load("//third_party/jpeg:workspace.bzl", jpeg = "repo") load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/libprotobuf_mutator:workspace.bzl", libprotobuf_mutator = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") +load("//third_party/nanobind:workspace.bzl", nanobind = "repo") load("//third_party/nasm:workspace.bzl", nasm = "repo") load("//third_party/opencl_headers:workspace.bzl", opencl_headers = "repo") load("//third_party/pasta:workspace.bzl", pasta = "repo") @@ -47,6 +48,7 @@ load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo") load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo") load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo") +load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/ruy:workspace.bzl", ruy = "repo") load("//third_party/shardy:workspace.bzl", shardy = "repo") load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo") @@ -78,11 +80,13 @@ def _initialize_third_party(): kissfft() libprotobuf_mutator() ml_dtypes() + nanobind() nasm() opencl_headers() pasta() pybind11_abseil() pybind11_bazel() + robin_map() ruy() shardy() sobol_data() @@ -150,18 +154,18 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5", - strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"), + sha256 = "2a33eb922e6a4b55dfe9332ac61c8d4d128ae8f9e24e873e756a474e983d50a1", + strip_prefix = "XNNPACK-02764b305b430aec42c3df85ba32b9a3f8d6e3d4", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/02764b305b430aec42c3df85ba32b9a3f8d6e3d4.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) # XNNPack dependency. tf_http_archive( name = "KleidiAI", - sha256 = "ad37707084a6d4ff41be10cbe8540c75bea057ba79d0de6c367c1bfac6ba0852", - strip_prefix = "kleidiai-40a926833857fb64786e02f97703e42b1537cb57", - urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/40a926833857fb64786e02f97703e42b1537cb57/kleidiai-40a926833857fb64786e02f97703e42b1537cb57.zip"), + sha256 = "8ba8cdb9f945941174d34d10eb4ad158ad1cbc1aef259de5ad992b0bbe85861f", + strip_prefix = "kleidiai-7e8c4baf953227fa447a2f345e5d6491a504aa56", + urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/7e8c4baf953227fa447a2f345e5d6491a504aa56/kleidiai-7e8c4baf953227fa447a2f345e5d6491a504aa56.zip"), ) tf_http_archive( @@ -229,6 +233,10 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", "//third_party/mkl_dnn:onednn_acl_indirect_conv.patch", + "//third_party/mkl_dnn:onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch", + "//third_party/mkl_dnn:onednn_acl_fix_segfault_during_postop_execute.patch", + "//third_party/mkl_dnn:onednn_acl_add_bf16_platform_support_check.patch", + "//third_party/mkl_dnn:onednn_acl_add_sbgemm_matmul_primitive_definition.patch", ], sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3", strip_prefix = "oneDNN-3.2.1", @@ -635,14 +643,6 @@ def _tf_repositories(): ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/neon2sse.cmake) - tf_http_archive( - name = "double_conversion", - sha256 = "3dbcdf186ad092a8b71228a5962009b5c96abde9a315257a3452eb988414ea3b", - strip_prefix = "double-conversion-3.2.0", - system_build_file = "//third_party/systemlibs:double_conversion.BUILD", - urls = tf_mirror_urls("https://github.com/google/double-conversion/archive/v3.2.0.tar.gz"), - ) - tf_http_archive( name = "tflite_mobilenet_float", build_file = "//third_party:tflite_mobilenet_float.BUILD", diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl index 8f4aaa7a646781..364163ee70a1d4 100644 --- a/third_party/flatbuffers/build_defs.bzl +++ b/third_party/flatbuffers/build_defs.bzl @@ -415,6 +415,7 @@ def flatbuffer_py_library( name, srcs, deps = [], + visibility = None, include_paths = []): """A py_library with the generated reader/writers for the given schema. @@ -465,6 +466,7 @@ def flatbuffer_py_library( deps = deps + [ "@flatbuffers//:runtime_py", ], + visibility = visibility, ) def flatbuffer_java_library( diff --git a/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/gpus/crosstool/BUILD.rocm.tpl index 6c1523e3e46929..36a50e5f40d058 100644 --- a/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -110,7 +110,7 @@ filegroup( ) filegroup( - name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = [":clang/bin/crosstool_wrapper_driver_is_not_gcc"], + data = ["@local_config_rocm//rocm:all_files"], ) - diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 5d708a68f03715..d9e40de81a9b07 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -24,7 +24,7 @@ import pipes # Template values set by rocm_configure.bzl. CPU_COMPILER = ('%{cpu_compiler}') -USE_CLANG = ('%{compiler}' == 'clang') +USE_CLANG = ('%{compiler_is_clang}' == 'True') HOST_COMPILER_PATH = ('%{host_compiler_path}') HIPCC_PATH = '%{hipcc_path}' @@ -192,6 +192,7 @@ def InvokeHipcc(argv, log=False): hipccopts += defines hipccopts += std_options hipccopts += m_options + hipccopts += ' --rocm-path="%{rocm_path}" ' if depfiles: # Generate the dependency file diff --git a/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl index 510235d801de4e..d8f125fa3d3253 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -21,12 +25,14 @@ cc_library( name = "cublas", visibility = ["//visibility:public"], %{comment}deps = [":cublas_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cublas/lib"), ) cc_library( name = "cublasLt", visibility = ["//visibility:public"], %{comment}deps = [":cublasLt_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cublas/lib"), ) cc_library( diff --git a/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl index 04d2de148c78c0..fabb310001cd39 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -36,6 +40,7 @@ cc_library( %{comment}}) + [ %{comment}":cudart_shared_library", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_runtime/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl index 165c5b1579e73f..c3701a6241243d 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -58,6 +62,7 @@ cc_library( %{comment}"@cuda_nvrtc//:nvrtc", %{comment}":cudnn_main", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl index 7f36054a51bb5b..4e8bcbd84e0327 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -65,6 +69,7 @@ cc_library( %{comment}"@cuda_nvrtc//:nvrtc", %{comment}":cudnn_main", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl index 48ccb0ea3cd197..2e55a742d54967 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -13,6 +17,7 @@ cc_import( cc_library( name = "cufft", %{comment}deps = [":cufft_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cufft/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl index 3991b486195bc5..16d6991b584154 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -1,5 +1,10 @@ licenses(["restricted"]) # NVIDIA proprietary license load("@local_config_cuda//cuda:build_defs.bzl", "if_version_equal_or_greater_than") +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) + exports_files([ "version.txt", ]) @@ -13,6 +18,7 @@ cc_import( cc_library( name = "cupti", %{comment}deps = [":cupti_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_cupti/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl index 50e5a8f18a96fd..746503fcf22229 100644 --- a/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -13,6 +17,7 @@ cc_import( cc_library( name = "curand", %{comment}deps = [":curand_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/curand/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl index 943a08ebeb96e1..30bacf07eebda2 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -19,6 +23,7 @@ cc_import( cc_library( name = "cusolver", %{comment}deps = [":cusolver_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cusolver/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl index 46b24366ce1c04..b7765ab22508dc 100644 --- a/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -14,6 +18,7 @@ cc_import( cc_library( name = "cusparse", %{comment}deps = [":cusparse_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cusparse/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl index 0494008e7924f3..5be8d6ef2408ba 100644 --- a/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -13,6 +17,7 @@ cc_import( cc_library( name = "nvjitlink", %{comment}deps = [":nvjitlink_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/nvjitlink/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl index de18489b455b79..fea4c5d7ce7ed5 100644 --- a/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -1,4 +1,9 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) + %{multiline_comment} cc_import( name = "nvrtc_main", @@ -16,5 +21,6 @@ cc_library( %{comment}":nvrtc_main", %{comment}":nvrtc_builtins", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_nvrtc/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index aa3688e335df37..7ebf2773eb48b1 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -1,8 +1,22 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_version_number", "select_threshold") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like -package(default_visibility = ["//visibility:public"]) +package(default_visibility = ["//visibility:private"]) + +bool_flag( + name = "use_rocm_hermetic_rpath", + build_setting_default = False, +) + +config_setting( + name = "build_hermetic", + flag_values = { + ":use_rocm_hermetic_rpath": "True", + }, +) config_setting( name = "using_hipcc", @@ -12,171 +26,434 @@ config_setting( ) cc_library( - name = "rocm_headers", + name = "config", hdrs = [ - "rocm/rocm_config.h", - %{rocm_headers} + "rocm_config/rocm_config.h", ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config", +) + +cc_library( + name = "config_hermetic", + hdrs = [ + "rocm_config_hermetic/rocm_config.h", + ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config_hermetic", +) + +cc_library( + name = "rocm_config", + visibility = ["//visibility:public"], + deps = select({ + ":build_hermetic": [ + ":config_hermetic", + ], + "//conditions:default": [ + "config", + ], + }), +) + +cc_library( + name = "rocm_headers", + hdrs = glob([ + "%{rocm_root}/include/**", + "%{rocm_root}/lib/llvm/lib/**/*.h", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", - "rocm/include/roctracer", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", + "%{rocm_root}/include/roctracer", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [ + ":rocm_rpath", + ], ) cc_library( - name = "hip", - srcs = ["rocm/lib/%{hip_lib}"], - data = ["rocm/lib/%{hip_lib}"], + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + ":hip", + ":hipblas", + ":hipblaslt", + ":hiprand", + ":hipsolver", + ":hipsparse", + ":hsa_rocr", + ":miopen", + ":rocblas", + ":rocm_config", + ":rocprofiler_register", + ":rocsolver", + ":roctracer", + ":rocsparse", + ] + select_threshold( + above_or_eq = [":hipfft"], + below = [":rocfft"], + threshold = 40100, + value = rocm_version_number(), + ), +) + +cc_library( + name = "hsa_rocr", + srcs = glob(["%{rocm_root}/lib/libhsa-runtime*.so*"]), + hdrs = glob(["%{rocm_root}/include/hsa/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_rpath", + linkopts = select({ + ":build_hermetic": [ + "-Wl,-rpath=%{rocm_toolkit_path}/lib", + ], + "//conditions:default": [ + "-Wl,-rpath=/opt/rocm/lib", + ], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hip", visibility = ["//visibility:public"], + deps = [ + ":rocm_hip", + ":rocm_rpath", + ], +) + +cc_library( + name = "rocm_hip", + srcs = glob(["%{rocm_root}/lib/libamdhip*.so*"]), + hdrs = glob(["%{rocm_root}/include/hip/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [ + ":amd_comgr", + ":hsa_rocr", + ":rocm_config", + ":rocm_smi", + ":rocprofiler_register", + ":system_libs", + ], ) cc_library( name = "rocblas", - srcs = ["rocm/lib/%{rocblas_lib}"], - data = ["rocm/lib/%{rocblas_lib}"], + hdrs = glob(["%{rocm_root}/include/rocblas/**"]), + data = glob([ + "%{rocm_root}/lib/librocblas*.so*", + "%{rocm_root}/lib/rocblas/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring tensile files to the same fs layout as expected in the lib + # rocblas assumes that tensile files are located in ../roblas/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "%{hipfft_or_rocfft}", - srcs = ["rocm/lib/%{hipfft_or_rocfft_lib}"], - data = ["rocm/lib/%{hipfft_or_rocfft_lib}"], + name = "rocfft", + srcs = glob(["%{rocm_root}/lib/librocfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "hiprand", - srcs = ["rocm/lib/%{hiprand_lib}"], - data = ["rocm/lib/%{hiprand_lib}"], + name = "hipfft", + srcs = glob(["%{rocm_root}/lib/libhipfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", + "%{rocm_root}/include", ], linkstatic = 1, - visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "miopen", - srcs = ["rocm/lib/%{miopen_lib}"], - data = ["rocm/lib/%{miopen_lib}"], + name = "hiprand", + srcs = glob(["%{rocm_root}/lib/libhiprand*.so*"]), + hdrs = glob(["%{rocm_root}/include/hiprand/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rccl", - srcs = ["rocm/lib/%{rccl_lib}"], - data = ["rocm/lib/%{rccl_lib}"], + name = "miopen", + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + data = glob([ + "%{rocm_root}/lib/libMIOpen*.so*", + "%{rocm_root}/share/miopen/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring miopen db files to the same fs layout as expected in the lib + # rocblas assumes that miopen db files are located in ../share/miopen/db directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rocm", - visibility = ["//visibility:public"], - deps = [ - ":rocm_headers", - ":hip", - ":rocblas", - ":hipblas", - ":%{hipfft_or_rocfft}", - ":hiprand", - ":miopen", - ":hipsparse", - ":roctracer", - ":rocsolver", - ":hipsolver", + name = "rccl", + srcs = glob(["%{rocm_root}/lib/librccl*.so*"]), + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", ], + linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], + visibility = ["//visibility:public"], ) cc_library( name = "rocprim", srcs = [ - "rocm/include/hipcub/hipcub_version.hpp", - "rocm/include/rocprim/rocprim_version.hpp", + "%{rocm_root}/include/hipcub/hipcub_version.hpp", + "%{rocm_root}/include/rocprim/rocprim_version.hpp", ], hdrs = glob([ - "rocm/include/hipcub/**", - "rocm/include/rocprim/**", + "%{rocm_root}/include/hipcub/**", + "%{rocm_root}/include/rocprim/**", ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include/hipcub", - "rocm/include/rocprim", + "%{rocm_root}/include/hipcub", + "%{rocm_root}/include/rocprim", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], deps = [ - "@local_config_rocm//rocm:rocm_headers", + ":rocm_config", + ":rocm_headers", ], ) cc_library( name = "hipsparse", - srcs = ["rocm/lib/%{hipsparse_lib}"], - data = ["rocm/lib/%{hipsparse_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsparse/**"]), + data = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "roctracer", - data = ["rocm/lib/%{roctracer_lib}"], + hdrs = glob(["%{rocm_root}/include/roctracer/**"]), + data = glob(["%{rocm_root}/lib/libroctracer*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "rocsolver", - srcs = ["rocm/lib/%{rocsolver_lib}"], - data = ["rocm/lib/%{rocsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocsolver/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocsparse", + srcs = glob(["%{rocm_root}/lib/librocsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipsolver", - srcs = ["rocm/lib/%{hipsolver_lib}"], - data = ["rocm/lib/%{hipsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), + data = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipblas", - srcs = ["rocm/lib/%{hipblas_lib}"], - data = ["rocm/lib/%{hipblas_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipblas.so*"]), + hdrs = glob(["%{rocm_root}/include/hipblas/**"]), + data = glob(["%{rocm_root}/lib/libhipblas.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "hipblaslt", + hdrs = glob(["%{rocm_root}/include/hipblaslt/**"]), + data = glob([ + "%{rocm_root}/lib/hipblaslt/**", + "%{rocm_root}/lib/libhipblaslt.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + # workaround to bring tensile files to the same fs layout as expected in the lib + # hibplatslt assumes that tensile files are located in ../hipblaslt/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocrand", + srcs = glob(["%{rocm_root}/lib/librocrand*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocrand/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocprofiler_register", + srcs = glob([ + "%{rocm_root}/lib/librocprofiler-register.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "amd_comgr", + srcs = glob([ + "%{rocm_root}/lib/libamd_comgr.so*", + ]), + hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_smi", + srcs = glob([ + "%{rocm_root}/lib/librocm_smi64.so*", + "%{rocm_root}/lib/libroam.so*", + ]), + hdrs = glob([ + "%{rocm_root}/include/oam/**", + "%{rocm_root}/include/rocm_smi/**", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "system_libs", + srcs = glob([ + "rocm_dist/usr/lib/**/libelf.so*", + "rocm_dist/usr/lib/**/libdrm.so*", + "rocm_dist/usr/lib/**/libnuma.so*", + "rocm_dist/usr/lib/**/libdrm_amdgpu.so*", + ]), + data = glob([ + "rocm_dist/usr/**", + ]), ) filegroup( name = "rocm_root", srcs = [ - "rocm/bin/clang-offload-bundler", + "%{rocm_root}/bin/clang-offload-bundler", ], + visibility = ["//visibility:public"], ) -%{copy_rules} +filegroup( + name = "all_files", + srcs = glob(["%{rocm_root}/**"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 83a7e9dababf38..d327083e4dc8ea 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -11,6 +11,8 @@ def if_rocm(if_true, if_false = []): "//conditions:default": if_false }) +def select_threshold(value, above_or_eq, threshold, below): + return below if value < threshold else above_or_eq def rocm_default_copts(): """Default options for all ROCm compilations.""" diff --git a/third_party/gpus/rocm/rocm_redist.bzl b/third_party/gpus/rocm/rocm_redist.bzl new file mode 100644 index 00000000000000..ca64cc8ec9b61b --- /dev/null +++ b/third_party/gpus/rocm/rocm_redist.bzl @@ -0,0 +1,18 @@ +load( + "@local_tsl//third_party/gpus/rocm:rocm_redist_ubuntu_20_04.bzl", + "rocm_redist_ubuntu_20_04", +) +load( + "@local_tsl//third_party/gpus/rocm:rocm_redist_ubuntu_22_04.bzl", + "rocm_redist_ubuntu_22_04", +) +load( + "@local_tsl//third_party/gpus/rocm:rocm_redist_ubuntu_24_04.bzl", + "rocm_redist_ubuntu_24_04", +) + +rocm_redist = { + "ubuntu_20.04": rocm_redist_ubuntu_20_04, + "ubuntu_22.04": rocm_redist_ubuntu_22_04, + "ubuntu_24.04": rocm_redist_ubuntu_24_04, +} diff --git a/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl b/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl new file mode 100644 index 00000000000000..ecae2197563b33 --- /dev/null +++ b/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_20_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~20.04_amd64.deb", + sha256 = "fabf4a831f21b5248932e08654149bc215da2a816613ad8d05b805d4e226171a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "215fae8759742bc048699feaacd6256a3ac2138771b69731dab7779325bb1b41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "e901d66275b3b520ee73250caa4a1836be142823083528b4db6cc31a18bfb94d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "f8a20128b5c26198bd9ecec894f8a4c74fa28ee668e4ef1bf73d0c3edff8c144", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "ab3ee54b33eba013fbf3d9aefe64b54e1918b9fb72790ca0b57fb391cb662cf0", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~20.04_amd64.deb", + sha256 = "a68123c046b8c913705262014463a8a30768167a1b68a78d8455deaf85a802d7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "c71fab59f62ad9d4b60aa4217f4db42c6996d83d5ad7ba29e127cc13bda59afc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "25887526ea2e955d4c0afa4749f8db55a49e399a349d43ccf66e0ad99ff78b2a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "3cfec840c79c6bce4e83bf6e056e241cc13ff572352b040a952c7642b61d45aa", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "cb56dd79ff52eaddfed379831023484d9ec32b9538bc3d02ee34c328457cd20e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "1e968f9405c8b90fbb58dff09d8bab08cf31c8386880fff95e1cb8932320bc37", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "f08ba25b6b950754b5a2bb64c125a01b9f44280f227ff19eeb78e188f0b17320", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "e9464369619bbea7299ac83e17b3cbbabdeb16e6d4da116400532e7737332b65", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "2efed49be9413e08e91b3fb67736644bb0e8809fc673d310a0abab65b69eacad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "19564fb2f9616860234aa8bd69cca324a1a3ec33476581ec57200a1dac1d4dcb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~20.04_amd64.deb", + sha256 = "e4940a5d47e9e39d603f18936e7921c603fd8dde0e359e0be796f9c1cdacd431", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "638a28c5407c3af7d16e1b0179b7494b0aeb36c314114af148b1bcd52e883db1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "77c9d26c4f0053b71fb86f7a6b489655e27053f9605efca3a16344ccf286e313", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "2b3ce1ca2e58e891963f26d4bd31ae45894480483f691d371f269e698f75f8eb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "0dedbffa5bb272d656086a9586e3705551345945f35f4f6be6dc8a27b63127a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "6e5b3caeadf592367f8638db67a70b8dd9231a8257dc2012a9c46e2c5974fff5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "eaefe5a7d75ef61314b83af5bb85d8e652a730deaa58e1d600b1e9c2e673673c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "b2bfe29ab688781bad5bc067ee682658085e22caaf09b18278f2f4b9905081d3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "e94d50fd6f24d70649ce046dbfe4dda2587d1d82892d4c126a4c3e91d1570071", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "0e16c9fc58fc904542be4dad63bb2ff34268b5c13957c432e91ec0e4fd149c82", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "14f47d79b508eb259bfe4e0e5f360edb5721b908caf3bb981a4eee4181783be9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "97e6e77eaea56de6cc4ea2c525dd8b9a587546eb99c782c7af46cdc5363b99bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "ae055b579d319e1a779783ba774f119fb0e1a731d058a03b36dc5c15214d210a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "3bcf3dc22dbede7da70299cde1484776827808b967d371441f6cf6d3fe8af30d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "ce17d2b85407b9539e0feda513fd360a48ebfd971c19af122dda21d60448c9fc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "322ca8425c3a8f2ec17c551bad606b96d957b0c1eea07196dd66ac9f15460ed5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~20.04_amd64.deb", + sha256 = "1bbdb32d21dbc12bf9a736f6ca8726df9673e4401465d2b9b537c47b358b67f1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "e74e1907eb90a692344626e881cb88eeed5565ac3b487eb94ad4ac02ffd838ed", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~20.04_amd64.deb", + sha256 = "4be88c5010c2cf0223c1dd7dc9d4a430fc54ee401ca093de2dcca60dabea763a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~20.04_amd64.deb", + sha256 = "ddd0ac44b08470dfc128d6f6d2598a9728879f5a78bc5290645baebf22433b63", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "b94cdf230b372ebcaf97085cf67f01ef7977f814280fdaf1886797f39899ef41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "9a85b57eea3790432eae06421081b3e59d3c9841d59646364ecd174f9ed4821a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "87dcd34a9b50f46161ecdb7781ab03c2b311fb7e13aa167c4a9c5e3bcf24b473", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "21e4aa1957e7bc5d293a418a983d9b3c3917fb78eb79d3d4d55a253b9bae7743", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "dacc13278f2be1cd847fca30ce409dcf95749df5f1a27635bc6dbd61be488d14", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.101-2_amd64.deb", + sha256 = "4cd2e10f9486456a2782487f8bfd39f330f35a4d5bd6d693412b9e4ca2a6acbd", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.101-2_amd64.deb", + sha256 = "d4567a30f7d68b4dcf794f8677b96e89083693c94e88279fecf577ceba8b9774", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.176-1.1build1_amd64.deb", + sha256 = "78a8761227efc04a1e37527f2f33ba608c6fb5d6c911616346ada5d7b9b72ee3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.12-1_amd64.deb", + sha256 = "0b1edf08cf9befecd21fe94e298ac25e476f87fd876ddd4adf42ef713449e637", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl b/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl new file mode 100644 index 00000000000000..88dca226f795b7 --- /dev/null +++ b/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_22_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~22.04_amd64.deb", + sha256 = "bc5d620e4e0db3746fc6b2279e463f618681f1f95ba973e40b687cef50ca2489", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "38e9670bedc7bbdc0b9f38c7a0fe90f73ef80f161cbf63c98d30e422438ce2c5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "c66cc8c19b57cab740710811457f02a16e24cff761e5c99c3640f63ceefe8281", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "fbd647e1b13e7aa2c14c9581f9102c069ddab9ecb47a4b226d433ec37b19e92d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "885cf3f3a52ebde9caadf6348a6cda28fd15e3bc52bab0c90b587d72b29ff7ef", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~22.04_amd64.deb", + sha256 = "468026fa8eb70121f0c545557a926ddc41228cef9457b4a00d8fc3a36b04310f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "c2c7d2ec5a8a31837c0addfc619ee67a374ea967cc6d43900472005489f62722", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "6e649430cc5e247bbd052dff2d681b6bf0ef09d0bc3446a4911f4ab4cd317140", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "389b0c83a39adbeeec442adde3fedba2820ed948179a4a0df03d67560501cd97", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "adf9aad1fc062445e34cdddbeca80db9c02f4c5f258e01c45e2a6222d15cb66d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "cb46dfbff3943a3167f6173fc381d744eb966a3451bcff49458c696888ec452c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "8c7a216aeef6ceeb3881d3e443a89a0f5c15a17deb5926cba4b787554c8fab87", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "501cad72df5f09572f99c11eebbb1eff49afb6ca8c91bcf4966f81068177a95d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "b20c86be57698a944f91048699d0fbde5253bea28ba9d4035ce1de1d3c20f9ac", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "9dab6f44b92b6020e183777f6f07219d68de5d10cad7538c7ddcae0192aa3e33", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~22.04_amd64.deb", + sha256 = "62d280204d8ff642b464dab03fc344442df6dc5f04e152da20604e8050303c41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "6c2aa042067e51d5b70a264ca83c92ffaa6e81d00d08b55986917da860e66d85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "f3452b2bd9c2869c550c7f963cca65fb35a37183ad4a56d96e05c69adb2f1d04", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "f3205c0a7d736f457ee2262988260e8dc4c495fa74a394ff73a9dfe002aff335", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "953a248cd44f403e5423185918166bfa29a009519c3d7b5b5a8e067fdf672602", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "c306ca3e59b851ebb35872e09e5598adf2e2ebb736c1b200ff4ee204fe262f7e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "115d0e9ec1b93bf7cba5fa1e3de1428f0d999d931c2dd495e4cdad22b5078936", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "0d40fc9aa1da617cd8864258cd1259a0e7444ea0da446297d154b5b3422393af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "8c1e72cf1c165e20960b0c2f3c499900a809d59340d14a0acff95c543c7087f2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "22c80c1a704f4ce7d6a49a8b41acd64f3ed0513cd7f5570a0664a10df5858334", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "9c2ff1dc100e342969bd51a7cd4918048c8b25579de709efde56425d969cd50f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "1101f3edb9dbc9f4914d7f26b5569ec9bde076d52d4125c98d22a99dd730ab51", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "d5b660df350130e0ab04ddf3e36dd442bde27ae9cbb8e5f12c047b0d3cb05463", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "0d06a84ac53d388089b7b8c80133f60c1eea5bfd85155ecc113efb206a747c25", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "4a29539480a7e4b27991ccf533a35526dd3994a457fa84e4c960192c2fa05b46", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "febb8614cedd98f13ba0624072ffdd13b9a6dc3431380a17a0eaf87583627890", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "3d859bb735ff8bf1962ce680e9257dcc574ab36224f50069f833fa19c6d7e69d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~22.04_amd64.deb", + sha256 = "ffd4e064e8a1d52b9e72114e8a1d51c78004a960f1d923448af8ed07a1b6f30b", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~22.04_amd64.deb", + sha256 = "66df78d8c5e2d1a0ae43cd4a5e41cf75ec120c870a0bbd7da18a2ba4dec42f9c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~22.04_amd64.deb", + sha256 = "317c16a6e0b0b456153437406dd92225e17dbd454fc1304b0c3fef5fbfc69bc2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9ddf8835f1e94d5004b4c466091c8110cb72e11eda545d0de395395832076c0a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9a9ed0c66d3a9d9ff50f1fc3a9e9105bb8b1a6d93c1f856682625dfb68ab627f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "5b86bf7b33a3ffa7098878f27d1b119aada69ebb02bd121b47209559c32703be", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "4573f99191fbe3a2afab84fdf5a05e024bd230ca7866d7eba71a5f2560a3a0bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "4fbc91db9085ecd80a5e051bba56863ae33b22516d727ab3fef15fb500187222", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.110-1ubuntu1_amd64.deb", + sha256 = "e5ea68db36b31aab442c790e1c78ecdf53646c16b0cd83db15966632ba04152c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.110-1ubuntu1_amd64.deb", + sha256 = "ae1f0d77668d7275d085ba820206ba91e90833dd1a02b8e251af0c73aa119ba3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.186-1build1_amd64.deb", + sha256 = "8effc4d7a0cc341bcf6cb11af0134f3defa6292376ecfdfc697a9b228606345c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.14-3ubuntu2_amd64.deb", + sha256 = "0721c89001fbbd1ada23e89da5d60e762763c1a7b3dc814a2e9a518480a8043d", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl b/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl new file mode 100644 index 00000000000000..da9ef00998f936 --- /dev/null +++ b/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl @@ -0,0 +1,187 @@ +rocm_redist_ubuntu_24_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~24.04_amd64.deb", + sha256 = "7e1ff2d9f2435f5b9db9aa952bb57d1a878a8aa7d96bda61361c107b7e1428e3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "5e6601ada30432ee0dab0473585bdf1fa7c398f0c655538d48eba9c44e6dc77a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "7ff8f6308c744c71008959b17ab6338de1c6fd3e4581dd94271e6eca9fdc4c13", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "e9f71e71db600d72dcb2b61e64b965b6c60d47bd4bb699e8abec85edb260b819", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt6.2.0/hipblaslt6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "e5dfd8ba9e49f919a96c102d3a652e8ef0c4d1a63b3f3909c856d40b1745e2a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt-dev6.2.0/hipblaslt-dev6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "639bd47010035ee6719425510be33d2f54483004a909dfa4c64f853d7394a22f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~24.04_amd64.deb", + sha256 = "c2782a98633e4400f46ba732605e56b2821366db60ec06d88db0615e4d1acf3c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "48fec4d06aef3159db4117125b728242a1eeb480ea3d55d3901d945d4b883694", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "8dd73cdbd4f0563f4a0481304771e4cbcac5905eea1f2d8ef41f922cdf9aba85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "e3c0a4ebda8d3aacd44b19c6872f23222513be0a5c04f793605088d9183f1be4", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "adbba9ffcf8b5e4202efbe45924d87520bf4100ec5464bd0ba3beb61cb535c6c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "01d3dd6195111808b40a5837d3e51d8c27c4700b4bd8bb2d901e39d0474fd98a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "2ba33a96388cd3edd7b5b8b261fe99cbd569894f4d7db291fc0dd0ff5d7c67ce", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "6a767f493a722e2d4260a9bc23cf9db66fd275a094b395c768e305f60d6b4fe9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "82f182134b415080ba4a12fd7993b6099ee9b9e549c72bfebee24c8486704078", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "011d5c28f45cd9d756e0cf6ea6a3d37eabd98a3381ffd961c772ab92a37e4ee8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~24.04_amd64.deb", + sha256 = "fa04f707debb75087ea2bf5e327602034eaa3a6900421f2cf32ad5f5f1c887b9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "2dbf6d126d0de6930e0cd94d0e525e07d3019d90bd7256f3151a7f1fbc2250af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "df5fdd2218e4d380b133ba402f3734fbe0589d9cdd8618a101b71b968909b4ba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4d7efa4ee6aa2bf69b0aab449cc1d01c25ca65814e1b3cb07f6b59fa8b1608b8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4ab4f880344e04d61b6fa746be5c4bdc2841409fb6987ee61e39c6420b4eca42", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "521c87ce396c6ce10076cc641b6035451fd68ddb36a684c5a9c9538dfc831ade", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "00f135ce2ae47c35085ef06248ff7d5ce8c12fd0d5b82e7bd77b1dbc0ce7058e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "40c936452e84bfec87236f08de5a9d3f232c397a3305b6143c26697ed56ceda1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "eb3904263b396d46799eeea1081d8e8d1a551a890432a803364db2d013849f92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "af5fcbe8dc2b6cbec30e2d39d30736e8a47a0b9d0ca2be7f179f2947f9c98245", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "228f07a3caefc41f6efd5345eb9d3630f1db769f9b4abd1313cbcb32d077ce53", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "cda72054d2011dbb062e75386766d928fd8905c15c88685c3ef87fc963bd88ad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "298544f717dfb236b9257b19a0ab81abaaa770128976d4abfdea546cd32d8b02", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "8e78ed8e480b55a496153b150acb22bab39c3bb8cf1e62f9aff7eaf75a3a3a92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "72c388eae7c0f54151b46fbd8fa6e26f1ca81e2b8b415c43411a156b3f25b6e7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "3e85a859c5dafa82a9a57dda096d566b821217bacfac995f7cc45ed460b68999", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~24.04_amd64.deb", + sha256 = "c094e3022c73fca2aa6c8bb435f93550109531de37fe8de5fbf6cfe1f047b645", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "6c832e2feb0885fbe481245825c76a466921b294f530eb0d0da70a44cfe6e608", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~24.04_amd64.deb", + sha256 = "d198d010fedfbe51d3fd19444e2848d430e08f91d19a5b2661b94ac6d1135863", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~24.04_amd64.deb", + sha256 = "2a2a95185ce0e54df226474b2f5cfcdc9e5ede5a6d88a8a70c2635ea2237abba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "2f2fb6f8d06ace89131934c833b0ea359335a4b45aeec1559b293d7bc14b1d1d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "c6c781ee87c459aed32e943b389137f98ecd402fb83a3d1c98de9a76abadc3a3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5e4b3e38556f0826e5322971635a49a72283d60862ccc4d28efd11c8fb955b47", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5bb6ae92a25f33488f2ee5f123ac4f67ad130e18e4949161715451509be3b89d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "1867833a569fbf3f87b82c81bc47f5d62085ea40f12d1cb33475c1f2dec89bc4", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.120-2build1_amd64.deb", + sha256 = "f5fb4e7ce17921cc466fb7911abf91495ffb181b36772f68e2e82cb621703112", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.120-2build1_amd64.deb", + sha256 = "e149d4daea33f58853b8013fd6c24888429ce7716a4b26d1a1f45181b5a4e73e", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1t64_0.190-1.1build4_amd64.deb", + sha256 = "b277e52769302778bd052376ac6687b52954b6605dd5f781bff8631e3504d58f", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.18-1build1_amd64.deb", + sha256 = "508daa855e99959acaa945e6a89d218e0be6b5727fd28773580942ff37cf5805", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 39eec6edd4fdc5..7748cbe1d4223e 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -12,6 +12,10 @@ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ +load( + "//third_party/gpus/rocm:rocm_redist.bzl", + "rocm_redist", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -33,15 +37,12 @@ load( load( ":cuda_configure.bzl", "enable_cuda", - "make_copy_dir_rule", - "make_copy_files_rule", ) load( ":sycl_configure.bzl", "enable_sycl", ) - _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" _CLANG_COMPILER_PATH = "CLANG_COMPILER_PATH" @@ -49,6 +50,9 @@ _TF_SYSROOT = "TF_SYSROOT" _ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" +_DISTRIBUTION_PATH = "rocm/rocm_dist" +_OS = "OS" +_ROCM_VERSION = "ROCM_VERSION" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" @@ -207,29 +211,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): """ inc_dirs = [] - # Add HSA headers (needs to match $HSA_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include") - - # Add HIP headers (needs to match $HIP_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") - if int(rocm_config.rocm_version_number) >= 50200: - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip") - inc_dirs.append(rocm_config.rocm_paths["ROCPRIM"] + "/include/rocprim") - inc_dirs.append(rocm_config.rocm_paths["ROCSOLVER"] + "/include/rocsolver") - inc_dirs.append(rocm_config.rocm_paths["ROCBLAS"] + "/include/rocblas") - - # Add HIP-Clang headers (realpath relative to compiler binary) - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/12.0.0/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/13.0.0/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/14.0.0/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/15.0.0/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/16.0.0/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/17.0.0/include/") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/17/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/18/include") - inc_dirs.append(rocm_config.llvm_path + "/lib/clang/19/include") - rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) + # Add full paths + rocm_toolkit_path = str(repository_ctx.path(rocm_config.rocm_toolkit_path)) inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/8.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") @@ -381,7 +364,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): return libs -def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin): +def _find_libs(repository_ctx, rocm_config, bash_bin): """Returns the ROCm libraries on the system. Args: @@ -397,7 +380,6 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin): for name, path in [ ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_paths["ROCBLAS"]), - (hipfft_or_rocfft, rocm_config.rocm_paths[hipfft_or_rocfft.upper()]), ("hiprand", rocm_config.rocm_paths["HIPRAND"]), ("MIOpen", rocm_config.rocm_paths["MIOPEN"]), ("rccl", rocm_config.rocm_paths["RCCL"]), @@ -415,17 +397,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin): libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_paths["HIPBLASLT"]), True)) return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) -def find_rocm_config(repository_ctx): +def find_rocm_config(repository_ctx, rocm_path): """Returns ROCm config dictionary from running find_rocm_config.py""" python_bin = get_python_bin(repository_ctx) - exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config]) + exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config], env_vars = {"ROCM_PATH": rocm_path}) if exec_result.return_code: auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result)) # Parse the dict from stdout. return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()]) -def _get_rocm_config(repository_ctx, bash_bin): +def _get_rocm_config(repository_ctx, bash_bin, rocm_path, install_path): """Detects and returns information about the ROCm installation on the system. Args: @@ -440,7 +422,7 @@ def _get_rocm_config(repository_ctx, bash_bin): miopen_version_number: The version of MIOpen on the system. hipruntime_version_number: The version of HIP Runtime on the system. """ - config = find_rocm_config(repository_ctx) + config = find_rocm_config(repository_ctx, rocm_path) rocm_toolkit_path = config["rocm_toolkit_path"] rocm_version_number = config["rocm_version_number"] miopen_version_number = config["miopen_version_number"] @@ -470,21 +452,21 @@ def _get_rocm_config(repository_ctx, bash_bin): # Check if the environment variable which specifies the path to the rocm component is set and that # the rocm component is not already installed in the rocm_toolkit_path component_path = get_host_environ(repository_ctx, component + "_PATH") - if component_path==None: + if component_path == None: rocm_paths[component] = rocm_toolkit_path else: rocm_paths[component] = component_path rocm_paths["MIOPEN"] = get_host_environ(repository_ctx, "MIOPEN_PATH") - if rocm_paths["MIOPEN"]==None: + if rocm_paths["MIOPEN"] == None: # For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path rocm_paths["MIOPEN"] = rocm_toolkit_path + "/miopen" if int(rocm_version_number) < 50200 else rocm_toolkit_path rocm_paths["RCCL"] = get_host_environ(repository_ctx, "RCCL_PATH") - if rocm_paths["RCCL"]==None: + if rocm_paths["RCCL"] == None: rocm_paths["RCCL"] = rocm_toolkit_path + "/rccl" if int(rocm_version_number) < 50200 else rocm_toolkit_path llvm_path = get_host_environ(repository_ctx, "LLVM_PATH") - if llvm_path==None: + if llvm_path == None: llvm_path = rocm_toolkit_path + "/llvm" return struct( amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_paths["ROCMINFO"], bash_bin), @@ -494,6 +476,7 @@ def _get_rocm_config(repository_ctx, bash_bin): hipruntime_version_number = hipruntime_version_number, rocm_paths = rocm_paths, llvm_path = llvm_path, + install_path = install_path, ) def _tpl_path(repository_ctx, labelname): @@ -557,15 +540,12 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": "hipfft", - "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), "%{roctracer_lib}": _lib_name("roctracer64"), "%{rocsolver_lib}": _lib_name("rocsolver"), "%{hipsolver_lib}": _lib_name("hipsolver"), "%{hipblaslt_lib}": _lib_name("hipblaslt"), - "%{copy_rules}": "", "%{rocm_headers}": "", }, ) @@ -583,7 +563,7 @@ def _create_dummy_repository(repository_ctx): "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, "%{hipblaslt_flag}": "0", }, - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", ) # If rocm_configure is not configured to build with GPU support, and the user @@ -635,6 +615,53 @@ def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def _download_package(repository_ctx, archive): + file_name = _get_file_name(archive.url) + tmp_dir = "tmp" + repository_ctx.file(tmp_dir + "/.idx") # create tmp dir + + repository_ctx.report_progress("Downloading and extracting {}, expected hash is {}".format(archive.url, archive.sha256)) # buildifier: disable=print + repository_ctx.download_and_extract( + url = archive.url, + output = tmp_dir if archive.url.endswith(".deb") else _DISTRIBUTION_PATH, + sha256 = archive.sha256, + ) + + all_files = repository_ctx.path(tmp_dir).readdir() + + matched_files = [f for f in all_files if _get_file_name(str(f)).startswith("data.")] + for f in matched_files: + repository_ctx.extract(f, _DISTRIBUTION_PATH) + + repository_ctx.delete(tmp_dir) + repository_ctx.delete(file_name) + +def _remove_root_dir(path, root_dir): + if path.startswith(root_dir + "/"): + return path[len(root_dir) + 1:] + return path + +def _setup_rocm_distro_dir(repository_ctx): + """Sets up the rocm hermetic installation directory to be used in hermetic build""" + bash_bin = get_bash_bin(repository_ctx) + os = repository_ctx.os.environ.get(_OS) + rocm_version = repository_ctx.os.environ.get(_ROCM_VERSION) + if os and rocm_version: + redist = rocm_redist[os][rocm_version] + repository_ctx.file("rocm/.index") + for archive in redist["archives"]: + _download_package(repository_ctx, archive) + return _get_rocm_config(repository_ctx, bash_bin, "{}/{}".format(_DISTRIBUTION_PATH, redist["rocm_root"]), "/{}".format(redist["rocm_root"])) + else: + rocm_path = repository_ctx.os.environ.get(_ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + repository_ctx.report_progress("Using local rocm installation {}".format(rocm_path)) # buildifier: disable=print + repository_ctx.symlink(rocm_path, _DISTRIBUTION_PATH) + return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + def _create_local_rocm_repository(repository_ctx): """Creates the repository containing files set up to build with ROCm.""" @@ -647,73 +674,23 @@ def _create_local_rocm_repository(repository_ctx): "rocm:rocm_config.h", ]} - bash_bin = get_bash_bin(repository_ctx) - rocm_config = _get_rocm_config(repository_ctx, bash_bin) - - # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft + rocm_config = _setup_rocm_distro_dir(repository_ctx) rocm_version_number = int(rocm_config.rocm_version_number) - hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft" # Copy header and library files to execroot. # rocm_toolkit_path - rocm_toolkit_path = rocm_config.rocm_toolkit_path - copy_rules = [ - make_copy_dir_rule( - repository_ctx, - name = "rocm-include", - src_dir = rocm_toolkit_path + "/include", - out_dir = "rocm/include", - ), - ] - - rocm_components_include = "" - - # install all the rocm component include directories that aren't in the rocm_toolkit_path and haven't - # already been installed to the local rocm repo - for component_label in rocm_config.rocm_paths: - component_name = component_label.lower().replace("_","-") - component_toolkit_include_path = rocm_config.rocm_toolkit_path + "/include/" + component_name - toolkit_include_exists = files_exist(repository_ctx, [component_toolkit_include_path], bash_bin) - component_include_path = rocm_config.rocm_paths[component_label] + "/include/" + component_name - if not toolkit_include_exists[0] and repository_ctx.path(component_include_path).exists: - rocm_components_include = rocm_components_include + '":' + component_name + '-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = component_name + "-include", - src_dir = component_include_path, - out_dir = "rocm/include/" + component_name, - ), - ) - - rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin) + rocm_toolkit_path = _remove_root_dir(rocm_config.rocm_toolkit_path, "rocm") + bash_bin = get_bash_bin(repository_ctx) + rocm_libs = _find_libs(repository_ctx, rocm_config, bash_bin) rocm_lib_srcs = [] rocm_lib_outs = [] for lib in rocm_libs.values(): if lib: rocm_lib_srcs.append(lib.path) rocm_lib_outs.append("rocm/lib/" + lib.file_name) - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-lib", - srcs = rocm_lib_srcs, - outs = rocm_lib_outs, - )) clang_offload_bundler_path = rocm_config.llvm_path + "/bin/clang-offload-bundler" - # copy files mentioned in third_party/gpus/rocm/BUILD - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-bin", - srcs = [ - clang_offload_bundler_path, - ], - outs = [ - "rocm/bin/" + "clang-offload-bundler", - ], - )) - have_hipblaslt = "1" if rocm_libs["hipblaslt"] != None else "0" # Set up BUILD file for rocm/ @@ -735,18 +712,8 @@ def _create_local_rocm_repository(repository_ctx): ) repository_dict = { - "%{hip_lib}": rocm_libs["amdhip64"].file_name, - "%{rocblas_lib}": rocm_libs["rocblas"].file_name, - "%{hipfft_or_rocfft}": hipfft_or_rocfft, - "%{hipfft_or_rocfft_lib}": rocm_libs[hipfft_or_rocfft].file_name, - "%{hiprand_lib}": rocm_libs["hiprand"].file_name, - "%{miopen_lib}": rocm_libs["MIOpen"].file_name, - "%{rccl_lib}": rocm_libs["rccl"].file_name, - "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name, - "%{roctracer_lib}": rocm_libs["roctracer64"].file_name, - "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name, - "%{copy_rules}": "\n".join(copy_rules), - "%{rocm_headers}": ('":rocm-include",\n' + rocm_components_include), + "%{rocm_root}": rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), } is_rocm_clang = _use_rocm_clang(repository_ctx) @@ -766,7 +733,6 @@ def _create_local_rocm_repository(repository_ctx): ) # Set up crosstool/ - cc = find_cc(repository_ctx, is_rocm_clang) host_compiler_includes = get_cxx_inc_directories( repository_ctx, @@ -829,6 +795,7 @@ def _create_local_rocm_repository(repository_ctx): repository_ctx.template( "crosstool/cc_toolchain_config.bzl", tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"], + rocm_defines, ) repository_ctx.template( @@ -836,12 +803,13 @@ def _create_local_rocm_repository(repository_ctx): tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"], { "%{cpu_compiler}": str(cc), - "%{compiler}": "clang" if is_rocm_clang else "unknown", - "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc", + "%{compiler_is_clang}": "True" if is_rocm_clang else "False", + "%{hipcc_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/bin/hipcc")), "%{hipcc_env}": _hipcc_env(repository_ctx), - "%{rocr_runtime_path}": rocm_config.rocm_paths["HSA"] + "/lib", + "%{rocm_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), + "%{rocr_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{hip_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{hip_runtime_library}": "amdhip64", "%{rccl_runtime_path}": rocm_config.rocm_paths["RCCL"] + "/lib", "%{rocblas_runtime_path}": rocm_config.rocm_paths["ROCBLAS"] + "/lib", @@ -857,13 +825,32 @@ def _create_local_rocm_repository(repository_ctx): # Set up rocm_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. repository_ctx.template( - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", + tpl_paths["rocm:rocm_config.h"], + { + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), + "%{rocm_toolkit_path}": rocm_config.install_path, + "%{rocm_version_number}": rocm_config.rocm_version_number, + "%{miopen_version_number}": rocm_config.miopen_version_number, + "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, + "%{hipblaslt_flag}": have_hipblaslt, + "%{hip_soversion_number}": "6" if int(rocm_config.rocm_version_number) >= 60000 else "5", + "%{rocblas_soversion_number}": "4" if int(rocm_config.rocm_version_number) >= 60000 else "3", + }, + ) + + # Set up rocm_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "rocm/rocm_config_hermetic/rocm_config.h", tpl_paths["rocm:rocm_config.h"], { "%{rocm_amdgpu_targets}": ",".join( ["\"%s\"" % c for c in rocm_config.amdgpu_targets], ), - "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), "%{rocm_version_number}": rocm_config.rocm_version_number, "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, @@ -940,6 +927,8 @@ _ENVIRONS = [ _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, "CLANG_COMPILER_PATH", + _OS, + _ROCM_VERSION, ] remote_rocm_configure = repository_rule( diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 4f8ac49c4524db..c14fe64d0b0902 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,42 +1,12 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst ---- a/clang/docs/ReleaseNotes.rst -+++ b/clang/docs/ReleaseNotes.rst -@@ -796,7 +796,6 @@ - - Fixed an assertion failure caused by mangled names with invalid identifiers. (#GH112205) - - Fixed an incorrect lambda scope of generic lambdas that caused Clang to crash when computing potential lambda - captures at the end of a full expression. (#GH115931) --- Clang no longer rejects deleting a pointer of incomplete enumeration type. (#GH99278) +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll b/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll +--- a/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll ++++ b/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll +@@ -2,6 +2,7 @@ + ; The constant 0 is generated by a transfer immediate instruction. - Bug Fixes to AST Handling - ^^^^^^^^^^^^^^^^^^^^^^^^^ -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp ---- a/clang/lib/Sema/SemaExprCXX.cpp -+++ b/clang/lib/Sema/SemaExprCXX.cpp -@@ -3747,8 +3747,7 @@ - } else if (!Pointee->isDependentType()) { - // FIXME: This can result in errors if the definition was imported from a - // module but is hidden. -- if (!Pointee->isStructureOrClassType() || -- !RequireCompleteType(StartLoc, Pointee, -+ if (!RequireCompleteType(StartLoc, Pointee, - LangOpts.CPlusPlus26 - ? diag::err_delete_incomplete - : diag::warn_delete_incomplete, -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/new-delete.cpp b/clang/test/SemaCXX/new-delete.cpp ---- a/clang/test/SemaCXX/new-delete.cpp -+++ b/clang/test/SemaCXX/new-delete.cpp -@@ -540,13 +540,6 @@ - void f(A *x) { delete x; } // expected-warning {{delete called on 'PR10504::A' that is abstract but has non-virtual destructor}} - } + ; RUN: llc -march=hexagon -debug-only=isel 2>&1 < %s - | FileCheck %s ++; REQUIRES: asserts --#if __cplusplus >= 201103L --enum GH99278_1 { -- zero = decltype(delete static_cast(nullptr), 0){} -- // expected-warning@-1 {{expression with side effects has no effect in an unevaluated context}} --}; --#endif -- - struct PlacementArg {}; - inline void *operator new[](size_t, const PlacementArg &) throw () { - return 0; + ; CHECK: [[R0:%[0-9]+]]:intregs = A2_tfrsi 0 + ; CHECK-NEXT: predregs = C2_tfrrp killed [[R0]]:intregs diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 7c3347b7a73784..c35f4e43aec473 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "1d95825d4d168a17a4f27401dec3f2977a59a70e" - LLVM_SHA256 = "d3276c678b616c0d820fe14a3404b43591f4e1bc75b6bed2782e0776e0c9b401" + LLVM_COMMIT = "35e76b6a4fc74e64bd6c91e5b9b9eb6a03aa802e" + LLVM_SHA256 = "bf4e52c430ff8eb2b055a4abcbd70468d2e6ea7f277e472575e92903bd7d8981" tf_http_archive( name = name, diff --git a/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/mkl_dnn/mkldnn_acl.BUILD index 868a2972a44861..56686b95fbefef 100644 --- a/third_party/mkl_dnn/mkldnn_acl.BUILD +++ b/third_party/mkl_dnn/mkldnn_acl.BUILD @@ -167,6 +167,7 @@ cc_library( "include/**/*", "include/*", "src/common/*.hpp", + "src/common/**/*.h", "src/cpu/**/*.hpp", "src/cpu/*.hpp", "src/cpu/aarch64/xbyak_aarch64/**/*.h", diff --git a/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch b/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch new file mode 100644 index 00000000000000..42dd262323b577 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp +index 65b887ea21..eabdb827bd 100644 +--- a/src/cpu/platform.cpp ++++ b/src/cpu/platform.cpp +@@ -117,6 +117,8 @@ bool has_data_type_support(data_type_t data_type) { + #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) + return true; + #endif ++#elif DNNL_AARCH64_USE_ACL ++ return arm_compute::CPUInfo::get().has_bf16(); + #else + return false; + #endif +-- +2.34.1 + diff --git a/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch b/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch new file mode 100644 index 00000000000000..779608a68058d2 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index ab13efb9b2..ec261e156d 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -78,11 +78,21 @@ struct acl_matmul_t : public primitive_t { + = utils::everyone_is(data_type::f16, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type) + && platform::has_data_type_support(data_type::f16); ++ const bool is_fp32_bf16_ok ++ = (utils::everyone_is(data_type::f32, src_md()->data_type, ++ dst_md()->data_type, desc()->accum_data_type) ++ && platform::has_data_type_support(data_type::f32) ++ && utils::everyone_is( ++ data_type::bf16, weights_md()->data_type) ++ && platform::has_data_type_support( ++ data_type::bf16)); ++ + const bool is_weights_md_format_ok + = utils::one_of(weights_format_kind_received, + format_kind::any, format_kind::blocked); + bool ok = is_dense_data() +- && utils::one_of(true, is_fp32_ok, is_fp16_ok) ++ && utils::one_of( ++ true, is_fp32_ok, is_fp16_ok, is_fp32_bf16_ok) + && !has_zero_dim_memory() && is_weights_md_format_ok + && set_default_formats() + && attr()->has_default_values( +-- +2.34.1 diff --git a/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch b/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch new file mode 100644 index 00000000000000..ec2cb97f5131ba --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch @@ -0,0 +1,100 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index 451cc78d52..ab13efb9b2 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -67,6 +67,8 @@ struct acl_matmul_t : public primitive_t { + + status_t init(engine_t *engine) { + using smask_t = primitive_attr_t::skip_mask_t; ++ const format_kind_t weights_format_kind_received ++ = weights_md_.format_kind; + const bool is_fp32_ok + = utils::everyone_is(data_type::f32, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type, +@@ -76,18 +78,20 @@ struct acl_matmul_t : public primitive_t { + = utils::everyone_is(data_type::f16, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type) + && platform::has_data_type_support(data_type::f16); ++ const bool is_weights_md_format_ok ++ = utils::one_of(weights_format_kind_received, ++ format_kind::any, format_kind::blocked); + bool ok = is_dense_data() + && utils::one_of(true, is_fp32_ok, is_fp16_ok) +- && !has_zero_dim_memory() +- && weights_md_.format_kind == format_kind::any ++ && !has_zero_dim_memory() && is_weights_md_format_ok + && set_default_formats() + && attr()->has_default_values( + smask_t::oscale | smask_t::post_ops) + && attr_oscale_ok() && !has_runtime_dims_or_strides(); + if (!ok) return status::unimplemented; + +- CHECK(acl_matmul_utils::init_conf_matmul( +- amp_, src_md_, weights_md_, dst_md_, *desc(), *attr())); ++ CHECK(acl_matmul_utils::init_conf_matmul(amp_, src_md_, weights_md_, ++ dst_md_, *desc(), *attr(), weights_format_kind_received)); + + arm_compute::ActivationLayerInfo act_info; + CHECK(post_ops.init(engine, attr_.post_ops_, dst_md_, act_info)); +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +index a314d96384..027f915a8a 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +@@ -27,7 +27,8 @@ namespace acl_matmul_utils { + + status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, +- const primitive_attr_t &attr) { ++ const primitive_attr_t &attr, ++ format_kind_t weights_format_kind_received) { + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); +@@ -128,9 +129,16 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + for (dim_t i = K_dim - 1; i >= 0; --i) + batch_dims.push_back(i); + ++ const memory_desc_t weights_md_received = wei_md; + acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, + expected_weight_format, K_dim, N_dim, {}, batch_dims); + ++ ACL_CHECK_SUPPORT((weights_format_kind_received == format_kind::blocked) ++ && !(dnnl_memory_desc_equal(&weights_md_received, &wei_md)), ++ "specified blocked format not supported by ACL, use " ++ "format_kind_t::any to find a supported blocked format for " ++ "your platform"); ++ + return status::success; + } + +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +index 67bb2e78eb..5ba4241abc 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +@@ -52,7 +52,8 @@ namespace acl_matmul_utils { + + status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, +- const primitive_attr_t &attr); ++ const primitive_attr_t &attr, ++ format_kind_t weights_format_kind_received); + + } // namespace acl_matmul_utils + +-- +2.34.1 diff --git a/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch b/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch new file mode 100644 index 00000000000000..39f7e74345e08b --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/acl_post_ops.cpp b/src/cpu/aarch64/acl_post_ops.cpp +index ea4bb200ec..3eb53b81bd 100644 +--- a/src/cpu/aarch64/acl_post_ops.cpp ++++ b/src/cpu/aarch64/acl_post_ops.cpp +@@ -24,7 +24,7 @@ namespace aarch64 { + + status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const { + +- int post_op_index = 0; ++ int post_op_index = post_op_start_index_; + + // As these are post ops, this src will also be our dst. If we have a sum + // post op, the src/dst will start off in a temporary, then change to +diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/aarch64/acl_post_ops.hpp +index 7b59ad71d3..ceaa95b73a 100644 +--- a/src/cpu/aarch64/acl_post_ops.hpp ++++ b/src/cpu/aarch64/acl_post_ops.hpp +@@ -32,7 +32,9 @@ struct acl_post_ops_t { + // init the acl_post_ops_t. Note that this function modifies the passed in + // post ops by setting the preferred memory formats + status_t init(engine_t *engine, post_ops_t &post_ops, +- const memory_desc_t &dst_md) { ++ const memory_desc_t &dst_md, int post_op_start_index = 0) { ++ ++ post_op_start_index_ = post_op_start_index; + + CHECK(post_ops.set_default_formats(&dst_md)); + dst_data_type = dst_md.data_type; +@@ -41,7 +43,7 @@ struct acl_post_ops_t { + sum_index = -1; + post_op_primitives = {}; + +- for (int i = 0; i < post_ops.len(); i++) { ++ for (int i = post_op_start_index; i < post_ops.len(); i++) { + auto &po = post_ops.entry_[i]; + + if (po.is_sum()) { +@@ -135,7 +137,8 @@ struct acl_post_ops_t { + // formats + status_t init(engine_t *engine, post_ops_t &base_post_ops, + const memory_desc_t &dst_md, +- arm_compute::ActivationLayerInfo &act_info_to_fuse) { ++ arm_compute::ActivationLayerInfo &act_info_to_fuse, ++ int post_op_start_index = 0) { + + CHECK(base_post_ops.set_default_formats(&dst_md)); + dst_data_type = dst_md.data_type; +@@ -149,18 +152,11 @@ struct acl_post_ops_t { + "eltwise post op scale must be 1 (no scale)"); + CHECK(acl_utils::convert_to_acl_act(first_po, act_info_to_fuse)); + +- // Copy all but the first, because it has been fused +- post_ops_t post_ops; +- for (int idx = 1; idx < base_post_ops.len(); ++idx) { +- // Construct empty entry then copy, so that we can check for failure +- post_ops.entry_.emplace_back(); +- post_ops.entry_.back().copy_from(base_post_ops.entry_[idx]); +- } +- return init(engine, post_ops, dst_md); +- ++ // post_op_start_index + 1 to skip the fused eltwise ++ return init(engine, base_post_ops, dst_md, post_op_start_index + 1); + } else { + // Nothing to fuse, just copy all post ops +- return init(engine, base_post_ops, dst_md); ++ return init(engine, base_post_ops, dst_md, post_op_start_index); + } + } + +@@ -179,6 +175,9 @@ struct acl_post_ops_t { + private: + // Index of the sum post op if there is one, < 0 means no sum + int sum_index = -1; ++ // Index of the first post op this primitive executes. This is typically the ++ // number of post ops which were fused. ++ int post_op_start_index_ = 0; + data_type_t dst_data_type; + // Vector of primitives used to execute the post ops. They are constructed + // in init to be either acl_binary_t (for sum, add, sub, div, mul, min and +-- +2.34.1 diff --git a/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl index 61d7809bcdaad1..51e7c35200fd34 100644 --- a/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl +++ b/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -14,6 +18,7 @@ cc_import( cc_library( name = "nccl", %{comment}deps = [":nccl_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/nccl/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/py/BUILD b/third_party/py/BUILD index 7250861f26bfa2..661e8950c4dc2d 100644 --- a/third_party/py/BUILD +++ b/third_party/py/BUILD @@ -53,22 +53,8 @@ config_setting( }, ) -# Flag indicating if the target requires manylinux compliance verification. -bool_flag( - name = "verify_manylinux", - # TODO(ybaturina): Enable the flag by default when hermetic C++ toolchain is ready. - build_setting_default = False, +filegroup( + name = "manylinux_compliance_test", + srcs = ["manylinux_compliance_test.py"], visibility = ["//visibility:public"], ) - -py_binary( - name = "verify_manylinux_compliance", - srcs = [ - "verify_manylinux_compliance.py", - ], - main = "verify_manylinux_compliance.py", - visibility = ["//visibility:public"], - deps = [ - "@pypi_auditwheel//:pkg", - ], -) diff --git a/third_party/py/ml_dtypes/workspace.bzl b/third_party/py/ml_dtypes/workspace.bzl index 29a551da8d0017..962fb487c2d2f4 100644 --- a/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/py/ml_dtypes/workspace.bzl @@ -7,8 +7,8 @@ float8 varieties, and int4. load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - ML_DTYPES_COMMIT = "c12281a501469d553483eb4d68065826b9c2fcb5" - ML_DTYPES_SHA256 = "cee11c4bed5147bece9e385a88c20887344ad9b89b3acb09bf3d7c9c21fb9715" + ML_DTYPES_COMMIT = "0fa5313b65efe848c5968a15dd37dd220cc29567" + ML_DTYPES_SHA256 = "69c562bb961a21d92357c7709430553c226caac75a751c0aa52955ca14ce8641" tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", diff --git a/third_party/py/python_init_rules.bzl b/third_party/py/python_init_rules.bzl index 79bc343aae489e..796ae3d92d999f 100644 --- a/third_party/py/python_init_rules.bzl +++ b/third_party/py/python_init_rules.bzl @@ -8,4 +8,6 @@ def python_init_rules(): sha256 = "62ddebb766b4d6ddf1712f753dac5740bea072646f630eb9982caa09ad8a7687", strip_prefix = "rules_python-0.39.0", url = "https://github.com/bazelbuild/rules_python/releases/download/0.39.0/rules_python-0.39.0.tar.gz", + patch_args = ["-p1"], + patches = [Label("//third_party/py:rules_python.patch")], ) diff --git a/third_party/py/rules_python.patch b/third_party/py/rules_python.patch new file mode 100644 index 00000000000000..ef7ff2fc6f8e52 --- /dev/null +++ b/third_party/py/rules_python.patch @@ -0,0 +1,39 @@ +diff --git a/python/private/pypi/deps.bzl b/python/private/pypi/deps.bzl +index 8949ed4a..8d0ab0e7 100644 +--- a/python/private/pypi/deps.bzl ++++ b/python/private/pypi/deps.bzl +@@ -51,8 +51,8 @@ _RULE_DEPS = [ + ), + ( + "pypi__packaging", +- "https://files.pythonhosted.org/packages/49/df/1fceb2f8900f8639e278b056416d49134fb8d84c5942ffaa01ad34782422/packaging-24.0-py3-none-any.whl", +- "2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5", ++ "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", ++ "09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", + ), + ( + "pypi__pep517", +@@ -61,8 +61,8 @@ _RULE_DEPS = [ + ), + ( + "pypi__pip", +- "https://files.pythonhosted.org/packages/8a/6a/19e9fe04fca059ccf770861c7d5721ab4c2aebc539889e97c7977528a53b/pip-24.0-py3-none-any.whl", +- "ba0d021a166865d2265246961bec0152ff124de910c5cc39f1156ce3fa7c69dc", ++ "https://files.pythonhosted.org/packages/ef/7d/500c9ad20238fcfcb4cb9243eede163594d7020ce87bd9610c9e02771876/pip-24.3.1-py3-none-any.whl", ++ "3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed", + ), + ( + "pypi__pip_tools", +diff --git a/python/private/pypi/evaluate_markers.bzl b/python/private/pypi/evaluate_markers.bzl +index c805fd7a..e57e6138 100644 +--- a/python/private/pypi/evaluate_markers.bzl ++++ b/python/private/pypi/evaluate_markers.bzl +@@ -20,7 +20,7 @@ load(":pypi_repo_utils.bzl", "pypi_repo_utils") + SRCS = [ + # When the version, or any of the files in `packaging` package changes, + # this file will change as well. +- Label("@pypi__packaging//:packaging-24.0.dist-info/RECORD"), ++ Label("@pypi__packaging//:packaging-24.2.dist-info/RECORD"), + Label("//python/private/pypi/requirements_parser:resolve_target_platforms.py"), + Label("//python/private/pypi/whl_installer:platform.py"), + ] \ No newline at end of file diff --git a/third_party/remote_config/common.bzl b/third_party/remote_config/common.bzl index 57fb6fcf7aca9a..c70c0ba5b51db6 100644 --- a/third_party/remote_config/common.bzl +++ b/third_party/remote_config/common.bzl @@ -212,7 +212,8 @@ def execute( cmdline, error_msg = None, error_details = None, - allow_failure = False): + allow_failure = False, + env_vars = {}): """Executes an arbitrary shell command. Args: @@ -222,10 +223,11 @@ def execute( error_details: string, details about the error or steps to fix it allow_failure: bool, if True, an empty stdout result or output to stderr is fine, otherwise either of these is an error + env_vars: environment variables Returns: The result of repository_ctx.execute(cmdline) """ - result = raw_exec(repository_ctx, cmdline) + result = raw_exec(repository_ctx, cmdline, env_vars) if (result.stderr or not result.stdout) and not allow_failure: fail( "\n".join([ @@ -236,7 +238,7 @@ def execute( ) return result -def raw_exec(repository_ctx, cmdline): +def raw_exec(repository_ctx, cmdline, env_vars = {}): """Executes a command via repository_ctx.execute() and returns the result. This method is useful for debugging purposes. For example, to print all @@ -245,11 +247,12 @@ def raw_exec(repository_ctx, cmdline): Args: repository_ctx: the repository_ctx cmdline: the list of args + env_vars: environment variables Returns: The 'exec_result' of repository_ctx.execute(). """ - return repository_ctx.execute(cmdline) + return repository_ctx.execute(cmdline, environment = env_vars) def files_exist(repository_ctx, paths, bash_bin = None): """Checks which files in paths exists. diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 6d102a47289fe0..5675d833f11002 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,117 +1,32 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index dfa4b78..4f8ac49 100644 +index 509398d..c14fe64 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,57 +1,42 @@ +@@ -1 +1,12 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp ----- a/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp --+++ b/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp --@@ -573,7 +573,6 @@ -- // Create __imp_ symbol -- jitlink::Symbol &Ptr = -- jitlink::x86_64::createAnonymousPointer(*G, Sec, &Target); --- auto name = getImpPrefix() + *KV.first; -- Ptr.setName(G->intern((Twine(getImpPrefix()) + *KV.first).str())); -- Ptr.setLinkage(jitlink::Linkage::Strong); -- Ptr.setScope(jitlink::Scope::Default); --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel b/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel --@@ -285,6 +285,7 @@ -- "//llvm:MCParser", -- "//llvm:Object", -- "//llvm:ObjectYAML", --+ "//llvm:OrcShared", -- "//llvm:Support", -- "//llvm:TargetParser", -- "//llvm:config", --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --@@ -1442,7 +1442,10 @@ -- hdrs = glob(["src/__support/time/*.h"]), -- deps = [ -- ":__support_common", --+ ":__support_error_or", -- ":hdr_time_macros", --+ ":types_clockid_t", --+ ":types_struct_timespec", -- ":types_time_t", -- ], -- ) --@@ -1486,6 +1489,8 @@ -- ":__support_common", -- ":__support_error_or", -- ":__support_osutil_vdso", --+ ":types_clockid_t", --+ ":types_struct_timespec", -- ], -- ) -+diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst -+--- a/clang/docs/ReleaseNotes.rst -++++ b/clang/docs/ReleaseNotes.rst -+@@ -796,7 +796,6 @@ -+ - Fixed an assertion failure caused by mangled names with invalid identifiers. (#GH112205) -+ - Fixed an incorrect lambda scope of generic lambdas that caused Clang to crash when computing potential lambda -+ captures at the end of a full expression. (#GH115931) -+-- Clang no longer rejects deleting a pointer of incomplete enumeration type. (#GH99278) - --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel --@@ -2800,6 +2800,7 @@ -- ":MC", -- ":MCDisassembler", -- ":Object", --+ ":OrcShared", -- ":OrcTargetProcess", -- ":Passes", -- ":Support", -+ Bug Fixes to AST Handling -+ ^^^^^^^^^^^^^^^^^^^^^^^^^ -+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp -+--- a/clang/lib/Sema/SemaExprCXX.cpp -++++ b/clang/lib/Sema/SemaExprCXX.cpp -+@@ -3747,8 +3747,7 @@ -+ } else if (!Pointee->isDependentType()) { -+ // FIXME: This can result in errors if the definition was imported from a -+ // module but is hidden. -+- if (!Pointee->isStructureOrClassType() || -+- !RequireCompleteType(StartLoc, Pointee, -++ if (!RequireCompleteType(StartLoc, Pointee, -+ LangOpts.CPlusPlus26 -+ ? diag::err_delete_incomplete -+ : diag::warn_delete_incomplete, -+diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/new-delete.cpp b/clang/test/SemaCXX/new-delete.cpp -+--- a/clang/test/SemaCXX/new-delete.cpp -++++ b/clang/test/SemaCXX/new-delete.cpp -+@@ -540,13 +540,6 @@ -+ void f(A *x) { delete x; } // expected-warning {{delete called on 'PR10504::A' that is abstract but has non-virtual destructor}} -+ } ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll b/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll ++--- a/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll +++++ b/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll ++@@ -2,6 +2,7 @@ ++ ; The constant 0 is generated by a transfer immediate instruction. + -+-#if __cplusplus >= 201103L -+-enum GH99278_1 { -+- zero = decltype(delete static_cast(nullptr), 0){} -+- // expected-warning@-1 {{expression with side effects has no effect in an unevaluated context}} -+-}; -+-#endif -+- -+ struct PlacementArg {}; -+ inline void *operator new[](size_t, const PlacementArg &) throw () { -+ return 0; ++ ; RUN: llc -march=hexagon -debug-only=isel 2>&1 < %s - | FileCheck %s +++; REQUIRES: asserts ++ ++ ; CHECK: [[R0:%[0-9]+]]:intregs = A2_tfrsi 0 ++ ; CHECK-NEXT: predregs = C2_tfrrp killed [[R0]]:intregs diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index e60a1c8..7c3347b 100644 +index 02401a7..c35f4e4 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "2ccf7ed277df28651b94bbee9fccefdf22fb074f" -- LLVM_SHA256 = "ca68a54dcd12c0dde32732a90899bf57e0f3f96fc43d8d1124d95a5eae627508" -+ LLVM_COMMIT = "1d95825d4d168a17a4f27401dec3f2977a59a70e" -+ LLVM_SHA256 = "d3276c678b616c0d820fe14a3404b43591f4e1bc75b6bed2782e0776e0c9b401" +- LLVM_COMMIT = "a531800344dc54e9c197a13b22e013f919f3f5e1" +- LLVM_SHA256 = "74a873f8d4c677d192e9bfade095af3363c76b0fb23c5f6260121d74322744bc" ++ LLVM_COMMIT = "35e76b6a4fc74e64bd6c91e5b9b9eb6a03aa802e" ++ LLVM_SHA256 = "bf4e52c430ff8eb2b055a4abcbd70468d2e6ea7f277e472575e92903bd7d8981" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index a2396e5007c48e..a8f7e817753eae 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "cdc7e854703cecf8dcd16db45b92b7be005c4f60" - SHARDY_SHA256 = "13f4f2d5cf241f97ba098ba5683fe066cf075f62cfdcba6287ba3b225a78e40e" + SHARDY_COMMIT = "2ca9cd74b9f9fc5851d0b19c4cc07b1cfc35f0e3" + SHARDY_SHA256 = "502353ad1b00303cab5141ac3a85f4bb6ef61340679353cf79a5d6d1b58139dd" tf_http_archive( name = "shardy", diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..071bba3084c74b 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,154 @@ +diff --ruN a/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py b/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +--- stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py ++++ stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +@@ -71,8 +71,15 @@ + + output_file = os.path.relpath( + os.path.normpath( +- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", +- "transforms", output_filename)), ++ os.path.join( ++ os.path.dirname(__file__), ++ "..", ++ "..", ++ "stablehlo", ++ "transforms", ++ output_filename, ++ ) ++ ), + os.getcwd(), + ) + +@@ -105,7 +112,8 @@ + func = getattr(fa.algorithms, fname, None) + if func is None: + warnings.warn( +- f"{fa.algorithms.__name__} does not define {fname}. Skipping.") ++ f"{fa.algorithms.__name__} does not define {fname}. Skipping." ++ ) + continue + ctx = fa.Context(paths=[fa.algorithms], + parameters=dict(rewrite_keep_integer_literals=True)) +@@ -116,14 +124,15 @@ + sources[-1] += src + source = "\n\n".join(sources) + "\n" + +- if chloname.startswith('StableHLO_'): ++ if chloname.startswith("StableHLO_"): + # an ugly hack to fix the definition of stablehlo complex math + # functions. TODO(pearu): add the corresponding feature to + # functional_algorithms stablehlo printer +- NameOp = chloname.split('_', 1)[1] ++ NameOp = chloname.split("_", 1)[1] + source = source.replace( +- f'def : Pat<({chloname}', +- f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}' ++ f"def : Pat<({chloname}", ++ f"def {NameOp}_ComplexElementType_ComplexMathExpander :" ++ f" Pat<({chloname}", + ) + + if os.path.isfile(output_file): +diff --ruN a/stablehlo/build_tools/math/generate_tests.py b/stablehlo/build_tools/math/generate_tests.py +--- stablehlo/build_tools/math/generate_tests.py ++++ stablehlo/build_tools/math/generate_tests.py +@@ -64,10 +64,12 @@ + dict(name="acosh", mpmath_name="arccosh"), + dict(name="atanh", mpmath_name="arctanh"), + dict(name="square", mpmath_name="square"), +- dict(name="log_plus_one", +- mpmath_name="log1p", +- namespace="stablehlo", +- passes="--stablehlo-complex-math-expander"), ++ dict( ++ name="log_plus_one", ++ mpmath_name="log1p", ++ namespace="stablehlo", ++ passes="--stablehlo-complex-math-expander", ++ ), + ] + + +@@ -138,13 +140,16 @@ + params = fa.utils.function_validation_parameters(opname, dtype) + max_ulp_difference = op.get( + "max_ulp_difference", +- params.get("max_valid_ulp_count", default_max_ulp_difference)) ++ params.get("max_valid_ulp_count", default_max_ulp_difference), ++ ) + + nmp = fa.utils.numpy_with_mpmath( + extra_prec_multiplier=op.get( + "extra_prec_multiplier", +- params.get("extra_prec_multiplier", +- default_extra_prec_multiplier)), ++ params.get( ++ "extra_prec_multiplier", default_extra_prec_multiplier ++ ), ++ ), + flush_subnormals=flush_subnormals, + ) + +@@ -208,8 +213,10 @@ + continue + + f = open(fname, "w") +- f.write(f"// RUN: stablehlo-opt {passes} %s |" +- " stablehlo-translate --interpret\n") ++ f.write( ++ f"// RUN: stablehlo-opt {passes} %s |" ++ " stablehlo-translate --interpret\n" ++ ) + f.write( + "// This file is generated, see build_tools/math/README.md for more" + " information.\n") +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +@@ -107,6 +107,8 @@ + + LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() { + addArgumentMaterialization(scalarToTensor); ++ addSourceMaterialization(scalarToTensor); ++ addTargetMaterialization(scalarToTensor); + } + + } // namespace mlir::stablehlo +diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ++++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +@@ -440,7 +440,6 @@ + } + + // ----- +- + + // CHECK-LABEL: func.func @asinh_f64( + // CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +@@ -2788,7 +2787,6 @@ + + // ----- + +- + // CHECK-LABEL: @sinh_f32 + // CHECK-SAME: (%[[X:.*]]: tensor) + func.func @sinh_f32(%x : tensor) -> tensor { +@@ -3891,6 +3889,8 @@ + return + } + ++// ----- ++ + // CHECK-LABEL: @square_complex_f32( + // CHECK-SAME: %[[VAL_0:.*]]: tensor>) -> tensor> { + // CHECK: %[[VAL_1:.*]] = stablehlo.real %[[VAL_0]] : (tensor>) -> tensor +@@ -3916,6 +3916,8 @@ + func.return %result : tensor> + } + ++// ----- ++ + // CHECK-LABEL: @square_f32( + // CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { + // CHECK: %[[VAL_1:.*]] = stablehlo.multiply %[[VAL_0]], %[[VAL_0]] : tensor diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 20badb638791f8..dfae5f53d44715 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "b3d3cacde8994df313297e68713ed74c2ca279ee" - STABLEHLO_SHA256 = "8bb81d7f60f19493b1edfc916adcfe1f9d1deeaf77c9ca7a896e05861505817d" + STABLEHLO_COMMIT = "38bb2f9bf63b714e8a49fe34a478139058ee1660" + STABLEHLO_SHA256 = "74feb9f9f34eb4dd0b11404371af58f7a5a5ded177d38b01b53174ce757a3a61" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/systemlibs/double_conversion.BUILD b/third_party/systemlibs/double_conversion.BUILD deleted file mode 100644 index 568460181ae0bc..00000000000000 --- a/third_party/systemlibs/double_conversion.BUILD +++ /dev/null @@ -1,12 +0,0 @@ -licenses(["notice"]) - -filegroup( - name = "LICENSE", - visibility = ["//visibility:public"], -) - -cc_library( - name = "double-conversion", - linkopts = ["-ldouble-conversion"], - visibility = ["//visibility:public"], -) diff --git a/third_party/systemlibs/grpc.bazel.generate_cc.bzl b/third_party/systemlibs/grpc.bazel.generate_cc.bzl index c659ca16366b7a..aa5d18eaa9a488 100644 --- a/third_party/systemlibs/grpc.bazel.generate_cc.bzl +++ b/third_party/systemlibs/grpc.bazel.generate_cc.bzl @@ -11,6 +11,7 @@ load( "get_proto_root", "proto_path_to_generated_filename", ) +load("@rules_proto//proto:defs.bzl", "ProtoInfo") _GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h" _GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc" diff --git a/third_party/systemlibs/grpc.bazel.protobuf.bzl b/third_party/systemlibs/grpc.bazel.protobuf.bzl index 3eca97dc2311fb..cfb124ce43b1ef 100644 --- a/third_party/systemlibs/grpc.bazel.protobuf.bzl +++ b/third_party/systemlibs/grpc.bazel.protobuf.bzl @@ -1,5 +1,7 @@ """Utility functions for generating protobuf code.""" +load("@rules_proto//proto:defs.bzl", "ProtoInfo") + _PROTO_EXTENSION = ".proto" _VIRTUAL_IMPORTS = "/_virtual_imports/" diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl index f2fc22480f4989..3c734e475f412b 100644 --- a/third_party/systemlibs/syslibs_configure.bzl +++ b/third_party/systemlibs/syslibs_configure.bzl @@ -21,7 +21,6 @@ VALID_LIBS = [ "curl", "cython", "dill_archive", - "double_conversion", "flatbuffers", "functools32_archive", "gast_archive", diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 5b0519e9359709..30aeaeed284332 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "d02348ca01f8dbe413b11394dd913aa69002a378" - TFRT_SHA256 = "0548608af9f64645e68b8eb922fded98d014408685f26e6f4ab5f635c0140e48" + TFRT_COMMIT = "c6ecd4a29d5052301238120206d6aaa287a4cdc0" + TFRT_SHA256 = "653cef57364a4f716be6565cbd20a499d1ccb9c1b6530b2f75cd4460bee81e89" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/llvm_integration/cl704999069.patch b/third_party/triton/llvm_integration/cl704999069.patch new file mode 100644 index 00000000000000..95dd8fe8292fed --- /dev/null +++ b/third_party/triton/llvm_integration/cl704999069.patch @@ -0,0 +1,21 @@ + +--- a/lib/Dialect/Triton/Transforms/Combine.td 2024-12-05 23:53:31.000000000 -0800 ++++ b/lib/Dialect/Triton/Transforms/Combine.td 2024-12-11 00:38:55.000000000 -0800 +@@ -17,7 +17,7 @@ + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + def CombineDotAddFPattern : Pat< +- (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm), ++ (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), +@@ -29,7 +29,7 @@ + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + def CombineDotAddFRevPattern : Pat< +- (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm), ++ (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), diff --git a/third_party/triton/temporary/const_signature_fixes.patch b/third_party/triton/temporary/const_signature_fixes.patch new file mode 100644 index 00000000000000..26c3d8014e953f --- /dev/null +++ b/third_party/triton/temporary/const_signature_fixes.patch @@ -0,0 +1,92 @@ +diff --git a/third_party/f2reduce/f2reduce.cpp b/third_party/f2reduce/f2reduce.cpp +--- a/third_party/f2reduce/f2reduce.cpp ++++ b/third_party/f2reduce/f2reduce.cpp +@@ -470,8 +470,8 @@ namespace f2reduce { + + void inplace_rref_strided(uint64_t *matrix, uint64_t rows, uint64_t cols, uint64_t stride) { + +- if (rows <= 1) { +- // If the matrix has 0 or 1 rows, it must already be in RREF: ++ if (rows <= 1 || cols <= 1) { ++ // If the matrix has 0 or 1 rows or columns, it must already be in RREF: + return; + } + +diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc +--- a/third_party/nvidia/backend/cuda_utils.cc ++++ b/third_party/nvidia/backend/cuda_utils.cc +@@ -276,8 +276,10 @@ const ExtractionInfo kExtractionInfos[]{ + ExtractionInfo::build({"'u64'"}), + ExtractionInfo::build({"'fp16'", "'bf16'", "'fp32'", "'f32'"}), + ExtractionInfo::build({"'fp64'"}), ++ // Note: types are e.g. '*fp32', so no closing quote is intentional. + ExtractionInfo::build({"'*"}, extractPointer), +- ExtractionInfo{{"None"}, 0, nullptr}, // Represent constexprs as None ++ ExtractionInfo{ ++ {"None", "'none'"}, 0, nullptr}, // Represent constexprs as None + }; + + // Finds an extractor that supports a given type_repr in the extractor list. +diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py +--- a/third_party/nvidia/backend/driver.py ++++ b/third_party/nvidia/backend/driver.py +@@ -92,7 +92,22 @@ def ty_to_cpp(ty): + }[ty] + + +-def make_launcher(constants : dict[int, str], signature : dict[int, any]) -> Callable[..., None]: ++def flatten_tuples(xs): ++ """Recursively flattens tuple elements in xs.""" ++ for x in xs: ++ if isinstance(x, tuple): ++ yield from flatten_tuples(x) ++ else: ++ yield x ++ ++ ++def make_launcher(constants : dict[int, str], signature : dict[int, any], ids : dict[str, tuple]) -> Callable[..., None]: ++ ++ signature = {k: v for k, v in signature.items() if v != 'constexpr'} ++ signature = ','.join(signature.values()).replace('[', '').replace(']', '') ++ signature = list(filter(bool, signature.split(','))) ++ signature = {i: s for i, s in enumerate(signature)} ++ + # We seem to have 3 categories of arguments: + # 1. arguments listed in signature + # 2. arguments listed in constants +@@ -103,8 +118,8 @@ def make_launcher(constants : dict[int, + # category (3). The generic C++ launcher currently does not do that, so we + # are doing it in the python wrapper. + signature_metadata = cuda_utils.build_signature_metadata( +- ty if arg_id not in constants else None +- for arg_id, ty in signature.items()) ++ ty for ty in signature.values()) ++ + def wrapper(grid_dim_x: int, grid_dim_y: int, grid_dim_z: int, + stream: int, kernel: int, global_scratch: any, + packed_metadata: tuple[int, int, int, int, int, int], +@@ -115,18 +130,18 @@ def make_launcher(constants : dict[int, + cuda_utils.launch(grid_dim_x, grid_dim_y, grid_dim_z, stream, kernel, + packed_metadata, hook_args, launch_enter_hook, + launch_exit_hook, signature_metadata, global_scratch, +- args) ++ flatten_tuples(args)) + return wrapper + + + class CudaLauncher(object): + + def __init__(self, src, metadata): +- constants = getattr(src, "constants", dict()) +- cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i +- constants = {cst_key(key): value for key, value in constants.items()} +- signature = {cst_key(key): value for key, value in src.signature.items()} +- self.launch = make_launcher(constants, signature) ++ ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} ++ constants = src.constants if hasattr(src, "constants") else dict() ++ constants = {idx: value for idx, value in constants.items()} ++ signature = {idx: value for idx, value in src.signature.items()} ++ self.launch = make_launcher(constants, signature, ids) + self.global_scratch_size = metadata.global_scratch_size + self.global_scratch_align = metadata.global_scratch_align + diff --git a/third_party/triton/temporary/numpy_type_promotion.patch b/third_party/triton/temporary/numpy_type_promotion.patch new file mode 100644 index 00000000000000..e41638db8fcaf8 --- /dev/null +++ b/third_party/triton/temporary/numpy_type_promotion.patch @@ -0,0 +1,12 @@ +--- a/python/test/unit/language/test_core.py ++++ b/python/test/unit/language/test_core.py +@@ -363,8 +363,7 @@ def _test_binary(dtype_x, dtype_y, expr, + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) +- with promotion_numpy_2_0(): +- z_ref = eval(scalar_expr) ++ z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + diff --git a/third_party/triton/temporary/revert_67ea999.patch b/third_party/triton/temporary/revert_67ea999.patch new file mode 100644 index 00000000000000..22239930a1005c --- /dev/null +++ b/third_party/triton/temporary/revert_67ea999.patch @@ -0,0 +1,556 @@ +This patch is reverting https://github.com/triton-lang/triton/commit/67ea999935f4511a535a25bdecb27e79e3c3af41 +which breaks //learning/deepmind/jax/triton/ops:attention_test_gpu_a100 +The patch is very intrusive due to how big the change is, so it should be prioritized for removal. +This is tracked in b/385090655. + +diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h +--- a/include/triton/Tools/LinearLayout.h ++++ b/include/triton/Tools/LinearLayout.h +@@ -681,6 +681,13 @@ public: + // (i.e. every input bit affects the output). + llvm::MapVector getFreeVariableMasks() const; + ++ // Increase an input dimension without affecting the output dimension. The ++ // added free variables are mapped to 0, ensuring that the new input ++ // dimensions correspond directly to the existing output space. The function ++ // errors out if `newInDimSize` is less than the current size or the new size ++ // is not a power of 2. ++ LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; ++ + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); +diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp +--- a/lib/Analysis/Utility.cpp ++++ b/lib/Analysis/Utility.cpp +@@ -683,8 +683,42 @@ std::optional minimalCvtLa + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); +- +- auto comp = dstLayout->invertAndCompose(*srcLayout); ++ auto numSrcRegs = srcLayout->getInDimSize(kRegister); ++ auto numDstRegs = dstLayout->getInDimSize(kRegister); ++ // The `invertAndCompose` function will generate a layout that is injective ++ // by assigning new output dimensions to free variables. For instance, ++ // consider a scenario where `srcLayout` has a free variable in the lane ++ // dimension, while `dstLayout` has two free variables in the lane ++ // dimension and also a larger number of registers. ++ // The injective form of `srcLayout` will add only a single additional row ++ // to the transformation matrix, whereas the injective form of `dstLayout` ++ // will add two additional rows. This discrepancy causes misleading results ++ // because the matrices end up with a different number of rows. ++ // ++ // Take `dstLayout ⋅ srcLayout^-1` as an example: ++ // ++ // - `injective(dstLayout)`: [n, m] → [n + 2, m] ++ // - `injective(srcLayout)`: [n, m] → [n + 1, m] ++ // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] ++ // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + ++ // 1] → [n + 2, n + 1] ++ // ++ // Here, the `(n + 1)`-th row added by `dstLayout` represents the free ++ // variable in registers, and the `(n + 2)`-th row represents the free ++ // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` ++ // represents the free variable in lanes. As a result, the `(n + 1)`-th row ++ // in two layouts do not correspond to the same free variable. ++ // ++ // To address this issue, we pad the free variables in `srcLayout` and ++ // `dstLayout` to ensure they have the same number of registers. This ++ // guarantees that the resulting matrices have the same number of rows, ++ // ensuring consistency in the composition process. ++ auto numRegs = std::max(numSrcRegs, numDstRegs); ++ auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); ++ auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); ++ // comp describes the layout function to create dst from src. ++ LinearLayout comp = ++ dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); + // We try to quotient by the largest subspace first + auto dims = SmallVector{"block", "warp", "lane", "register"}; + for (auto dim : dims) { +diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +@@ -315,10 +315,14 @@ struct ConvertLayoutOpUsingLinearLayouts + // TODO(Keren): implement warp shuffle instead of using the general + // approach that uses shared memory + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); +- } else if (llvm::is_contained(dims, kRegister)) { ++ } else if (llvm::is_contained(dims, kRegister) || ++ dstLayout.getInDimSize(kRegister) != ++ srcLayout.getInDimSize(kRegister)) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). +- return transferWithinThread(op, *conversion, adaptor, rewriter); ++ return transferWithinThread( ++ op, dstLayout.getFreeVariableMasks()[kRegister], ++ dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); + } else { + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. +@@ -328,8 +332,8 @@ struct ConvertLayoutOpUsingLinearLayouts + } + + LogicalResult +- transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, +- OpAdaptor adaptor, ++ transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, ++ const LinearLayout &conversion, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); +@@ -339,9 +343,16 @@ struct ConvertLayoutOpUsingLinearLayouts + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); +- SmallVector outVals(conversion.getInDimSize(kRegister)); +- for (int i = 0; i < outVals.size(); i++) { +- auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; ++ SmallVector outVals(numRegs); ++ for (int i = 0; i < numRegs; i++) { ++ // Remove free masks from the register index ++ // For example, if idx = 0b00111, and masks = 0b00100, then we get ++ // 0b00011. It means that register 7 (0b111) has the same value as ++ // register 3 (0b011). ++ auto idx = i & (~regMasks); ++ auto srcIdx = conversion.hasInDim(kRegister) ++ ? conversion.apply({{kRegister, idx}}).begin()->second ++ : idx; + outVals[i] = inVals[srcIdx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, +diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp +--- a/lib/Tools/LinearLayout.cpp ++++ b/lib/Tools/LinearLayout.cpp +@@ -112,6 +112,30 @@ std::unique_ptr getMatrix(co + return m; + } + ++// Get a matrix for `layout` with its codomain expanded so it's injective, i.e. ++// each input element maps to a unique output element. We do this by finding ++// columns that are equal to 0 and adding a new row with a 1 in that column. ++std::tuple, int /*numRows*/, int /*numCols*/> ++getInjectiveMat(const LinearLayout &layout) { ++ int numRows = layout.getTotalOutDimSizeLog2(); ++ int numCols = layout.getTotalInDimSizeLog2(); ++ std::unique_ptr mat = getMatrix(layout); ++ ++ // Bits of mat or-reduced along the columns (so there's just one row). ++ uint64_t colBits = 0; ++ for (int r = 0; r < numRows; r++) { ++ colBits |= mat[r]; ++ } ++ auto expanded = std::unique_ptr(new uint64_t[numRows + numCols]); ++ std::memcpy(expanded.get(), mat.get(), numRows * sizeof(uint64_t)); ++ for (int c = 0; c < numCols; c++) { ++ if ((colBits & (1 << c)) == 0) { ++ expanded[numRows++] = (1 << c); ++ } ++ } ++ return std::make_tuple(std::move(expanded), numRows, numCols); ++} ++ + // Compute the rank of the matrix formed by taking the bases for the given + // outDim as columns. In other words, finds the number of linearly-independent + // bases for this output dimension. +@@ -780,179 +804,118 @@ LinearLayout LinearLayout::compose(const + compositionIsSurjective); + } + +-namespace { +-std::unique_ptr concatMatrices(const LinearLayout &A, +- const LinearLayout &B) { +- // In plain words, "convert_layout does not change the shape of a tensor" +- assert(A.getTotalOutDimSizeLog2() == B.getTotalOutDimSizeLog2() && +- "Matrices must have the same number of output dimensions"); +- int numRows = A.getTotalOutDimSizeLog2(); +- int numColsA = A.getTotalInDimSizeLog2(); +- +- // rref expects the lower bits to be the lower indices of the matrix +- auto concat = getMatrix(A); +- auto BMat = getMatrix(B); +- for (int r = 0; r < numRows; r++) { +- concat[r] |= BMat[r] << numColsA; ++LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { ++ assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getOutDimNames()); ++ for (StringAttr outDim : getOutDimNames()) { ++ assert(getOutDimSize(outDim) <= outer.getOutDimSize(outDim)); + } +- return concat; +-} ++ assert(outer.isSurjective()); + +-LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) { +- // Solve the least square system AX = B for A = outer, B = *this +- // and return the least square solution X of minimal norm +- // A and B may not be surjective, but we assume that Im(B) \subset Im(A) +- // Sketch of the algorithm: +- // https://github.com/triton-lang/triton/pull/5309#discussion_r1869084111 +- int numRows = A.getTotalOutDimSizeLog2(); +- int numColsA = A.getTotalInDimSizeLog2(); +- int numColsB = B.getTotalInDimSizeLog2(); +- int numCols = numColsA + numColsB; +- std::unique_ptr combinedMat = concatMatrices(A, B); +- f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols, ++ // Make both `this` and `outer` injective. We need to do this on the ++ // `outer` layout because we can't invert a non-injective function. We ++ // choose to do so on the `this` layout as well. The rest of the comment ++ // explains why we make that choice. ++ // ++ // Recall from the header that C = A.invertAndCompose(B) just means that ++ // A(x) = B(C(x)). ++ // ++ // Sometimes we may have a choice of multiple values for a particular ++ // C(x). For example, if A(1) = B(0) = B(1) = 0, then C(1) can be either 0 ++ // or 1. ++ // ++ // We want to choose C such that C(x) != 0 where possible. For example, ++ // suppose we are transferring from registers to registers and we have the ++ // following layouts. ++ // ++ // A(thread=1, block=0) = 1 ++ // A(thread=2, block=0) = 2 ++ // A(thread=0, block=1) = 0 ++ // ++ // B(thread=1, block=0) = 2 ++ // B(thread=2, block=0) = 1 ++ // B(thread=0, block=1) = 0 ++ // ++ // Notice that A and B both have the same data in each of their two ++ // blocks. So if we want to transfer from A to B, we don't need to cross ++ // blocks, which is expensive. We want A.invertAndCompose(B) to reflect ++ // that choice. ++ // ++ // Let A' be A with the last line changed to "=4", and similarly for B'. ++ // When transferring from A' to B', we can't cross blocks even if we wanted ++ // to, because the two blocks now have different data. But also, any ++ // mapping of thread+block from A' to B' is also valid for mapping from A ++ // to B. ++ // ++ // Thus making A and B injective encodes our desire not to cross blocks, ++ // or more generally our desire that C(x) != 0 where possible. ++ auto [matThis, numRowsThis, numColsThis] = getInjectiveMat(*this); ++ auto [matOuter, numRowsOuter, numColsOuter] = getInjectiveMat( ++ outer.transposeOuts(llvm::to_vector(this->getOutDimNames()))); ++ ++ // Concatenate `matOuter` and `matThis` horizontally (i.e. `matThis` ++ // is to the right of `matOuter`). ++ int combinedNumRows = std::max(numRowsThis, numRowsOuter); ++ int combinedNumCols = numColsThis + numColsOuter; ++ assert(combinedNumCols <= 64 && "Can't handle huge layouts"); ++ ++ std::unique_ptr m(new uint64_t[combinedNumRows]()); ++ for (int r = 0; r < numRowsOuter; r++) { ++ m[r] = matOuter[r]; ++ } ++ for (int r = 0; r < numRowsThis; r++) { ++ m[r] |= matThis[r] << numColsOuter; ++ } ++ ++ // Perform Gaussian elimination on `m`. Because `outer` was modified to ++ // be bijective, the first half of the matrix should be the identity ++ // matrix. The remaining half are the bases for the combined ++ // transformation. ++ // ++ // `stride` is specified in number of 64-bit words per row, and we pack ++ // our matrix so that there's only one uint64_t per row. ++ f2reduce::inplace_rref_strided(m.get(), combinedNumRows, combinedNumCols, + /*stride=*/1); + +- // Compute the pivot columns +- // Since A and B have the same image, each row will either have a pivot +- // or will be all zeros +- SmallVector pivotCols; +- for (int r = 0; r < numRows; r++) { +- auto row = combinedMat[r]; +- if (row == 0) { +- continue; ++ // Check that the first half of the matrix is indeed the identity. ++ for (int r = 0; r < std::min(numRowsOuter, numColsOuter); r++) { ++ for (int c = 0; c < std::min(numColsOuter, numRowsOuter); c++) { ++ if (((m[r] >> c) & 1) != (r == c ? 1 : 0)) { ++ llvm::report_fatal_error("First half of the matrix was not the " ++ "identity, bug in invertAndCompose"); ++ } + } +- int c = __builtin_ctzll(row); +- assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)"); +- assert(pivotCols.empty() || +- pivotCols.back() < c && "Pivot columns are not in increasing order"); +- pivotCols.push_back(c); +- } +- +- // Extract A^{-1}B and complete the matrix using zeros +- std::unique_ptr retMat(new uint64_t[numColsA]()); +- int j = 0; +- for (int r = 0; r < numColsA; r++) { +- auto isPivot = j < pivotCols.size() && pivotCols[j] == r; +- retMat[r] = isPivot ? combinedMat[j++] >> numColsA : 0; + } + + // We need names for the in/out dim of the flattened layout we're going to + // read off from `m`. These could be anything, doesn't matter. +- StringAttr inDim1D = *A.getInDimNames().begin(); +- StringAttr outDim1D = *A.getOutDimNames().begin(); ++ StringAttr inDim1D = *getInDimNames().begin(); ++ StringAttr outDim1D = *getOutDimNames().begin(); + + // Read off the new bases. These are for a flattened 1D -> 1D +- LinearLayout::BasesT retBases; +- auto &bs = retBases[inDim1D]; +- for (int c = 0; c < numColsB; c++) { ++ // transformation from `this`'s in-dims to `outer`'s in-dims. ++ BasesT newBases; ++ auto &bs = newBases[inDim1D]; ++ for (int c = 0; c < numColsThis; c++) { + int32_t basis = 0; +- for (int r = 0; r < numColsA; r++) { +- basis |= (retMat[r] >> c & 1) << r; ++ for (int r = 0; r < numRowsOuter; r++) { ++ basis |= (m[r] >> (numColsOuter + c) & 1) << r; + } + bs.push_back({basis}); + } + +- LinearLayout retFlattened(std::move(retBases), +- {{outDim1D, A.getTotalInDimSize()}}, ++ LinearLayout flatComposed(std::move(newBases), ++ {{outDim1D, outer.getTotalInDimSize()}}, + /*requireSurjective=*/false); + + SmallVector> retInDims; + SmallVector> retOutDims; +- for (StringAttr dim : B.getInDimNames()) { +- retInDims.push_back({dim, B.getInDimSize(dim)}); +- } +- for (StringAttr dim : A.getInDimNames()) { +- retOutDims.push_back({dim, A.getInDimSize(dim)}); +- } +- return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims); +-} +- +-} // namespace +- +-LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { +- // TODO(Lezcano) Make friend and perhaps rename to `convertFrom` or `lstsq` +- // For this, we need to implement our LLVM lowerings by inverting the "outer" +- // layout, and then iterating over the elements from the "this" layout and +- // fetching the corresponding element from the "outer" layout. This exercises +- // the broadcasting that we incentivise via choosing the minimum norm solution +- // in lstsq. +- +- // The order of dims does not matter. We choose to transpose outer +- auto outDims = llvm::to_vector(getOutDimNames()); +- assertDimsEqualIgnoringOrder(outDims, outer.getOutDimNames()); +- const auto &B = *this; +- const auto A = outer.transposeOuts(outDims); +- for (auto dim : outDims) { +- assert(A.getOutDimSize(dim) == B.getOutDimSize(dim) && +- "Convert layout does not change the shape of a tensor"); ++ for (StringAttr dim : getInDimNames()) { ++ retInDims.push_back({dim, getInDimSize(dim)}); + } +- +- // We'll write A^{-1} to mean the inverse or the pseudo-inverse of A +- // We are computing A^{-1}B so A must be surjective so that +- // it has a left inverse. +- assert(A.isSurjective()); +- +- // Broadcasting heuristic +- // Imagine we have two layouts with `warps = [[0, 0],  [0, 0]]` +- // (broadcasting) on both layouts. We could map any warp to any warp in the +- // conversion. Now, we want to map them as the identity map, to mark that +- // nothing needs to be done there (`lstsq` would map all the warps to the +- // zero warp, minimum norm solution). The heuristic here is as follows: +- // - If a dimension is the same for both layouts, we want to map it as the +- // identity +- // Equivalently, we don't add it to the conversion +- // - Otherwise, we just call lstsq (i.e. map all the equivalent elements +- // to the same input element) to take advantage of broadcasting in shared +- // memory and avoid saving repeated elements in shared memory +- SmallVector identityDims; +- for (auto dim : A.getInDimNames()) { +- if (B.hasInDim(dim) && +- A.sublayout(dim, outDims) == B.sublayout(dim, outDims)) { +- identityDims.push_back(dim); +- } +- } +- SmallVector ANonIdentityInDims; +- SmallVector BNonIdentityInDims; +- for (auto dim : A.getInDimNames()) { +- if (!llvm::is_contained(identityDims, dim)) { +- ANonIdentityInDims.push_back(dim); +- } ++ for (StringAttr dim : outer.getInDimNames()) { ++ retOutDims.push_back({dim, outer.getInDimSize(dim)}); + } +- for (auto dim : B.getInDimNames()) { +- if (!llvm::is_contained(identityDims, dim)) { +- BNonIdentityInDims.push_back(dim); +- } +- } +- +- auto AReduced = A.sublayout(ANonIdentityInDims, outDims); +- auto BReduced = B.sublayout(BNonIdentityInDims, outDims); +- +- // If one is empty, the other must be empty as well +- assert((AReduced == LinearLayout::empty()) == +- (BReduced == LinearLayout::empty())); +- bool isEmpty = AReduced == LinearLayout::empty(); +- +- auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced); +- +- // TODO(Lezcano): We should return the reduced layout instead of re-adding the +- // identity maps. With this, we'll be able to kill `minimalCvtLayout` +- +- // Add the identity maps for the dimensions that are the same for both layouts +- for (auto dim : identityDims) { +- ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim); +- } +- +- // Reshape the result +- SmallVector> inDimsA; +- SmallVector> inDimsB; +- for (auto dim : A.getInDimNames()) { +- inDimsA.push_back({dim, A.getInDimSize(dim)}); +- } +- for (auto dim : B.getInDimNames()) { +- inDimsB.push_back({dim, B.getInDimSize(dim)}); +- } +- ret = ret.reshapeIns(inDimsB).reshapeOuts(inDimsA); +- return ret; ++ return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); + } + + llvm::MapVector +@@ -1041,6 +1004,21 @@ bool LinearLayout::equalIgnoringOutDimSi + return true; + } + ++LinearLayout LinearLayout::resize(StringAttr inDim, ++ int32_t newInDimSize) const { ++ BasesT bases = getBases(); ++ assert(bases.contains(inDim) && "inDim not in layout"); ++ assert(llvm::isPowerOf2_32(newInDimSize) && ++ "newInDimSize must be a power of 2"); ++ assert(newInDimSize >= getInDimSize(inDim) && ++ "newInDimSize must be >= old size"); ++ auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); ++ for (int i = 0; i < numFreeVariables; i++) { ++ bases[inDim].push_back(std::vector(getNumOutDims(), 0)); ++ } ++ return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); ++} ++ + std::string LinearLayout::toString() const { + // Start with a newline because we print out a bulleted list; it doesn't + // make sense for the first line of this list to be on the same line as +diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir +--- a/test/Conversion/tritongpu_to_llvm.mlir ++++ b/test/Conversion/tritongpu_to_llvm.mlir +@@ -1698,7 +1698,8 @@ module attributes {"ttg.target" = "cuda: + // CHECK-LABEL: convert_single_element + // CHECK-NOT: llvm.store + // CHECK-NOT: llvm.load +- // CHECK: llvm.return ++ // CHECK: llvm.insertvalue ++ // CHECK: llvm.extractvalue + tt.func public @convert_single_element() attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> +diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp +--- a/unittest/Tools/LinearLayoutTest.cpp ++++ b/unittest/Tools/LinearLayoutTest.cpp +@@ -410,6 +410,26 @@ TEST_F(LinearLayoutTest, InvertAndCompos + EXPECT_EQ(composition.compose(l2), l1); + } + ++TEST_F(LinearLayoutTest, InvertAndCompose_SmallerResult) { ++ // The domain of l2 is [0,16), but the codomain of the result is only [0,8), ++ // because there's no value v in the codomain of l1 such that l2^-1(v) >= 8. ++ LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); ++ LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {8}}}}, {S("out")}); ++ // Pseudo-inverse of l2 is ++ // ++ // out(1) = 2 ++ // out(2) = 4 ++ // out(4) = 1 ++ // out(8) = 8 ++ // ++ // Composing with l1 gives back l2^-1 without the out(8) entry. ++ LinearLayout composition = l1.invertAndCompose(l2); ++ EXPECT_EQ(composition, ++ LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, ++ /*requireSurjective=*/false)); ++ EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); ++} ++ + TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedInDim) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{4}, {1}, {2}}}}, {S("out")}); +@@ -494,10 +514,8 @@ TEST_F(LinearLayoutTest, InvertAndCompos + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in3"), {{1}, {2}, {4}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = l1.invertAndCompose(l2); +- EXPECT_EQ(c, LinearLayout( +- {{S("in1"), {{1, 0}, {2, 0}, {4, 0}}}, {S("in2"), {{0, 0}}}}, +- {{S("in3"), 8}, {S("in4"), 2}}, +- /*requireSurjective=*/false)); ++ EXPECT_EQ(c, LinearLayout::identity1D(8, S("in1"), S("in3")) * ++ LinearLayout::identity1D(2, S("in2"), S("in4"))); + EXPECT_EQ(c.compose(l2), + l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); + } +@@ -507,9 +525,8 @@ TEST_F(LinearLayoutTest, InvertAndCompos + LinearLayout b({{S("in3"), {{2}, {1}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = a.invertAndCompose(b); + EXPECT_EQ(c, +- LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 0}}}}, +- {{S("in3"), 4}, {S("in4"), 2}}, +- /*requireSurjective=*/false)); ++ LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 1}}}}, ++ {S("in3"), S("in4")})); + EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); + } + +@@ -729,6 +746,40 @@ TEST_F(LinearLayoutTest, QuotientIdentit + ASSERT_TRUE(quotientLayout.has_value()); + ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); + } ++ ++TEST_F(LinearLayoutTest, Resize) { ++ auto init = LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}}}, ++ {S("in1"), {{1, 0}, {2, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")}); ++ EXPECT_EQ(init.resize(S("in0"), 8), ++ LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, ++ {S("in1"), {{1, 0}, {2, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")})); ++ EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}}}, ++ {S("in1"), {{1, 0}, {2, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")})); ++ EXPECT_EQ(init.resize(S("in1"), 8), ++ LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}}}, ++ {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")})); ++} ++ + } // anonymous namespace + } // namespace mlir::triton + diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 4fa55269e3323c..0348fe0cbb87f7 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,5 +14,6 @@ those to this list. """ temporary_patch_list = [ + "//third_party/triton:temporary/numpy_type_promotion.patch", # Add new patches just above this line ] diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 5de24fa70a5b75..2b93e2bababa45 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl702724623" - TRITON_SHA256 = "7348c9fcc01f24d97daf71b9757b9065a36fedfe05a5fbe1ea79b603b89a65b9" + TRITON_COMMIT = "cl706678601" + TRITON_SHA256 = "904377c36458ef842e6fa2daa8e55f4fe0d235f08cce3011c5b33b50f4ffe93a" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/triton/xla_extensions/series.bzl b/third_party/triton/xla_extensions/series.bzl index 0e0291d7def6d5..9a12588aae7bcc 100644 --- a/third_party/triton/xla_extensions/series.bzl +++ b/third_party/triton/xla_extensions/series.bzl @@ -8,5 +8,6 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to extensions_files_patch_list = [ "//third_party/triton:xla_extensions/sparse_wgmma_op.patch", # Sparsity internal patch + "//third_party/triton:xla_extensions/sparse_fenceinsertion_pass.patch", # Sparse internal patch # Add new patches just above this line ] diff --git a/third_party/triton/xla_extensions/sparse_fenceinsertion_pass.patch b/third_party/triton/xla_extensions/sparse_fenceinsertion_pass.patch new file mode 100644 index 00000000000000..d9a1a25fe2d1f9 --- /dev/null +++ b/third_party/triton/xla_extensions/sparse_fenceinsertion_pass.patch @@ -0,0 +1,13 @@ +# Tracked in b/377699102 +--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp 2024-12-05 23:53:31.000000000 -0800 ++++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp 2024-12-19 07:03:31.000000000 -0800 +@@ -44,7 +44,8 @@ + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { +- if (!op->hasTrait()) ++ if (!isa(op) && ++ op->getName().getStringRef() != "triton_xla.sparse_dot") + return WalkResult::advance(); + OpBuilder builder(op); + auto a = op->getOperand(0); diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index 342f35280adf36..48618aa1acb6c2 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -451,30 +451,36 @@ build:avx_linux --copt=-mavx build:avx_linux --host_copt=-mavx build:avx_win --copt=/arch:AVX +# TODO(belitskiy): Remove once Win2019 is gone. # Use Clang-cl compiler on Windows +build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +build:win_clang --extra_execution_platforms=//tools/toolchains/win:x64_windows-clang-cl +build:win_clang --host_platform=//tools/toolchains/win:x64_windows-clang-cl build:win_clang --copt=/clang:-Weverything build:win_clang --host_copt=/clang:-Weverything -build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang --extra_execution_platforms=//tensorflow/tools/toolchains/win:x64_windows-clang-cl -build:win_clang --host_platform=//tensorflow/tools/toolchains/win:x64_windows-clang-cl build:win_clang --compiler=clang-cl build:win_clang --linkopt=/FORCE:MULTIPLE build:win_clang --host_linkopt=/FORCE:MULTIPLE test:win_clang --linkopt=/FORCE:MULTIPLE test:win_clang --host_linkopt=/FORCE:MULTIPLE - -# Same config as above but for XLA, which has different toolchain paths -build:win_clang_xla --copt=/clang:-Weverything -build:win_clang_xla --host_copt=/clang:-Weverything -build:win_clang_xla --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang_xla --extra_execution_platforms=//tools/toolchains/win:x64_windows-clang-cl -build:win_clang_xla --host_platform=//tools/toolchains/win:x64_windows-clang-cl -build:win_clang_xla --compiler=clang-cl -build:win_clang_xla --linkopt=/FORCE:MULTIPLE -build:win_clang_xla --host_linkopt=/FORCE:MULTIPLE -test:win_clang_xla --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW -test:win_clang_xla --linkopt=/FORCE:MULTIPLE -test:win_clang_xla --host_linkopt=/FORCE:MULTIPLE +test:win_clang --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW + +# build:windows_x86_cpu --extra_toolchains="//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +# build:windows_x86_cpu --extra_execution_platforms="//tools/toolchains/win2022:windows_ltsc2022_clang" +# build:windows_x86_cpu --host_platform="//tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --crosstool_top="//tools/toolchains/win2022/20241118:toolchain" +build:windows_x86_cpu --extra_toolchains="//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +build:windows_x86_cpu --extra_execution_platforms="//tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --host_platform="//tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --platforms="//tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --copt=/clang:-Weverything +build:windows_x86_cpu --host_copt=/clang:-Weverything +build:windows_x86_cpu --compiler=clang-cl +build:windows_x86_cpu --linkopt=/FORCE:MULTIPLE +build:windows_x86_cpu --host_linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --host_linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW # Options to build TensorFlow 1.x or 2.x. # TODO(kanglan): Change v2's define to default behavior @@ -533,9 +539,9 @@ build:rbe_linux_cpu --crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_linux_cpu --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cpu --repo_env=CC="/usr/lib/llvm-18/bin/clang" build:rbe_linux_cpu --repo_env=TF_SYSROOT="/dt9" -build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.17-clang_config_platform//:platform" -build:rbe_linux_cpu --host_platform="@sigbuild-r2.17-clang_config_platform//:platform" -build:rbe_linux_cpu --platforms="@sigbuild-r2.17-clang_config_platform//:platform" +build:rbe_linux_cpu --extra_execution_platforms="@ml_build_config_platform//:platform" +build:rbe_linux_cpu --host_platform="@ml_build_config_platform//:platform" +build:rbe_linux_cpu --platforms="@ml_build_config_platform//:platform" # This is needed for all Clang17 builds but must not be present in GCC builds. build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument # This was added in clang-16 by https://reviews.llvm.org/D133574. @@ -578,11 +584,11 @@ build:rbe_win_base --nobuild_python_zip build:rbe_win_base --define=override_eigen_strong_inline=true build:rbe_win_clang --config=rbe_win_base -build:rbe_win_clang --crosstool_top="//tensorflow/tools/toolchains/win/20240424:toolchain" -build:rbe_win_clang --extra_toolchains="//tensorflow/tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" -build:rbe_win_clang --extra_execution_platforms="//tensorflow/tools/toolchains/win:x64_windows-clang-cl" -build:rbe_win_clang --host_platform="//tensorflow/tools/toolchains/win:x64_windows-clang-cl" -build:rbe_win_clang --platforms="//tensorflow/tools/toolchains/win:x64_windows-clang-cl" +build:rbe_win_clang --crosstool_top="//tools/toolchains/win/20240424:toolchain" +build:rbe_win_clang --extra_toolchains="//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:rbe_win_clang --extra_execution_platforms="//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_win_clang --host_platform="//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_win_clang --platforms="//tools/toolchains/win:x64_windows-clang-cl" build:rbe_win_clang --compiler=clang-cl build:rbe_win_clang --linkopt=/FORCE:MULTIPLE build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE @@ -746,48 +752,54 @@ build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_ # LIBTENSORFLOW TESTS are for building Libtensorflow archives. These are CUDA/CPU-agnostic. test:linux_libtensorflow_test --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip +build:windows_libtensorflow_build --config=cuda_wheel --config=windows_x86_cpu -- //:LICENSE //tensorflow:tensorflow.dll //tensorflow:tensorflow_dll_import_lib //tensorflow/tools/lib_package:clicenses_generate //tensorflow/java:tensorflow_jni.dll //tensorflow/tools/lib_package:jnilicenses_generate # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# WINDOWS X86 WHEEL +test:windows_x86_cpu_wheel_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_wheel_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_wheel_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" +test:windows_x86_cpu_wheel_test --build_tests_only --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. # LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # LINUX ARM64 PYCPP # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on @@ -798,35 +810,35 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. -build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... @@ -840,64 +852,31 @@ test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --confi # seems it is this way because these flags are old and predate the distinction # between host and execution platform. build:cross_compile_base --host_cpu=k8 -build:cross_compile_base --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_base --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 - -# XLA related settings for cross-compiled build. Certain paths are -# different in the XLA repo. -build:cross_compile_base_xla --host_cpu=k8 -build:cross_compile_base_xla --host_crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_base_xla --extra_execution_platforms=//tools/toolchains/cross_compile/config:linux_x86_64 +build:cross_compile_base --host_crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_base --extra_execution_platforms=//tools/toolchains/cross_compile/config:linux_x86_64 build:rbe_cross_compile_base --config=rbe_base build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance -# XLA depends on some local Python headers that are configured as Genrule. They -# are present on the local host machine but not on the remote execution machine, -# leading to build failures. To resolve the issue, the following line is added -# to make sure all Genrule targets are excuted locally. -build:rbe_cross_compile_base_xla --config=rbe_cross_compile_base -build:rbe_cross_compile_base_xla --strategy=Genrule=standalone - -# Due to the above strategy, all Genrule commands are executed locally, but the -# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are -# only executabe on the RBE (x86) machine, so the strategy_regexp options are -# added to override and run the actions using remote strategy. -build:rbe_cross_compile_base_xla --strategy_regexp='Generating code from table.*=remote' -build:rbe_cross_compile_base_xla --strategy_regexp='Generating flatbuffer files.*=remote' -build:rbe_cross_compile_base_xla --strategy_regexp='Executing genrule @llvm-project.*=remote' - # Test-related settings below this point # We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to # force all tests to run locally on the Aarch64 host. test:rbe_cross_compile_base --strategy=TestRunner=local --build_tests_only test:rbe_cross_compile_base --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors -test:rbe_cross_compile_base_xla --config=rbe_cross_compile_base - # START LINUX AARCH64 CROSS-COMPILE CONFIGS build:cross_compile_linux_arm64 --config=cross_compile_base # Set the target CPU to Aarch64 -build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_aarch64 +build:cross_compile_linux_arm64 --platforms=//tools/toolchains/cross_compile/config:linux_aarch64 build:cross_compile_linux_arm64 --cpu=aarch64 -build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite - -# XLA uses different paths for platforms and crosstool_top. -build:cross_compile_linux_arm64_xla --config=cross_compile_base_xla -build:cross_compile_linux_arm64_xla --platforms=//tools/toolchains/cross_compile/config:linux_aarch64 -build:cross_compile_linux_arm64_xla --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_linux_arm64 --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite # RBE cross-compile configs for Linux Aarch64 build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base test:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base -# RBE cross-compile configs for XLA Linux Aarch64 -build:rbe_cross_compile_linux_arm64_xla --config=cross_compile_linux_arm64_xla -build:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla -test:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla - # END LINUX AARCH64 CROSS-COMPILE CONFIGS # START MACOS CROSS-COMPILE CONFIGS @@ -907,16 +886,16 @@ build:cross_compile_macos_x86 --config=nonccl build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15 # Set the target CPU to Darwin x86 -build:cross_compile_macos_x86 --platforms=//tensorflow/tools/toolchains/cross_compile/config:darwin_x86_64 +build:cross_compile_macos_x86 --platforms=//tools/toolchains/cross_compile/config:darwin_x86_64 build:cross_compile_macos_x86 --cpu=darwin -build:cross_compile_macos_x86 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite +build:cross_compile_macos_x86 --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite # When RBE cross-compiling for macOS, we need to explicitly register the # toolchain. Otherwise, oddly, RBE complains that a "docker container must be # specified". -build:cross_compile_macos_x86 --extra_toolchains=//tensorflow/tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain +build:cross_compile_macos_x86 --extra_toolchains=//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain # Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects() # and transistions that use these flags work. -build:cross_compile_macos_x86 --platform_mappings=tensorflow/tools/toolchains/cross_compile/config/platform_mappings +build:cross_compile_macos_x86 --platform_mappings=tools/toolchains/cross_compile/config/platform_mappings # RBE cross-compile configs for Darwin x86 build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86 --remote_download_minimal diff --git a/third_party/xla/.github/workflows/autorun_ci.py b/third_party/xla/.github/workflows/autorun_ci.py new file mode 100644 index 00000000000000..8221fdcd90cfb5 --- /dev/null +++ b/third_party/xla/.github/workflows/autorun_ci.py @@ -0,0 +1,43 @@ +# Copyright 2024 The OpenXLA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Autoruns CI for OpenXLA org members with membership set to public.""" +import logging +import os + +import github_api + +_OPENXLA_ORG_ID = 107584881 # https://api.github.com/orgs/107584881 + + +def main(): + username = os.getenv("PR_AUTHOR_USERNAME") + pr_number = os.getenv("PR_NUMBER") + api = github_api.GitHubAPI(os.getenv("GH_TOKEN")) + + orgs = api.get_user_orgs(username) + logging.info("Found public organizations for user %s: %s", username, orgs) + + if _OPENXLA_ORG_ID in {org["id"] for org in orgs}: + logging.info( + "Found OpenXLA org in public memberships, so adding kokoro:force-run" + " label." + ) + api.add_issue_labels("openxla/xla", pr_number, ["kokoro:force-run"]) + + +if __name__ == "__main__": + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + main() diff --git a/third_party/xla/.github/workflows/autorun_ci.yml b/third_party/xla/.github/workflows/autorun_ci.yml new file mode 100644 index 00000000000000..92ebd74e75797f --- /dev/null +++ b/third_party/xla/.github/workflows/autorun_ci.yml @@ -0,0 +1,38 @@ +# Copyright 2024 The OpenXLA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +name: Autorun CI for OpenXLA Public Members +permissions: + pull-requests: write +on: + pull_request_target: + branches: ["main"] + +jobs: + autorun-ci: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash + env: + GH_TOKEN: ${{ github.token }} + PR_NUMBER: ${{ github.event.number }} + PR_AUTHOR_USERNAME: ${{ github.event.pull_request.user.login }} + timeout-minutes: 6 + if: github.event.sender.type == 'User' + steps: + - name: "Checking out repository" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: "Autorun CI for public OpenXLA org members" + run: python3 .github/workflows/autorun_ci.py diff --git a/third_party/xla/.github/workflows/buildifier.yml b/third_party/xla/.github/workflows/buildifier.yml index e82ee7b489ca33..6a5b11ca49d36c 100644 --- a/third_party/xla/.github/workflows/buildifier.yml +++ b/third_party/xla/.github/workflows/buildifier.yml @@ -38,4 +38,4 @@ jobs: - name: "Install buildifier" run: parallel --ungroup --retries 3 --delay 15 --nonall -- go install github.com/bazelbuild/buildtools/buildifier@433ea85 # 6.4.0 - name: "Run buildifier" - run: buildifier --lint=warn --warnings=-out-of-order-load $(find xla/ -type f -name "BUILD" -or -name "*bzl" | grep -v /tsl/) + run: buildifier --lint=warn --warnings=-out-of-order-load -r xla/ diff --git a/third_party/xla/.github/workflows/cpu_benchmarks.yml b/third_party/xla/.github/workflows/cpu_benchmarks.yml new file mode 100644 index 00000000000000..d69bc9f3b8cc30 --- /dev/null +++ b/third_party/xla/.github/workflows/cpu_benchmarks.yml @@ -0,0 +1,93 @@ +# Copyright 2024 The OpenXLA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +name: Benchmarks + +on: + workflow_dispatch: # Allows manual triggering + schedule: + - cron: '0 */6 * * *' # Run every 6 hours (at minute 0 of hours 0, 6, 12, 18) + push: + branches: + - main + +jobs: + benchmark: + runs-on: ubuntu-24.04 + steps: + - name: Checkout OpenXLA + uses: actions/checkout@v3 + with: + repository: 'openxla/xla' + path: openxla + + - name: Print machine specs + run: | + lscpu + free -h # Memory information + df -h # Disk space information + uname -a # Kernel information + + - name: Build run_hlo_module + working-directory: openxla + run: bazelisk build -c opt --dynamic_mode=off xla/tools:run_hlo_module + + - name: Run HLO Module Benchmarks + working-directory: openxla + continue-on-error: true + run: | + for file in xla/tests/fuzz/*.hlo; do + filename=$(basename "$file") + # Skip expected failed hlo files. + if [[ "$filename" == "rand_000060.hlo" || "$filename" == "rand_000067.hlo" || "$filename" == "rand_000072.hlo" ]]; then + echo "Skipping benchmark on $file" + continue + fi + echo "Running benchmark on $file" + ./bazel-bin/xla/tools/run_hlo_module --input_format=hlo --platform=CPU "$file" + done + + - name: Create results directory + working-directory: openxla + run: mkdir results + + - name: Build CPU Benchmarks + working-directory: openxla + run: bazelisk build -c opt --dynamic_mode=off //xla/service/cpu/benchmarks:* + + - name: Run CPU benchmarks + working-directory: openxla + continue-on-error: true + run: | + find ./bazel-bin/xla/service/cpu/benchmarks/ -maxdepth 1 -type f -executable -name "*_test" -print0 | while IFS= read -r -d $'\0' benchmark; do + benchmark_name=$(basename "$benchmark" | sed 's/_test$//') + echo "Running benchmark: $benchmark_name" + + # Run the benchmark with default parameters. + $benchmark --benchmark_filter=".*" + $benchmark --benchmark_filter=".*" > "results/$benchmark_name.log" 2>&1 + + # Check the exit code of the benchmark + if [ $? -ne 0 ]; then + echo "Error: Benchmark '$benchmark_name' failed. Check the log file: results/$benchmark_name.log" + else + echo "Benchmark '$benchmark_name' completed successfully." + fi + done + + - name: Upload Results + uses: actions/upload-artifact@v4 + with: + name: cpu-xla-benchmarks + path: openxla/results diff --git a/third_party/xla/.github/workflows/github_api.py b/third_party/xla/.github/workflows/github_api.py index 57a8125d64539a..b178f048016b5f 100644 --- a/third_party/xla/.github/workflows/github_api.py +++ b/third_party/xla/.github/workflows/github_api.py @@ -120,3 +120,41 @@ def set_issue_status( """ endpoint = f"repos/{repo}/issues/{issue_number}" return self._make_request("POST", endpoint, status=status) + + def add_issue_labels( + self, repo: str, issue_number: int, labels: list[str] + ) -> requests.Response: + """Adds labels to an issue (or PR). + + https://docs.github.com/en/actions/managing-issues-and-pull-requests/adding-labels-to-issues + + Arguments: + repo: a string of the form `owner/repo_name`, e.g. openxla/xla + issue_number: the issue (or PR) to set the status of + labels: the labels to add to the issue + + Returns: + a requests.Response object containing the response from the API. + + Raises: + requests.exceptions.HTTPError + """ + endpoint = f"repos/{repo}/issues/{issue_number}/labels" + return self._make_request("POST", endpoint, labels=labels) + + def get_user_orgs(self, username: str) -> requests.Response: + """Gets all public org memberships for a user. + + https://docs.github.com/en/rest/orgs/orgs?apiVersion=2022-11-28#list-organizations-for-a-user + + Arguments: + username: The user's GitHub username as a string. + + Returns: + a requests.Response object containing the response from the API. + + Raises: + requests.exceptions.HTTPError + """ + endpoint = f"users/{username}/orgs" + return self._make_request("GET", endpoint, username=username) diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index 4cdaf4bbdff8c9..8bea3850d9edf3 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -304,7 +304,7 @@ def nvidia_gpu_build_with_compute_capability( type_=BuildType.CPU_ARM64, repo="openxla/xla", image_url=_ML_BUILD_ARM64_IMAGE, - configs=("warnings", "rbe_cross_compile_linux_arm64_xla", "nonccl"), + configs=("warnings", "rbe_cross_compile_linux_arm64", "nonccl"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, options={**_DEFAULT_BAZEL_OPTIONS, "build_tests_only": True}, build_tag_filters=cpu_arm_tag_filter, diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt index bc82a9e3b3d837..6be00b06a062d6 100644 --- a/third_party/xla/build_tools/ci/golden_commands.txt +++ b/third_party/xla/build_tools/ci/golden_commands.txt @@ -2,8 +2,8 @@ $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html docker pull us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/xla us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest bash -docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --config=warnings --config=rbe_cross_compile_linux_arm64_xla --config=nonccl --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --build_tests_only --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... -docker exec xla_ci bazel test --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --config=warnings --config=rbe_cross_compile_linux_arm64_xla --config=nonccl --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --build_tests_only -- //xla/... //build_tools/... @local_tsl//tsl/... +docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --config=warnings --config=rbe_cross_compile_linux_arm64 --config=nonccl --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --build_tests_only --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... +docker exec xla_ci bazel test --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --config=warnings --config=rbe_cross_compile_linux_arm64 --config=nonccl --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --build_tests_only -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci # END BuildType.CPU_ARM64 diff --git a/third_party/xla/build_tools/lint/tags.py b/third_party/xla/build_tools/lint/tags.py index 3e3a4680323d98..195257c8e8e7fe 100644 --- a/third_party/xla/build_tools/lint/tags.py +++ b/third_party/xla/build_tools/lint/tags.py @@ -83,13 +83,14 @@ "xla_gpu_h100": "Runs on an h100.", "xla_gpu_b100": "Runs on an b100.", # Below tags are consumed by `xla_test`. - "test_xla_cpu_thunks": ( - "Internally, `xla_test` sets `--xla_cpu_use_thunk_runtime`. Unused on" - " OpenXLA CI." + "test_xla_cpu_no_thunks": ( + "Internally, `xla_test` sets `--xla_cpu_use_thunk_runtime` to false." + " Unused on OpenXLA CI." ), - "test_hlo_pjrt_runner": ( - "Internally adds the appropriate" - " `xla/tests:pjrt_$BACKEND_client_registry`. Unused on OpenXLA CI." + "test_migrated_to_hlo_runner_pjrt": ( + "Adds the appropriate `xla/tests:pjrt_$BACKEND_client_registry` to the" + " annotated `xla_test` target. Adding this tag does not synthesize" + " additional targets." ), "multi_gpu": "Used by `xla_test` to signal that multiple GPUs are needed.", "multi_gpu_h100": ( diff --git a/third_party/xla/build_tools/rocm/run_xla.sh b/third_party/xla/build_tools/rocm/run_xla.sh index 140c6a9c1f0088..2ed5dc2d317acc 100755 --- a/third_party/xla/build_tools/rocm/run_xla.sh +++ b/third_party/xla/build_tools/rocm/run_xla.sh @@ -56,6 +56,7 @@ TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}" bazel \ test \ + --define xnn_enable_avxvnniint8=false --define xnn_enable_avx512fp16=false \ --config=rocm \ --build_tag_filters=${TAGS_FILTER} \ --test_tag_filters=${TAGS_FILTER} \ diff --git a/third_party/xla/docs/_toc.yaml b/third_party/xla/docs/_toc.yaml index 8028d8f1a7d69a..48d49d0f6451f9 100644 --- a/third_party/xla/docs/_toc.yaml +++ b/third_party/xla/docs/_toc.yaml @@ -35,6 +35,8 @@ toc: path: /xla/persisted_autotuning - title: Shapes and layout path: /xla/shapes + - title: Testing HLO passes + path: /xla/test_hlo_passes - title: Tiled layout path: /xla/tiled_layout - title: Using LSP autocompletion diff --git a/third_party/xla/docs/contributing.md b/third_party/xla/docs/contributing.md index 8d17d519ae0d9e..7c9ff58d47afc1 100644 --- a/third_party/xla/docs/contributing.md +++ b/third_party/xla/docs/contributing.md @@ -40,9 +40,9 @@ This project follows ### Developer Guide -For a guide on how to setup a development environment for OpenXLA, including getting -code, building it, running tests and submitting changes, please refer to the -[Developer guide](docs/developer_guide.md). +For a guide on how to setup a development environment for OpenXLA, including +getting code, building it, running tests and submitting changes, please refer to +the [Developer guide](./developer_guide.md). ### Code standards diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md index 1bd39c0e070405..840e284c7eab19 100644 --- a/third_party/xla/docs/custom_call.md +++ b/third_party/xla/docs/custom_call.md @@ -33,7 +33,7 @@ end to end examples of integrating custom calls and XLA FFI with JAX. XLA FFI binding is a compile-time specification of the custom call signature: custom call arguments, attributes and their types, and additional parameters passed via the execution context (i.e., gpu stream for GPU backend). XLA FFI -finding can be bound to any C++ callable (function pointer, lambda, etc.) with +binding can be bound to any C++ callable (function pointer, lambda, etc.) with compatible `operator()` signature. Constructed handler decodes XLA FFI call frame (defined by the stable C API), type check all parameters, and forward decoded results to the user-defined callback. diff --git a/third_party/xla/docs/gpu_architecture.md b/third_party/xla/docs/gpu_architecture.md new file mode 100644 index 00000000000000..295b206ae21353 --- /dev/null +++ b/third_party/xla/docs/gpu_architecture.md @@ -0,0 +1,253 @@ +# XLA:GPU Architecture Overview + +# Introduction + +XLA is a hardware- and framework- domain-specific compiler for linear algebra, +offering best-in-class performance. JAX, TF, Pytorch and others use XLA by +converting the user input to +[StableHLO](https://github.com/openxla/stablehlo/tree/main) (“high-level +operation”: a set of \~100 statically shaped instructions like addition, +subtraction, matmul, etc) operation set, from which XLA produces optimized code +for a variety of backends: + +![](./images/xla_hardware.png) + +During the execution, the frameworks invoke the +[PJRT runtime](https://opensource.googleblog.com/2023/05/pjrt-simplifying-ml-hardware-and-framework-integration.html) +API, which lets the frameworks perform the operation “populate the specified +buffers using a given StableHLO program on a specific device”. + +# XLA:GPU Pipeline + +XLA:GPU uses a combination of “native” (PTX, via LLVM) emitters and TritonIR +emitters to generate high-performance GPU kernels (blue color indicates 3P +components): + +![](./images/gpu_pipeline.png) + +## Running Example: JAX + +To illustrate the pipeline, let’s start with a running example in JAX, which +computes a matmul combined with multiplication by a constant and negation: + +``` +def f(a, b): +    return -((a @ b) * 0.125) +``` + +We can inspect the HLO generated by the function: + +``` +M = 1024 +K = 512 +N = 2048 +key = jax.random.PRNGKey(1701) +a = jax.random.randint(key, (M, K), dtype=jax.numpy.int8, minval=0, maxval=255) +b = jax.random.normal(key, (K, N), dtype=jax.dtypes.bfloat16) + +print(jax.xla_computation(f)(a, b).as_hlo_text()) +``` + +which generates: + +``` +HloModule xla_computation_f, entry_computation_layout={(s8[1024,512]{1,0}, bf16[512,2048]{1,0})->(bf16[1024,2048]{1,0})} + +ENTRY main.10 { +  Arg_0.1 = s8[1024,512]{1,0} parameter(0) +  convert.5 = bf16[1024,512]{1,0} convert(Arg_0.1) +  Arg_1.2 = bf16[512,2048]{1,0} parameter(1) +  dot.6 = bf16[1024,2048]{1,0} dot(convert.5, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +  constant.3 = bf16[] constant(0.125) +  broadcast.4 = bf16[1024,2048]{1,0} broadcast(constant.3), dimensions={} +  multiply.7 = bf16[1024,2048]{1,0} multiply(dot.6, broadcast.4) +  ROOT negate.8 = bf16[1024,2048]{1,0} negate(multiply.7) +} +``` + +We can visualize the input HLO computation as well, using +`jax.xla_computation(f)(a, b).as_hlo_dot_graph()`: + +![](./images/lowered_hlo.png) + +## Optimizations on HLO: Key Components + +A number of notable optimization passes happen on HLO, as HLO->HLO rewrites. + +### SPMD Partitioner + +The XLA SPMD partitioner, as described in the GSPMD +[publication](https://arxiv.org/pdf/2105.04663.pdf%C3%AF%C2%BC%E2%80%B0%C3%A3%E2%82%AC%E2%80%B9%C3%A5%E2%80%99%C5%92), +consumes HLO with sharding annotations (produced e.g. by `jax.pjit`), and +produces a sharded HLO which can then run on a number of hosts and devices. +Apart from partitioning, the SPMD attempts to optimize HLO for an optimal +execution schedule, +[overlapping](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) computation +and communication between the nodes. + +#### Example + +Consider starting from a simple JAX program sharded across two devices: + +``` +# Defines a mesh with two axes called ‘x’ and ‘y’, +# sharded across two devices: first and second CPU. +with jax.sharding.Mesh( +      [['cpu:0', 'cpu:1']], ('x', 'y')): + +    @pjit +    def f(a, b): +        out = -((a @ b) * 0.125) +        # Shard output matrix access across ‘x’ +        # and ‘y’ respectively. Generates ‘Sharding’ +        # custom call. +        out = with_sharding_constraint( +          out, jax.lax.PartitionSpec('x', 'y')) +        return out + +# Random inputs to call our function. +a = jax.random.randint(key, (1024, 512), jnp.int8) +b = jax.random.normal(key, (512, 2048), jnp.float32) + +print(f.lower(a, b).compiler_ir()) +``` + +Visualizing it, the sharding annotations are presented as custom calls: + +![](./images/annotated_module.png) + +To check how the SPMD partitioner expands the custom call, we can look at HLO +after optimizations: + +``` +print(f.lower(np.ones((8, 8)).compile().as_text()) +``` + +Which generates HLO with a collective: + +![](./images/partitioned_module.png) + +### Layout Assignment + +HLO decouples logical shape and physical layout (how tensors are laid out in +memory). For example, a matrix `f32[32, 64]` can be represented either in +row-major or column-major order, represented as `{1,0}` or `{0,1}` respectively. +In general, layout is represented as a part of shape, showing a permutation over +the rank indicating physical layout in memory. + +For each operation present in the HLO, the Layout Assignment pass chooses an +optimal layout (e.g. NHWC for a convolution on Ampere). For example, an +`int8xint8->int32`  matmul operation prefers `{0,1}` layout for the RHS of the +computation. Similarly, “transposes” inserted by the user are ignored, and +encoded as a layout change. + +The layouts are then propagated through the graph, and conflicts between layouts +or at graph endpoints are materialized as `copy` operations, which perform the +physical transposition. For example, starting from the graph + +![](./images/pre_layout_module.png) + +Running the layout assignment we see the following layouts and `copy` operation +inserted: + +![](./images/layout_assigned_module.png) + +### Fusion + +Fusion is XLA’s single most important optimization, which groups multiple +operations (e.g. addition into exponentiation into matmul) to a single kernel. +Since many GPU workloads tend to be memory-bound, fusion dramatically speeds up +the execution by avoiding the writing of intermediate tensors to HBM and then +reading them back, and instead passes them around in either registers or shared +memory. + +Fused HLO instructions are blocked together in a single fusion computation, +which establishes the following invariants: + +- No intermediate storage inside the fusion is materialized in HBM (it has to + be all passed through either registers or shared memory). + +- A fusion is always compiled to exactly one GPU kernel + +## HLO Optimizations on Running Example + +We can inspect the post-optimization HLO using `jax.jit(f).lower(a, +b).compile().as_text()`, and verify that a single fusion got generated: + +``` +HloModule jit_f, is_scheduled=true, entry_computation_layout={(s8[3,2]{1,0}, bf16[2,3]{1,0})->bf16[3,3]{1,0}}, allow_spmd_sharding_propagation_to_output={true} + +%triton_gemm_dot.6_computation (parameter_0: s8[3,2], parameter_1: bf16[2,3]) -> bf16[3,3] { +  %parameter_0 = s8[3,2]{1,0} parameter(0) +  %convert.0 = bf16[3,2]{1,0} convert(s8[3,2]{1,0} %parameter_0) +  %parameter_1 = bf16[2,3]{1,0} parameter(1) +  %dot.0 = bf16[3,3]{1,0} dot(bf16[3,2]{1,0} %convert.0, bf16[2,3]{1,0} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +  %convert.1 = f32[3,3]{1,0} convert(bf16[3,3]{1,0} %dot.0) +  %constant_0 = bf16[] constant(0.125) +  %broadcast.0 = bf16[3,3]{1,0} broadcast(bf16[] %constant_0), dimensions={} +  %convert.2 = f32[3,3]{1,0} convert(bf16[3,3]{1,0} %broadcast.0) +  %multiply.0 = f32[3,3]{1,0} multiply(f32[3,3]{1,0} %convert.1, f32[3,3]{1,0} %convert.2) +  %negate.0 = f32[3,3]{1,0} negate(f32[3,3]{1,0} %multiply.0) +  ROOT %convert.6 = bf16[3,3]{1,0} convert(f32[3,3]{1,0} %negate.0) +} + +ENTRY %main.9 (Arg_0.1: s8[3,2], Arg_1.2: bf16[2,3]) -> bf16[3,3] { +  %Arg_1.2 = bf16[2,3]{1,0} parameter(1), sharding={replicated} +  %Arg_0.1 = s8[3,2]{1,0} parameter(0), sharding={replicated} +  ROOT %triton_gemm_dot.6 = bf16[3,3]{1,0} fusion(s8[3,2]{1,0} %Arg_0.1, bf16[2,3]{1,0} %Arg_1.2), kind=kCustom, calls=%triton_gemm_dot.6_computation, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"64","block_k":"64","split_k":"1","num_stages":"2","num_warps":"4"}} +} +``` + +Note that the fusion `backend_config` tells us that Triton will be used as a +code generation strategy, and it specifies the chosen tiling. + +We can also visualize the resulting module: + +![](./images/fused_module.png) + +## Buffer Assignment and Scheduling + +A buffer assignment pass takes into account the shape information, and aims to +produce an optimal buffer allocation for the program, minimizing the amount of +intermediate memory consumed. Unlike TF or PyTorch immediate-mode (non-compiled) +execution, where the memory allocator does not know the graph in advance, the +XLA scheduler can “look into the future” and produce an optimal computation +schedule. + +## Compiler Backend: Codegen and Library Selection + +For every HLO instruction in the computation, XLA chooses whether to run it +using a library linked into a runtime, or to codegen it to PTX. + +### Library Selection + +For many common operations, XLA:GPU uses fast-performance libraries from NVIDIA, +such as cuBLAS, cuDNN, and NCCL. The libraries have an advantage of verified +fast performance, but often preclude complex fusion opportunities. + +### Direct code generation + +The XLA:GPU backend generates high-performance LLVM IR directly for a number of +operations (reductions, transposes, etc). + +### Triton code generation + +For more advanced fusions which include matrix multiplication or softmax, +XLA:GPU uses [Triton](https://github.com/openai/triton) as a code-generation +layer. HLO Fusions are converted to TritonIR (an MLIR dialect which serves as an +input to Triton), selects tiling parameters and invokes Triton for PTX +generation: + +![](./images/triton_opt_pipeline.png) + +We have observed the resulting code to perform very well on Ampere, at +near-roofline performance with properly tuned tile sizes. + +## Runtime + +XLA Runtime converts the resulting sequence of CUDA kernel calls and library +invocations into a RuntimeIR (an MLIR dialect in XLA), on which CUDA graph +extraction is performed. CUDA graph is still work in progress, only some nodes +are currently supported. Once CUDA graph boundaries are extracted, RuntimeIR is +compiled via LLVM to a CPU executable, which can then be stored or transferred +for Ahead-Of-Time compilation. diff --git a/third_party/xla/docs/images/annotated_module.png b/third_party/xla/docs/images/annotated_module.png new file mode 100644 index 00000000000000..ba013533c11058 Binary files /dev/null and b/third_party/xla/docs/images/annotated_module.png differ diff --git a/third_party/xla/docs/images/fused_module.png b/third_party/xla/docs/images/fused_module.png new file mode 100644 index 00000000000000..044e477babad26 Binary files /dev/null and b/third_party/xla/docs/images/fused_module.png differ diff --git a/third_party/xla/docs/images/gpu_pipeline.png b/third_party/xla/docs/images/gpu_pipeline.png new file mode 100644 index 00000000000000..38ac530c9bb204 Binary files /dev/null and b/third_party/xla/docs/images/gpu_pipeline.png differ diff --git a/third_party/xla/docs/images/layout_assigned_module.png b/third_party/xla/docs/images/layout_assigned_module.png new file mode 100644 index 00000000000000..8a32c1d34d12b4 Binary files /dev/null and b/third_party/xla/docs/images/layout_assigned_module.png differ diff --git a/third_party/xla/docs/images/lowered_hlo.png b/third_party/xla/docs/images/lowered_hlo.png new file mode 100644 index 00000000000000..fa8a79918b4e3f Binary files /dev/null and b/third_party/xla/docs/images/lowered_hlo.png differ diff --git a/third_party/xla/docs/images/partitioned_module.png b/third_party/xla/docs/images/partitioned_module.png new file mode 100644 index 00000000000000..3b60284aecaeff Binary files /dev/null and b/third_party/xla/docs/images/partitioned_module.png differ diff --git a/third_party/xla/docs/images/pre_layout_module.png b/third_party/xla/docs/images/pre_layout_module.png new file mode 100644 index 00000000000000..0558c9ff0ac188 Binary files /dev/null and b/third_party/xla/docs/images/pre_layout_module.png differ diff --git a/third_party/xla/docs/images/triton_opt_pipeline.png b/third_party/xla/docs/images/triton_opt_pipeline.png new file mode 100644 index 00000000000000..2391094bad0c02 Binary files /dev/null and b/third_party/xla/docs/images/triton_opt_pipeline.png differ diff --git a/third_party/xla/docs/images/xla_hardware.png b/third_party/xla/docs/images/xla_hardware.png new file mode 100644 index 00000000000000..56f47de9fa7a36 Binary files /dev/null and b/third_party/xla/docs/images/xla_hardware.png differ diff --git a/third_party/xla/docs/indexing.md b/third_party/xla/docs/indexing.md index 29fe34895771da..fb524a9f42d2b7 100644 --- a/third_party/xla/docs/indexing.md +++ b/third_party/xla/docs/indexing.md @@ -300,7 +300,8 @@ d1 in [0, 29] ``` ### [Gather](https://openxla.org/xla/operation_semantics#gather) -Only the simplified gather is supported. See [gather_simplifier].(https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/gather_simplifier.h). + +Only the simplified gather is supported. See [gather_simplifier.h](https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/gather_simplifier.h). ```c++ operand = f32[33,76,70] parameter(0) @@ -326,10 +327,7 @@ rt0 in [0, 26], rt1 in [0, 68] ``` -Note that now we have **s** on the right side for the input-to-output mapping. -Those are the symbols that represent runtime values. For example, in this -particular case for every element of the output with indices `d0, d1, d2, d3` we -extract elements (d0, 0) and (d0, 1) from `indices` tensor. +Note that now we have **rt** symbols that represent runtime values. The output to input map for `indices`: @@ -342,10 +340,10 @@ d2 in [0, 7], d3 in [0, 3], s0 in [0, 1] ``` + The range variable `s0` shows that we need the entire row (d0, *) of the `indices` tensor to compute an element of the output. - ### [Transpose](https://openxla.org/xla/operation_semantics#transpose) Indexing map for transpose is a permutation of input/output dimensions. diff --git a/third_party/xla/docs/operation_semantics.md b/third_party/xla/docs/operation_semantics.md index 2a2f24f77c3b03..11704cbe2d73cc 100644 --- a/third_party/xla/docs/operation_semantics.md +++ b/third_party/xla/docs/operation_semantics.md @@ -180,9 +180,9 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ![](images/ops_alltoall.png) In this example, there are 4 cores participating in the Alltoall. On each core, -the operand is split into 4 parts along dimension 0, so each part has shape +the operand is split into 4 parts along dimension 1, so each part has shape f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates -the received parts along dimension 1, in the order of core 0-4. So the output on +the received parts along dimension 0, in the order of core 0-4. So the output on each core has shape f32[16,4]. ## BatchNormGrad diff --git a/third_party/xla/docs/test_hlo_passes.md b/third_party/xla/docs/test_hlo_passes.md new file mode 100644 index 00000000000000..c7ddc089997005 --- /dev/null +++ b/third_party/xla/docs/test_hlo_passes.md @@ -0,0 +1,73 @@ +# Writing unit tests for HLO passes + +There are different ways to write unit test for HLO passes. This page describes +the preferred method to ensure consistency and readability. + +## `FileCheck` with `CHECK` lines interleaved + +Most HLO passes can be tested using +[`FileCheck`](https://llvm.org/docs/CommandGuide/FileCheck.html) tests. +Interleave `CHECK` lines in input HLO module texts, and make sure to use `// +CHECK` instead of `; CHECK` uniformly as the `FileCheck` delimiter. + +For example, you can re-write the +[`fusion cc_test` for a `priotity_fusion` pass](https://github.com/openxla/xla/blob/fe30942a406659bff75399a2a10585bbd1287e07/xla/service/gpu/transforms/priority_fusion_test.cc#L133-L149) +as follows: + +``` +TEST_F(PriorityFusionTest, FuseBroadcastIntoBitcastConsumers) { + absl::string_view kHlo = R"( + HloModule test_module + + // CHECK: ENTRY main + ENTRY main { + // CHECK-NEXT: %[[PARAM:.*]] = f32[96]{0} parameter(0) + param_0 = f32[96]{0} parameter(0) + broadcast = f32[8,96,128,7]{3,2,1,0} broadcast(param_0), dimensions={1} + bitcast.6079.2 = f32[8,24,4,128,7]{4,3,2,1,0} bitcast(broadcast) + // CHECK-NEXT: ROOT %{{.*}} fusion(%[[PARAM]]) {{.*}} + ROOT transpose.1990.2 = f32[8,24,128,7,4]{4,3,2,1,0} transpose(bitcast.6079.2), dimensions={0,1,3,4,2} + } + )"; + RunAndFilecheckHloRewrite(kHlo, std::move(priority_fusion_)); +} +``` + +Note: Currently, the codebase has some tests where input HLO module and expected +module are written separately. Inlining the `CHECK` lines is the preferred +method for future tests. It enables better readability and a similar signature +as MLIR based tests +[like in `stablehlo_aggressive_folder.mlir`](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir#L31-L39). + +## `LIT` runner and `hlo-opt` + +Where feasible, use [`LIT`](https://llvm.org/docs/CommandGuide/lit.html) runner +and `hlo-opt`, and place `CHECK` lines locally next to the input IR they +correspond to. Again, make sure to use `// CHECK` instead of `; CHECK` as the +delimiter. + +For example, some +[GPU tests](https://github.com/openxla/xla/tree/main/xla/service/gpu/tests) can +be written as follows: + +``` +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK-%{PTX} %s + +HloModule Test, is_scheduled=true +fused_computation { + param_0 = f32[100,200]{1,0} parameter(0) + ROOT b.1 = f32[200,100]{1,0} transpose(f32[100,200]{1,0} param_0), dimensions={1,0} +} +ENTRY main { + a = f32[100, 200]{1,0} parameter(0) + // CHECK-PTX: call void @llvm.nvvm.barrier0 + // CHECK-GCN: call void @llvm.amdgcn.s.barrier + ROOT wrapped_b = f32[200,100]{1,0} fusion(f32[100,200]{1,0} a), kind=kInput, calls=fused_computation +} +``` + +## (Don't) Graph traversal + +Refrain from writing tests that travel leaf nodes of the result graph and match +with expected op. These tests are tedious to write, difficult to quickly read, +and more difficult to debug and fix. Use one of the above options instead. diff --git a/third_party/xla/docs/tf2xla/index.md b/third_party/xla/docs/tf2xla/index.md index edde1f7de62374..6cb58700e80993 100644 --- a/third_party/xla/docs/tf2xla/index.md +++ b/third_party/xla/docs/tf2xla/index.md @@ -143,11 +143,6 @@ experimental. For a detailed usage example see the [auto-clustering tutorial colab](./tutorials/autoclustering_xla.ipynb). -### AOT (Ahead-of-time) compilation for CPU with `tfcompile` - -You can also use a standalone [`tfcompile`](./tfcompile.md) tool, which converts -TensorFlow graph into executable code (for x86-64 CPU only). - ## Inspect compiled programs XLA provides introspection facilities which let you inspect the generated diff --git a/third_party/xla/docs/tf2xla/tfcompile.md b/third_party/xla/docs/tf2xla/tfcompile.md deleted file mode 100644 index 5d60a4e90a9acb..00000000000000 --- a/third_party/xla/docs/tf2xla/tfcompile.md +++ /dev/null @@ -1,279 +0,0 @@ -# Using AOT compilation - -## What is tfcompile? - -`tfcompile` is a standalone tool that ahead-of-time (AOT) compiles TensorFlow -graphs into executable code. It can reduce total binary size, and also avoid -some runtime overheads. A typical use-case of `tfcompile` is to compile an -inference graph into executable code for mobile devices. - -The TensorFlow graph is normally executed by the TensorFlow runtime. This incurs -some runtime overhead for execution of each node in the graph. This also leads -to a larger total binary size, since the code for the TensorFlow runtime needs -to be available, in addition to the graph itself. The executable code produced -by `tfcompile` does not use the TensorFlow runtime, and only has dependencies on -kernels that are actually used in the computation. - -The compiler is built on top of the XLA framework. The code bridging TensorFlow -to the XLA framework resides under -[tensorflow/compiler](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/). - -## What does tfcompile do? - -`tfcompile` takes a subgraph, identified by the TensorFlow concepts of -feeds and fetches, and generates a function that implements that subgraph. -The `feeds` are the input arguments for the function, and the `fetches` are the -output arguments for the function. All inputs must be fully specified by the -feeds; the resulting pruned subgraph cannot contain Placeholder or Variable -nodes. It is common to specify all Placeholders and Variables as feeds, which -ensures the resulting subgraph no longer contains these nodes. The generated -function is packaged as a `cc_library`, with a header file exporting the -function signature, and an object file containing the implementation. The user -writes code to invoke the generated function as appropriate. - -## Using tfcompile - -This section details high level steps for generating an executable binary with -`tfcompile` from a TensorFlow subgraph. The steps are: - -* Step 1: Configure the subgraph to compile -* Step 2: Use the `tf_library` build macro to compile the subgraph -* Step 3: Write code to invoke the subgraph -* Step 4: Create the final binary - -### Step 1: Configure the subgraph to compile - -Identify the feeds and fetches that correspond to the input and output -arguments for the generated function. Then configure the `feeds` and `fetches` -in a [`tensorflow.tf2xla.Config`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/tf2xla.proto) -proto. - -```textproto -# Each feed is a positional input argument for the generated function. The order -# of each entry matches the order of each input argument. Here “x_hold” and “y_hold” -# refer to the names of placeholder nodes defined in the graph. -feed { - id { node_name: "x_hold" } - shape { - dim { size: 2 } - dim { size: 3 } - } -} -feed { - id { node_name: "y_hold" } - shape { - dim { size: 3 } - dim { size: 2 } - } -} - -# Each fetch is a positional output argument for the generated function. The order -# of each entry matches the order of each output argument. Here “x_y_prod” -# refers to the name of a matmul node defined in the graph. -fetch { - id { node_name: "x_y_prod" } -} -``` - -### Step 2: Use tf_library build macro to compile the subgraph - -This step converts the graph into a `cc_library` using the `tf_library` build -macro. The `cc_library` consists of an object file containing the code generated -from the graph, along with a header file that gives access to the generated -code. `tf_library` utilizes `tfcompile` to compile the TensorFlow graph into -executable code. - -```build -load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") - -# Use the tf_library macro to compile your graph into executable code. -tf_library( - # name is used to generate the following underlying build rules: - # : cc_library packaging the generated header and object files - # _test : cc_test containing a simple test and benchmark - # _benchmark : cc_binary containing a stand-alone benchmark with minimal deps; - # can be run on a mobile device - name = "test_graph_tfmatmul", - # cpp_class specifies the name of the generated C++ class, with namespaces allowed. - # The class will be generated in the given namespace(s), or if no namespaces are - # given, within the global namespace. - cpp_class = "foo::bar::MatMulComp", - # graph is the input GraphDef proto, by default expected in binary format. To - # use the text format instead, just use the ‘.pbtxt’ suffix. A subgraph will be - # created from this input graph, with feeds as inputs and fetches as outputs. - # No Placeholder or Variable ops may exist in this subgraph. - graph = "test_graph_tfmatmul.pb", - # config is the input Config proto, by default expected in binary format. To - # use the text format instead, use the ‘.pbtxt’ suffix. This is where the - # feeds and fetches were specified above, in the previous step. - config = "test_graph_tfmatmul.config.pbtxt", -) -``` - -> To generate the GraphDef proto (test_graph_tfmatmul.pb) for this example, run -> [make_test_graphs.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/make_test_graphs.py) -> and specify the output location with the --out_dir flag. - -Typical graphs contain [`Variables`](https://www.tensorflow.org/guide/variables) -representing the weights that are learned via training, but `tfcompile` cannot -compile a subgraph that contain `Variables`. The -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) -tool converts variables into constants, using values stored in a checkpoint -file. As a convenience, the `tf_library` macro supports the `freeze_checkpoint` -argument, which runs the tool. For more examples see -[tensorflow/compiler/aot/tests/BUILD](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/BUILD). - -> Constants that show up in the compiled subgraph are compiled directly into the -> generated code. To pass the constants into the generated function, rather than -> having them compiled-in, simply pass them in as feeds. - -For details on the `tf_library` build macro, see -[tfcompile.bzl](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tfcompile.bzl). - -For details on the underlying `tfcompile` tool, see -[tfcompile_main.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tfcompile_main.cc). - -### Step 3: Write code to invoke the subgraph - -This step uses the header file (`test_graph_tfmatmul.h`) generated by the -`tf_library` build macro in the previous step to invoke the generated code. The -header file is located in the `bazel-bin` directory corresponding to the -build package, and is named based on the name attribute set on the `tf_library` -build macro. For example, the header generated for `test_graph_tfmatmul` would -be `test_graph_tfmatmul.h`. Below is an abbreviated version of what is -generated. The generated file, in `bazel-bin`, contains additional useful -comments. - -```c++ -namespace foo { -namespace bar { - -// MatMulComp represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. -class MatMulComp { - public: - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - ARGS_RESULTS_AND_TEMPS, // Allocate arg, result and temp buffers - RESULTS_AND_TEMPS_ONLY, // Only allocate result and temp buffers - }; - - MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS); - ~MatMulComp(); - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run(); - - // Arg methods for managing input buffers. Buffers are in row-major order. - // There is a set of methods for each positional argument. - void** args(); - - void set_arg0_data(float* data); - float* arg0_data(); - float& arg0(size_t dim0, size_t dim1); - - void set_arg1_data(float* data); - float* arg1_data(); - float& arg1(size_t dim0, size_t dim1); - - // Result methods for managing output buffers. Buffers are in row-major order. - // Must only be called after a successful Run call. There is a set of methods - // for each positional result. - void** results(); - - - float* result0_data(); - float& result0(size_t dim0, size_t dim1); -}; - -} // end namespace bar -} // end namespace foo -``` - -The generated C++ class is called `MatMulComp` in the `foo::bar` namespace, -because that was the `cpp_class` specified in the `tf_library` macro. All -generated classes have a similar API, with the only difference being the methods -to handle arg and result buffers. Those methods differ based on the number and -types of the buffers, which were specified by the `feed` and `fetch` arguments -to the `tf_library` macro. - -There are three types of buffers managed within the generated class: `args` -representing the inputs, `results` representing the outputs, and `temps` -representing temporary buffers used internally to perform the computation. By -default, each instance of the generated class allocates and manages all of these -buffers for you. The `AllocMode` constructor argument may be used to change this -behavior. All buffers are aligned to 64-byte boundaries. - -The generated C++ class is just a wrapper around the low-level code generated by -XLA. - -Example of invoking the generated function based on -[`tfcompile_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/aot/tests/tfcompile_test.cc): - -```c++ -#define EIGEN_USE_THREADS -#define EIGEN_USE_CUSTOM_THREAD_POOL - -#include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated - -int main(int argc, char** argv) { - Eigen::ThreadPool tp(2); // Size the thread pool as appropriate. - Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); - - - foo::bar::MatMulComp matmul; - matmul.set_thread_pool(&device); - - // Set up args and run the computation. - const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - std::copy(args + 0, args + 6, matmul.arg0_data()); - std::copy(args + 6, args + 12, matmul.arg1_data()); - matmul.Run(); - - // Check result - if (matmul.result0(0, 0) == 58) { - std::cout << "Success" << std::endl; - } else { - std::cout << "Failed. Expected value 58 at 0,0. Got:" - << matmul.result0(0, 0) << std::endl; - } - - return 0; -} -``` - -### Step 4: Create the final binary - -This step combines the library generated by `tf_library` in step 2 and the code -written in step 3 to create a final binary. Below is an example `bazel` BUILD -file. - -```build -# Example of linking your binary -# Also see //tensorflow/compiler/aot/tests/BUILD -load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") - -# The same tf_library call from step 2 above. -tf_library( - name = "test_graph_tfmatmul", - ... -) - -# The executable code generated by tf_library can then be linked into your code. -cc_binary( - name = "my_binary", - srcs = [ - "my_code.cc", # include test_graph_tfmatmul.h to access the generated header - ], - deps = [ - ":test_graph_tfmatmul", # link in the generated object file - "//third_party/eigen3", - ], - linkopts = [ - "-lpthread", - ] -) -``` diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index ec12dc189805fb..d3572bee439d60 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -1,4 +1,5 @@ compiler/xla/backends/cpu/nanort/package_groups.bzl: +compiler/xla/backends/cpu/package_groups.bzl: compiler/xla/internal/package_groups.bzl: compiler/xla/mlir_hlo/WORKSPACE: compiler/xla/package_groups.bzl: @@ -56,5 +57,7 @@ tools/toolchains/win/20240424/BUILD: tools/toolchains/win/BUILD: tools/toolchains/win/bazel_211/BUILD: tools/toolchains/win/tf_win_05022023/BUILD: +tools/toolchains/win2022/20241118/BUILD: +tools/toolchains/win2022/BUILD: tools/toolchains/win_1803/py38/BUILD: tools/toolchains/win_1803/py39/BUILD: diff --git a/third_party/xla/third_party/py/BUILD b/third_party/xla/third_party/py/BUILD index 7250861f26bfa2..661e8950c4dc2d 100644 --- a/third_party/xla/third_party/py/BUILD +++ b/third_party/xla/third_party/py/BUILD @@ -53,22 +53,8 @@ config_setting( }, ) -# Flag indicating if the target requires manylinux compliance verification. -bool_flag( - name = "verify_manylinux", - # TODO(ybaturina): Enable the flag by default when hermetic C++ toolchain is ready. - build_setting_default = False, +filegroup( + name = "manylinux_compliance_test", + srcs = ["manylinux_compliance_test.py"], visibility = ["//visibility:public"], ) - -py_binary( - name = "verify_manylinux_compliance", - srcs = [ - "verify_manylinux_compliance.py", - ], - main = "verify_manylinux_compliance.py", - visibility = ["//visibility:public"], - deps = [ - "@pypi_auditwheel//:pkg", - ], -) diff --git a/third_party/xla/third_party/py/ml_dtypes/workspace.bzl b/third_party/xla/third_party/py/ml_dtypes/workspace.bzl index 29a551da8d0017..962fb487c2d2f4 100644 --- a/third_party/xla/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/xla/third_party/py/ml_dtypes/workspace.bzl @@ -7,8 +7,8 @@ float8 varieties, and int4. load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - ML_DTYPES_COMMIT = "c12281a501469d553483eb4d68065826b9c2fcb5" - ML_DTYPES_SHA256 = "cee11c4bed5147bece9e385a88c20887344ad9b89b3acb09bf3d7c9c21fb9715" + ML_DTYPES_COMMIT = "0fa5313b65efe848c5968a15dd37dd220cc29567" + ML_DTYPES_SHA256 = "69c562bb961a21d92357c7709430553c226caac75a751c0aa52955ca14ce8641" tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", diff --git a/third_party/xla/third_party/py/python_init_rules.bzl b/third_party/xla/third_party/py/python_init_rules.bzl index 79bc343aae489e..796ae3d92d999f 100644 --- a/third_party/xla/third_party/py/python_init_rules.bzl +++ b/third_party/xla/third_party/py/python_init_rules.bzl @@ -8,4 +8,6 @@ def python_init_rules(): sha256 = "62ddebb766b4d6ddf1712f753dac5740bea072646f630eb9982caa09ad8a7687", strip_prefix = "rules_python-0.39.0", url = "https://github.com/bazelbuild/rules_python/releases/download/0.39.0/rules_python-0.39.0.tar.gz", + patch_args = ["-p1"], + patches = [Label("//third_party/py:rules_python.patch")], ) diff --git a/third_party/xla/third_party/py/rules_python.patch b/third_party/xla/third_party/py/rules_python.patch new file mode 100644 index 00000000000000..ef7ff2fc6f8e52 --- /dev/null +++ b/third_party/xla/third_party/py/rules_python.patch @@ -0,0 +1,39 @@ +diff --git a/python/private/pypi/deps.bzl b/python/private/pypi/deps.bzl +index 8949ed4a..8d0ab0e7 100644 +--- a/python/private/pypi/deps.bzl ++++ b/python/private/pypi/deps.bzl +@@ -51,8 +51,8 @@ _RULE_DEPS = [ + ), + ( + "pypi__packaging", +- "https://files.pythonhosted.org/packages/49/df/1fceb2f8900f8639e278b056416d49134fb8d84c5942ffaa01ad34782422/packaging-24.0-py3-none-any.whl", +- "2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5", ++ "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", ++ "09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", + ), + ( + "pypi__pep517", +@@ -61,8 +61,8 @@ _RULE_DEPS = [ + ), + ( + "pypi__pip", +- "https://files.pythonhosted.org/packages/8a/6a/19e9fe04fca059ccf770861c7d5721ab4c2aebc539889e97c7977528a53b/pip-24.0-py3-none-any.whl", +- "ba0d021a166865d2265246961bec0152ff124de910c5cc39f1156ce3fa7c69dc", ++ "https://files.pythonhosted.org/packages/ef/7d/500c9ad20238fcfcb4cb9243eede163594d7020ce87bd9610c9e02771876/pip-24.3.1-py3-none-any.whl", ++ "3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed", + ), + ( + "pypi__pip_tools", +diff --git a/python/private/pypi/evaluate_markers.bzl b/python/private/pypi/evaluate_markers.bzl +index c805fd7a..e57e6138 100644 +--- a/python/private/pypi/evaluate_markers.bzl ++++ b/python/private/pypi/evaluate_markers.bzl +@@ -20,7 +20,7 @@ load(":pypi_repo_utils.bzl", "pypi_repo_utils") + SRCS = [ + # When the version, or any of the files in `packaging` package changes, + # this file will change as well. +- Label("@pypi__packaging//:packaging-24.0.dist-info/RECORD"), ++ Label("@pypi__packaging//:packaging-24.2.dist-info/RECORD"), + Label("//python/private/pypi/requirements_parser:resolve_target_platforms.py"), + Label("//python/private/pypi/whl_installer:platform.py"), + ] \ No newline at end of file diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 6d102a47289fe0..5675d833f11002 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,117 +1,32 @@ diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index dfa4b78..4f8ac49 100644 +index 509398d..c14fe64 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch -@@ -1,57 +1,42 @@ +@@ -1 +1,12 @@ Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp ----- a/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp --+++ b/llvm/lib/ExecutionEngine/Orc/ExecutionUtils.cpp --@@ -573,7 +573,6 @@ -- // Create __imp_ symbol -- jitlink::Symbol &Ptr = -- jitlink::x86_64::createAnonymousPointer(*G, Sec, &Target); --- auto name = getImpPrefix() + *KV.first; -- Ptr.setName(G->intern((Twine(getImpPrefix()) + *KV.first).str())); -- Ptr.setLinkage(jitlink::Linkage::Strong); -- Ptr.setScope(jitlink::Scope::Default); --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel b/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/bolt/BUILD.bazel --@@ -285,6 +285,7 @@ -- "//llvm:MCParser", -- "//llvm:Object", -- "//llvm:ObjectYAML", --+ "//llvm:OrcShared", -- "//llvm:Support", -- "//llvm:TargetParser", -- "//llvm:config", --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel --@@ -1442,7 +1442,10 @@ -- hdrs = glob(["src/__support/time/*.h"]), -- deps = [ -- ":__support_common", --+ ":__support_error_or", -- ":hdr_time_macros", --+ ":types_clockid_t", --+ ":types_struct_timespec", -- ":types_time_t", -- ], -- ) --@@ -1486,6 +1489,8 @@ -- ":__support_common", -- ":__support_error_or", -- ":__support_osutil_vdso", --+ ":types_clockid_t", --+ ":types_struct_timespec", -- ], -- ) -+diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst -+--- a/clang/docs/ReleaseNotes.rst -++++ b/clang/docs/ReleaseNotes.rst -+@@ -796,7 +796,6 @@ -+ - Fixed an assertion failure caused by mangled names with invalid identifiers. (#GH112205) -+ - Fixed an incorrect lambda scope of generic lambdas that caused Clang to crash when computing potential lambda -+ captures at the end of a full expression. (#GH115931) -+-- Clang no longer rejects deleting a pointer of incomplete enumeration type. (#GH99278) - --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel ----- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel --+++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel --@@ -2800,6 +2800,7 @@ -- ":MC", -- ":MCDisassembler", -- ":Object", --+ ":OrcShared", -- ":OrcTargetProcess", -- ":Passes", -- ":Support", -+ Bug Fixes to AST Handling -+ ^^^^^^^^^^^^^^^^^^^^^^^^^ -+diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp -+--- a/clang/lib/Sema/SemaExprCXX.cpp -++++ b/clang/lib/Sema/SemaExprCXX.cpp -+@@ -3747,8 +3747,7 @@ -+ } else if (!Pointee->isDependentType()) { -+ // FIXME: This can result in errors if the definition was imported from a -+ // module but is hidden. -+- if (!Pointee->isStructureOrClassType() || -+- !RequireCompleteType(StartLoc, Pointee, -++ if (!RequireCompleteType(StartLoc, Pointee, -+ LangOpts.CPlusPlus26 -+ ? diag::err_delete_incomplete -+ : diag::warn_delete_incomplete, -+diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/new-delete.cpp b/clang/test/SemaCXX/new-delete.cpp -+--- a/clang/test/SemaCXX/new-delete.cpp -++++ b/clang/test/SemaCXX/new-delete.cpp -+@@ -540,13 +540,6 @@ -+ void f(A *x) { delete x; } // expected-warning {{delete called on 'PR10504::A' that is abstract but has non-virtual destructor}} -+ } ++diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll b/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll ++--- a/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll +++++ b/llvm/test/CodeGen/Hexagon/isel/isel-tfrrp.ll ++@@ -2,6 +2,7 @@ ++ ; The constant 0 is generated by a transfer immediate instruction. + -+-#if __cplusplus >= 201103L -+-enum GH99278_1 { -+- zero = decltype(delete static_cast(nullptr), 0){} -+- // expected-warning@-1 {{expression with side effects has no effect in an unevaluated context}} -+-}; -+-#endif -+- -+ struct PlacementArg {}; -+ inline void *operator new[](size_t, const PlacementArg &) throw () { -+ return 0; ++ ; RUN: llc -march=hexagon -debug-only=isel 2>&1 < %s - | FileCheck %s +++; REQUIRES: asserts ++ ++ ; CHECK: [[R0:%[0-9]+]]:intregs = A2_tfrsi 0 ++ ; CHECK-NEXT: predregs = C2_tfrrp killed [[R0]]:intregs diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index e60a1c8..7c3347b 100644 +index 02401a7..c35f4e4 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "2ccf7ed277df28651b94bbee9fccefdf22fb074f" -- LLVM_SHA256 = "ca68a54dcd12c0dde32732a90899bf57e0f3f96fc43d8d1124d95a5eae627508" -+ LLVM_COMMIT = "1d95825d4d168a17a4f27401dec3f2977a59a70e" -+ LLVM_SHA256 = "d3276c678b616c0d820fe14a3404b43591f4e1bc75b6bed2782e0776e0c9b401" +- LLVM_COMMIT = "a531800344dc54e9c197a13b22e013f919f3f5e1" +- LLVM_SHA256 = "74a873f8d4c677d192e9bfade095af3363c76b0fb23c5f6260121d74322744bc" ++ LLVM_COMMIT = "35e76b6a4fc74e64bd6c91e5b9b9eb6a03aa802e" ++ LLVM_SHA256 = "bf4e52c430ff8eb2b055a4abcbd70468d2e6ea7f277e472575e92903bd7d8981" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index a2396e5007c48e..a8f7e817753eae 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "cdc7e854703cecf8dcd16db45b92b7be005c4f60" - SHARDY_SHA256 = "13f4f2d5cf241f97ba098ba5683fe066cf075f62cfdcba6287ba3b225a78e40e" + SHARDY_COMMIT = "2ca9cd74b9f9fc5851d0b19c4cc07b1cfc35f0e3" + SHARDY_SHA256 = "502353ad1b00303cab5141ac3a85f4bb6ef61340679353cf79a5d6d1b58139dd" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8b137891791fe9..071bba3084c74b 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1 +1,154 @@ +diff --ruN a/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py b/stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +--- stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py ++++ stablehlo/build_tools/math/generate_ChloDecompositionPatternsMath.py +@@ -71,8 +71,15 @@ + + output_file = os.path.relpath( + os.path.normpath( +- os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", +- "transforms", output_filename)), ++ os.path.join( ++ os.path.dirname(__file__), ++ "..", ++ "..", ++ "stablehlo", ++ "transforms", ++ output_filename, ++ ) ++ ), + os.getcwd(), + ) + +@@ -105,7 +112,8 @@ + func = getattr(fa.algorithms, fname, None) + if func is None: + warnings.warn( +- f"{fa.algorithms.__name__} does not define {fname}. Skipping.") ++ f"{fa.algorithms.__name__} does not define {fname}. Skipping." ++ ) + continue + ctx = fa.Context(paths=[fa.algorithms], + parameters=dict(rewrite_keep_integer_literals=True)) +@@ -116,14 +124,15 @@ + sources[-1] += src + source = "\n\n".join(sources) + "\n" + +- if chloname.startswith('StableHLO_'): ++ if chloname.startswith("StableHLO_"): + # an ugly hack to fix the definition of stablehlo complex math + # functions. TODO(pearu): add the corresponding feature to + # functional_algorithms stablehlo printer +- NameOp = chloname.split('_', 1)[1] ++ NameOp = chloname.split("_", 1)[1] + source = source.replace( +- f'def : Pat<({chloname}', +- f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}' ++ f"def : Pat<({chloname}", ++ f"def {NameOp}_ComplexElementType_ComplexMathExpander :" ++ f" Pat<({chloname}", + ) + + if os.path.isfile(output_file): +diff --ruN a/stablehlo/build_tools/math/generate_tests.py b/stablehlo/build_tools/math/generate_tests.py +--- stablehlo/build_tools/math/generate_tests.py ++++ stablehlo/build_tools/math/generate_tests.py +@@ -64,10 +64,12 @@ + dict(name="acosh", mpmath_name="arccosh"), + dict(name="atanh", mpmath_name="arctanh"), + dict(name="square", mpmath_name="square"), +- dict(name="log_plus_one", +- mpmath_name="log1p", +- namespace="stablehlo", +- passes="--stablehlo-complex-math-expander"), ++ dict( ++ name="log_plus_one", ++ mpmath_name="log1p", ++ namespace="stablehlo", ++ passes="--stablehlo-complex-math-expander", ++ ), + ] + + +@@ -138,13 +140,16 @@ + params = fa.utils.function_validation_parameters(opname, dtype) + max_ulp_difference = op.get( + "max_ulp_difference", +- params.get("max_valid_ulp_count", default_max_ulp_difference)) ++ params.get("max_valid_ulp_count", default_max_ulp_difference), ++ ) + + nmp = fa.utils.numpy_with_mpmath( + extra_prec_multiplier=op.get( + "extra_prec_multiplier", +- params.get("extra_prec_multiplier", +- default_extra_prec_multiplier)), ++ params.get( ++ "extra_prec_multiplier", default_extra_prec_multiplier ++ ), ++ ), + flush_subnormals=flush_subnormals, + ) + +@@ -208,8 +213,10 @@ + continue + + f = open(fname, "w") +- f.write(f"// RUN: stablehlo-opt {passes} %s |" +- " stablehlo-translate --interpret\n") ++ f.write( ++ f"// RUN: stablehlo-opt {passes} %s |" ++ " stablehlo-translate --interpret\n" ++ ) + f.write( + "// This file is generated, see build_tools/math/README.md for more" + " information.\n") +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +@@ -107,6 +107,8 @@ + + LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() { + addArgumentMaterialization(scalarToTensor); ++ addSourceMaterialization(scalarToTensor); ++ addTargetMaterialization(scalarToTensor); + } + + } // namespace mlir::stablehlo +diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ++++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +@@ -440,7 +440,6 @@ + } + + // ----- +- + + // CHECK-LABEL: func.func @asinh_f64( + // CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +@@ -2788,7 +2787,6 @@ + + // ----- + +- + // CHECK-LABEL: @sinh_f32 + // CHECK-SAME: (%[[X:.*]]: tensor) + func.func @sinh_f32(%x : tensor) -> tensor { +@@ -3891,6 +3889,8 @@ + return + } + ++// ----- ++ + // CHECK-LABEL: @square_complex_f32( + // CHECK-SAME: %[[VAL_0:.*]]: tensor>) -> tensor> { + // CHECK: %[[VAL_1:.*]] = stablehlo.real %[[VAL_0]] : (tensor>) -> tensor +@@ -3916,6 +3916,8 @@ + func.return %result : tensor> + } + ++// ----- ++ + // CHECK-LABEL: @square_f32( + // CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { + // CHECK: %[[VAL_1:.*]] = stablehlo.multiply %[[VAL_0]], %[[VAL_0]] : tensor diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 20badb638791f8..dfae5f53d44715 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "b3d3cacde8994df313297e68713ed74c2ca279ee" - STABLEHLO_SHA256 = "8bb81d7f60f19493b1edfc916adcfe1f9d1deeaf77c9ca7a896e05861505817d" + STABLEHLO_COMMIT = "38bb2f9bf63b714e8a49fe34a478139058ee1660" + STABLEHLO_SHA256 = "74feb9f9f34eb4dd0b11404371af58f7a5a5ded177d38b01b53174ce757a3a61" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/triton/llvm_integration/cl704999069.patch b/third_party/xla/third_party/triton/llvm_integration/cl704999069.patch new file mode 100644 index 00000000000000..95dd8fe8292fed --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl704999069.patch @@ -0,0 +1,21 @@ + +--- a/lib/Dialect/Triton/Transforms/Combine.td 2024-12-05 23:53:31.000000000 -0800 ++++ b/lib/Dialect/Triton/Transforms/Combine.td 2024-12-11 00:38:55.000000000 -0800 +@@ -17,7 +17,7 @@ + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + def CombineDotAddFPattern : Pat< +- (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm), ++ (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), +@@ -29,7 +29,7 @@ + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + def CombineDotAddFRevPattern : Pat< +- (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm), ++ (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), diff --git a/third_party/xla/third_party/triton/temporary/const_signature_fixes.patch b/third_party/xla/third_party/triton/temporary/const_signature_fixes.patch new file mode 100644 index 00000000000000..26c3d8014e953f --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/const_signature_fixes.patch @@ -0,0 +1,92 @@ +diff --git a/third_party/f2reduce/f2reduce.cpp b/third_party/f2reduce/f2reduce.cpp +--- a/third_party/f2reduce/f2reduce.cpp ++++ b/third_party/f2reduce/f2reduce.cpp +@@ -470,8 +470,8 @@ namespace f2reduce { + + void inplace_rref_strided(uint64_t *matrix, uint64_t rows, uint64_t cols, uint64_t stride) { + +- if (rows <= 1) { +- // If the matrix has 0 or 1 rows, it must already be in RREF: ++ if (rows <= 1 || cols <= 1) { ++ // If the matrix has 0 or 1 rows or columns, it must already be in RREF: + return; + } + +diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc +--- a/third_party/nvidia/backend/cuda_utils.cc ++++ b/third_party/nvidia/backend/cuda_utils.cc +@@ -276,8 +276,10 @@ const ExtractionInfo kExtractionInfos[]{ + ExtractionInfo::build({"'u64'"}), + ExtractionInfo::build({"'fp16'", "'bf16'", "'fp32'", "'f32'"}), + ExtractionInfo::build({"'fp64'"}), ++ // Note: types are e.g. '*fp32', so no closing quote is intentional. + ExtractionInfo::build({"'*"}, extractPointer), +- ExtractionInfo{{"None"}, 0, nullptr}, // Represent constexprs as None ++ ExtractionInfo{ ++ {"None", "'none'"}, 0, nullptr}, // Represent constexprs as None + }; + + // Finds an extractor that supports a given type_repr in the extractor list. +diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py +--- a/third_party/nvidia/backend/driver.py ++++ b/third_party/nvidia/backend/driver.py +@@ -92,7 +92,22 @@ def ty_to_cpp(ty): + }[ty] + + +-def make_launcher(constants : dict[int, str], signature : dict[int, any]) -> Callable[..., None]: ++def flatten_tuples(xs): ++ """Recursively flattens tuple elements in xs.""" ++ for x in xs: ++ if isinstance(x, tuple): ++ yield from flatten_tuples(x) ++ else: ++ yield x ++ ++ ++def make_launcher(constants : dict[int, str], signature : dict[int, any], ids : dict[str, tuple]) -> Callable[..., None]: ++ ++ signature = {k: v for k, v in signature.items() if v != 'constexpr'} ++ signature = ','.join(signature.values()).replace('[', '').replace(']', '') ++ signature = list(filter(bool, signature.split(','))) ++ signature = {i: s for i, s in enumerate(signature)} ++ + # We seem to have 3 categories of arguments: + # 1. arguments listed in signature + # 2. arguments listed in constants +@@ -103,8 +118,8 @@ def make_launcher(constants : dict[int, + # category (3). The generic C++ launcher currently does not do that, so we + # are doing it in the python wrapper. + signature_metadata = cuda_utils.build_signature_metadata( +- ty if arg_id not in constants else None +- for arg_id, ty in signature.items()) ++ ty for ty in signature.values()) ++ + def wrapper(grid_dim_x: int, grid_dim_y: int, grid_dim_z: int, + stream: int, kernel: int, global_scratch: any, + packed_metadata: tuple[int, int, int, int, int, int], +@@ -115,18 +130,18 @@ def make_launcher(constants : dict[int, + cuda_utils.launch(grid_dim_x, grid_dim_y, grid_dim_z, stream, kernel, + packed_metadata, hook_args, launch_enter_hook, + launch_exit_hook, signature_metadata, global_scratch, +- args) ++ flatten_tuples(args)) + return wrapper + + + class CudaLauncher(object): + + def __init__(self, src, metadata): +- constants = getattr(src, "constants", dict()) +- cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i +- constants = {cst_key(key): value for key, value in constants.items()} +- signature = {cst_key(key): value for key, value in src.signature.items()} +- self.launch = make_launcher(constants, signature) ++ ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} ++ constants = src.constants if hasattr(src, "constants") else dict() ++ constants = {idx: value for idx, value in constants.items()} ++ signature = {idx: value for idx, value in src.signature.items()} ++ self.launch = make_launcher(constants, signature, ids) + self.global_scratch_size = metadata.global_scratch_size + self.global_scratch_align = metadata.global_scratch_align + diff --git a/third_party/xla/third_party/triton/temporary/numpy_type_promotion.patch b/third_party/xla/third_party/triton/temporary/numpy_type_promotion.patch new file mode 100644 index 00000000000000..e41638db8fcaf8 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/numpy_type_promotion.patch @@ -0,0 +1,12 @@ +--- a/python/test/unit/language/test_core.py ++++ b/python/test/unit/language/test_core.py +@@ -363,8 +363,7 @@ def _test_binary(dtype_x, dtype_y, expr, + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) +- with promotion_numpy_2_0(): +- z_ref = eval(scalar_expr) ++ z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + diff --git a/third_party/xla/third_party/triton/temporary/revert_67ea999.patch b/third_party/xla/third_party/triton/temporary/revert_67ea999.patch new file mode 100644 index 00000000000000..22239930a1005c --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/revert_67ea999.patch @@ -0,0 +1,556 @@ +This patch is reverting https://github.com/triton-lang/triton/commit/67ea999935f4511a535a25bdecb27e79e3c3af41 +which breaks //learning/deepmind/jax/triton/ops:attention_test_gpu_a100 +The patch is very intrusive due to how big the change is, so it should be prioritized for removal. +This is tracked in b/385090655. + +diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h +--- a/include/triton/Tools/LinearLayout.h ++++ b/include/triton/Tools/LinearLayout.h +@@ -681,6 +681,13 @@ public: + // (i.e. every input bit affects the output). + llvm::MapVector getFreeVariableMasks() const; + ++ // Increase an input dimension without affecting the output dimension. The ++ // added free variables are mapped to 0, ensuring that the new input ++ // dimensions correspond directly to the existing output space. The function ++ // errors out if `newInDimSize` is less than the current size or the new size ++ // is not a power of 2. ++ LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; ++ + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); +diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp +--- a/lib/Analysis/Utility.cpp ++++ b/lib/Analysis/Utility.cpp +@@ -683,8 +683,42 @@ std::optional minimalCvtLa + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); +- +- auto comp = dstLayout->invertAndCompose(*srcLayout); ++ auto numSrcRegs = srcLayout->getInDimSize(kRegister); ++ auto numDstRegs = dstLayout->getInDimSize(kRegister); ++ // The `invertAndCompose` function will generate a layout that is injective ++ // by assigning new output dimensions to free variables. For instance, ++ // consider a scenario where `srcLayout` has a free variable in the lane ++ // dimension, while `dstLayout` has two free variables in the lane ++ // dimension and also a larger number of registers. ++ // The injective form of `srcLayout` will add only a single additional row ++ // to the transformation matrix, whereas the injective form of `dstLayout` ++ // will add two additional rows. This discrepancy causes misleading results ++ // because the matrices end up with a different number of rows. ++ // ++ // Take `dstLayout ⋅ srcLayout^-1` as an example: ++ // ++ // - `injective(dstLayout)`: [n, m] → [n + 2, m] ++ // - `injective(srcLayout)`: [n, m] → [n + 1, m] ++ // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] ++ // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + ++ // 1] → [n + 2, n + 1] ++ // ++ // Here, the `(n + 1)`-th row added by `dstLayout` represents the free ++ // variable in registers, and the `(n + 2)`-th row represents the free ++ // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` ++ // represents the free variable in lanes. As a result, the `(n + 1)`-th row ++ // in two layouts do not correspond to the same free variable. ++ // ++ // To address this issue, we pad the free variables in `srcLayout` and ++ // `dstLayout` to ensure they have the same number of registers. This ++ // guarantees that the resulting matrices have the same number of rows, ++ // ensuring consistency in the composition process. ++ auto numRegs = std::max(numSrcRegs, numDstRegs); ++ auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); ++ auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); ++ // comp describes the layout function to create dst from src. ++ LinearLayout comp = ++ dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); + // We try to quotient by the largest subspace first + auto dims = SmallVector{"block", "warp", "lane", "register"}; + for (auto dim : dims) { +diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +@@ -315,10 +315,14 @@ struct ConvertLayoutOpUsingLinearLayouts + // TODO(Keren): implement warp shuffle instead of using the general + // approach that uses shared memory + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); +- } else if (llvm::is_contained(dims, kRegister)) { ++ } else if (llvm::is_contained(dims, kRegister) || ++ dstLayout.getInDimSize(kRegister) != ++ srcLayout.getInDimSize(kRegister)) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). +- return transferWithinThread(op, *conversion, adaptor, rewriter); ++ return transferWithinThread( ++ op, dstLayout.getFreeVariableMasks()[kRegister], ++ dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); + } else { + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. +@@ -328,8 +332,8 @@ struct ConvertLayoutOpUsingLinearLayouts + } + + LogicalResult +- transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, +- OpAdaptor adaptor, ++ transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, ++ const LinearLayout &conversion, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); +@@ -339,9 +343,16 @@ struct ConvertLayoutOpUsingLinearLayouts + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); +- SmallVector outVals(conversion.getInDimSize(kRegister)); +- for (int i = 0; i < outVals.size(); i++) { +- auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; ++ SmallVector outVals(numRegs); ++ for (int i = 0; i < numRegs; i++) { ++ // Remove free masks from the register index ++ // For example, if idx = 0b00111, and masks = 0b00100, then we get ++ // 0b00011. It means that register 7 (0b111) has the same value as ++ // register 3 (0b011). ++ auto idx = i & (~regMasks); ++ auto srcIdx = conversion.hasInDim(kRegister) ++ ? conversion.apply({{kRegister, idx}}).begin()->second ++ : idx; + outVals[i] = inVals[srcIdx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, +diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp +--- a/lib/Tools/LinearLayout.cpp ++++ b/lib/Tools/LinearLayout.cpp +@@ -112,6 +112,30 @@ std::unique_ptr getMatrix(co + return m; + } + ++// Get a matrix for `layout` with its codomain expanded so it's injective, i.e. ++// each input element maps to a unique output element. We do this by finding ++// columns that are equal to 0 and adding a new row with a 1 in that column. ++std::tuple, int /*numRows*/, int /*numCols*/> ++getInjectiveMat(const LinearLayout &layout) { ++ int numRows = layout.getTotalOutDimSizeLog2(); ++ int numCols = layout.getTotalInDimSizeLog2(); ++ std::unique_ptr mat = getMatrix(layout); ++ ++ // Bits of mat or-reduced along the columns (so there's just one row). ++ uint64_t colBits = 0; ++ for (int r = 0; r < numRows; r++) { ++ colBits |= mat[r]; ++ } ++ auto expanded = std::unique_ptr(new uint64_t[numRows + numCols]); ++ std::memcpy(expanded.get(), mat.get(), numRows * sizeof(uint64_t)); ++ for (int c = 0; c < numCols; c++) { ++ if ((colBits & (1 << c)) == 0) { ++ expanded[numRows++] = (1 << c); ++ } ++ } ++ return std::make_tuple(std::move(expanded), numRows, numCols); ++} ++ + // Compute the rank of the matrix formed by taking the bases for the given + // outDim as columns. In other words, finds the number of linearly-independent + // bases for this output dimension. +@@ -780,179 +804,118 @@ LinearLayout LinearLayout::compose(const + compositionIsSurjective); + } + +-namespace { +-std::unique_ptr concatMatrices(const LinearLayout &A, +- const LinearLayout &B) { +- // In plain words, "convert_layout does not change the shape of a tensor" +- assert(A.getTotalOutDimSizeLog2() == B.getTotalOutDimSizeLog2() && +- "Matrices must have the same number of output dimensions"); +- int numRows = A.getTotalOutDimSizeLog2(); +- int numColsA = A.getTotalInDimSizeLog2(); +- +- // rref expects the lower bits to be the lower indices of the matrix +- auto concat = getMatrix(A); +- auto BMat = getMatrix(B); +- for (int r = 0; r < numRows; r++) { +- concat[r] |= BMat[r] << numColsA; ++LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { ++ assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getOutDimNames()); ++ for (StringAttr outDim : getOutDimNames()) { ++ assert(getOutDimSize(outDim) <= outer.getOutDimSize(outDim)); + } +- return concat; +-} ++ assert(outer.isSurjective()); + +-LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) { +- // Solve the least square system AX = B for A = outer, B = *this +- // and return the least square solution X of minimal norm +- // A and B may not be surjective, but we assume that Im(B) \subset Im(A) +- // Sketch of the algorithm: +- // https://github.com/triton-lang/triton/pull/5309#discussion_r1869084111 +- int numRows = A.getTotalOutDimSizeLog2(); +- int numColsA = A.getTotalInDimSizeLog2(); +- int numColsB = B.getTotalInDimSizeLog2(); +- int numCols = numColsA + numColsB; +- std::unique_ptr combinedMat = concatMatrices(A, B); +- f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols, ++ // Make both `this` and `outer` injective. We need to do this on the ++ // `outer` layout because we can't invert a non-injective function. We ++ // choose to do so on the `this` layout as well. The rest of the comment ++ // explains why we make that choice. ++ // ++ // Recall from the header that C = A.invertAndCompose(B) just means that ++ // A(x) = B(C(x)). ++ // ++ // Sometimes we may have a choice of multiple values for a particular ++ // C(x). For example, if A(1) = B(0) = B(1) = 0, then C(1) can be either 0 ++ // or 1. ++ // ++ // We want to choose C such that C(x) != 0 where possible. For example, ++ // suppose we are transferring from registers to registers and we have the ++ // following layouts. ++ // ++ // A(thread=1, block=0) = 1 ++ // A(thread=2, block=0) = 2 ++ // A(thread=0, block=1) = 0 ++ // ++ // B(thread=1, block=0) = 2 ++ // B(thread=2, block=0) = 1 ++ // B(thread=0, block=1) = 0 ++ // ++ // Notice that A and B both have the same data in each of their two ++ // blocks. So if we want to transfer from A to B, we don't need to cross ++ // blocks, which is expensive. We want A.invertAndCompose(B) to reflect ++ // that choice. ++ // ++ // Let A' be A with the last line changed to "=4", and similarly for B'. ++ // When transferring from A' to B', we can't cross blocks even if we wanted ++ // to, because the two blocks now have different data. But also, any ++ // mapping of thread+block from A' to B' is also valid for mapping from A ++ // to B. ++ // ++ // Thus making A and B injective encodes our desire not to cross blocks, ++ // or more generally our desire that C(x) != 0 where possible. ++ auto [matThis, numRowsThis, numColsThis] = getInjectiveMat(*this); ++ auto [matOuter, numRowsOuter, numColsOuter] = getInjectiveMat( ++ outer.transposeOuts(llvm::to_vector(this->getOutDimNames()))); ++ ++ // Concatenate `matOuter` and `matThis` horizontally (i.e. `matThis` ++ // is to the right of `matOuter`). ++ int combinedNumRows = std::max(numRowsThis, numRowsOuter); ++ int combinedNumCols = numColsThis + numColsOuter; ++ assert(combinedNumCols <= 64 && "Can't handle huge layouts"); ++ ++ std::unique_ptr m(new uint64_t[combinedNumRows]()); ++ for (int r = 0; r < numRowsOuter; r++) { ++ m[r] = matOuter[r]; ++ } ++ for (int r = 0; r < numRowsThis; r++) { ++ m[r] |= matThis[r] << numColsOuter; ++ } ++ ++ // Perform Gaussian elimination on `m`. Because `outer` was modified to ++ // be bijective, the first half of the matrix should be the identity ++ // matrix. The remaining half are the bases for the combined ++ // transformation. ++ // ++ // `stride` is specified in number of 64-bit words per row, and we pack ++ // our matrix so that there's only one uint64_t per row. ++ f2reduce::inplace_rref_strided(m.get(), combinedNumRows, combinedNumCols, + /*stride=*/1); + +- // Compute the pivot columns +- // Since A and B have the same image, each row will either have a pivot +- // or will be all zeros +- SmallVector pivotCols; +- for (int r = 0; r < numRows; r++) { +- auto row = combinedMat[r]; +- if (row == 0) { +- continue; ++ // Check that the first half of the matrix is indeed the identity. ++ for (int r = 0; r < std::min(numRowsOuter, numColsOuter); r++) { ++ for (int c = 0; c < std::min(numColsOuter, numRowsOuter); c++) { ++ if (((m[r] >> c) & 1) != (r == c ? 1 : 0)) { ++ llvm::report_fatal_error("First half of the matrix was not the " ++ "identity, bug in invertAndCompose"); ++ } + } +- int c = __builtin_ctzll(row); +- assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)"); +- assert(pivotCols.empty() || +- pivotCols.back() < c && "Pivot columns are not in increasing order"); +- pivotCols.push_back(c); +- } +- +- // Extract A^{-1}B and complete the matrix using zeros +- std::unique_ptr retMat(new uint64_t[numColsA]()); +- int j = 0; +- for (int r = 0; r < numColsA; r++) { +- auto isPivot = j < pivotCols.size() && pivotCols[j] == r; +- retMat[r] = isPivot ? combinedMat[j++] >> numColsA : 0; + } + + // We need names for the in/out dim of the flattened layout we're going to + // read off from `m`. These could be anything, doesn't matter. +- StringAttr inDim1D = *A.getInDimNames().begin(); +- StringAttr outDim1D = *A.getOutDimNames().begin(); ++ StringAttr inDim1D = *getInDimNames().begin(); ++ StringAttr outDim1D = *getOutDimNames().begin(); + + // Read off the new bases. These are for a flattened 1D -> 1D +- LinearLayout::BasesT retBases; +- auto &bs = retBases[inDim1D]; +- for (int c = 0; c < numColsB; c++) { ++ // transformation from `this`'s in-dims to `outer`'s in-dims. ++ BasesT newBases; ++ auto &bs = newBases[inDim1D]; ++ for (int c = 0; c < numColsThis; c++) { + int32_t basis = 0; +- for (int r = 0; r < numColsA; r++) { +- basis |= (retMat[r] >> c & 1) << r; ++ for (int r = 0; r < numRowsOuter; r++) { ++ basis |= (m[r] >> (numColsOuter + c) & 1) << r; + } + bs.push_back({basis}); + } + +- LinearLayout retFlattened(std::move(retBases), +- {{outDim1D, A.getTotalInDimSize()}}, ++ LinearLayout flatComposed(std::move(newBases), ++ {{outDim1D, outer.getTotalInDimSize()}}, + /*requireSurjective=*/false); + + SmallVector> retInDims; + SmallVector> retOutDims; +- for (StringAttr dim : B.getInDimNames()) { +- retInDims.push_back({dim, B.getInDimSize(dim)}); +- } +- for (StringAttr dim : A.getInDimNames()) { +- retOutDims.push_back({dim, A.getInDimSize(dim)}); +- } +- return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims); +-} +- +-} // namespace +- +-LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { +- // TODO(Lezcano) Make friend and perhaps rename to `convertFrom` or `lstsq` +- // For this, we need to implement our LLVM lowerings by inverting the "outer" +- // layout, and then iterating over the elements from the "this" layout and +- // fetching the corresponding element from the "outer" layout. This exercises +- // the broadcasting that we incentivise via choosing the minimum norm solution +- // in lstsq. +- +- // The order of dims does not matter. We choose to transpose outer +- auto outDims = llvm::to_vector(getOutDimNames()); +- assertDimsEqualIgnoringOrder(outDims, outer.getOutDimNames()); +- const auto &B = *this; +- const auto A = outer.transposeOuts(outDims); +- for (auto dim : outDims) { +- assert(A.getOutDimSize(dim) == B.getOutDimSize(dim) && +- "Convert layout does not change the shape of a tensor"); ++ for (StringAttr dim : getInDimNames()) { ++ retInDims.push_back({dim, getInDimSize(dim)}); + } +- +- // We'll write A^{-1} to mean the inverse or the pseudo-inverse of A +- // We are computing A^{-1}B so A must be surjective so that +- // it has a left inverse. +- assert(A.isSurjective()); +- +- // Broadcasting heuristic +- // Imagine we have two layouts with `warps = [[0, 0],  [0, 0]]` +- // (broadcasting) on both layouts. We could map any warp to any warp in the +- // conversion. Now, we want to map them as the identity map, to mark that +- // nothing needs to be done there (`lstsq` would map all the warps to the +- // zero warp, minimum norm solution). The heuristic here is as follows: +- // - If a dimension is the same for both layouts, we want to map it as the +- // identity +- // Equivalently, we don't add it to the conversion +- // - Otherwise, we just call lstsq (i.e. map all the equivalent elements +- // to the same input element) to take advantage of broadcasting in shared +- // memory and avoid saving repeated elements in shared memory +- SmallVector identityDims; +- for (auto dim : A.getInDimNames()) { +- if (B.hasInDim(dim) && +- A.sublayout(dim, outDims) == B.sublayout(dim, outDims)) { +- identityDims.push_back(dim); +- } +- } +- SmallVector ANonIdentityInDims; +- SmallVector BNonIdentityInDims; +- for (auto dim : A.getInDimNames()) { +- if (!llvm::is_contained(identityDims, dim)) { +- ANonIdentityInDims.push_back(dim); +- } ++ for (StringAttr dim : outer.getInDimNames()) { ++ retOutDims.push_back({dim, outer.getInDimSize(dim)}); + } +- for (auto dim : B.getInDimNames()) { +- if (!llvm::is_contained(identityDims, dim)) { +- BNonIdentityInDims.push_back(dim); +- } +- } +- +- auto AReduced = A.sublayout(ANonIdentityInDims, outDims); +- auto BReduced = B.sublayout(BNonIdentityInDims, outDims); +- +- // If one is empty, the other must be empty as well +- assert((AReduced == LinearLayout::empty()) == +- (BReduced == LinearLayout::empty())); +- bool isEmpty = AReduced == LinearLayout::empty(); +- +- auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced); +- +- // TODO(Lezcano): We should return the reduced layout instead of re-adding the +- // identity maps. With this, we'll be able to kill `minimalCvtLayout` +- +- // Add the identity maps for the dimensions that are the same for both layouts +- for (auto dim : identityDims) { +- ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim); +- } +- +- // Reshape the result +- SmallVector> inDimsA; +- SmallVector> inDimsB; +- for (auto dim : A.getInDimNames()) { +- inDimsA.push_back({dim, A.getInDimSize(dim)}); +- } +- for (auto dim : B.getInDimNames()) { +- inDimsB.push_back({dim, B.getInDimSize(dim)}); +- } +- ret = ret.reshapeIns(inDimsB).reshapeOuts(inDimsA); +- return ret; ++ return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); + } + + llvm::MapVector +@@ -1041,6 +1004,21 @@ bool LinearLayout::equalIgnoringOutDimSi + return true; + } + ++LinearLayout LinearLayout::resize(StringAttr inDim, ++ int32_t newInDimSize) const { ++ BasesT bases = getBases(); ++ assert(bases.contains(inDim) && "inDim not in layout"); ++ assert(llvm::isPowerOf2_32(newInDimSize) && ++ "newInDimSize must be a power of 2"); ++ assert(newInDimSize >= getInDimSize(inDim) && ++ "newInDimSize must be >= old size"); ++ auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); ++ for (int i = 0; i < numFreeVariables; i++) { ++ bases[inDim].push_back(std::vector(getNumOutDims(), 0)); ++ } ++ return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); ++} ++ + std::string LinearLayout::toString() const { + // Start with a newline because we print out a bulleted list; it doesn't + // make sense for the first line of this list to be on the same line as +diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir +--- a/test/Conversion/tritongpu_to_llvm.mlir ++++ b/test/Conversion/tritongpu_to_llvm.mlir +@@ -1698,7 +1698,8 @@ module attributes {"ttg.target" = "cuda: + // CHECK-LABEL: convert_single_element + // CHECK-NOT: llvm.store + // CHECK-NOT: llvm.load +- // CHECK: llvm.return ++ // CHECK: llvm.insertvalue ++ // CHECK: llvm.extractvalue + tt.func public @convert_single_element() attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> +diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp +--- a/unittest/Tools/LinearLayoutTest.cpp ++++ b/unittest/Tools/LinearLayoutTest.cpp +@@ -410,6 +410,26 @@ TEST_F(LinearLayoutTest, InvertAndCompos + EXPECT_EQ(composition.compose(l2), l1); + } + ++TEST_F(LinearLayoutTest, InvertAndCompose_SmallerResult) { ++ // The domain of l2 is [0,16), but the codomain of the result is only [0,8), ++ // because there's no value v in the codomain of l1 such that l2^-1(v) >= 8. ++ LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); ++ LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {8}}}}, {S("out")}); ++ // Pseudo-inverse of l2 is ++ // ++ // out(1) = 2 ++ // out(2) = 4 ++ // out(4) = 1 ++ // out(8) = 8 ++ // ++ // Composing with l1 gives back l2^-1 without the out(8) entry. ++ LinearLayout composition = l1.invertAndCompose(l2); ++ EXPECT_EQ(composition, ++ LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, ++ /*requireSurjective=*/false)); ++ EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); ++} ++ + TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedInDim) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{4}, {1}, {2}}}}, {S("out")}); +@@ -494,10 +514,8 @@ TEST_F(LinearLayoutTest, InvertAndCompos + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in3"), {{1}, {2}, {4}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = l1.invertAndCompose(l2); +- EXPECT_EQ(c, LinearLayout( +- {{S("in1"), {{1, 0}, {2, 0}, {4, 0}}}, {S("in2"), {{0, 0}}}}, +- {{S("in3"), 8}, {S("in4"), 2}}, +- /*requireSurjective=*/false)); ++ EXPECT_EQ(c, LinearLayout::identity1D(8, S("in1"), S("in3")) * ++ LinearLayout::identity1D(2, S("in2"), S("in4"))); + EXPECT_EQ(c.compose(l2), + l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); + } +@@ -507,9 +525,8 @@ TEST_F(LinearLayoutTest, InvertAndCompos + LinearLayout b({{S("in3"), {{2}, {1}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = a.invertAndCompose(b); + EXPECT_EQ(c, +- LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 0}}}}, +- {{S("in3"), 4}, {S("in4"), 2}}, +- /*requireSurjective=*/false)); ++ LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 1}}}}, ++ {S("in3"), S("in4")})); + EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); + } + +@@ -729,6 +746,40 @@ TEST_F(LinearLayoutTest, QuotientIdentit + ASSERT_TRUE(quotientLayout.has_value()); + ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); + } ++ ++TEST_F(LinearLayoutTest, Resize) { ++ auto init = LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}}}, ++ {S("in1"), {{1, 0}, {2, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")}); ++ EXPECT_EQ(init.resize(S("in0"), 8), ++ LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, ++ {S("in1"), {{1, 0}, {2, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")})); ++ EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}}}, ++ {S("in1"), {{1, 0}, {2, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")})); ++ EXPECT_EQ(init.resize(S("in1"), 8), ++ LinearLayout( ++ { ++ {S("in0"), {{0, 1}, {0, 2}}}, ++ {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, ++ {S("in2"), {}}, ++ }, ++ {S("dim0"), S("dim1")})); ++} ++ + } // anonymous namespace + } // namespace mlir::triton + diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index 4fa55269e3323c..0348fe0cbb87f7 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -14,5 +14,6 @@ those to this list. """ temporary_patch_list = [ + "//third_party/triton:temporary/numpy_type_promotion.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 5de24fa70a5b75..2b93e2bababa45 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl702724623" - TRITON_SHA256 = "7348c9fcc01f24d97daf71b9757b9065a36fedfe05a5fbe1ea79b603b89a65b9" + TRITON_COMMIT = "cl706678601" + TRITON_SHA256 = "904377c36458ef842e6fa2daa8e55f4fe0d235f08cce3011c5b33b50f4ffe93a" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/xla_extensions/series.bzl b/third_party/xla/third_party/triton/xla_extensions/series.bzl index 0e0291d7def6d5..9a12588aae7bcc 100644 --- a/third_party/xla/third_party/triton/xla_extensions/series.bzl +++ b/third_party/xla/third_party/triton/xla_extensions/series.bzl @@ -8,5 +8,6 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to extensions_files_patch_list = [ "//third_party/triton:xla_extensions/sparse_wgmma_op.patch", # Sparsity internal patch + "//third_party/triton:xla_extensions/sparse_fenceinsertion_pass.patch", # Sparse internal patch # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_fenceinsertion_pass.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_fenceinsertion_pass.patch new file mode 100644 index 00000000000000..d9a1a25fe2d1f9 --- /dev/null +++ b/third_party/xla/third_party/triton/xla_extensions/sparse_fenceinsertion_pass.patch @@ -0,0 +1,13 @@ +# Tracked in b/377699102 +--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp 2024-12-05 23:53:31.000000000 -0800 ++++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp 2024-12-19 07:03:31.000000000 -0800 +@@ -44,7 +44,8 @@ + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { +- if (!op->hasTrait()) ++ if (!isa(op) && ++ op->getName().getStringRef() != "triton_xla.sparse_dot") + return WalkResult::advance(); + OpBuilder builder(op); + auto a = op->getOperand(0); diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index 342f35280adf36..04fb49a09186a8 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -451,30 +451,36 @@ build:avx_linux --copt=-mavx build:avx_linux --host_copt=-mavx build:avx_win --copt=/arch:AVX +# TODO(belitskiy): Remove once Win2019 is gone. # Use Clang-cl compiler on Windows -build:win_clang --copt=/clang:-Weverything -build:win_clang --host_copt=/clang:-Weverything build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl build:win_clang --extra_execution_platforms=//tensorflow/tools/toolchains/win:x64_windows-clang-cl build:win_clang --host_platform=//tensorflow/tools/toolchains/win:x64_windows-clang-cl +build:win_clang --copt=/clang:-Weverything +build:win_clang --host_copt=/clang:-Weverything build:win_clang --compiler=clang-cl build:win_clang --linkopt=/FORCE:MULTIPLE build:win_clang --host_linkopt=/FORCE:MULTIPLE test:win_clang --linkopt=/FORCE:MULTIPLE test:win_clang --host_linkopt=/FORCE:MULTIPLE - -# Same config as above but for XLA, which has different toolchain paths -build:win_clang_xla --copt=/clang:-Weverything -build:win_clang_xla --host_copt=/clang:-Weverything -build:win_clang_xla --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang_xla --extra_execution_platforms=//tools/toolchains/win:x64_windows-clang-cl -build:win_clang_xla --host_platform=//tools/toolchains/win:x64_windows-clang-cl -build:win_clang_xla --compiler=clang-cl -build:win_clang_xla --linkopt=/FORCE:MULTIPLE -build:win_clang_xla --host_linkopt=/FORCE:MULTIPLE -test:win_clang_xla --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW -test:win_clang_xla --linkopt=/FORCE:MULTIPLE -test:win_clang_xla --host_linkopt=/FORCE:MULTIPLE +test:win_clang --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW + +# build:windows_x86_cpu --extra_toolchains="//tensorflow/tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +# build:windows_x86_cpu --extra_execution_platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +# build:windows_x86_cpu --host_platform="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --crosstool_top="//tensorflow/tools/toolchains/win2022/20241118:toolchain" +build:windows_x86_cpu --extra_toolchains="//tensorflow/tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" +build:windows_x86_cpu --extra_execution_platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --host_platform="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" +build:windows_x86_cpu --copt=/clang:-Weverything +build:windows_x86_cpu --host_copt=/clang:-Weverything +build:windows_x86_cpu --compiler=clang-cl +build:windows_x86_cpu --linkopt=/FORCE:MULTIPLE +build:windows_x86_cpu --host_linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --host_linkopt=/FORCE:MULTIPLE +test:windows_x86_cpu --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW # Options to build TensorFlow 1.x or 2.x. # TODO(kanglan): Change v2's define to default behavior @@ -533,9 +539,9 @@ build:rbe_linux_cpu --crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_linux_cpu --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cpu --repo_env=CC="/usr/lib/llvm-18/bin/clang" build:rbe_linux_cpu --repo_env=TF_SYSROOT="/dt9" -build:rbe_linux_cpu --extra_execution_platforms="@sigbuild-r2.17-clang_config_platform//:platform" -build:rbe_linux_cpu --host_platform="@sigbuild-r2.17-clang_config_platform//:platform" -build:rbe_linux_cpu --platforms="@sigbuild-r2.17-clang_config_platform//:platform" +build:rbe_linux_cpu --extra_execution_platforms="@ml_build_config_platform//:platform" +build:rbe_linux_cpu --host_platform="@ml_build_config_platform//:platform" +build:rbe_linux_cpu --platforms="@ml_build_config_platform//:platform" # This is needed for all Clang17 builds but must not be present in GCC builds. build:rbe_linux_cpu --copt=-Wno-error=unused-command-line-argument # This was added in clang-16 by https://reviews.llvm.org/D133574. @@ -746,48 +752,54 @@ build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_ # LIBTENSORFLOW TESTS are for building Libtensorflow archives. These are CUDA/CPU-agnostic. test:linux_libtensorflow_test --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip +build:windows_libtensorflow_build --config=cuda_wheel --config=windows_x86_cpu -- //:LICENSE //tensorflow:tensorflow.dll //tensorflow:tensorflow_dll_import_lib //tensorflow/tools/lib_package:clicenses_generate //tensorflow/java:tensorflow_jni.dll //tensorflow/tools/lib_package:jnilicenses_generate # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/tools/pip_package:import_api_packages_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +# WINDOWS X86 WHEEL +test:windows_x86_cpu_wheel_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_wheel_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_wheel_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" +test:windows_x86_cpu_wheel_test --build_tests_only --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. # LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # LINUX ARM64 PYCPP # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on @@ -798,35 +810,35 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? -build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test +build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test # CROSS-COMPILE ARM64 PYCPP build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium -test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. -build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium -build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... +build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... @@ -843,38 +855,15 @@ build:cross_compile_base --host_cpu=k8 build:cross_compile_base --host_crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite build:cross_compile_base --extra_execution_platforms=//tensorflow/tools/toolchains/cross_compile/config:linux_x86_64 -# XLA related settings for cross-compiled build. Certain paths are -# different in the XLA repo. -build:cross_compile_base_xla --host_cpu=k8 -build:cross_compile_base_xla --host_crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -build:cross_compile_base_xla --extra_execution_platforms=//tools/toolchains/cross_compile/config:linux_x86_64 - build:rbe_cross_compile_base --config=rbe_base build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance -# XLA depends on some local Python headers that are configured as Genrule. They -# are present on the local host machine but not on the remote execution machine, -# leading to build failures. To resolve the issue, the following line is added -# to make sure all Genrule targets are excuted locally. -build:rbe_cross_compile_base_xla --config=rbe_cross_compile_base -build:rbe_cross_compile_base_xla --strategy=Genrule=standalone - -# Due to the above strategy, all Genrule commands are executed locally, but the -# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are -# only executabe on the RBE (x86) machine, so the strategy_regexp options are -# added to override and run the actions using remote strategy. -build:rbe_cross_compile_base_xla --strategy_regexp='Generating code from table.*=remote' -build:rbe_cross_compile_base_xla --strategy_regexp='Generating flatbuffer files.*=remote' -build:rbe_cross_compile_base_xla --strategy_regexp='Executing genrule @llvm-project.*=remote' - # Test-related settings below this point # We cannot run cross-compiled tests on the remote Linux x86 VMs so we need to # force all tests to run locally on the Aarch64 host. test:rbe_cross_compile_base --strategy=TestRunner=local --build_tests_only test:rbe_cross_compile_base --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors -test:rbe_cross_compile_base_xla --config=rbe_cross_compile_base - # START LINUX AARCH64 CROSS-COMPILE CONFIGS build:cross_compile_linux_arm64 --config=cross_compile_base @@ -883,21 +872,11 @@ build:cross_compile_linux_arm64 --platforms=//tensorflow/tools/toolchains/cross_ build:cross_compile_linux_arm64 --cpu=aarch64 build:cross_compile_linux_arm64 --crosstool_top=//tensorflow/tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite -# XLA uses different paths for platforms and crosstool_top. -build:cross_compile_linux_arm64_xla --config=cross_compile_base_xla -build:cross_compile_linux_arm64_xla --platforms=//tools/toolchains/cross_compile/config:linux_aarch64 -build:cross_compile_linux_arm64_xla --crosstool_top=//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite - # RBE cross-compile configs for Linux Aarch64 build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64 build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base test:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base -# RBE cross-compile configs for XLA Linux Aarch64 -build:rbe_cross_compile_linux_arm64_xla --config=cross_compile_linux_arm64_xla -build:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla -test:rbe_cross_compile_linux_arm64_xla --config=rbe_cross_compile_base_xla - # END LINUX AARCH64 CROSS-COMPILE CONFIGS # START MACOS CROSS-COMPILE CONFIGS diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index 9ad725817356c5..ad43bd44d8ef37 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -99,17 +99,18 @@ third_party/nvtx/LICENSE: third_party/protobuf/BUILD: third_party/py/BUILD.tpl: third_party/py/BUILD: +third_party/py/manylinux_compliance_test.py: third_party/py/ml_dtypes/BUILD: third_party/py/ml_dtypes/LICENSE: third_party/py/numpy/BUILD: third_party/py/py_import.bzl: +third_party/py/py_manylinux_compliance_test.bzl: third_party/py/python_configure.bzl: third_party/py/python_init_pip.bzl: third_party/py/python_init_repositories.bzl: third_party/py/python_init_rules.bzl: third_party/py/python_init_toolchains.bzl: third_party/py/python_repo.bzl: -third_party/py/verify_manylinux_compliance.py: third_party/pybind11.BUILD: third_party/pybind11_bazel/BUILD: third_party/python_runtime/BUILD: @@ -131,7 +132,6 @@ third_party/systemlibs/boringssl.BUILD: third_party/systemlibs/build_defs.bzl.tpl: third_party/systemlibs/curl.BUILD: third_party/systemlibs/cython.BUILD: -third_party/systemlibs/double_conversion.BUILD: third_party/systemlibs/gif.BUILD: third_party/systemlibs/google_cloud_cpp.BUILD: third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD: @@ -176,5 +176,7 @@ tools/toolchains/win/20240424/BUILD: tools/toolchains/win/BUILD: tools/toolchains/win/bazel_211/BUILD: tools/toolchains/win/tf_win_05022023/BUILD: +tools/toolchains/win2022/20241118/BUILD: +tools/toolchains/win2022/BUILD: tools/toolchains/win_1803/py38/BUILD: tools/toolchains/win_1803/py39/BUILD: diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl index 03a9dde83cfddc..ac3082fbcb3055 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -111,7 +111,7 @@ filegroup( ) filegroup( - name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = [":clang/bin/crosstool_wrapper_driver_is_not_gcc"], + data = ["@local_config_rocm//rocm:all_files"], ) - diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 98bcb800c60e87..e75ea610a02403 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -24,7 +24,7 @@ import pipes # Template values set by rocm_configure.bzl. CPU_COMPILER = ('%{cpu_compiler}') -USE_CLANG = ('%{compiler}' == 'clang') +USE_CLANG = ('%{compiler_is_clang}' == 'True') HOST_COMPILER_PATH = ('%{host_compiler_path}') HIPCC_PATH = '%{hipcc_path}' @@ -186,6 +186,7 @@ def InvokeHipcc(argv, log=False): hipccopts += defines hipccopts += std_options hipccopts += m_options + hipccopts += ' --rocm-path="%{rocm_path}" ' if depfiles: # Generate the dependency file diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl index 510235d801de4e..d8f125fa3d3253 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -21,12 +25,14 @@ cc_library( name = "cublas", visibility = ["//visibility:public"], %{comment}deps = [":cublas_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cublas/lib"), ) cc_library( name = "cublasLt", visibility = ["//visibility:public"], %{comment}deps = [":cublasLt_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cublas/lib"), ) cc_library( diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl index 04d2de148c78c0..fabb310001cd39 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -36,6 +40,7 @@ cc_library( %{comment}}) + [ %{comment}":cudart_shared_library", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_runtime/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl index 165c5b1579e73f..c3701a6241243d 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -58,6 +62,7 @@ cc_library( %{comment}"@cuda_nvrtc//:nvrtc", %{comment}":cudnn_main", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl index 7f36054a51bb5b..4e8bcbd84e0327 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -65,6 +69,7 @@ cc_library( %{comment}"@cuda_nvrtc//:nvrtc", %{comment}":cudnn_main", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl index 48ccb0ea3cd197..2e55a742d54967 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -13,6 +17,7 @@ cc_import( cc_library( name = "cufft", %{comment}deps = [":cufft_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cufft/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl index 3991b486195bc5..16d6991b584154 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -1,5 +1,10 @@ licenses(["restricted"]) # NVIDIA proprietary license load("@local_config_cuda//cuda:build_defs.bzl", "if_version_equal_or_greater_than") +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) + exports_files([ "version.txt", ]) @@ -13,6 +18,7 @@ cc_import( cc_library( name = "cupti", %{comment}deps = [":cupti_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_cupti/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl index 50e5a8f18a96fd..746503fcf22229 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -13,6 +17,7 @@ cc_import( cc_library( name = "curand", %{comment}deps = [":curand_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/curand/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl index 943a08ebeb96e1..30bacf07eebda2 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -19,6 +23,7 @@ cc_import( cc_library( name = "cusolver", %{comment}deps = [":cusolver_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cusolver/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl index 46b24366ce1c04..b7765ab22508dc 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -14,6 +18,7 @@ cc_import( cc_library( name = "cusparse", %{comment}deps = [":cusparse_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/cusparse/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl index 0494008e7924f3..5be8d6ef2408ba 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -13,6 +17,7 @@ cc_import( cc_library( name = "nvjitlink", %{comment}deps = [":nvjitlink_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/nvjitlink/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl index de18489b455b79..fea4c5d7ce7ed5 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -1,4 +1,9 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) + %{multiline_comment} cc_import( name = "nvrtc_main", @@ -16,5 +21,6 @@ cc_library( %{comment}":nvrtc_main", %{comment}":nvrtc_builtins", %{comment}], + %{comment}linkopts = cuda_rpath_flags("nvidia/cuda_nvrtc/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/BUILD.tpl index aa3688e335df37..7ebf2773eb48b1 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/BUILD.tpl @@ -1,8 +1,22 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_version_number", "select_threshold") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like -package(default_visibility = ["//visibility:public"]) +package(default_visibility = ["//visibility:private"]) + +bool_flag( + name = "use_rocm_hermetic_rpath", + build_setting_default = False, +) + +config_setting( + name = "build_hermetic", + flag_values = { + ":use_rocm_hermetic_rpath": "True", + }, +) config_setting( name = "using_hipcc", @@ -12,171 +26,434 @@ config_setting( ) cc_library( - name = "rocm_headers", + name = "config", hdrs = [ - "rocm/rocm_config.h", - %{rocm_headers} + "rocm_config/rocm_config.h", ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config", +) + +cc_library( + name = "config_hermetic", + hdrs = [ + "rocm_config_hermetic/rocm_config.h", + ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config_hermetic", +) + +cc_library( + name = "rocm_config", + visibility = ["//visibility:public"], + deps = select({ + ":build_hermetic": [ + ":config_hermetic", + ], + "//conditions:default": [ + "config", + ], + }), +) + +cc_library( + name = "rocm_headers", + hdrs = glob([ + "%{rocm_root}/include/**", + "%{rocm_root}/lib/llvm/lib/**/*.h", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", - "rocm/include/roctracer", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", + "%{rocm_root}/include/roctracer", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [ + ":rocm_rpath", + ], ) cc_library( - name = "hip", - srcs = ["rocm/lib/%{hip_lib}"], - data = ["rocm/lib/%{hip_lib}"], + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + ":hip", + ":hipblas", + ":hipblaslt", + ":hiprand", + ":hipsolver", + ":hipsparse", + ":hsa_rocr", + ":miopen", + ":rocblas", + ":rocm_config", + ":rocprofiler_register", + ":rocsolver", + ":roctracer", + ":rocsparse", + ] + select_threshold( + above_or_eq = [":hipfft"], + below = [":rocfft"], + threshold = 40100, + value = rocm_version_number(), + ), +) + +cc_library( + name = "hsa_rocr", + srcs = glob(["%{rocm_root}/lib/libhsa-runtime*.so*"]), + hdrs = glob(["%{rocm_root}/include/hsa/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_rpath", + linkopts = select({ + ":build_hermetic": [ + "-Wl,-rpath=%{rocm_toolkit_path}/lib", + ], + "//conditions:default": [ + "-Wl,-rpath=/opt/rocm/lib", + ], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hip", visibility = ["//visibility:public"], + deps = [ + ":rocm_hip", + ":rocm_rpath", + ], +) + +cc_library( + name = "rocm_hip", + srcs = glob(["%{rocm_root}/lib/libamdhip*.so*"]), + hdrs = glob(["%{rocm_root}/include/hip/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [ + ":amd_comgr", + ":hsa_rocr", + ":rocm_config", + ":rocm_smi", + ":rocprofiler_register", + ":system_libs", + ], ) cc_library( name = "rocblas", - srcs = ["rocm/lib/%{rocblas_lib}"], - data = ["rocm/lib/%{rocblas_lib}"], + hdrs = glob(["%{rocm_root}/include/rocblas/**"]), + data = glob([ + "%{rocm_root}/lib/librocblas*.so*", + "%{rocm_root}/lib/rocblas/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring tensile files to the same fs layout as expected in the lib + # rocblas assumes that tensile files are located in ../roblas/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "%{hipfft_or_rocfft}", - srcs = ["rocm/lib/%{hipfft_or_rocfft_lib}"], - data = ["rocm/lib/%{hipfft_or_rocfft_lib}"], + name = "rocfft", + srcs = glob(["%{rocm_root}/lib/librocfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "hiprand", - srcs = ["rocm/lib/%{hiprand_lib}"], - data = ["rocm/lib/%{hiprand_lib}"], + name = "hipfft", + srcs = glob(["%{rocm_root}/lib/libhipfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", + "%{rocm_root}/include", ], linkstatic = 1, - visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "miopen", - srcs = ["rocm/lib/%{miopen_lib}"], - data = ["rocm/lib/%{miopen_lib}"], + name = "hiprand", + srcs = glob(["%{rocm_root}/lib/libhiprand*.so*"]), + hdrs = glob(["%{rocm_root}/include/hiprand/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rccl", - srcs = ["rocm/lib/%{rccl_lib}"], - data = ["rocm/lib/%{rccl_lib}"], + name = "miopen", + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + data = glob([ + "%{rocm_root}/lib/libMIOpen*.so*", + "%{rocm_root}/share/miopen/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring miopen db files to the same fs layout as expected in the lib + # rocblas assumes that miopen db files are located in ../share/miopen/db directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rocm", - visibility = ["//visibility:public"], - deps = [ - ":rocm_headers", - ":hip", - ":rocblas", - ":hipblas", - ":%{hipfft_or_rocfft}", - ":hiprand", - ":miopen", - ":hipsparse", - ":roctracer", - ":rocsolver", - ":hipsolver", + name = "rccl", + srcs = glob(["%{rocm_root}/lib/librccl*.so*"]), + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", ], + linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], + visibility = ["//visibility:public"], ) cc_library( name = "rocprim", srcs = [ - "rocm/include/hipcub/hipcub_version.hpp", - "rocm/include/rocprim/rocprim_version.hpp", + "%{rocm_root}/include/hipcub/hipcub_version.hpp", + "%{rocm_root}/include/rocprim/rocprim_version.hpp", ], hdrs = glob([ - "rocm/include/hipcub/**", - "rocm/include/rocprim/**", + "%{rocm_root}/include/hipcub/**", + "%{rocm_root}/include/rocprim/**", ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include/hipcub", - "rocm/include/rocprim", + "%{rocm_root}/include/hipcub", + "%{rocm_root}/include/rocprim", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], deps = [ - "@local_config_rocm//rocm:rocm_headers", + ":rocm_config", + ":rocm_headers", ], ) cc_library( name = "hipsparse", - srcs = ["rocm/lib/%{hipsparse_lib}"], - data = ["rocm/lib/%{hipsparse_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsparse/**"]), + data = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "roctracer", - data = ["rocm/lib/%{roctracer_lib}"], + hdrs = glob(["%{rocm_root}/include/roctracer/**"]), + data = glob(["%{rocm_root}/lib/libroctracer*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "rocsolver", - srcs = ["rocm/lib/%{rocsolver_lib}"], - data = ["rocm/lib/%{rocsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocsolver/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocsparse", + srcs = glob(["%{rocm_root}/lib/librocsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipsolver", - srcs = ["rocm/lib/%{hipsolver_lib}"], - data = ["rocm/lib/%{hipsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), + data = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipblas", - srcs = ["rocm/lib/%{hipblas_lib}"], - data = ["rocm/lib/%{hipblas_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipblas.so*"]), + hdrs = glob(["%{rocm_root}/include/hipblas/**"]), + data = glob(["%{rocm_root}/lib/libhipblas.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "hipblaslt", + hdrs = glob(["%{rocm_root}/include/hipblaslt/**"]), + data = glob([ + "%{rocm_root}/lib/hipblaslt/**", + "%{rocm_root}/lib/libhipblaslt.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + # workaround to bring tensile files to the same fs layout as expected in the lib + # hibplatslt assumes that tensile files are located in ../hipblaslt/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocrand", + srcs = glob(["%{rocm_root}/lib/librocrand*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocrand/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocprofiler_register", + srcs = glob([ + "%{rocm_root}/lib/librocprofiler-register.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "amd_comgr", + srcs = glob([ + "%{rocm_root}/lib/libamd_comgr.so*", + ]), + hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_smi", + srcs = glob([ + "%{rocm_root}/lib/librocm_smi64.so*", + "%{rocm_root}/lib/libroam.so*", + ]), + hdrs = glob([ + "%{rocm_root}/include/oam/**", + "%{rocm_root}/include/rocm_smi/**", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "system_libs", + srcs = glob([ + "rocm_dist/usr/lib/**/libelf.so*", + "rocm_dist/usr/lib/**/libdrm.so*", + "rocm_dist/usr/lib/**/libnuma.so*", + "rocm_dist/usr/lib/**/libdrm_amdgpu.so*", + ]), + data = glob([ + "rocm_dist/usr/**", + ]), ) filegroup( name = "rocm_root", srcs = [ - "rocm/bin/clang-offload-bundler", + "%{rocm_root}/bin/clang-offload-bundler", ], + visibility = ["//visibility:public"], ) -%{copy_rules} +filegroup( + name = "all_files", + srcs = glob(["%{rocm_root}/**"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl index 83a7e9dababf38..d327083e4dc8ea 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl @@ -11,6 +11,8 @@ def if_rocm(if_true, if_false = []): "//conditions:default": if_false }) +def select_threshold(value, above_or_eq, threshold, below): + return below if value < threshold else above_or_eq def rocm_default_copts(): """Default options for all ROCm compilations.""" diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl new file mode 100644 index 00000000000000..ca64cc8ec9b61b --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl @@ -0,0 +1,18 @@ +load( + "@local_tsl//third_party/gpus/rocm:rocm_redist_ubuntu_20_04.bzl", + "rocm_redist_ubuntu_20_04", +) +load( + "@local_tsl//third_party/gpus/rocm:rocm_redist_ubuntu_22_04.bzl", + "rocm_redist_ubuntu_22_04", +) +load( + "@local_tsl//third_party/gpus/rocm:rocm_redist_ubuntu_24_04.bzl", + "rocm_redist_ubuntu_24_04", +) + +rocm_redist = { + "ubuntu_20.04": rocm_redist_ubuntu_20_04, + "ubuntu_22.04": rocm_redist_ubuntu_22_04, + "ubuntu_24.04": rocm_redist_ubuntu_24_04, +} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl new file mode 100644 index 00000000000000..ecae2197563b33 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_20_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~20.04_amd64.deb", + sha256 = "fabf4a831f21b5248932e08654149bc215da2a816613ad8d05b805d4e226171a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "215fae8759742bc048699feaacd6256a3ac2138771b69731dab7779325bb1b41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "e901d66275b3b520ee73250caa4a1836be142823083528b4db6cc31a18bfb94d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "f8a20128b5c26198bd9ecec894f8a4c74fa28ee668e4ef1bf73d0c3edff8c144", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "ab3ee54b33eba013fbf3d9aefe64b54e1918b9fb72790ca0b57fb391cb662cf0", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~20.04_amd64.deb", + sha256 = "a68123c046b8c913705262014463a8a30768167a1b68a78d8455deaf85a802d7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "c71fab59f62ad9d4b60aa4217f4db42c6996d83d5ad7ba29e127cc13bda59afc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "25887526ea2e955d4c0afa4749f8db55a49e399a349d43ccf66e0ad99ff78b2a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "3cfec840c79c6bce4e83bf6e056e241cc13ff572352b040a952c7642b61d45aa", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "cb56dd79ff52eaddfed379831023484d9ec32b9538bc3d02ee34c328457cd20e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "1e968f9405c8b90fbb58dff09d8bab08cf31c8386880fff95e1cb8932320bc37", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "f08ba25b6b950754b5a2bb64c125a01b9f44280f227ff19eeb78e188f0b17320", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "e9464369619bbea7299ac83e17b3cbbabdeb16e6d4da116400532e7737332b65", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "2efed49be9413e08e91b3fb67736644bb0e8809fc673d310a0abab65b69eacad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "19564fb2f9616860234aa8bd69cca324a1a3ec33476581ec57200a1dac1d4dcb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~20.04_amd64.deb", + sha256 = "e4940a5d47e9e39d603f18936e7921c603fd8dde0e359e0be796f9c1cdacd431", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "638a28c5407c3af7d16e1b0179b7494b0aeb36c314114af148b1bcd52e883db1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "77c9d26c4f0053b71fb86f7a6b489655e27053f9605efca3a16344ccf286e313", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "2b3ce1ca2e58e891963f26d4bd31ae45894480483f691d371f269e698f75f8eb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "0dedbffa5bb272d656086a9586e3705551345945f35f4f6be6dc8a27b63127a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "6e5b3caeadf592367f8638db67a70b8dd9231a8257dc2012a9c46e2c5974fff5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "eaefe5a7d75ef61314b83af5bb85d8e652a730deaa58e1d600b1e9c2e673673c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "b2bfe29ab688781bad5bc067ee682658085e22caaf09b18278f2f4b9905081d3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "e94d50fd6f24d70649ce046dbfe4dda2587d1d82892d4c126a4c3e91d1570071", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "0e16c9fc58fc904542be4dad63bb2ff34268b5c13957c432e91ec0e4fd149c82", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "14f47d79b508eb259bfe4e0e5f360edb5721b908caf3bb981a4eee4181783be9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "97e6e77eaea56de6cc4ea2c525dd8b9a587546eb99c782c7af46cdc5363b99bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "ae055b579d319e1a779783ba774f119fb0e1a731d058a03b36dc5c15214d210a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "3bcf3dc22dbede7da70299cde1484776827808b967d371441f6cf6d3fe8af30d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "ce17d2b85407b9539e0feda513fd360a48ebfd971c19af122dda21d60448c9fc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "322ca8425c3a8f2ec17c551bad606b96d957b0c1eea07196dd66ac9f15460ed5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~20.04_amd64.deb", + sha256 = "1bbdb32d21dbc12bf9a736f6ca8726df9673e4401465d2b9b537c47b358b67f1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "e74e1907eb90a692344626e881cb88eeed5565ac3b487eb94ad4ac02ffd838ed", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~20.04_amd64.deb", + sha256 = "4be88c5010c2cf0223c1dd7dc9d4a430fc54ee401ca093de2dcca60dabea763a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~20.04_amd64.deb", + sha256 = "ddd0ac44b08470dfc128d6f6d2598a9728879f5a78bc5290645baebf22433b63", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "b94cdf230b372ebcaf97085cf67f01ef7977f814280fdaf1886797f39899ef41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "9a85b57eea3790432eae06421081b3e59d3c9841d59646364ecd174f9ed4821a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "87dcd34a9b50f46161ecdb7781ab03c2b311fb7e13aa167c4a9c5e3bcf24b473", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "21e4aa1957e7bc5d293a418a983d9b3c3917fb78eb79d3d4d55a253b9bae7743", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "dacc13278f2be1cd847fca30ce409dcf95749df5f1a27635bc6dbd61be488d14", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.101-2_amd64.deb", + sha256 = "4cd2e10f9486456a2782487f8bfd39f330f35a4d5bd6d693412b9e4ca2a6acbd", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.101-2_amd64.deb", + sha256 = "d4567a30f7d68b4dcf794f8677b96e89083693c94e88279fecf577ceba8b9774", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.176-1.1build1_amd64.deb", + sha256 = "78a8761227efc04a1e37527f2f33ba608c6fb5d6c911616346ada5d7b9b72ee3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.12-1_amd64.deb", + sha256 = "0b1edf08cf9befecd21fe94e298ac25e476f87fd876ddd4adf42ef713449e637", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl new file mode 100644 index 00000000000000..88dca226f795b7 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_22_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~22.04_amd64.deb", + sha256 = "bc5d620e4e0db3746fc6b2279e463f618681f1f95ba973e40b687cef50ca2489", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "38e9670bedc7bbdc0b9f38c7a0fe90f73ef80f161cbf63c98d30e422438ce2c5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "c66cc8c19b57cab740710811457f02a16e24cff761e5c99c3640f63ceefe8281", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "fbd647e1b13e7aa2c14c9581f9102c069ddab9ecb47a4b226d433ec37b19e92d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "885cf3f3a52ebde9caadf6348a6cda28fd15e3bc52bab0c90b587d72b29ff7ef", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~22.04_amd64.deb", + sha256 = "468026fa8eb70121f0c545557a926ddc41228cef9457b4a00d8fc3a36b04310f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "c2c7d2ec5a8a31837c0addfc619ee67a374ea967cc6d43900472005489f62722", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "6e649430cc5e247bbd052dff2d681b6bf0ef09d0bc3446a4911f4ab4cd317140", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "389b0c83a39adbeeec442adde3fedba2820ed948179a4a0df03d67560501cd97", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "adf9aad1fc062445e34cdddbeca80db9c02f4c5f258e01c45e2a6222d15cb66d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "cb46dfbff3943a3167f6173fc381d744eb966a3451bcff49458c696888ec452c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "8c7a216aeef6ceeb3881d3e443a89a0f5c15a17deb5926cba4b787554c8fab87", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "501cad72df5f09572f99c11eebbb1eff49afb6ca8c91bcf4966f81068177a95d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "b20c86be57698a944f91048699d0fbde5253bea28ba9d4035ce1de1d3c20f9ac", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "9dab6f44b92b6020e183777f6f07219d68de5d10cad7538c7ddcae0192aa3e33", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~22.04_amd64.deb", + sha256 = "62d280204d8ff642b464dab03fc344442df6dc5f04e152da20604e8050303c41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "6c2aa042067e51d5b70a264ca83c92ffaa6e81d00d08b55986917da860e66d85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "f3452b2bd9c2869c550c7f963cca65fb35a37183ad4a56d96e05c69adb2f1d04", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "f3205c0a7d736f457ee2262988260e8dc4c495fa74a394ff73a9dfe002aff335", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "953a248cd44f403e5423185918166bfa29a009519c3d7b5b5a8e067fdf672602", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "c306ca3e59b851ebb35872e09e5598adf2e2ebb736c1b200ff4ee204fe262f7e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "115d0e9ec1b93bf7cba5fa1e3de1428f0d999d931c2dd495e4cdad22b5078936", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "0d40fc9aa1da617cd8864258cd1259a0e7444ea0da446297d154b5b3422393af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "8c1e72cf1c165e20960b0c2f3c499900a809d59340d14a0acff95c543c7087f2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "22c80c1a704f4ce7d6a49a8b41acd64f3ed0513cd7f5570a0664a10df5858334", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "9c2ff1dc100e342969bd51a7cd4918048c8b25579de709efde56425d969cd50f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "1101f3edb9dbc9f4914d7f26b5569ec9bde076d52d4125c98d22a99dd730ab51", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "d5b660df350130e0ab04ddf3e36dd442bde27ae9cbb8e5f12c047b0d3cb05463", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "0d06a84ac53d388089b7b8c80133f60c1eea5bfd85155ecc113efb206a747c25", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "4a29539480a7e4b27991ccf533a35526dd3994a457fa84e4c960192c2fa05b46", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "febb8614cedd98f13ba0624072ffdd13b9a6dc3431380a17a0eaf87583627890", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "3d859bb735ff8bf1962ce680e9257dcc574ab36224f50069f833fa19c6d7e69d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~22.04_amd64.deb", + sha256 = "ffd4e064e8a1d52b9e72114e8a1d51c78004a960f1d923448af8ed07a1b6f30b", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~22.04_amd64.deb", + sha256 = "66df78d8c5e2d1a0ae43cd4a5e41cf75ec120c870a0bbd7da18a2ba4dec42f9c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~22.04_amd64.deb", + sha256 = "317c16a6e0b0b456153437406dd92225e17dbd454fc1304b0c3fef5fbfc69bc2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9ddf8835f1e94d5004b4c466091c8110cb72e11eda545d0de395395832076c0a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9a9ed0c66d3a9d9ff50f1fc3a9e9105bb8b1a6d93c1f856682625dfb68ab627f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "5b86bf7b33a3ffa7098878f27d1b119aada69ebb02bd121b47209559c32703be", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "4573f99191fbe3a2afab84fdf5a05e024bd230ca7866d7eba71a5f2560a3a0bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "4fbc91db9085ecd80a5e051bba56863ae33b22516d727ab3fef15fb500187222", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.110-1ubuntu1_amd64.deb", + sha256 = "e5ea68db36b31aab442c790e1c78ecdf53646c16b0cd83db15966632ba04152c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.110-1ubuntu1_amd64.deb", + sha256 = "ae1f0d77668d7275d085ba820206ba91e90833dd1a02b8e251af0c73aa119ba3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.186-1build1_amd64.deb", + sha256 = "8effc4d7a0cc341bcf6cb11af0134f3defa6292376ecfdfc697a9b228606345c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.14-3ubuntu2_amd64.deb", + sha256 = "0721c89001fbbd1ada23e89da5d60e762763c1a7b3dc814a2e9a518480a8043d", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl new file mode 100644 index 00000000000000..da9ef00998f936 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl @@ -0,0 +1,187 @@ +rocm_redist_ubuntu_24_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~24.04_amd64.deb", + sha256 = "7e1ff2d9f2435f5b9db9aa952bb57d1a878a8aa7d96bda61361c107b7e1428e3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "5e6601ada30432ee0dab0473585bdf1fa7c398f0c655538d48eba9c44e6dc77a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "7ff8f6308c744c71008959b17ab6338de1c6fd3e4581dd94271e6eca9fdc4c13", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "e9f71e71db600d72dcb2b61e64b965b6c60d47bd4bb699e8abec85edb260b819", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt6.2.0/hipblaslt6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "e5dfd8ba9e49f919a96c102d3a652e8ef0c4d1a63b3f3909c856d40b1745e2a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt-dev6.2.0/hipblaslt-dev6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "639bd47010035ee6719425510be33d2f54483004a909dfa4c64f853d7394a22f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~24.04_amd64.deb", + sha256 = "c2782a98633e4400f46ba732605e56b2821366db60ec06d88db0615e4d1acf3c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "48fec4d06aef3159db4117125b728242a1eeb480ea3d55d3901d945d4b883694", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "8dd73cdbd4f0563f4a0481304771e4cbcac5905eea1f2d8ef41f922cdf9aba85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "e3c0a4ebda8d3aacd44b19c6872f23222513be0a5c04f793605088d9183f1be4", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "adbba9ffcf8b5e4202efbe45924d87520bf4100ec5464bd0ba3beb61cb535c6c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "01d3dd6195111808b40a5837d3e51d8c27c4700b4bd8bb2d901e39d0474fd98a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "2ba33a96388cd3edd7b5b8b261fe99cbd569894f4d7db291fc0dd0ff5d7c67ce", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "6a767f493a722e2d4260a9bc23cf9db66fd275a094b395c768e305f60d6b4fe9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "82f182134b415080ba4a12fd7993b6099ee9b9e549c72bfebee24c8486704078", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "011d5c28f45cd9d756e0cf6ea6a3d37eabd98a3381ffd961c772ab92a37e4ee8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~24.04_amd64.deb", + sha256 = "fa04f707debb75087ea2bf5e327602034eaa3a6900421f2cf32ad5f5f1c887b9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "2dbf6d126d0de6930e0cd94d0e525e07d3019d90bd7256f3151a7f1fbc2250af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "df5fdd2218e4d380b133ba402f3734fbe0589d9cdd8618a101b71b968909b4ba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4d7efa4ee6aa2bf69b0aab449cc1d01c25ca65814e1b3cb07f6b59fa8b1608b8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4ab4f880344e04d61b6fa746be5c4bdc2841409fb6987ee61e39c6420b4eca42", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "521c87ce396c6ce10076cc641b6035451fd68ddb36a684c5a9c9538dfc831ade", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "00f135ce2ae47c35085ef06248ff7d5ce8c12fd0d5b82e7bd77b1dbc0ce7058e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "40c936452e84bfec87236f08de5a9d3f232c397a3305b6143c26697ed56ceda1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "eb3904263b396d46799eeea1081d8e8d1a551a890432a803364db2d013849f92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "af5fcbe8dc2b6cbec30e2d39d30736e8a47a0b9d0ca2be7f179f2947f9c98245", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "228f07a3caefc41f6efd5345eb9d3630f1db769f9b4abd1313cbcb32d077ce53", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "cda72054d2011dbb062e75386766d928fd8905c15c88685c3ef87fc963bd88ad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "298544f717dfb236b9257b19a0ab81abaaa770128976d4abfdea546cd32d8b02", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "8e78ed8e480b55a496153b150acb22bab39c3bb8cf1e62f9aff7eaf75a3a3a92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "72c388eae7c0f54151b46fbd8fa6e26f1ca81e2b8b415c43411a156b3f25b6e7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "3e85a859c5dafa82a9a57dda096d566b821217bacfac995f7cc45ed460b68999", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~24.04_amd64.deb", + sha256 = "c094e3022c73fca2aa6c8bb435f93550109531de37fe8de5fbf6cfe1f047b645", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "6c832e2feb0885fbe481245825c76a466921b294f530eb0d0da70a44cfe6e608", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~24.04_amd64.deb", + sha256 = "d198d010fedfbe51d3fd19444e2848d430e08f91d19a5b2661b94ac6d1135863", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~24.04_amd64.deb", + sha256 = "2a2a95185ce0e54df226474b2f5cfcdc9e5ede5a6d88a8a70c2635ea2237abba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "2f2fb6f8d06ace89131934c833b0ea359335a4b45aeec1559b293d7bc14b1d1d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "c6c781ee87c459aed32e943b389137f98ecd402fb83a3d1c98de9a76abadc3a3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5e4b3e38556f0826e5322971635a49a72283d60862ccc4d28efd11c8fb955b47", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5bb6ae92a25f33488f2ee5f123ac4f67ad130e18e4949161715451509be3b89d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "1867833a569fbf3f87b82c81bc47f5d62085ea40f12d1cb33475c1f2dec89bc4", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.120-2build1_amd64.deb", + sha256 = "f5fb4e7ce17921cc466fb7911abf91495ffb181b36772f68e2e82cb621703112", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.120-2build1_amd64.deb", + sha256 = "e149d4daea33f58853b8013fd6c24888429ce7716a4b26d1a1f45181b5a4e73e", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1t64_0.190-1.1build4_amd64.deb", + sha256 = "b277e52769302778bd052376ac6687b52954b6605dd5f781bff8631e3504d58f", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.18-1build1_amd64.deb", + sha256 = "508daa855e99959acaa945e6a89d218e0be6b5727fd28773580942ff37cf5805", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index f315d6f60f750f..b61324179ca597 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -12,6 +12,10 @@ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ +load( + "//third_party/gpus/rocm:rocm_redist.bzl", + "rocm_redist", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -33,8 +37,6 @@ load( load( ":cuda_configure.bzl", "enable_cuda", - "make_copy_dir_rule", - "make_copy_files_rule", ) load( ":sycl_configure.bzl", @@ -48,6 +50,9 @@ _TF_SYSROOT = "TF_SYSROOT" _ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" +_DISTRIBUTION_PATH = "rocm/rocm_dist" +_OS = "OS" +_ROCM_VERSION = "ROCM_VERSION" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" @@ -203,20 +208,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): """ inc_dirs = [] - # Add HSA headers (needs to match $HSA_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include") - - # Add HIP headers (needs to match $HIP_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") - if int(rocm_config.rocm_version_number) >= 50200: - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocprim") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocsolver") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocblas") - - # Add HIP-Clang headers (realpath relative to compiler binary) - rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) + # Add full paths + rocm_toolkit_path = str(repository_ctx.path(rocm_config.rocm_toolkit_path)) inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/8.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") @@ -367,7 +360,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): return libs -def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin): +def _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin): """Returns the ROCm libraries on the system. Args: @@ -383,7 +376,6 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ for name, path in [ ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), - (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), @@ -401,17 +393,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True)) return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) -def find_rocm_config(repository_ctx): +def find_rocm_config(repository_ctx, rocm_path): """Returns ROCm config dictionary from running find_rocm_config.py""" python_bin = get_python_bin(repository_ctx) - exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config]) + exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config], env_vars = {"ROCM_PATH": rocm_path}) if exec_result.return_code: auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result)) # Parse the dict from stdout. return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()]) -def _get_rocm_config(repository_ctx, bash_bin): +def _get_rocm_config(repository_ctx, bash_bin, rocm_path, install_path): """Detects and returns information about the ROCm installation on the system. Args: @@ -426,7 +418,7 @@ def _get_rocm_config(repository_ctx, bash_bin): miopen_version_number: The version of MIOpen on the system. hipruntime_version_number: The version of HIP Runtime on the system. """ - config = find_rocm_config(repository_ctx) + config = find_rocm_config(repository_ctx, rocm_path) rocm_toolkit_path = config["rocm_toolkit_path"] rocm_version_number = config["rocm_version_number"] miopen_version_number = config["miopen_version_number"] @@ -437,6 +429,7 @@ def _get_rocm_config(repository_ctx, bash_bin): rocm_version_number = rocm_version_number, miopen_version_number = miopen_version_number, hipruntime_version_number = hipruntime_version_number, + install_path = install_path, ) def _tpl_path(repository_ctx, labelname): @@ -500,15 +493,12 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": "hipfft", - "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), "%{roctracer_lib}": _lib_name("roctracer64"), "%{rocsolver_lib}": _lib_name("rocsolver"), "%{hipsolver_lib}": _lib_name("hipsolver"), "%{hipblaslt_lib}": _lib_name("hipblaslt"), - "%{copy_rules}": "", "%{rocm_headers}": "", }, ) @@ -526,7 +516,7 @@ def _create_dummy_repository(repository_ctx): "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, "%{hipblaslt_flag}": "0", }, - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", ) # If rocm_configure is not configured to build with GPU support, and the user @@ -578,6 +568,53 @@ def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def _download_package(repository_ctx, archive): + file_name = _get_file_name(archive.url) + tmp_dir = "tmp" + repository_ctx.file(tmp_dir + "/.idx") # create tmp dir + + repository_ctx.report_progress("Downloading and extracting {}, expected hash is {}".format(archive.url, archive.sha256)) # buildifier: disable=print + repository_ctx.download_and_extract( + url = archive.url, + output = tmp_dir if archive.url.endswith(".deb") else _DISTRIBUTION_PATH, + sha256 = archive.sha256, + ) + + all_files = repository_ctx.path(tmp_dir).readdir() + + matched_files = [f for f in all_files if _get_file_name(str(f)).startswith("data.")] + for f in matched_files: + repository_ctx.extract(f, _DISTRIBUTION_PATH) + + repository_ctx.delete(tmp_dir) + repository_ctx.delete(file_name) + +def _remove_root_dir(path, root_dir): + if path.startswith(root_dir + "/"): + return path[len(root_dir) + 1:] + return path + +def _setup_rocm_distro_dir(repository_ctx): + """Sets up the rocm hermetic installation directory to be used in hermetic build""" + bash_bin = get_bash_bin(repository_ctx) + os = repository_ctx.os.environ.get(_OS) + rocm_version = repository_ctx.os.environ.get(_ROCM_VERSION) + if os and rocm_version: + redist = rocm_redist[os][rocm_version] + repository_ctx.file("rocm/.index") + for archive in redist["archives"]: + _download_package(repository_ctx, archive) + return _get_rocm_config(repository_ctx, bash_bin, "{}/{}".format(_DISTRIBUTION_PATH, redist["rocm_root"]), "/{}".format(redist["rocm_root"])) + else: + rocm_path = repository_ctx.os.environ.get(_ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + repository_ctx.report_progress("Using local rocm installation {}".format(rocm_path)) # buildifier: disable=print + repository_ctx.symlink(rocm_path, _DISTRIBUTION_PATH) + return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + def _create_local_rocm_repository(repository_ctx): """Creates the repository containing files set up to build with ROCm.""" @@ -590,12 +627,8 @@ def _create_local_rocm_repository(repository_ctx): "rocm:rocm_config.h", ]} - bash_bin = get_bash_bin(repository_ctx) - rocm_config = _get_rocm_config(repository_ctx, bash_bin) - - # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft + rocm_config = _setup_rocm_distro_dir(repository_ctx) rocm_version_number = int(rocm_config.rocm_version_number) - hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft" # For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path miopen_path = rocm_config.rocm_toolkit_path + "/miopen" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path @@ -603,75 +636,19 @@ def _create_local_rocm_repository(repository_ctx): # Copy header and library files to execroot. # rocm_toolkit_path - rocm_toolkit_path = rocm_config.rocm_toolkit_path - copy_rules = [ - make_copy_dir_rule( - repository_ctx, - name = "rocm-include", - src_dir = rocm_toolkit_path + "/include", - out_dir = "rocm/include", - ), - ] - - # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and - # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include - # dir has been removed. This removal will happen in a near-future ROCm release. - hiprand_include = "" - hiprand_include_softlink = rocm_config.rocm_toolkit_path + "/include/hiprand" - softlink_exists = files_exist(repository_ctx, [hiprand_include_softlink], bash_bin) - if not softlink_exists[0]: - hiprand_include = '":hiprand-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "hiprand-include", - src_dir = rocm_toolkit_path + "/hiprand/include", - out_dir = "rocm/include/hiprand", - ), - ) - - rocrand_include = "" - rocrand_include_softlink = rocm_config.rocm_toolkit_path + "/include/rocrand" - softlink_exists = files_exist(repository_ctx, [rocrand_include_softlink], bash_bin) - if not softlink_exists[0]: - rocrand_include = '":rocrand-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "rocrand-include", - src_dir = rocm_toolkit_path + "/rocrand/include", - out_dir = "rocm/include/rocrand", - ), - ) + rocm_toolkit_path = _remove_root_dir(rocm_config.rocm_toolkit_path, "rocm") - rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin) + bash_bin = get_bash_bin(repository_ctx) + rocm_libs = _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin) rocm_lib_srcs = [] rocm_lib_outs = [] for lib in rocm_libs.values(): if lib: rocm_lib_srcs.append(lib.path) rocm_lib_outs.append("rocm/lib/" + lib.file_name) - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-lib", - srcs = rocm_lib_srcs, - outs = rocm_lib_outs, - )) clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler" - # copy files mentioned in third_party/gpus/rocm/BUILD - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-bin", - srcs = [ - clang_offload_bundler_path, - ], - outs = [ - "rocm/bin/" + "clang-offload-bundler", - ], - )) - have_hipblaslt = "1" if rocm_libs["hipblaslt"] != None else "0" # Set up BUILD file for rocm/ @@ -693,20 +670,8 @@ def _create_local_rocm_repository(repository_ctx): ) repository_dict = { - "%{hip_lib}": rocm_libs["amdhip64"].file_name, - "%{rocblas_lib}": rocm_libs["rocblas"].file_name, - "%{hipfft_or_rocfft}": hipfft_or_rocfft, - "%{hipfft_or_rocfft_lib}": rocm_libs[hipfft_or_rocfft].file_name, - "%{hiprand_lib}": rocm_libs["hiprand"].file_name, - "%{miopen_lib}": rocm_libs["MIOpen"].file_name, - "%{rccl_lib}": rocm_libs["rccl"].file_name, - "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name, - "%{roctracer_lib}": rocm_libs["roctracer64"].file_name, - "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name, - "%{copy_rules}": "\n".join(copy_rules), - "%{rocm_headers}": ('":rocm-include",\n' + - hiprand_include + - rocrand_include), + "%{rocm_root}": rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), } is_rocm_clang = _use_rocm_clang(repository_ctx) @@ -726,7 +691,6 @@ def _create_local_rocm_repository(repository_ctx): ) # Set up crosstool/ - cc = find_cc(repository_ctx, is_rocm_clang) host_compiler_includes = get_cxx_inc_directories( repository_ctx, @@ -785,6 +749,7 @@ def _create_local_rocm_repository(repository_ctx): repository_ctx.template( "crosstool/cc_toolchain_config.bzl", tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"], + rocm_defines, ) repository_ctx.template( @@ -794,10 +759,12 @@ def _create_local_rocm_repository(repository_ctx): "%{cpu_compiler}": str(cc), "%{compiler}": "clang" if is_rocm_clang else "unknown", "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc", + "%{compiler_is_clang}": "True" if is_rocm_clang else "False", "%{hipcc_env}": _hipcc_env(repository_ctx), - "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{rocm_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), + "%{rocr_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{hip_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), @@ -807,13 +774,32 @@ def _create_local_rocm_repository(repository_ctx): # Set up rocm_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. repository_ctx.template( - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", + tpl_paths["rocm:rocm_config.h"], + { + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), + "%{rocm_toolkit_path}": rocm_config.install_path, + "%{rocm_version_number}": rocm_config.rocm_version_number, + "%{miopen_version_number}": rocm_config.miopen_version_number, + "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, + "%{hipblaslt_flag}": have_hipblaslt, + "%{hip_soversion_number}": "6" if int(rocm_config.rocm_version_number) >= 60000 else "5", + "%{rocblas_soversion_number}": "4" if int(rocm_config.rocm_version_number) >= 60000 else "3", + }, + ) + + # Set up rocm_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "rocm/rocm_config_hermetic/rocm_config.h", tpl_paths["rocm:rocm_config.h"], { "%{rocm_amdgpu_targets}": ",".join( ["\"%s\"" % c for c in rocm_config.amdgpu_targets], ), - "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), "%{rocm_version_number}": rocm_config.rocm_version_number, "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, @@ -889,6 +875,8 @@ _ENVIRONS = [ "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, + _OS, + _ROCM_VERSION, ] remote_rocm_configure = repository_rule( diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD index 868a2972a44861..56686b95fbefef 100644 --- a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_acl.BUILD @@ -167,6 +167,7 @@ cc_library( "include/**/*", "include/*", "src/common/*.hpp", + "src/common/**/*.h", "src/cpu/**/*.hpp", "src/cpu/*.hpp", "src/cpu/aarch64/xbyak_aarch64/**/*.h", diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch new file mode 100644 index 00000000000000..42dd262323b577 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp +index 65b887ea21..eabdb827bd 100644 +--- a/src/cpu/platform.cpp ++++ b/src/cpu/platform.cpp +@@ -117,6 +117,8 @@ bool has_data_type_support(data_type_t data_type) { + #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) + return true; + #endif ++#elif DNNL_AARCH64_USE_ACL ++ return arm_compute::CPUInfo::get().has_bf16(); + #else + return false; + #endif +-- +2.34.1 + diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch new file mode 100644 index 00000000000000..779608a68058d2 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index ab13efb9b2..ec261e156d 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -78,11 +78,21 @@ struct acl_matmul_t : public primitive_t { + = utils::everyone_is(data_type::f16, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type) + && platform::has_data_type_support(data_type::f16); ++ const bool is_fp32_bf16_ok ++ = (utils::everyone_is(data_type::f32, src_md()->data_type, ++ dst_md()->data_type, desc()->accum_data_type) ++ && platform::has_data_type_support(data_type::f32) ++ && utils::everyone_is( ++ data_type::bf16, weights_md()->data_type) ++ && platform::has_data_type_support( ++ data_type::bf16)); ++ + const bool is_weights_md_format_ok + = utils::one_of(weights_format_kind_received, + format_kind::any, format_kind::blocked); + bool ok = is_dense_data() +- && utils::one_of(true, is_fp32_ok, is_fp16_ok) ++ && utils::one_of( ++ true, is_fp32_ok, is_fp16_ok, is_fp32_bf16_ok) + && !has_zero_dim_memory() && is_weights_md_format_ok + && set_default_formats() + && attr()->has_default_values( +-- +2.34.1 diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch new file mode 100644 index 00000000000000..ec2cb97f5131ba --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch @@ -0,0 +1,100 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index 451cc78d52..ab13efb9b2 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -67,6 +67,8 @@ struct acl_matmul_t : public primitive_t { + + status_t init(engine_t *engine) { + using smask_t = primitive_attr_t::skip_mask_t; ++ const format_kind_t weights_format_kind_received ++ = weights_md_.format_kind; + const bool is_fp32_ok + = utils::everyone_is(data_type::f32, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type, +@@ -76,18 +78,20 @@ struct acl_matmul_t : public primitive_t { + = utils::everyone_is(data_type::f16, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type) + && platform::has_data_type_support(data_type::f16); ++ const bool is_weights_md_format_ok ++ = utils::one_of(weights_format_kind_received, ++ format_kind::any, format_kind::blocked); + bool ok = is_dense_data() + && utils::one_of(true, is_fp32_ok, is_fp16_ok) +- && !has_zero_dim_memory() +- && weights_md_.format_kind == format_kind::any ++ && !has_zero_dim_memory() && is_weights_md_format_ok + && set_default_formats() + && attr()->has_default_values( + smask_t::oscale | smask_t::post_ops) + && attr_oscale_ok() && !has_runtime_dims_or_strides(); + if (!ok) return status::unimplemented; + +- CHECK(acl_matmul_utils::init_conf_matmul( +- amp_, src_md_, weights_md_, dst_md_, *desc(), *attr())); ++ CHECK(acl_matmul_utils::init_conf_matmul(amp_, src_md_, weights_md_, ++ dst_md_, *desc(), *attr(), weights_format_kind_received)); + + arm_compute::ActivationLayerInfo act_info; + CHECK(post_ops.init(engine, attr_.post_ops_, dst_md_, act_info)); +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +index a314d96384..027f915a8a 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +@@ -27,7 +27,8 @@ namespace acl_matmul_utils { + + status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, +- const primitive_attr_t &attr) { ++ const primitive_attr_t &attr, ++ format_kind_t weights_format_kind_received) { + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); +@@ -128,9 +129,16 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + for (dim_t i = K_dim - 1; i >= 0; --i) + batch_dims.push_back(i); + ++ const memory_desc_t weights_md_received = wei_md; + acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, + expected_weight_format, K_dim, N_dim, {}, batch_dims); + ++ ACL_CHECK_SUPPORT((weights_format_kind_received == format_kind::blocked) ++ && !(dnnl_memory_desc_equal(&weights_md_received, &wei_md)), ++ "specified blocked format not supported by ACL, use " ++ "format_kind_t::any to find a supported blocked format for " ++ "your platform"); ++ + return status::success; + } + +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +index 67bb2e78eb..5ba4241abc 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +@@ -52,7 +52,8 @@ namespace acl_matmul_utils { + + status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, +- const primitive_attr_t &attr); ++ const primitive_attr_t &attr, ++ format_kind_t weights_format_kind_received); + + } // namespace acl_matmul_utils + +-- +2.34.1 diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch new file mode 100644 index 00000000000000..39f7e74345e08b --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/acl_post_ops.cpp b/src/cpu/aarch64/acl_post_ops.cpp +index ea4bb200ec..3eb53b81bd 100644 +--- a/src/cpu/aarch64/acl_post_ops.cpp ++++ b/src/cpu/aarch64/acl_post_ops.cpp +@@ -24,7 +24,7 @@ namespace aarch64 { + + status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const { + +- int post_op_index = 0; ++ int post_op_index = post_op_start_index_; + + // As these are post ops, this src will also be our dst. If we have a sum + // post op, the src/dst will start off in a temporary, then change to +diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/aarch64/acl_post_ops.hpp +index 7b59ad71d3..ceaa95b73a 100644 +--- a/src/cpu/aarch64/acl_post_ops.hpp ++++ b/src/cpu/aarch64/acl_post_ops.hpp +@@ -32,7 +32,9 @@ struct acl_post_ops_t { + // init the acl_post_ops_t. Note that this function modifies the passed in + // post ops by setting the preferred memory formats + status_t init(engine_t *engine, post_ops_t &post_ops, +- const memory_desc_t &dst_md) { ++ const memory_desc_t &dst_md, int post_op_start_index = 0) { ++ ++ post_op_start_index_ = post_op_start_index; + + CHECK(post_ops.set_default_formats(&dst_md)); + dst_data_type = dst_md.data_type; +@@ -41,7 +43,7 @@ struct acl_post_ops_t { + sum_index = -1; + post_op_primitives = {}; + +- for (int i = 0; i < post_ops.len(); i++) { ++ for (int i = post_op_start_index; i < post_ops.len(); i++) { + auto &po = post_ops.entry_[i]; + + if (po.is_sum()) { +@@ -135,7 +137,8 @@ struct acl_post_ops_t { + // formats + status_t init(engine_t *engine, post_ops_t &base_post_ops, + const memory_desc_t &dst_md, +- arm_compute::ActivationLayerInfo &act_info_to_fuse) { ++ arm_compute::ActivationLayerInfo &act_info_to_fuse, ++ int post_op_start_index = 0) { + + CHECK(base_post_ops.set_default_formats(&dst_md)); + dst_data_type = dst_md.data_type; +@@ -149,18 +152,11 @@ struct acl_post_ops_t { + "eltwise post op scale must be 1 (no scale)"); + CHECK(acl_utils::convert_to_acl_act(first_po, act_info_to_fuse)); + +- // Copy all but the first, because it has been fused +- post_ops_t post_ops; +- for (int idx = 1; idx < base_post_ops.len(); ++idx) { +- // Construct empty entry then copy, so that we can check for failure +- post_ops.entry_.emplace_back(); +- post_ops.entry_.back().copy_from(base_post_ops.entry_[idx]); +- } +- return init(engine, post_ops, dst_md); +- ++ // post_op_start_index + 1 to skip the fused eltwise ++ return init(engine, base_post_ops, dst_md, post_op_start_index + 1); + } else { + // Nothing to fuse, just copy all post ops +- return init(engine, base_post_ops, dst_md); ++ return init(engine, base_post_ops, dst_md, post_op_start_index); + } + } + +@@ -179,6 +175,9 @@ struct acl_post_ops_t { + private: + // Index of the sum post op if there is one, < 0 means no sum + int sum_index = -1; ++ // Index of the first post op this primitive executes. This is typically the ++ // number of post ops which were fused. ++ int post_op_start_index_ = 0; + data_type_t dst_data_type; + // Vector of primitives used to execute the post ops. They are constructed + // in init to be either acl_binary_t (for sum, add, sub, div, mul, min and +-- +2.34.1 diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl index 61d7809bcdaad1..51e7c35200fd34 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -1,4 +1,8 @@ licenses(["restricted"]) # NVIDIA proprietary license +load( + "@local_xla//xla/tsl/platform/default:cuda_build_defs.bzl", + "cuda_rpath_flags", +) exports_files([ "version.txt", @@ -14,6 +18,7 @@ cc_import( cc_library( name = "nccl", %{comment}deps = [":nccl_shared_library"], + %{comment}linkopts = cuda_rpath_flags("nvidia/nccl/lib"), visibility = ["//visibility:public"], ) diff --git a/third_party/xla/third_party/tsl/third_party/py/BUILD b/third_party/xla/third_party/tsl/third_party/py/BUILD index 7250861f26bfa2..661e8950c4dc2d 100644 --- a/third_party/xla/third_party/tsl/third_party/py/BUILD +++ b/third_party/xla/third_party/tsl/third_party/py/BUILD @@ -53,22 +53,8 @@ config_setting( }, ) -# Flag indicating if the target requires manylinux compliance verification. -bool_flag( - name = "verify_manylinux", - # TODO(ybaturina): Enable the flag by default when hermetic C++ toolchain is ready. - build_setting_default = False, +filegroup( + name = "manylinux_compliance_test", + srcs = ["manylinux_compliance_test.py"], visibility = ["//visibility:public"], ) - -py_binary( - name = "verify_manylinux_compliance", - srcs = [ - "verify_manylinux_compliance.py", - ], - main = "verify_manylinux_compliance.py", - visibility = ["//visibility:public"], - deps = [ - "@pypi_auditwheel//:pkg", - ], -) diff --git a/third_party/xla/third_party/tsl/third_party/py/verify_manylinux_compliance.py b/third_party/xla/third_party/tsl/third_party/py/manylinux_compliance_test.py similarity index 65% rename from third_party/xla/third_party/tsl/third_party/py/verify_manylinux_compliance.py rename to third_party/xla/third_party/tsl/third_party/py/manylinux_compliance_test.py index 5afbae839abff6..734892d5469ebf 100644 --- a/third_party/xla/third_party/tsl/third_party/py/verify_manylinux_compliance.py +++ b/third_party/xla/third_party/tsl/third_party/py/manylinux_compliance_test.py @@ -1,40 +1,44 @@ -# Copyright 2024 The Tensorflow Authors. +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tool to verify wheel manylinux compliance.""" +# ============================================================================== import argparse import io +import platform import re import sys from auditwheel import main_show -def parse_args() -> argparse.Namespace: +def parse_args(): """Arguments parser.""" parser = argparse.ArgumentParser( - description="Helper for auditwheel", fromfile_prefix_chars="@" + description="Helper for manylinux compliance verification", + fromfile_prefix_chars="@", ) parser.add_argument( - "--wheel_path", required=True, help="Path of the wheel, mandatory" + "--wheel-path", required=True, help="Path of the wheel, mandatory" ) parser.add_argument( - "--compliance-tag", help="ManyLinux compliance tag", required=False + "--aarch64-compliance-tag", + required=True, + help="ManyLinux compliance tag for aarch64", ) parser.add_argument( - "--auditwheel-show-log-path", - help="Path to file with auditwheel show results, mandatory", + "--x86_64-compliance-tag", required=True, + help="ManyLinux compliance tag for x86_64", ) return parser.parse_args() @@ -70,39 +74,37 @@ def get_auditwheel_output(wheel_path: str) -> None: def verify_manylinux_compliance( auditwheel_log: str, compliance_tag: str, - auditwheel_show_log_path: str, ) -> None: """Verify manylinux compliance. Args: auditwheel_log: "auditwheel show" execution results compliance_tag: manyLinux compliance tag - auditwheel_show_log_path: path to file with auditwheel show results Raises: RuntimeError: if the wheel is not manyLinux compliant. """ - with open(auditwheel_show_log_path, "w") as auditwheel_show_log: - auditwheel_show_log.write(auditwheel_log) - if not compliance_tag: - return regex = 'following platform tag: "{}"'.format(compliance_tag) if not re.search(regex, auditwheel_log): raise RuntimeError( - ( - "The wheel is not compliant with tag {tag}." - + " If you want to disable this check, please provide" - + " `--@local_tsl//third_party/py:verify_manylinux=false`." - + "\n{result}" - ).format(tag=compliance_tag, result=auditwheel_log) + ("The wheel is not compliant with the tag {tag}.\n{result}").format( + tag=compliance_tag, result=auditwheel_log + ) ) -if __name__ == "__main__": - args = parse_args() +def test_manylinux_compliance(args): + machine_type = platform.uname().machine + if machine_type == "x86_64": + compliance_tag = args.x86_64_compliance_tag + else: + compliance_tag = args.aarch64_compliance_tag auditwheel_output = get_auditwheel_output(args.wheel_path) verify_manylinux_compliance( auditwheel_output, - args.compliance_tag, - args.auditwheel_show_log_path, + compliance_tag, ) + + +if __name__ == "__main__": + test_manylinux_compliance(parse_args()) diff --git a/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl b/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl index 29a551da8d0017..962fb487c2d2f4 100644 --- a/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/ml_dtypes/workspace.bzl @@ -7,8 +7,8 @@ float8 varieties, and int4. load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - ML_DTYPES_COMMIT = "c12281a501469d553483eb4d68065826b9c2fcb5" - ML_DTYPES_SHA256 = "cee11c4bed5147bece9e385a88c20887344ad9b89b3acb09bf3d7c9c21fb9715" + ML_DTYPES_COMMIT = "0fa5313b65efe848c5968a15dd37dd220cc29567" + ML_DTYPES_SHA256 = "69c562bb961a21d92357c7709430553c226caac75a751c0aa52955ca14ce8641" tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", diff --git a/third_party/xla/third_party/tsl/third_party/py/py_import.bzl b/third_party/xla/third_party/tsl/third_party/py/py_import.bzl index b00ca49418423d..38a1ae1da7c325 100644 --- a/third_party/xla/third_party/tsl/third_party/py/py_import.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/py_import.bzl @@ -2,31 +2,23 @@ def _unpacked_wheel_impl(ctx): output_dir = ctx.actions.declare_directory(ctx.label.name) - libs = [] - for dep in ctx.attr.cc_deps: - linker_inputs = dep[CcInfo].linking_context.linker_inputs.to_list() - for linker_input in linker_inputs: - if linker_input.libraries and linker_input.libraries[0].dynamic_library: - lib = linker_input.libraries[0].dynamic_library - libs.append(lib) - wheel = None - for w in ctx.files.wheel_rule_outputs: - if w.basename.endswith(".whl"): - wheel = w - break + wheel = ctx.file.wheel script = """ {zipper} x {wheel} -d {output} - for lib in {libs}; do - cp $lib {output}/tensorflow + for wheel_dep in {wheel_deps}; do + {zipper} x $wheel_dep -d {output} done """.format( zipper = ctx.executable.zipper.path, wheel = wheel.path, output = output_dir.path, - libs = " ".join(["'%s'" % lib.path for lib in libs]), + wheel_deps = " ".join([ + "'%s'" % wheel_dep.path + for wheel_dep in ctx.files.wheel_deps + ]), ) ctx.actions.run_shell( - inputs = ctx.files.wheel_rule_outputs + libs, + inputs = ctx.files.wheel + ctx.files.wheel_deps, command = script, outputs = [output_dir], tools = [ctx.executable.zipper], @@ -39,22 +31,26 @@ def _unpacked_wheel_impl(ctx): _unpacked_wheel = rule( implementation = _unpacked_wheel_impl, attrs = { - "wheel_rule_outputs": attr.label(mandatory = True, allow_files = True), + "wheel": attr.label(mandatory = True, allow_single_file = True), "zipper": attr.label( default = Label("@bazel_tools//tools/zip:zipper"), cfg = "exec", executable = True, ), - "cc_deps": attr.label_list(providers = [CcInfo]), + "wheel_deps": attr.label_list(allow_files = True), }, ) -def py_import(name, wheel, deps = [], cc_deps = []): +def py_import( + name, + wheel, + deps = [], + wheel_deps = []): unpacked_wheel_name = name + "_unpacked_wheel" _unpacked_wheel( name = unpacked_wheel_name, - wheel_rule_outputs = wheel, - cc_deps = cc_deps, + wheel = wheel, + wheel_deps = wheel_deps, ) native.py_library( name = name, @@ -68,6 +64,6 @@ def py_import(name, wheel, deps = [], cc_deps = []): Args: wheel: wheel file to unpack. deps: dependencies of the py_library. - cc_deps: dependencies that will be copied in the folder - with the unpacked wheel content. + wheel_deps: additional wheels to unpack. These wheels will be unpacked in the + same folder as the wheel. """ # buildifier: disable=no-effect diff --git a/third_party/xla/third_party/tsl/third_party/py/py_manylinux_compliance_test.bzl b/third_party/xla/third_party/tsl/third_party/py/py_manylinux_compliance_test.bzl new file mode 100644 index 00000000000000..e0a7e822507650 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/py_manylinux_compliance_test.bzl @@ -0,0 +1,25 @@ +""" Macros for manylinux compliance verification test. """ + +load("@rules_python//python:py_test.bzl", "py_test") + +def verify_manylinux_compliance_test( + name, + wheel, + aarch64_compliance_tag, + x86_64_compliance_tag, + test_tags = []): + py_test( + name = name, + srcs = [Label("//third_party/py:manylinux_compliance_test")], + data = [ + wheel, + ], + deps = ["@pypi_auditwheel//:pkg"], + args = [ + "--wheel-path=$(location {})".format(wheel), + "--aarch64-compliance-tag={}".format(aarch64_compliance_tag), + "--x86_64-compliance-tag={}".format(x86_64_compliance_tag), + ], + main = "manylinux_compliance_test.py", + tags = ["manual"] + test_tags, + ) diff --git a/third_party/xla/third_party/tsl/third_party/py/python_init_rules.bzl b/third_party/xla/third_party/tsl/third_party/py/python_init_rules.bzl index 79bc343aae489e..796ae3d92d999f 100644 --- a/third_party/xla/third_party/tsl/third_party/py/python_init_rules.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/python_init_rules.bzl @@ -8,4 +8,6 @@ def python_init_rules(): sha256 = "62ddebb766b4d6ddf1712f753dac5740bea072646f630eb9982caa09ad8a7687", strip_prefix = "rules_python-0.39.0", url = "https://github.com/bazelbuild/rules_python/releases/download/0.39.0/rules_python-0.39.0.tar.gz", + patch_args = ["-p1"], + patches = [Label("//third_party/py:rules_python.patch")], ) diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_python.patch b/third_party/xla/third_party/tsl/third_party/py/rules_python.patch new file mode 100644 index 00000000000000..ef7ff2fc6f8e52 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/rules_python.patch @@ -0,0 +1,39 @@ +diff --git a/python/private/pypi/deps.bzl b/python/private/pypi/deps.bzl +index 8949ed4a..8d0ab0e7 100644 +--- a/python/private/pypi/deps.bzl ++++ b/python/private/pypi/deps.bzl +@@ -51,8 +51,8 @@ _RULE_DEPS = [ + ), + ( + "pypi__packaging", +- "https://files.pythonhosted.org/packages/49/df/1fceb2f8900f8639e278b056416d49134fb8d84c5942ffaa01ad34782422/packaging-24.0-py3-none-any.whl", +- "2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5", ++ "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", ++ "09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", + ), + ( + "pypi__pep517", +@@ -61,8 +61,8 @@ _RULE_DEPS = [ + ), + ( + "pypi__pip", +- "https://files.pythonhosted.org/packages/8a/6a/19e9fe04fca059ccf770861c7d5721ab4c2aebc539889e97c7977528a53b/pip-24.0-py3-none-any.whl", +- "ba0d021a166865d2265246961bec0152ff124de910c5cc39f1156ce3fa7c69dc", ++ "https://files.pythonhosted.org/packages/ef/7d/500c9ad20238fcfcb4cb9243eede163594d7020ce87bd9610c9e02771876/pip-24.3.1-py3-none-any.whl", ++ "3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed", + ), + ( + "pypi__pip_tools", +diff --git a/python/private/pypi/evaluate_markers.bzl b/python/private/pypi/evaluate_markers.bzl +index c805fd7a..e57e6138 100644 +--- a/python/private/pypi/evaluate_markers.bzl ++++ b/python/private/pypi/evaluate_markers.bzl +@@ -20,7 +20,7 @@ load(":pypi_repo_utils.bzl", "pypi_repo_utils") + SRCS = [ + # When the version, or any of the files in `packaging` package changes, + # this file will change as well. +- Label("@pypi__packaging//:packaging-24.0.dist-info/RECORD"), ++ Label("@pypi__packaging//:packaging-24.2.dist-info/RECORD"), + Label("//python/private/pypi/requirements_parser:resolve_target_platforms.py"), + Label("//python/private/pypi/whl_installer:platform.py"), + ] \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl index 98428b51486efd..fb225d16aa0d8e 100644 --- a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl @@ -1,49 +1,30 @@ -import os -import re - - -def __calc_import_path(): - module_name = os.path.basename(__file__)[:-3] - outer_module_name = "" # template_val - for var in ["PYWRAP_TARGET", "TEST_TARGET"]: - path = __find_pywrap_module_by_target_label(os.environ.get(var)) - if path: - return "%s.%s%s" % (path, outer_module_name, module_name) - - for var in ["RUNFILES_MANIFEST_FILE", "RUNFILES_DIR"]: - path = __find_pywrap_module_by_runfiles_env(os.environ.get(var)) - if path: - return "%s.%s%s" % (path, outer_module_name, module_name) - - raise RuntimeError("Could not detect original test/binary location") - - -def __find_pywrap_module_by_target_label(target_label): - if target_label: - return target_label.split("//", 1)[1].split(":")[0].replace("/", ".") - return None - - -def __find_pywrap_module_by_runfiles_env(runfiles_env_var): - pattern = re.compile( - r"bazel-out/.*/bin/(?P[\w/]*)/(?P\w+)(\.exe)?\.runfiles" - ) - if runfiles_env_var: - match = pattern.search(runfiles_env_var) - return match.group("pkg").replace("/", ".") - return None - - def __update_globals(pywrap_m): if hasattr(pywrap_m, '__all__'): all_names = pywrap_m.__all__ else: all_names = [name for name in dir(pywrap_m) if not name.startswith('_')] - extra_names = [] # template_val + extra_names = [] # template_val all_names.extend(extra_names) globals().update({name: getattr(pywrap_m, name) for name in all_names}) -__pywrap_m = __import__(__calc_import_path(), fromlist=["*"]) -__update_globals(__pywrap_m) +def __try_import(): + imports_paths = [] # template_val + exceptions = [] + last_exception = None + for import_path in imports_paths: + try: + pywrap_m = __import__(import_path, fromlist=["*"]) + __update_globals(pywrap_m) + return + except ImportError as e: + exceptions.append(str(e)) + last_exception = e + pass + + raise RuntimeError(f""" +Could not import original test/binary location, import paths tried: {imports_paths}. +Previous exceptions: {exceptions}""", last_exception) + +__try_import() diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl index 7aa60b07dd3329..ea1b0bb39e50b7 100644 --- a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl @@ -18,7 +18,6 @@ def pybind_extension( win_def_file = None, # original testonly = None, # original compatible_with = None, # original - outer_module_name = "", # deprecate additional_exported_symbols = [], data = None, # original # Garbage parameters, exist only to maingain backward compatibility for @@ -26,12 +25,6 @@ def pybind_extension( # To patch top-level deps lists in sophisticated cases pywrap_ignored_deps_filter = ["@pybind11", "@pybind11//:pybind11"], - pywrap_private_deps_filter = [ - "@pybind11_abseil//pybind11_abseil:absl_casters", - "@pybind11_abseil//pybind11_abseil:import_status_module", - "@pybind11_abseil//pybind11_abseil:status_casters", - "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], pytype_srcs = None, # alias for data hdrs = [], # merge into sources pytype_deps = None, # ignore? @@ -54,7 +47,6 @@ def pybind_extension( pytype_deps, ] - private_deps_filter_dict = {k: None for k in pywrap_private_deps_filter} ignored_deps_filter_dict = {k: None for k in pywrap_ignored_deps_filter} actual_srcs = srcs + hdrs @@ -68,13 +60,10 @@ def pybind_extension( actual_private_deps = [] actual_default_deps = ["@pybind11//:pybind11"] - if type(deps) == list: + if not deps or type(deps) == list: for dep in deps: if dep in ignored_deps_filter_dict: continue - if dep in private_deps_filter_dict: - actual_private_deps.append(dep) - continue actual_deps.append(dep) else: actual_deps = deps @@ -84,12 +73,10 @@ def pybind_extension( name = name, deps = actual_deps, srcs = actual_srcs, - private_deps = actual_private_deps, visibility = visibility, win_def_file = win_def_file, testonly = testonly, compatible_with = compatible_with, - outer_module_name = outer_module_name, additional_exported_symbols = additional_exported_symbols, data = actual_data, default_deps = actual_default_deps, diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl index c80b88ea7f76d4..3597758c95f5a5 100644 --- a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl @@ -3,11 +3,11 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_c PywrapInfo = provider( fields = { "cc_info": "Wrapped CcInfo", - "private_deps": "Libraries to link only to individual pywrap libraries, but not in commmon library", "owner": "Owner's label", + "common_lib_packages": "Packages in which to search for common pywrap library", "py_stub": "Pybind Python stub used to resolve cross-package references", - "outer_module_name": "Outer module name for deduping libraries with the same name", "cc_only": "True if this PywrapInfo represents cc-only library (no PyIni_)", + "starlark_only": "", }, ) @@ -19,22 +19,23 @@ CollectedPywrapInfo = provider( PywrapFilters = provider( fields = { - "py_cc_linker_inputs": "", - "cc_linker_inputs": "", - "pywrap_private_linker_inputs": "", + "pywrap_lib_filter": "", + "common_lib_filters": "", }, ) def pywrap_library( name, deps, - py_cc_deps_filter = [], - cc_deps_filter = [], - linkopts = [], - py_cc_linkopts = [], + starlark_only_deps = [], + pywrap_lib_filter = None, + pywrap_lib_exclusion_filter = None, + common_lib_filters = {}, + common_lib_version_scripts = {}, + common_lib_linkopts = {}, win_def_file = None, - py_cc_win_def_file = None, pywrap_count = None, + starlark_only_pywrap_count = 0, extra_deps = ["@pybind11//:pybind11"], visibility = None, testonly = None, @@ -43,6 +44,9 @@ def pywrap_library( # targets directly, so actual pywrap_count should just be equal to number # of deps. actual_pywrap_count = len(deps) if pywrap_count == None else pywrap_count + if starlark_only_deps: + starlark_only_pywrap_count = len(starlark_only_deps) + actual_deps = deps + starlark_only_deps # 1) Create common libraries cc-only (C API) and py-specific (parts reused # by different pywrap libraries but dependin on Python symbols). @@ -51,81 +55,84 @@ def pywrap_library( info_collector_name = "_%s_info_collector" % name collected_pywrap_infos( name = info_collector_name, - deps = deps, + deps = actual_deps, pywrap_count = actual_pywrap_count, + starlark_only_pywrap_count = starlark_only_pywrap_count, ) linker_input_filters_name = "_%s_linker_input_filters" % name - _linker_input_filters( - name = linker_input_filters_name, - dep = ":%s" % info_collector_name, - py_cc_deps_filter = py_cc_deps_filter, - cc_deps_filter = cc_deps_filter, - ) - # _internal binary - common_split_name = "_%s_split" % name - _pywrap_split_library( - name = common_split_name, - mode = "cc_common", - dep = ":%s" % info_collector_name, - linker_input_filters = "%s" % linker_input_filters_name, - testonly = testonly, - compatible_with = compatible_with, - ) + cur_pkg = native.package_name() + cur_pkg = cur_pkg + "/" if native.package_name() else cur_pkg + starlark_only_filter_full_name = None + if starlark_only_pywrap_count > 0: + starlark_only_filter_full_name = "%s%s__starlark_only_common" % (cur_pkg, name) - common_cc_binary_name = "%s_internal" % name - common_import_name = _construct_common_binary( - common_cc_binary_name, - [":%s" % common_split_name], - linkopts, - testonly, - compatible_with, - win_def_file, - None, - ) - - # _py_internal binary - py_common_split_name = "_%s_py_split" % name - _pywrap_split_library( - name = py_common_split_name, - mode = "py_common", + _linker_input_filters( + name = linker_input_filters_name, dep = ":%s" % info_collector_name, - linker_input_filters = "%s" % linker_input_filters_name, - testonly = testonly, - compatible_with = compatible_with, + pywrap_lib_filter = pywrap_lib_filter, + pywrap_lib_exclusion_filter = pywrap_lib_exclusion_filter, + common_lib_filters = {v: k for k, v in common_lib_filters.items()}, + starlark_only_filter_name = starlark_only_filter_full_name, ) - common_py_cc_binary_name = "%s_py_internal" % name - common_py_import_name = _construct_common_binary( - common_py_cc_binary_name, - [ - ":%s" % py_common_split_name, - ":%s" % common_import_name, - "@pybind11//:pybind11", - ], - py_cc_linkopts, - testonly, - compatible_with, - py_cc_win_def_file, - None, - ) - - common_deps = extra_deps + [ - ":%s" % common_import_name, - ":%s" % common_py_import_name, - ] - binaries_data = [ - ":%s" % common_cc_binary_name, - ":%s" % common_py_cc_binary_name, - ] + common_deps = [] + starlark_only_common_deps = [] + binaries_data = {} + starlark_only_binaries_data = {} + internal_binaries = [] + + common_lib_full_names = [] + common_lib_full_names.extend(common_lib_filters.keys()) + common_lib_full_names.append("%s%s_common" % (cur_pkg, name)) + if starlark_only_filter_full_name: + common_lib_full_names.append(starlark_only_filter_full_name) + + for common_lib_full_name in common_lib_full_names: + common_lib_pkg, common_lib_name = _get_common_lib_package_and_name( + common_lib_full_name, + ) + common_split_name = "_%s_split" % common_lib_name + _pywrap_common_split_library( + name = common_split_name, + dep = ":%s" % info_collector_name, + common_lib_full_name = common_lib_full_name, + linker_input_filters = "%s" % linker_input_filters_name, + testonly = testonly, + compatible_with = compatible_with, + ) + ver_script = common_lib_version_scripts.get(common_lib_full_name, None) + linkopts = common_lib_linkopts.get(common_lib_full_name, []) + + common_cc_binary_name = "%s" % common_lib_name + common_import_name = _construct_common_binary( + common_cc_binary_name, + [":%s" % common_split_name] + common_deps, + linkopts, + testonly, + compatible_with, + win_def_file, + None, + binaries_data.values(), + common_lib_pkg, + ver_script, + ) + actual_binaries_data = binaries_data + actual_common_deps = common_deps + if common_lib_full_name == starlark_only_filter_full_name: + actual_binaries_data = starlark_only_binaries_data + actual_common_deps = starlark_only_common_deps + internal_binaries.append(":%s" % common_cc_binary_name) + actual_binaries_data[":%s" % common_cc_binary_name] = common_lib_pkg + actual_common_deps.append(":%s" % common_import_name) # 2) Create individual super-thin pywrap libraries, which depend on the # common one. The individual libraries must link in statically only the # object file with Python Extension's init function PyInit_ # shared_objects = [] - for pywrap_index in range(0, actual_pywrap_count): + for pywrap_index in range(0, actual_pywrap_count + starlark_only_pywrap_count): dep_name = "_%s_%s" % (name, pywrap_index) shared_object_name = "%s_shared_object" % dep_name win_def_name = "%s_win_def" % dep_name @@ -133,7 +140,6 @@ def pywrap_library( _pywrap_split_library( name = pywrap_name, - mode = "pywrap", dep = ":%s" % info_collector_name, linker_input_filters = "%s" % linker_input_filters_name, pywrap_index = pywrap_index, @@ -149,10 +155,14 @@ def pywrap_library( compatible_with = compatible_with, ) + actual_common_deps = common_deps + if pywrap_index >= actual_pywrap_count: + actual_common_deps = common_deps + starlark_only_common_deps + native.cc_binary( name = shared_object_name, srcs = [], - deps = [":%s" % pywrap_name] + common_deps, + deps = [":%s" % pywrap_name] + actual_common_deps, linkshared = True, linkstatic = True, win_def_file = ":%s" % win_def_name, @@ -165,38 +175,42 @@ def pywrap_library( # attribute in a py_library, which is the final and only public artifact of # this macro # - pywrap_binaries_name = "%s_internal_binaries" % name + pywrap_binaries_name = "%s_common_binaries" % name + wheel_locations_json_name = ":%s_wheel_locations.json" % pywrap_binaries_name _pywrap_binaries( name = pywrap_binaries_name, collected_pywraps = ":%s" % info_collector_name, deps = shared_objects, + common_binaries = binaries_data, + starlark_only_common_binaries = starlark_only_binaries_data, extension = select({ "@bazel_tools//src/conditions:windows": ".pyd", "//conditions:default": ".so", }), - wheel_locations_json = ":%s_wheel_locations.json" % pywrap_binaries_name, + wheel_locations_json = wheel_locations_json_name, testonly = testonly, compatible_with = compatible_with, ) + internal_binaries.append(":%s" % pywrap_binaries_name) + internal_binaries.append(wheel_locations_json_name) - binaries_data.append("%s" % pywrap_binaries_name) - binaries_data.extend([shared_objects[0]]) + all_binaries_data = list(binaries_data.keys()) + all_binaries_data.extend(starlark_only_binaries_data.keys()) + all_binaries_data.append(":%s" % pywrap_binaries_name) + all_binaries_data.extend([shared_objects[-1]]) native.py_library( name = name, srcs = [":%s" % info_collector_name], - data = binaries_data, + data = all_binaries_data, testonly = testonly, compatible_with = compatible_with, visibility = visibility, ) - # For debugging purposes only native.filegroup( - name = "_%s_all_binaries" % name, - srcs = binaries_data, - testonly = testonly, - compatible_with = compatible_with, + name = name + "_all_binaries", + srcs = internal_binaries, ) def _construct_common_binary( @@ -206,13 +220,24 @@ def _construct_common_binary( testonly, compatible_with, win_def_file, - local_defines): + local_defines, + dependency_common_lib_packages, + dependent_common_lib_package, + version_script): + actual_linkopts = _construct_linkopt_soname(name) + _construct_linkopt_rpaths( + dependency_common_lib_packages, + dependent_common_lib_package, + ) + _construct_linkopt_version_script(version_script) + native.cc_binary( name = name, - deps = deps, + deps = deps + ([version_script] if version_script else []), linkstatic = True, linkshared = True, - linkopts = linkopts, + linkopts = linkopts + select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": actual_linkopts, + }), testonly = testonly, compatible_with = compatible_with, win_def_file = win_def_file, @@ -232,7 +257,8 @@ def _construct_common_binary( native.cc_import( name = import_name, shared_library = ":%s" % name, - interface_library = ":%s" % if_lib_name, + # TODO: put it back to fix Windows + # interface_library = ":%s" % if_lib_name, testonly = testonly, compatible_with = compatible_with, ) @@ -241,57 +267,32 @@ def _construct_common_binary( def _pywrap_split_library_impl(ctx): pywrap_index = ctx.attr.pywrap_index - pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() + pw_list = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() + pw = pw_list[pywrap_index] + linker_inputs = pw.cc_info.linking_context.linker_inputs.to_list() + li = linker_inputs[0] + user_link_flags = li.user_link_flags + split_linker_inputs = [] private_linker_inputs = [] - - mode = ctx.attr.mode - filters = ctx.attr.linker_input_filters[PywrapFilters] - py_cc_linker_inputs = filters.py_cc_linker_inputs - - if mode == "pywrap": - pw = pywrap_infos[pywrap_index] - - # print("%s matches %s" % (str(pw.owner), ctx.label)) - if not pw.cc_only: - li = pw.cc_info.linking_context.linker_inputs.to_list()[0] - split_linker_inputs.append(li) - private_linker_inputs = [ - depset(direct = filters.pywrap_private_linker_inputs[pywrap_index].keys()), - ] - else: - for i in range(0, len(pywrap_infos)): - pw = pywrap_infos[i] - pw_private_linker_inputs = filters.pywrap_private_linker_inputs[i] - pw_lis = pw.cc_info.linking_context.linker_inputs.to_list()[1:] - for li in pw_lis: - if li in pw_private_linker_inputs: - continue - if li in filters.py_cc_linker_inputs: - if mode == "py_common": - split_linker_inputs.append(li) - elif mode == "cc_common": - split_linker_inputs.append(li) - - dependency_libraries = _construct_dependency_libraries( + if not pw.cc_only: + split_linker_inputs.append(li) + pywrap_lib_filter = ctx.attr.linker_input_filters[PywrapFilters].pywrap_lib_filter + private_lis = [] + for li in linker_inputs[1:]: + if li in pywrap_lib_filter: + private_lis.append(li) + private_linker_inputs = [ + depset(direct = private_lis), + ] + + return _construct_split_library_cc_info( ctx, split_linker_inputs, + user_link_flags, + private_linker_inputs, ) - linker_input = cc_common.create_linker_input( - owner = ctx.label, - libraries = depset(direct = dependency_libraries), - ) - - linking_context = cc_common.create_linking_context( - linker_inputs = depset( - direct = [linker_input], - transitive = private_linker_inputs, - ), - ) - - return [CcInfo(linking_context = linking_context)] - _pywrap_split_library = rule( attrs = { "dep": attr.label( @@ -305,9 +306,53 @@ _pywrap_split_library = rule( mandatory = True, ), "pywrap_index": attr.int(mandatory = False, default = -1), - "mode": attr.string( + "_cc_toolchain": attr.label( + default = "@bazel_tools//tools/cpp:current_cc_toolchain", + ), + }, + fragments = ["cpp"], + toolchains = use_cpp_toolchain(), + implementation = _pywrap_split_library_impl, +) + +def _pywrap_common_split_library_impl(ctx): + pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() + split_linker_inputs = [] + + filters = ctx.attr.linker_input_filters[PywrapFilters] + + libs_to_exclude = {} + libs_to_include = {} + include_all_not_excluded = False + + if ctx.attr.common_lib_full_name not in filters.common_lib_filters: + for common_lib_filter in filters.common_lib_filters.values(): + libs_to_exclude.update(common_lib_filter) + include_all_not_excluded = True + else: + libs_to_include = filters.common_lib_filters[ctx.attr.common_lib_full_name] + + for pw in pywrap_infos: + pw_lis = pw.cc_info.linking_context.linker_inputs.to_list()[1:] + for li in pw_lis: + if li in libs_to_exclude: + continue + if include_all_not_excluded or (li in libs_to_include): + split_linker_inputs.append(li) + + return _construct_split_library_cc_info(ctx, split_linker_inputs, [], []) + +_pywrap_common_split_library = rule( + attrs = { + "dep": attr.label( + allow_files = False, + providers = [CollectedPywrapInfo], + ), + "common_lib_full_name": attr.string(mandatory = True), + "linker_input_filters": attr.label( + allow_files = False, + providers = [PywrapFilters], mandatory = True, - values = ["pywrap", "cc_common", "py_common"], ), "_cc_toolchain": attr.label( default = "@bazel_tools//tools/cpp:current_cc_toolchain", @@ -315,9 +360,34 @@ _pywrap_split_library = rule( }, fragments = ["cpp"], toolchains = use_cpp_toolchain(), - implementation = _pywrap_split_library_impl, + implementation = _pywrap_common_split_library_impl, ) +def _construct_split_library_cc_info( + ctx, + split_linker_inputs, + user_link_flags, + private_linker_inputs): + dependency_libraries = _construct_dependency_libraries( + ctx, + split_linker_inputs, + ) + + linker_input = cc_common.create_linker_input( + owner = ctx.label, + libraries = depset(direct = dependency_libraries), + user_link_flags = depset(direct = user_link_flags), + ) + + linking_context = cc_common.create_linking_context( + linker_inputs = depset( + direct = [linker_input], + transitive = private_linker_inputs, + ), + ) + + return [CcInfo(linking_context = linking_context)] + def _construct_dependency_libraries(ctx, split_linker_inputs): cc_toolchain = find_cpp_toolchain(ctx) feature_configuration = cc_common.configure_features( @@ -345,32 +415,46 @@ def _construct_dependency_libraries(ctx, split_linker_inputs): return dependency_libraries def _linker_input_filters_impl(ctx): - py_cc_linker_inputs = {} - for py_cc_dep in ctx.attr.py_cc_deps_filter: - for li in py_cc_dep[CcInfo].linking_context.linker_inputs.to_list()[:1]: - py_cc_linker_inputs[li] = li.owner - - cc_linker_inputs = {} - for cc_dep in ctx.attr.cc_deps_filter: - for li in cc_dep[CcInfo].linking_context.linker_inputs.to_list()[:1]: - cc_linker_inputs[li] = li.owner + pywrap_lib_exclusion_filter = {} + pywrap_lib_filter = {} + visited_filters = {} + if ctx.attr.pywrap_lib_exclusion_filter: + for li in ctx.attr.pywrap_lib_exclusion_filter[CcInfo].linking_context.linker_inputs.to_list(): + pywrap_lib_exclusion_filter[li] = li.owner + + if ctx.attr.pywrap_lib_filter: + for li in ctx.attr.pywrap_lib_filter[CcInfo].linking_context.linker_inputs.to_list(): + if li not in pywrap_lib_exclusion_filter: + pywrap_lib_filter[li] = li.owner + + common_lib_filters = {k: {} for k in ctx.attr.common_lib_filters.values()} + + for filter, name in ctx.attr.common_lib_filters.items(): + filter_li = filter[CcInfo].linking_context.linker_inputs.to_list() + for li in filter_li: + if li not in visited_filters: + common_lib_filters[name][li] = li.owner + visited_filters[li] = li.owner pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() - pywrap_private_linker_inputs = [] + starlark_only_filter = {} - for pw in pywrap_infos: - private_linker_inputs = {} + if ctx.attr.starlark_only_filter_name: + for pw in pywrap_infos: + if pw.starlark_only: + for li in pw.cc_info.linking_context.linker_inputs.to_list()[1:]: + starlark_only_filter[li] = li.owner - for private_dep in pw.private_deps: - for priv_li in private_dep[CcInfo].linking_context.linker_inputs.to_list(): - if (priv_li not in py_cc_linker_inputs) and (priv_li not in cc_linker_inputs): - private_linker_inputs[priv_li] = priv_li.owner - pywrap_private_linker_inputs.append(private_linker_inputs) + for pw in pywrap_infos: + if not pw.starlark_only: + for li in pw.cc_info.linking_context.linker_inputs.to_list()[1:]: + starlark_only_filter.pop(li, None) + common_lib_filters[ctx.attr.starlark_only_filter_name] = starlark_only_filter return [ PywrapFilters( - py_cc_linker_inputs = py_cc_linker_inputs, - pywrap_private_linker_inputs = pywrap_private_linker_inputs, + pywrap_lib_filter = pywrap_lib_filter, + common_lib_filters = common_lib_filters, ), ] @@ -380,43 +464,43 @@ _linker_input_filters = rule( allow_files = False, providers = [CollectedPywrapInfo], ), - "py_cc_deps_filter": attr.label_list( + "pywrap_lib_filter": attr.label( allow_files = False, providers = [CcInfo], mandatory = False, - default = [], ), - "cc_deps_filter": attr.label_list( + "pywrap_lib_exclusion_filter": attr.label( allow_files = False, providers = [CcInfo], mandatory = False, - default = [], ), + "common_lib_filters": attr.label_keyed_string_dict( + allow_files = False, + providers = [CcInfo], + mandatory = False, + default = {}, + ), + "starlark_only_filter_name": attr.string(mandatory = False), }, implementation = _linker_input_filters_impl, ) -def pywrap_common_library(name, dep): +def pywrap_common_library(name, dep, filter_name = None): native.alias( name = name, - actual = "%s_internal_import" % dep, + actual = "%s_import" % (filter_name if filter_name else dep + "_common"), ) -def pywrap_py_common_library(name, dep): +def pywrap_binaries(name, dep, **kwargs): native.alias( name = name, - actual = "%s_py_internal_import" % dep, + actual = "%s_all_binaries" % dep, + **kwargs ) - -def pywrap_binaries(name, dep): - native.filegroup( - name = name, - srcs = [ - "%s_internal_binaries_wheel_locations.json" % dep, - "%s_internal_binaries" % dep, - "%s_py_internal" % dep, - "%s_internal" % dep, - ], + native.alias( + name = name + ".json", + actual = "%s_common_binaries_wheel_locations.json" % dep, + **kwargs ) def _generated_win_def_file_impl(ctx): @@ -456,20 +540,20 @@ def pybind_extension( name, deps, srcs = [], - private_deps = [], + common_lib_packages = [], visibility = None, win_def_file = None, testonly = None, compatible_with = None, - outer_module_name = "", additional_exported_symbols = [], default_deps = ["@pybind11//:pybind11"], + linkopts = [], + starlark_only = False, **kwargs): cc_library_name = "_%s_cc_library" % name - native.cc_library( name = cc_library_name, - deps = deps + private_deps + default_deps, + deps = deps + default_deps, srcs = srcs, linkstatic = True, alwayslink = True, @@ -477,6 +561,13 @@ def pybind_extension( testonly = testonly, compatible_with = compatible_with, local_defines = ["PROTOBUF_USE_DLLS", "ABSL_CONSUME_DLL"], + linkopts = linkopts + select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": _construct_linkopt_rpaths( + common_lib_packages + [native.package_name()], + native.package_name(), + ), + }), **kwargs ) @@ -486,15 +577,16 @@ def pybind_extension( deps = ["%s" % cc_library_name], testonly = testonly, compatible_with = compatible_with, + common_lib_packages = common_lib_packages, visibility = visibility, ) else: _pywrap_info_wrapper( name = name, deps = ["%s" % cc_library_name], - private_deps = private_deps, - outer_module_name = outer_module_name, + common_lib_packages = common_lib_packages, additional_exported_symbols = additional_exported_symbols, + starlark_only = starlark_only, testonly = testonly, compatible_with = compatible_with, visibility = visibility, @@ -502,21 +594,26 @@ def pybind_extension( def _pywrap_info_wrapper_impl(ctx): #the attribute is called deps not dep to match aspect's attr_aspects - if len(ctx.attr.deps) != 1: fail("deps attribute must contain exactly one dependency") py_stub = ctx.actions.declare_file("%s.py" % ctx.attr.name) substitutions = {} - outer_module_name = ctx.attr.outer_module_name - if outer_module_name: - val = 'outer_module_name = "%s."' % outer_module_name - substitutions['outer_module_name = "" # template_val'] = val additional_exported_symbols = ctx.attr.additional_exported_symbols + + py_pkgs = [] + for pkg in ctx.attr.common_lib_packages: + if pkg: + py_pkgs.append(pkg.replace("/", ".") + "." + ctx.attr.name) + + if py_pkgs: + val = "imports_paths = %s # template_val" % py_pkgs + substitutions["imports_paths = [] # template_val"] = val + if additional_exported_symbols: val = "extra_names = %s # template_val" % additional_exported_symbols - substitutions["extra_names = [] # template_val"] = val + substitutions["extra_names = [] # template_val"] = val ctx.actions.expand_template( template = ctx.file.py_stub_src, @@ -528,19 +625,18 @@ def _pywrap_info_wrapper_impl(ctx): PyInfo(transitive_sources = depset()), PywrapInfo( cc_info = ctx.attr.deps[0][CcInfo], - private_deps = ctx.attr.private_deps, owner = ctx.label, + common_lib_packages = ctx.attr.common_lib_packages, py_stub = py_stub, - outer_module_name = outer_module_name, cc_only = False, + starlark_only = ctx.attr.starlark_only, ), ] _pywrap_info_wrapper = rule( attrs = { "deps": attr.label_list(providers = [CcInfo]), - "private_deps": attr.label_list(providers = [CcInfo]), - "outer_module_name": attr.string(mandatory = False, default = ""), + "common_lib_packages": attr.string_list(default = []), "py_stub_src": attr.label( allow_single_file = True, default = Label("//third_party/py/rules_pywrap:pybind_extension.py.tpl"), @@ -549,6 +645,7 @@ _pywrap_info_wrapper = rule( mandatory = False, default = [], ), + "starlark_only": attr.bool(mandatory = False, default = False), }, implementation = _pywrap_info_wrapper_impl, ) @@ -559,17 +656,18 @@ def _cc_only_pywrap_info_wrapper_impl(ctx): PyInfo(transitive_sources = depset()), PywrapInfo( cc_info = wrapped_dep[CcInfo], - private_deps = [], owner = ctx.label, + common_lib_packages = ctx.attr.common_lib_packages, py_stub = None, - outer_module_name = None, cc_only = True, + starlark_only = False, ), ] _cc_only_pywrap_info_wrapper = rule( attrs = { "deps": attr.label_list(providers = [CcInfo]), + "common_lib_packages": attr.string_list(default = []), }, implementation = _cc_only_pywrap_info_wrapper_impl, ) @@ -602,42 +700,67 @@ _pywrap_info_collector_aspect = aspect( ) def _collected_pywrap_infos_impl(ctx): - pywrap_infos = [] + pywrap_depsets = [] for dep in ctx.attr.deps: if CollectedPywrapInfo in dep: - pywrap_infos.append(dep[CollectedPywrapInfo].pywrap_infos) + pywrap_depsets.append(dep[CollectedPywrapInfo].pywrap_infos) - rv = CollectedPywrapInfo( + all_pywraps = CollectedPywrapInfo( pywrap_infos = depset( - transitive = pywrap_infos, + transitive = pywrap_depsets, order = "topological", ), ) - pywraps = rv.pywrap_infos.to_list() - if ctx.attr.pywrap_count != len(pywraps): - found_pywraps = "\n ".join([str(pw.owner) for pw in pywraps]) + pywraps = [] + sl_only_pywraps = [] + py_stubs = [] + + for pw in all_pywraps.pywrap_infos.to_list(): + if pw.starlark_only: + sl_only_pywraps.append(pw) + else: + pywraps.append(pw) + if pw.py_stub: + py_stubs.append(pw.py_stub) + + pw_count = ctx.attr.pywrap_count + sl_pw_count = ctx.attr.starlark_only_pywrap_count + + if pw_count != len(pywraps) or sl_pw_count != len(sl_only_pywraps): + found_pws = "\n ".join([str(pw.owner) for pw in pywraps]) + found_sl_pws = "\n ".join([str(pw.owner) for pw in sl_only_pywraps]) fail(""" Number of actual pywrap libraries does not match expected pywrap_count. - Expected pywrap_count: {expected_pywrap_count} - Actual pywrap_count: {actual_pywra_count} - Actual pywrap libraries in the transitive closure of {label}: - {found_pywraps} + Expected regular pywrap_count: {expected_pywrap_count} + Actual regular pywrap_count: {actual_pywrap_count} + Expected starlark-only pywrap_count: {expected_starlark_only_pywrap_count} + Actual starlark-only pywrap_count: {starlark_only_pywrap_count} + Actual regualar pywrap libraries in the transitive closure of {label}: + {found_pws} + Actual starlark-only pywrap libraries in the transitive closure of {label}: + {found_sl_pws} """.format( - expected_pywrap_count = ctx.attr.pywrap_count, - actual_pywra_count = len(pywraps), + expected_pywrap_count = pw_count, + expected_starlark_only_pywrap_count = sl_pw_count, + actual_pywrap_count = len(pywraps), + starlark_only_pywrap_count = len(sl_only_pywraps), label = ctx.label, - found_pywraps = found_pywraps, + found_pws = found_pws, + found_sl_pws = found_sl_pws, )) - py_stubs = [] - for pw in pywraps: - if pw.py_stub: - py_stubs.append(pw.py_stub) + categorized_pywraps = CollectedPywrapInfo( + pywrap_infos = depset( + direct = pywraps, + transitive = [depset(sl_only_pywraps)], + order = "topological", + ), + ) return [ DefaultInfo(files = depset(direct = py_stubs)), - rv, + categorized_pywraps, ] collected_pywrap_infos = rule( @@ -647,6 +770,7 @@ collected_pywrap_infos = rule( providers = [PyInfo], ), "pywrap_count": attr.int(mandatory = True, default = 1), + "starlark_only_pywrap_count": attr.int(mandatory = True, default = 0), }, implementation = _collected_pywrap_infos_impl, ) @@ -671,8 +795,6 @@ def _pywrap_binaries_impl(ctx): pywrap_info = pywrap_infos[i] original_binary = original_binaries[i] subfolder = "" - if pywrap_info.outer_module_name: - subfolder = pywrap_info.outer_module_name + "/" final_binary_name = "%s%s%s" % (subfolder, pywrap_info.owner.name, extension) final_binary = ctx.actions.declare_file(final_binary_name) original_binary_file = original_binary.files.to_list()[0] @@ -686,23 +808,47 @@ def _pywrap_binaries_impl(ctx): ) original_to_final_binaries.append( - " '{original}' => '{final}'".format( + " '{original}' => '{final}'{starlark_only}".format( original = original_binary_file.path, final = final_binary.path, + starlark_only = " (excluded from wheel)" if pywrap_info.starlark_only else "", ), ) final_binaries.append(final_binary) - final_binary_location = "{root}{new_package}/{basename}".format( - root = final_binary.path.split(final_binary.short_path, 1)[0], - new_package = pywrap_info.owner.package, - basename = final_binary.basename, - ) + final_binary_location = "" + if not pywrap_info.cc_only and not pywrap_info.starlark_only: + final_binary_location = _construct_final_binary_location( + final_binary, + pywrap_info.owner.package, + ) + wheel_locations[final_binary.path] = final_binary_location if pywrap_info.py_stub: wheel_locations[pywrap_info.py_stub.path] = "" + for common_binary, common_binary_pkg in ctx.attr.common_binaries.items(): + final_binary = common_binary.files.to_list()[0] + final_binary_location = _construct_final_binary_location( + final_binary, + common_binary_pkg, + ) + original_to_final_binaries.append( + " common lib => '{}'".format( + final_binary.path, + ), + ) + wheel_locations[final_binary.path] = final_binary_location + for starlark_only_common_binary in ctx.attr.starlark_only_common_binaries: + final_binary = starlark_only_common_binary.files.to_list()[0] + original_to_final_binaries.append( + " common lib => '{}' (excluded from wheel)".format( + final_binary.path, + ), + ) + wheel_locations[final_binary.path] = "" + ctx.actions.write( output = ctx.outputs.wheel_locations_json, content = str(wheel_locations), @@ -715,9 +861,24 @@ def _pywrap_binaries_impl(ctx): return [DefaultInfo(files = depset(direct = final_binaries))] +def _construct_final_binary_location(final_binary, new_package): + return "{root}{new_package}/{basename}".format( + root = final_binary.path.split(final_binary.short_path, 1)[0], + new_package = new_package, + basename = final_binary.basename, + ) + _pywrap_binaries = rule( attrs = { "deps": attr.label_list(mandatory = True, allow_files = False), + "common_binaries": attr.label_keyed_string_dict( + allow_files = False, + mandatory = True, + ), + "starlark_only_common_binaries": attr.label_keyed_string_dict( + allow_files = False, + mandatory = True, + ), "collected_pywraps": attr.label(mandatory = True, allow_files = False), "extension": attr.string(default = ".so"), "wheel_locations_json": attr.output(mandatory = True), @@ -756,3 +917,43 @@ stripped_cc_info = rule( }, implementation = _stripped_cc_info_impl, ) + +def _get_common_lib_package_and_name(common_lib_full_name): + if "/" in common_lib_full_name: + return common_lib_full_name.rsplit("/", 1) + return "", common_lib_full_name + +def _construct_linkopt_soname(name): + soname = name.rsplit("/", 1)[1] if "/" in name else name + soname = soname if name.startswith("lib") else ("lib%s" % soname) + if ".so" not in name: + soname += ".so" + return ["-Wl,-soname,%s" % soname] + +def _construct_linkopt_rpaths(dependency_lib_packages, dependent_lib_package): + linkopts = {} + for dependency_lib_package in dependency_lib_packages: + origin_pkg = _construct_rpath(dependency_lib_package, dependent_lib_package) + linkopts["-rpath,'$$ORIGIN/%s'" % origin_pkg] = True + return ["-Wl," + ",".join(linkopts.keys())] if linkopts else [] + +def _construct_rpath(dependency_lib_package, dependent_lib_package): + dependency_pkg_components = dependency_lib_package.split("/") + dependent_pkg_comonents = dependent_lib_package.split("/") + min_len = min(len(dependency_pkg_components), len(dependent_pkg_comonents)) + common_prefix_i = 0 + for i in range(0, min_len): + if dependency_pkg_components[i] == dependent_pkg_comonents[i]: + common_prefix_i = i + 1 + else: + break + + levels_up = "../" * (len(dependent_pkg_comonents) - common_prefix_i) + remaining_pkg = "/".join(dependency_pkg_components[common_prefix_i:]) + + return levels_up + remaining_pkg + +def _construct_linkopt_version_script(version_script): + if not version_script: + return [] + return ["-Wl,--version-script,$(location {})".format(version_script)] diff --git a/third_party/xla/third_party/tsl/third_party/remote_config/common.bzl b/third_party/xla/third_party/tsl/third_party/remote_config/common.bzl index 57fb6fcf7aca9a..c70c0ba5b51db6 100644 --- a/third_party/xla/third_party/tsl/third_party/remote_config/common.bzl +++ b/third_party/xla/third_party/tsl/third_party/remote_config/common.bzl @@ -212,7 +212,8 @@ def execute( cmdline, error_msg = None, error_details = None, - allow_failure = False): + allow_failure = False, + env_vars = {}): """Executes an arbitrary shell command. Args: @@ -222,10 +223,11 @@ def execute( error_details: string, details about the error or steps to fix it allow_failure: bool, if True, an empty stdout result or output to stderr is fine, otherwise either of these is an error + env_vars: environment variables Returns: The result of repository_ctx.execute(cmdline) """ - result = raw_exec(repository_ctx, cmdline) + result = raw_exec(repository_ctx, cmdline, env_vars) if (result.stderr or not result.stdout) and not allow_failure: fail( "\n".join([ @@ -236,7 +238,7 @@ def execute( ) return result -def raw_exec(repository_ctx, cmdline): +def raw_exec(repository_ctx, cmdline, env_vars = {}): """Executes a command via repository_ctx.execute() and returns the result. This method is useful for debugging purposes. For example, to print all @@ -245,11 +247,12 @@ def raw_exec(repository_ctx, cmdline): Args: repository_ctx: the repository_ctx cmdline: the list of args + env_vars: environment variables Returns: The 'exec_result' of repository_ctx.execute(). """ - return repository_ctx.execute(cmdline) + return repository_ctx.execute(cmdline, environment = env_vars) def files_exist(repository_ctx, paths, bash_bin = None): """Checks which files in paths exists. diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/double_conversion.BUILD b/third_party/xla/third_party/tsl/third_party/systemlibs/double_conversion.BUILD deleted file mode 100644 index 568460181ae0bc..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/double_conversion.BUILD +++ /dev/null @@ -1,12 +0,0 @@ -licenses(["notice"]) - -filegroup( - name = "LICENSE", - visibility = ["//visibility:public"], -) - -cc_library( - name = "double-conversion", - linkopts = ["-ldouble-conversion"], - visibility = ["//visibility:public"], -) diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl index c659ca16366b7a..aa5d18eaa9a488 100644 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl +++ b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.generate_cc.bzl @@ -11,6 +11,7 @@ load( "get_proto_root", "proto_path_to_generated_filename", ) +load("@rules_proto//proto:defs.bzl", "ProtoInfo") _GRPC_PROTO_HEADER_FMT = "{}.grpc.pb.h" _GRPC_PROTO_SRC_FMT = "{}.grpc.pb.cc" diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl index 3eca97dc2311fb..cfb124ce43b1ef 100644 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl +++ b/third_party/xla/third_party/tsl/third_party/systemlibs/grpc.bazel.protobuf.bzl @@ -1,5 +1,7 @@ """Utility functions for generating protobuf code.""" +load("@rules_proto//proto:defs.bzl", "ProtoInfo") + _PROTO_EXTENSION = ".proto" _VIRTUAL_IMPORTS = "/_virtual_imports/" diff --git a/third_party/xla/third_party/tsl/third_party/systemlibs/syslibs_configure.bzl b/third_party/xla/third_party/tsl/third_party/systemlibs/syslibs_configure.bzl index f2fc22480f4989..3c734e475f412b 100644 --- a/third_party/xla/third_party/tsl/third_party/systemlibs/syslibs_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/systemlibs/syslibs_configure.bzl @@ -21,7 +21,6 @@ VALID_LIBS = [ "curl", "cython", "dill_archive", - "double_conversion", "flatbuffers", "functools32_archive", "gast_archive", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 83f52d9af9970a..0079e66d203915 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -1,6 +1,6 @@ """Configurations of RBE builds used with remote config.""" -load("//tools/toolchains/remote_config:rbe_config.bzl", "sigbuild_tf_configs", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") +load("//tools/toolchains/remote_config:rbe_config.bzl", "ml_build_rbe_config", "sigbuild_tf_configs", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") def initialize_rbe_configs(): tensorflow_local_config( @@ -47,6 +47,11 @@ def initialize_rbe_configs(): python_bin_path = "C:/Python37/python.exe", ) + # The `ml-build-rbe` image is identical to the `ml-build` image except for the base image. + # The `ml-build`'s base image is a standard `ubuntu22.04` image. + # The `ml-build-rbe`'s base image is `nvidia/cuda:12.3.2-base-ubuntu22.04` which has nvidia driver installed. + ml_build_rbe_config("docker://us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe@sha256:aaeb29799463729092c05f5ac8393113b3bb5d1ecf085f9f1f2016e3a1ece11c") + # TF-Version-Specific SIG Build RBE Configs. The crosstool generated from these # configs are python-version-independent because they only care about the # tooling paths; the container mapping is useful only so that TF RBE users diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index 280b8d914283dd..dbfafdfb08c180 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -92,10 +92,24 @@ def _tensorflow_local_config(name): platform_constraint = "@%s_config_platform//:platform_constraint" % name, ) +def _ml_build_rbe_config(container_image): + exec_properties = { + "container-image": container_image, + "Pool": "default", + } + + remote_platform_configure( + name = "ml_build_config_platform", + platform = "linux", + platform_exec_properties = exec_properties, + ) + tensorflow_rbe_config = _tensorflow_rbe_config tensorflow_rbe_win_config = _tensorflow_rbe_win_config tensorflow_local_config = _tensorflow_local_config +ml_build_rbe_config = _ml_build_rbe_config +# TODO(b/369382309): Remove this once ml_build_rbe_config is used everywhere. # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles # These containers do not support ROCm and all have CUDA. diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/BUILD index 93b3c90aff81d9..db4cf0eac92066 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/BUILD @@ -20,24 +20,6 @@ load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) -cc_library(name = "empty_lib") - -# Label flag for extra libraries to be linked into every binary. -# TODO(bazel-team): Support passing flag multiple times to build a list. -label_flag( - name = "link_extra_libs", - build_setting_default = ":empty_lib", -) - -# The final extra library to be linked into every binary target. This collects -# the above flag, but may also include more libraries depending on config. -cc_library( - name = "link_extra_lib", - deps = [ - ":link_extra_libs", - ], -) - cc_library( name = "malloc", ) @@ -228,7 +210,8 @@ cc_toolchain_config( compiler = "msvc-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -240,24 +223,24 @@ cc_toolchain_config( default_link_flags = ["/MACHINE:X64"], fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", host_system_name = "local", - msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", - msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/lib.exe", - msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/link.exe", - msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/ml64.exe", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", supports_parse_showincludes = True, target_libc = "msvcrt", target_system_name = "local", tool_paths = { - "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/lib.exe", - "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/ml64.exe", - "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", - "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", "gcov": "wrapper/bin/msvc_nop.bat", - "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/link.exe", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", "nm": "wrapper/bin/msvc_nop.bat", "objcopy": "wrapper/bin/msvc_nop.bat", "objdump": "wrapper/bin/msvc_nop.bat", @@ -303,7 +286,8 @@ cc_toolchain_config( compiler = "msvc-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -315,24 +299,24 @@ cc_toolchain_config( default_link_flags = ["/MACHINE:X86"], fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", host_system_name = "local", - msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", - msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/lib.exe", - msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/link.exe", - msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/ml.exe", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", supports_parse_showincludes = True, target_libc = "msvcrt", target_system_name = "local", tool_paths = { - "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/lib.exe", - "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/ml.exe", - "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", - "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", "gcov": "wrapper/bin/msvc_nop.bat", - "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/link.exe", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", "nm": "wrapper/bin/msvc_nop.bat", "objcopy": "wrapper/bin/msvc_nop.bat", "objdump": "wrapper/bin/msvc_nop.bat", @@ -511,7 +495,8 @@ cc_toolchain_config( compiler = "clang-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -521,13 +506,16 @@ cc_toolchain_config( "C:\\tools\\LLVM\\lib\\clang\\18\\include", ], dbg_mode_debug_flag = "/DEBUG", - default_link_flags = ["/MACHINE:X64"], + default_link_flags = [ + "/MACHINE:X64", + "/DEFAULTLIB:clang_rt.builtins-x86_64.lib", + ], fastbuild_mode_debug_flag = "/DEBUG", host_system_name = "local", msvc_cl_path = "C:/tools/LLVM/bin/clang-cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", msvc_lib_path = "C:/tools/LLVM/bin/llvm-lib.exe", msvc_link_path = "C:/tools/LLVM/bin/lld-link.exe", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl index 0a1fb6e0df84ce..f440b6083d71fb 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl @@ -3,3 +3,5 @@ that clang-cl reported. This file is a dependency of every compilation action an changes to it will be reflected in the action cache key. When some of these paths change, Bazel will make sure to rerun the action, even though none of declared action inputs or the action commandline changes. + + diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc index 55ba44f761e2c1..1380bc62e15b60 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc @@ -4,3 +4,4 @@ changes to it will be reflected in the action cache key. When some of these paths change, Bazel will make sure to rerun the action, even though none of declared action inputs or the action commandline changes. + diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/toolchain_image_info b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/toolchain_image_info index 807a14bebbdb44..ffa6a8e33c7933 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/toolchain_image_info +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/toolchain_image_info @@ -1,2 +1,2 @@ REPOSITORY TAG DIGEST IMAGE ID CREATED SIZE -gcr.io/tensorflow-testing/tf-win2019-docker-staging latest sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc b601adb43430 8 minutes ago 20.4GB \ No newline at end of file +gcr.io/tensorflow-testing/tf-win2019-rbe latest sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd b601adb43430 8 minutes ago 20.4GB \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl index 6d8e8af6d50e4a..03ff9b6b30078d 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl @@ -375,7 +375,6 @@ def _impl(ctx): compiler_param_file_feature = feature( name = "compiler_param_file", - enabled = True, ) copy_dynamic_libraries_to_binary_feature = feature( diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win/BUILD index 55ae6fb22b81f6..258ca032ecd1ea 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/win/BUILD +++ b/third_party/xla/third_party/tsl/tools/toolchains/win/BUILD @@ -17,7 +17,7 @@ platform( remote_execution_properties = """ properties:{ name: "container-image" - value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" } properties:{ name: "OSFamily" @@ -43,7 +43,7 @@ platform( remote_execution_properties = """ properties:{ name: "container-image" - value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" } properties:{ name: "OSFamily" diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/BUILD new file mode 100644 index 00000000000000..7d1ac7d0dfa1f2 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/BUILD @@ -0,0 +1,647 @@ +# Copyright 2018 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This becomes the BUILD file for @local_config_cc// under Windows. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") +load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "malloc", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "mingw_compiler_files", + srcs = [":builtin_include_directory_paths_mingw"], +) + +filegroup( + name = "clangcl_compiler_files", + srcs = [":builtin_include_directory_paths_clangcl"], +) + +filegroup( + name = "msvc_compiler_files", + srcs = [":builtin_include_directory_paths_msvc"], +) + +# Hardcoded toolchain, legacy behaviour. +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a", + "x64_windows|msvc-cl": ":cc-compiler-x64_windows", + "x64_x86_windows|msvc-cl": ":cc-compiler-x64_x86_windows", + "x64_arm_windows|msvc-cl": ":cc-compiler-x64_arm_windows", + "x64_arm64_windows|msvc-cl": ":cc-compiler-arm64_windows", + "arm64_windows|msvc-cl": ":cc-compiler-arm64_windows", + "x64_windows|msys-gcc": ":cc-compiler-x64_windows_msys", + "x64_windows|mingw-gcc": ":cc-compiler-x64_windows_mingw", + "x64_windows|clang-cl": ":cc-compiler-x64_windows-clang-cl", + "x64_windows_msys": ":cc-compiler-x64_windows_msys", + "x64_windows": ":cc-compiler-x64_windows", + "x64_x86_windows": ":cc-compiler-x64_x86_windows", + "x64_arm_windows": ":cc-compiler-x64_arm_windows", + "x64_arm64_windows": ":cc-compiler-arm64_windows", + "arm64_windows": ":cc-compiler-arm64_windows", + "x64_arm64_windows|clang-cl": ":cc-compiler-arm64_windows-clang-cl", + "arm64_windows|clang-cl": ":cc-compiler-arm64_windows-clang-cl", + "armeabi-v7a": ":cc-compiler-armeabi-v7a", + }, +) + +cc_toolchain( + name = "cc-compiler-x64_windows_msys", + all_files = ":empty", + ar_files = ":empty", + as_files = ":mingw_compiler_files", + compiler_files = ":mingw_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msys_x64", + toolchain_identifier = "msys_x64", +) + +cc_toolchain_config( + name = "msys_x64", + abi_libc_version = "local", + abi_version = "local", + compiler = "msys-gcc", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "c:/tools/msys64/usr/", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + target_libc = "msys", + target_system_name = "local", + tool_bin_path = "c:/tools/msys64/usr/bin", + tool_paths = { + "ar": "c:/tools/msys64/usr/bin/ar", + "cpp": "c:/tools/msys64/usr/bin/cpp", + "dwp": "c:/tools/msys64/usr/bin/dwp", + "gcc": "c:/tools/msys64/usr/bin/gcc", + "gcov": "c:/tools/msys64/usr/bin/gcov", + "ld": "c:/tools/msys64/usr/bin/ld", + "nm": "c:/tools/msys64/usr/bin/nm", + "objcopy": "c:/tools/msys64/usr/bin/objcopy", + "objdump": "c:/tools/msys64/usr/bin/objdump", + "strip": "c:/tools/msys64/usr/bin/strip", + }, +) + +toolchain( + name = "cc-toolchain-x64_windows_msys", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:msys", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows_msys", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows_mingw", + all_files = ":empty", + ar_files = ":empty", + as_files = ":mingw_compiler_files", + compiler_files = ":mingw_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 0, + toolchain_config = ":msys_x64_mingw", + toolchain_identifier = "msys_x64_mingw", +) + +cc_toolchain_config( + name = "msys_x64_mingw", + abi_libc_version = "local", + abi_version = "local", + compiler = "mingw-gcc", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "c:/tools/msys64/mingw64/", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + target_libc = "mingw", + target_system_name = "local", + tool_bin_path = "c:/tools/msys64/mingw64/bin", + tool_paths = { + "ar": "c:/tools/msys64/mingw64/bin/ar", + "cpp": "c:/tools/msys64/mingw64/bin/cpp", + "dwp": "c:/tools/msys64/mingw64/bin/dwp", + "gcc": "c:/tools/msys64/mingw64/bin/gcc", + "gcov": "c:/tools/msys64/mingw64/bin/gcov", + "ld": "c:/tools/msys64/mingw64/bin/ld", + "nm": "c:/tools/msys64/mingw64/bin/nm", + "objcopy": "c:/tools/msys64/mingw64/bin/objcopy", + "objdump": "c:/tools/msys64/mingw64/bin/objdump", + "strip": "c:/tools/msys64/mingw64/bin/strip", + }, +) + +toolchain( + name = "cc-toolchain-x64_windows_mingw", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:mingw", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows_mingw", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64", + toolchain_identifier = "msvc_x64", +) + +cc_toolchain_config( + name = "msvc_x64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X64"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + default_link_flags = ["/MACHINE:X64"], + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64", +) + +toolchain( + name = "cc-toolchain-x64_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_x86_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64_x86", + toolchain_identifier = "msvc_x64_x86", +) + +cc_toolchain_config( + name = "msvc_x64_x86", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X86"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + default_link_flags = ["/MACHINE:X86"], + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64_x86", +) + +toolchain( + name = "cc-toolchain-x64_x86_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:x86_32", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_x86_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_arm_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64_arm", + toolchain_identifier = "msvc_x64_arm", +) + +cc_toolchain_config( + name = "msvc_x64_arm", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm.bat", + msvc_env_include = "msvc_not_found", + msvc_env_lib = "msvc_not_found", + msvc_env_path = "msvc_not_found", + msvc_env_tmp = "msvc_not_found", + msvc_lib_path = "vc_installation_error_arm.bat", + msvc_link_path = "vc_installation_error_arm.bat", + msvc_ml_path = "vc_installation_error_arm.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "vc_installation_error_arm.bat", + "ml": "vc_installation_error_arm.bat", + "cpp": "vc_installation_error_arm.bat", + "gcc": "vc_installation_error_arm.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64_arm", +) + +toolchain( + name = "cc-toolchain-x64_arm_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:arm", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_arm_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-arm64_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_arm64", + toolchain_identifier = "msvc_arm64", +) + +cc_toolchain_config( + name = "msvc_arm64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM64"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM64"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm64.bat", + msvc_env_include = "msvc_not_found", + msvc_env_lib = "msvc_not_found", + msvc_env_path = "msvc_not_found", + msvc_env_tmp = "msvc_not_found", + msvc_lib_path = "vc_installation_error_arm64.bat", + msvc_link_path = "vc_installation_error_arm64.bat", + msvc_ml_path = "vc_installation_error_arm64.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "vc_installation_error_arm64.bat", + "ml": "vc_installation_error_arm64.bat", + "cpp": "vc_installation_error_arm64.bat", + "gcc": "vc_installation_error_arm64.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm64.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_arm64", +) + +toolchain( + name = "cc-toolchain-arm64_windows", + exec_compatible_with = [ + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:arm64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-arm64_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows-clang-cl", + all_files = ":empty", + ar_files = ":empty", + as_files = ":clangcl_compiler_files", + compiler_files = ":clangcl_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":clang_cl_x64", + toolchain_identifier = "clang_cl_x64", +) + +cc_toolchain_config( + name = "clang_cl_x64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X64"], + compiler = "clang-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + "C:\\tools\\LLVM\\lib\\clang\\18\\include", + ], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = [ + "/MACHINE:X64", + "/DEFAULTLIB:clang_rt.builtins-x86_64.lib", + ], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "C:/tools/LLVM/bin/clang-cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/tools/LLVM/bin/llvm-lib.exe", + msvc_link_path = "C:/tools/LLVM/bin/lld-link.exe", + msvc_ml_path = "C:/tools/LLVM/bin/clang-cl.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/tools/LLVM/bin/llvm-lib.exe", + "ml": "C:/tools/LLVM/bin/clang-cl.exe", + "cpp": "C:/tools/LLVM/bin/clang-cl.exe", + "gcc": "C:/tools/LLVM/bin/clang-cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/tools/LLVM/bin/lld-link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "clang_cl_x64", +) + +toolchain( + name = "cc-toolchain-x64_windows-clang-cl", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows-clang-cl", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-arm64_windows-clang-cl", + all_files = ":empty", + ar_files = ":empty", + as_files = ":clangcl_compiler_files", + compiler_files = ":clangcl_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":clang_cl_arm64", + toolchain_identifier = "clang_cl_arm64", +) + +cc_toolchain_config( + name = "clang_cl_arm64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM64"], + compiler = "clang-cl", + cpu = "arm64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM64"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm64.bat", + msvc_env_include = "clang_cl_not_found", + msvc_env_lib = "clang_cl_not_found", + msvc_env_path = "clang_cl_not_found", + msvc_env_tmp = "clang_cl_not_found", + msvc_lib_path = "vc_installation_error_arm64.bat", + msvc_link_path = "vc_installation_error_arm64.bat", + msvc_ml_path = "vc_installation_error_arm64.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "aarch64-pc-windows-msvc", + tool_paths = { + "ar": "vc_installation_error_arm64.bat", + "ml": "vc_installation_error_arm64.bat", + "cpp": "vc_installation_error_arm64.bat", + "gcc": "vc_installation_error_arm64.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm64.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "clang_cl_arm64", +) + +toolchain( + name = "cc-toolchain-arm64_windows-clang-cl", + exec_compatible_with = [ + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + target_compatible_with = [ + "@platforms//cpu:arm64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-arm64_windows-clang-cl", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-armeabi-v7a", + all_files = ":empty", + ar_files = ":empty", + as_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":stub_armeabi-v7a", + toolchain_identifier = "stub_armeabi-v7a", +) + +armeabi_cc_toolchain_config(name = "stub_armeabi-v7a") + +toolchain( + name = "cc-toolchain-armeabi-v7a", + exec_compatible_with = [ + ], + target_compatible_with = [ + "@platforms//cpu:armv7", + "@platforms//os:android", + ], + toolchain = ":cc-compiler-armeabi-v7a", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl new file mode 100644 index 00000000000000..72ef48ae6d6dfc --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl @@ -0,0 +1,82 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Starlark cc_toolchain configuration rule""" + +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "feature", + "tool_path", +) + +def _impl(ctx): + toolchain_identifier = "stub_armeabi-v7a" + host_system_name = "armeabi-v7a" + target_system_name = "armeabi-v7a" + target_cpu = "armeabi-v7a" + target_libc = "armeabi-v7a" + compiler = "compiler" + abi_version = "armeabi-v7a" + abi_libc_version = "armeabi-v7a" + cc_target_os = None + builtin_sysroot = None + action_configs = [] + + supports_pic_feature = feature(name = "supports_pic", enabled = True) + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + features = [supports_dynamic_linker_feature, supports_pic_feature] + + cxx_builtin_include_directories = [] + artifact_name_patterns = [] + make_variables = [] + + tool_paths = [ + tool_path(name = "ar", path = "/bin/false"), + tool_path(name = "cpp", path = "/bin/false"), + tool_path(name = "dwp", path = "/bin/false"), + tool_path(name = "gcc", path = "/bin/false"), + tool_path(name = "gcov", path = "/bin/false"), + tool_path(name = "ld", path = "/bin/false"), + tool_path(name = "llvm-profdata", path = "/bin/false"), + tool_path(name = "nm", path = "/bin/false"), + tool_path(name = "objcopy", path = "/bin/false"), + tool_path(name = "objdump", path = "/bin/false"), + tool_path(name = "strip", path = "/bin/false"), + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ) + +armeabi_cc_toolchain_config = rule( + implementation = _impl, + attrs = {}, + provides = [CcToolchainConfigInfo], +) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl new file mode 100644 index 00000000000000..f440b6083d71fb --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl @@ -0,0 +1,7 @@ +This file is generated by cc_configure and contains builtin include directories +that clang-cl reported. This file is a dependency of every compilation action and +changes to it will be reflected in the action cache key. When some of these +paths change, Bazel will make sure to rerun the action, even though none of +declared action inputs or the action commandline changes. + + diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc new file mode 100644 index 00000000000000..1380bc62e15b60 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc @@ -0,0 +1,7 @@ +This file is generated by cc_configure and contains builtin include directories +that msvc reported. This file is a dependency of every compilation action and +changes to it will be reflected in the action cache key. When some of these +paths change, Bazel will make sure to rerun the action, even though none of +declared action inputs or the action commandline changes. + + diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl new file mode 100644 index 00000000000000..03ff9b6b30078d --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl @@ -0,0 +1,1442 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Starlark cc_toolchain configuration rule for Windows""" + +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "artifact_name_pattern", + "env_entry", + "env_set", + "feature", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", + "with_feature_set", +) + +all_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, + ACTION_NAMES.lto_backend, +] + +all_cpp_compile_actions = [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, +] + +preprocessor_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, +] + +codegen_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, +] + +all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, +] + +def _use_msvc_toolchain(ctx): + return ctx.attr.cpu in ["x64_windows", "arm64_windows"] and (ctx.attr.compiler == "msvc-cl" or ctx.attr.compiler == "clang-cl") + +def _impl(ctx): + if _use_msvc_toolchain(ctx): + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "object_file", + prefix = "", + extension = ".obj", + ), + artifact_name_pattern( + category_name = "static_library", + prefix = "", + extension = ".lib", + ), + artifact_name_pattern( + category_name = "alwayslink_static_library", + prefix = "", + extension = ".lo.lib", + ), + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + artifact_name_pattern( + category_name = "dynamic_library", + prefix = "", + extension = ".dll", + ), + artifact_name_pattern( + category_name = "interface_library", + prefix = "", + extension = ".if.lib", + ), + ] + else: + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + ] + + if _use_msvc_toolchain(ctx): + cpp_link_nodeps_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_static_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_static_library, + implies = [ + "nologo", + "archiver_flags", + "input_param_flags", + "linker_param_file", + "msvc_env", + ], + tools = [tool(path = ctx.attr.msvc_lib_path)], + ) + + assemble_action = action_config( + action_name = ACTION_NAMES.assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + preprocess_assemble_action = action_config( + action_name = ACTION_NAMES.preprocess_assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + c_compile_action = action_config( + action_name = ACTION_NAMES.c_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + linkstamp_compile_action = action_config( + action_name = ACTION_NAMES.linkstamp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "default_compile_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_compile_action = action_config( + action_name = ACTION_NAMES.cpp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_link_executable_action = action_config( + action_name = ACTION_NAMES.cpp_link_executable, + implies = [ + "nologo", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + action_configs = [ + assemble_action, + preprocess_assemble_action, + c_compile_action, + linkstamp_compile_action, + cpp_compile_action, + cpp_link_executable_action, + cpp_link_dynamic_library_action, + cpp_link_nodeps_dynamic_library_action, + cpp_link_static_library_action, + ] + else: + action_configs = [] + + if _use_msvc_toolchain(ctx): + msvc_link_env_feature = feature( + name = "msvc_link_env", + env_sets = [ + env_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + env_entries = [env_entry(key = "LIB", value = ctx.attr.msvc_env_lib)], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["/DLL"])], + ), + ], + ) + + determinism_feature = feature( + name = "determinism", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "/wd4117", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ] + (["-Wno-builtin-macro-redefined"] if ctx.attr.compiler == "clang-cl" else []), + ), + ], + ), + ], + ) + + sysroot_feature = feature( + name = "sysroot", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + iterate_over = "sysroot", + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{unfiltered_compile_flags}"], + iterate_over = "unfiltered_compile_flags", + expand_if_available = "unfiltered_compile_flags", + ), + ], + ), + ], + ) + + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + compiler_param_file_feature = feature( + name = "compiler_param_file", + ) + + copy_dynamic_libraries_to_binary_feature = feature( + name = "copy_dynamic_libraries_to_binary", + ) + + input_param_flags_feature = feature( + name = "input_param_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{libopts}"], + iterate_over = "libopts", + expand_if_available = "libopts", + ), + ], + ), + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link.object_files", + flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + ], + ) + + fastbuild_feature = feature( + name = "fastbuild", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = [ctx.attr.fastbuild_mode_debug_flag, "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + flag_group( + flags = ctx.attr.archiver_flags, + ), + ], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ctx.attr.default_link_flags)], + ), + ], + ) + + static_link_msvcrt_feature = feature( + name = "static_link_msvcrt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MT"])], + with_features = [with_feature_set(not_features = ["dbg"])], + ), + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MTd"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + with_features = [with_feature_set(not_features = ["dbg"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + ], + ) + + dynamic_link_msvcrt_feature = feature( + name = "dynamic_link_msvcrt", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MD"])], + with_features = [with_feature_set(not_features = ["dbg", "static_link_msvcrt"])], + ), + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MDd"])], + with_features = [with_feature_set(features = ["dbg"], not_features = ["static_link_msvcrt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + with_features = [with_feature_set(not_features = ["dbg", "static_link_msvcrt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + with_features = [with_feature_set(features = ["dbg"], not_features = ["static_link_msvcrt"])], + ), + ], + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = [ctx.attr.dbg_mode_debug_flag, "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/O2"])], + ), + ], + implies = ["frame_pointer"], + ) + + supports_interface_shared_libraries_feature = feature( + name = "supports_interface_shared_libraries", + enabled = True, + ) + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0601", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/bigobj", + "/Zm500", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + msvc_compile_env_feature = feature( + name = "msvc_compile_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ], + env_entries = [env_entry(key = "INCLUDE", value = ctx.attr.msvc_env_include)], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + generate_pdb_file_feature = feature( + name = "generate_pdb_file", + ) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + disable_assertions_feature = feature( + name = "disable_assertions", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/DNDEBUG"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + has_configured_linker_path_feature = feature(name = "has_configured_linker_path") + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + no_stripping_feature = feature(name = "no_stripping") + + linker_param_file_feature = feature( + name = "linker_param_file", + flag_sets = [ + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + ], + ) + + ignore_noisy_warnings_feature = feature( + name = "ignore_noisy_warnings", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [flag_group(flags = ["/ignore:4221"])], + ), + ], + ) + + no_legacy_features_feature = feature(name = "no_legacy_features") + + parse_showincludes_feature = feature( + name = "parse_showincludes", + enabled = ctx.attr.supports_parse_showincludes, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + ], + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + # Force English (and thus a consistent locale) output so that Bazel can parse + # the /showIncludes output without having to guess the encoding. + env_entries = [env_entry(key = "VSLANG", value = "1033")], + ), + ], + ) + + # MSVC does not emit .d files. + no_dotd_file_feature = feature( + name = "no_dotd_file", + enabled = True, + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile] + all_link_actions, + flag_groups = [flag_group(flags = ["/WX"])], + ), + ], + ) + + windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") + + no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + external_include_paths_feature = feature( + name = "external_include_paths", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["/external:I", "%{external_include_paths}"], + iterate_over = "external_include_paths", + expand_if_available = "external_include_paths", + ), + ], + ), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + targets_windows_feature = feature( + name = "targets_windows", + enabled = True, + implies = ["copy_dynamic_libraries_to_binary"], + ) + + linker_subsystem_flag_feature = feature( + name = "linker_subsystem_flag", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], + ), + ], + ) + + frame_pointer_feature = feature( + name = "frame_pointer", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Oy-"])], + ), + ], + ) + + compiler_output_flags_feature = feature( + name = "compiler_output_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + expand_if_not_available = "output_preprocess_file", + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + nologo_feature = feature( + name = "nologo", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + flag_groups = [flag_group(flags = ["/nologo"])], + ), + ], + ) + + smaller_binary_feature = feature( + name = "smaller_binary", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Gy", "/Gw"])], + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/OPT:ICF", "/OPT:REF"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + compiler_input_flags_feature = feature( + name = "compiler_input_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ) + + def_file_feature = feature( + name = "def_file", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ) + + msvc_env_feature = feature( + name = "msvc_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.msvc_env_path), + env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), + env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), + ], + ), + ], + implies = ["msvc_compile_env", "msvc_link_env"], + ) + features = [ + no_legacy_features_feature, + nologo_feature, + has_configured_linker_path_feature, + no_stripping_feature, + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + default_compile_flags_feature, + msvc_env_feature, + msvc_compile_env_feature, + msvc_link_env_feature, + include_paths_feature, + external_include_paths_feature, + preprocessor_defines_feature, + parse_showincludes_feature, + no_dotd_file_feature, + generate_pdb_file_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + archiver_flags_feature, + input_param_flags_feature, + linker_subsystem_flag_feature, + user_link_flags_feature, + default_link_flags_feature, + linker_param_file_feature, + static_link_msvcrt_feature, + dynamic_link_msvcrt_feature, + dbg_feature, + fastbuild_feature, + opt_feature, + frame_pointer_feature, + disable_assertions_feature, + determinism_feature, + treat_warnings_as_errors_feature, + smaller_binary_feature, + ignore_noisy_warnings_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + archive_param_file_feature, + compiler_param_file_feature, + compiler_output_flags_feature, + compiler_input_flags_feature, + def_file_feature, + windows_export_all_symbols_feature, + no_windows_export_all_symbols_feature, + supports_dynamic_linker_feature, + supports_interface_shared_libraries_feature, + ] + else: + targets_windows_feature = feature( + name = "targets_windows", + implies = ["copy_dynamic_libraries_to_binary"], + enabled = True, + ) + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + gcc_env_feature = feature( + name = "gcc_env", + enabled = True, + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.tool_bin_path), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [flag_group(flags = ["-std=gnu++14"])], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lstdc++"])], + ), + ], + ) + + supports_dynamic_linker_feature = feature( + name = "supports_dynamic_linker", + enabled = True, + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-g", "-Og"])], + ), + ], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = [ + "-g0", + "-O3", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])], + ), + ], + ) + + if ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "mingw-gcc": + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + compiler_param_file_feature = feature( + name = "compiler_param_file", + ) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + default_compile_flags_feature, + archive_param_file_feature, + compiler_param_file_feature, + default_link_flags_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + ] + else: + supports_pic_feature = feature( + name = "supports_pic", + enabled = True, + ) + + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + fdo_optimize_feature = feature( + name = "fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-Werror"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,-fatal-warnings"])], + ), + ], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + supports_pic_feature, + default_compile_flags_feature, + default_link_flags_feature, + fdo_optimize_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + treat_warnings_as_errors_feature, + sysroot_feature, + ] + + tool_paths = [ + tool_path(name = name, path = path) + for name, path in ctx.attr.tool_paths.items() + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories, + toolchain_identifier = ctx.attr.toolchain_identifier, + host_system_name = ctx.attr.host_system_name, + target_system_name = ctx.attr.target_system_name, + target_cpu = ctx.attr.cpu, + target_libc = ctx.attr.target_libc, + compiler = ctx.attr.compiler, + abi_version = ctx.attr.abi_version, + abi_libc_version = ctx.attr.abi_libc_version, + tool_paths = tool_paths, + ) + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True), + "compiler": attr.string(), + "toolchain_identifier": attr.string(), + "host_system_name": attr.string(), + "target_system_name": attr.string(), + "target_libc": attr.string(), + "abi_version": attr.string(), + "abi_libc_version": attr.string(), + "tool_paths": attr.string_dict(), + "cxx_builtin_include_directories": attr.string_list(), + "archiver_flags": attr.string_list(default = []), + "default_link_flags": attr.string_list(default = []), + "msvc_env_tmp": attr.string(default = "msvc_not_found"), + "msvc_env_path": attr.string(default = "msvc_not_found"), + "msvc_env_include": attr.string(default = "msvc_not_found"), + "msvc_env_lib": attr.string(default = "msvc_not_found"), + "msvc_cl_path": attr.string(default = "vc_installation_error.bat"), + "msvc_ml_path": attr.string(default = "vc_installation_error.bat"), + "msvc_link_path": attr.string(default = "vc_installation_error.bat"), + "msvc_lib_path": attr.string(default = "vc_installation_error.bat"), + "dbg_mode_debug_flag": attr.string(), + "fastbuild_mode_debug_flag": attr.string(), + "tool_bin_path": attr.string(default = "not_found"), + "supports_parse_showincludes": attr.bool(), + }, + provides = [CcToolchainConfigInfo], +) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/win2022/BUILD b/third_party/xla/third_party/tsl/tools/toolchains/win2022/BUILD new file mode 100644 index 00000000000000..82434f82ddbdd3 --- /dev/null +++ b/third_party/xla/third_party/tsl/tools/toolchains/win2022/BUILD @@ -0,0 +1,37 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +java_runtime( + name = "windows_jdk8", + srcs = [], + java_home = "C:/openjdk", +) + +# Register a Windows 2022 (Clang) platform. +# Note that while this does support RBE, the current pool size is tiny, +# and this platform is meant to be used as a non-RBE one, for now. +platform( + name = "windows_ltsc2022_clang", + constraint_values = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + remote_execution_properties = """ + properties:{ + name: "container-image" + value: "docker://gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc" + } + properties:{ + name: "OSFamily" + value: "Windows" + } + properties:{ + name: "Pool" value: "win2022" + } + properties:{ + name: "dockerNetwork" value: "off" + } + """, +) diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index 618192a8888479..9744f7abe27aa3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -16,7 +16,6 @@ load( "tf_error_logging_deps", "tf_fingerprint_deps", "tf_google_mobile_srcs_no_runtime", - "tf_logging_deps", "tf_platform_deps", "tf_protobuf_compiler_deps", "tf_resource_deps", @@ -26,7 +25,6 @@ load( "tsl_grpc_credentials_deps", "tsl_protobuf_deps", ) -load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_static") load( "@local_xla//xla/tsl/platform:rules_cc.bzl", "cc_library", @@ -45,12 +43,12 @@ cc_library( srcs = ["base64.cc"], hdrs = ["base64.h"], deps = [ - ":errors", - ":macros", - ":status", ":stringpiece", - ":types", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:status", + "@local_xla//xla/tsl/platform:types", ], ) @@ -59,8 +57,8 @@ cc_library( hdrs = ["blocking_counter.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":logging", ":mutex", + "@local_xla//xla/tsl/platform:logging", ], ) @@ -78,7 +76,7 @@ cc_library( ":byte_order", ":stringpiece", ":tstring", - ":types", + "@local_xla//xla/tsl/platform:types", ], ) @@ -88,8 +86,8 @@ tsl_cc_test( srcs = ["cpu_info_test.cc"], deps = [ ":platform_port", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -108,8 +106,8 @@ tsl_cc_test( ], deps = [ ":criticality", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -118,9 +116,9 @@ cc_library( srcs = ["denormal.cc"], hdrs = ["denormal.h"], deps = [ - ":macros", ":platform", ":platform_port", + "@local_xla//xla/tsl/platform:macros", ], ) @@ -130,8 +128,8 @@ tsl_cc_test( srcs = ["denormal_test.cc"], deps = [ ":denormal", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -143,57 +141,40 @@ cc_library( "file_system_helper.h", "threadpool.h", ], - deps = tf_windows_aware_platform_deps("env") + if_static([":env_impl"]), + deps = [ + "@local_xla//xla/tsl/platform:env", + ], ) cc_library( name = "threadpool_async_executor", hdrs = ["threadpool_async_executor.h"], deps = [ - ":env", - "@local_xla//xla/tsl/concurrency:async_value", - ], -) - -tsl_cc_test( - name = "threadpool_async_executor_test", - srcs = ["threadpool_async_executor_test.cc"], - deps = [ - ":env", - ":env_impl", - ":test", - ":test_main", - ":threadpool_async_executor", - "@com_google_absl//absl/synchronization", + "@local_xla//xla/tsl/platform:threadpool_async_executor", ], ) cc_library( name = "env_impl", - deps = tf_windows_aware_platform_deps("env_impl"), + deps = [ + "@local_xla//xla/tsl/platform:env_impl", + ], ) cc_library( name = "env_time", compatible_with = get_compatible_with_portable(), textual_hdrs = ["env_time.h"], - deps = tf_windows_aware_platform_deps("env_time"), + deps = [ + "@local_xla//xla/tsl/platform:env_time", + ], ) cc_library( name = "errors", - srcs = ["errors.cc"], hdrs = ["errors.h"], deps = [ - ":logging", - ":macros", - ":status", - ":str_util", - ":strcat", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", + "@local_xla//xla/tsl/platform:errors", ], ) @@ -223,13 +204,14 @@ cc_library( srcs = ["numbers.cc"], hdrs = ["numbers.h"], deps = [ - ":logging", - ":macros", ":str_util", ":stringpiece", ":stringprintf", - ":types", - "@double_conversion//:double-conversion", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", ], ) @@ -238,14 +220,14 @@ cc_library( srcs = ["path.cc"], hdrs = ["path.h"], deps = [ - ":logging", ":mutex", ":scanner", ":str_util", ":strcat", ":stringpiece", - ":types", "@com_google_absl//absl/algorithm:container", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:types", ], alwayslink = True, ) @@ -259,7 +241,7 @@ cc_library( hdrs = ["protobuf.h"], deps = [ ":platform", - ":types", + "@local_xla//xla/tsl/platform:types", ] + tsl_protobuf_deps(), ) @@ -288,55 +270,26 @@ cc_library( cc_library( name = "status", - srcs = ["status.cc"], hdrs = ["status.h"], deps = [ - ":logging", - ":macros", - ":mutex", - ":platform", - ":stack_frame", - ":stacktrace", - ":str_util", - ":strcat", - ":stringprintf", - ":types", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/types:optional", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - ] + tf_platform_deps("status"), + "@local_xla//xla/tsl/platform:status", + ], ) cc_library( name = "status_to_from_proto", - srcs = [ - "status_to_from_proto.cc", - ], hdrs = ["status_to_from_proto.h"], deps = [ - ":status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - "@local_xla//xla/tsl/protobuf:status_proto_cc", - ] + tf_platform_deps("status"), + "@local_xla//xla/tsl/platform:status_to_from_proto", + ], ) cc_library( name = "status_matchers", testonly = 1, - srcs = ["status_matchers.cc"], hdrs = ["status_matchers.h"], deps = [ - ":status", - ":statusor", - ":test", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/platform:status_matchers", ], ) @@ -344,17 +297,8 @@ cc_library( name = "statusor", hdrs = ["statusor.h"], deps = [ - ":errors", - ":logging", - ":macros", - ":platform", - ":status", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ] + tf_platform_deps("statusor"), + "@local_xla//xla/tsl/platform:statusor", + ], ) cc_library( @@ -366,17 +310,10 @@ cc_library( cc_library( name = "test", testonly = True, - srcs = ["test.cc"], compatible_with = get_compatible_with_portable(), textual_hdrs = ["test.h"], deps = [ - ":logging", - ":macros", - ":net", - ":path", - ":platform", - ":types", - "@com_google_googletest//:gtest", + "@local_xla//xla/tsl/platform:test", ], ) @@ -386,8 +323,7 @@ cc_library( hdrs = ["test_benchmark.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":platform", - "@com_google_benchmark//:benchmark", + "@local_xla//xla/tsl/platform:test_benchmark", ], ) @@ -552,13 +488,10 @@ filegroup( "denormal.cc", "denormal.h", "dynamic_annotations.h", - "env.cc", "env.h", "env_time.h", - "errors.cc", "errors.h", "file_statistics.h", - "file_system.cc", "file_system.h", "file_system_helper.h", "hash.cc", @@ -591,7 +524,6 @@ filegroup( "setround.h", "snappy.h", "stacktrace.h", - "status.cc", "status.h", "statusor.h", "str_util.cc", @@ -602,7 +534,6 @@ filegroup( "stringprintf.cc", "stringprintf.h", "thread_annotations.h", - "threadpool.cc", "threadpool.h", "threadpool_interface.h", "tracing.h", @@ -610,7 +541,6 @@ filegroup( ] + select({ "@local_xla//xla/tsl:fuchsia": tf_google_mobile_srcs_no_runtime(), "//conditions:default": [ - "file_system_helper.cc", "tracing.cc", "@local_xla//xla/tsl/platform/default:mobile_srcs_no_runtime", ], @@ -672,13 +602,11 @@ exports_files( "criticality.h", "cuda_root_path.h", "demangle.h", - "env.cc", "env.h", "env_time.h", "error_logging.h", "file_system.cc", "file_system.h", - "file_system_helper.cc", "file_system_helper.h", "grpc_credentials.h", "host_info.h", @@ -811,6 +739,9 @@ cc_library( name = "macros", hdrs = ["macros.h"], compatible_with = get_compatible_with_portable(), + deps = [ + "@local_xla//xla/tsl/platform:macros", + ], ) filegroup( @@ -871,7 +802,7 @@ cc_library( hdrs = ["random.h"], deps = [ ":mutex", - ":types", + "@local_xla//xla/tsl/platform:types", ], ) @@ -881,10 +812,10 @@ cc_library( srcs = ["resource_loader.cc"], textual_hdrs = ["resource_loader.h"], deps = [ - ":logging", ":path", ":platform", - ":test", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", ], ) @@ -924,11 +855,11 @@ tsl_cc_test( ], tags = ["no_windows"], deps = [ - ":logging", ":stacktrace", ":stacktrace_handler", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -937,12 +868,12 @@ cc_library( srcs = ["str_util.cc"], hdrs = ["str_util.h"], deps = [ - ":logging", - ":macros", ":stringpiece", - ":types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", ], ) @@ -951,12 +882,12 @@ cc_library( srcs = ["strcat.cc"], hdrs = ["strcat.h"], deps = [ - ":logging", - ":macros", ":numbers", ":stringpiece", - ":types", "@com_google_absl//absl/meta:type_traits", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", ], ) @@ -984,8 +915,8 @@ cc_library( srcs = ["stringprintf.cc"], hdrs = ["stringprintf.h"], deps = [ - ":macros", - ":types", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1003,9 +934,7 @@ cc_library( hdrs = ["threadpool_interface.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":mutex", - ":types", - "@eigen_archive//:eigen3", + "@local_xla//xla/tsl/platform:threadpool_interface", ], ) @@ -1014,11 +943,8 @@ cc_library( hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":bfloat16", - ":ml_dtypes", - ":platform", - ":tstring", - ] + tf_platform_deps("types"), + "@local_xla//xla/tsl/platform:types", + ], ) cc_library( @@ -1030,14 +956,14 @@ cc_library( deps = [ ":byte_order", ":fingerprint", - ":macros", ":net", ":platform", ":platform_port", ":platform_strings", ":stacktrace_handler", ":stringpiece", - ":threadpool_interface", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:threadpool_interface", ], ) @@ -1058,6 +984,7 @@ cc_library( deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ], ) @@ -1076,7 +1003,9 @@ cc_library( visibility = [ "//visibility:public", ], - deps = tf_logging_deps(), + deps = [ + "@local_xla//xla/tsl/platform:logging", + ], ) cc_library( @@ -1114,10 +1043,10 @@ cc_library( hdrs = ["hash.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":macros", ":raw_coding", ":stringpiece", - ":types", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1133,7 +1062,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":byte_order", - ":types", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1163,8 +1092,8 @@ cc_library( srcs = ["setround.cc"], hdrs = ["setround.h"], deps = [ - ":logging", - ":macros", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:macros", ], ) @@ -1185,10 +1114,10 @@ tsl_cc_test( ], tags = ["no_windows"], deps = [ - ":logging", ":stacktrace", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1203,7 +1132,7 @@ cc_library( name = "file_statistics", hdrs = ["file_statistics.h"], deps = [ - ":types", + "@local_xla//xla/tsl/platform:file_statistics", ], ) @@ -1214,7 +1143,7 @@ cc_library( deps = [ ":platform", ":stringpiece", - ":types", + "@local_xla//xla/tsl/platform:types", ] + tf_fingerprint_deps(), ) @@ -1224,9 +1153,9 @@ tsl_cc_test( srcs = ["fingerprint_test.cc"], deps = [ ":fingerprint", - ":test", - ":test_main", - ":types", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1243,9 +1172,9 @@ cc_library( srcs = ["scanner.cc"], hdrs = ["scanner.h"], deps = [ - ":macros", ":str_util", ":stringpiece", + "@local_xla//xla/tsl/platform:macros", ], ) @@ -1269,9 +1198,9 @@ tsl_cc_test( size = "small", srcs = ["ctstring_test.cc"], deps = [ - ":test", - ":test_main", ":tstring", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1281,10 +1210,10 @@ tsl_cc_test( srcs = ["hash_test.cc"], deps = [ ":hash", - ":logging", - ":test", - ":test_benchmark", - ":test_main", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_benchmark", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1293,12 +1222,12 @@ tsl_cc_test( size = "small", srcs = ["path_test.cc"], deps = [ - ":env", - ":env_impl", ":path", ":stringpiece", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1307,9 +1236,9 @@ tsl_cc_test( srcs = ["random_test.cc"], deps = [ ":random", - ":test", - ":test_main", - ":types", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1321,81 +1250,21 @@ tsl_cc_test( ":cord", ":platform", ":stringpiece", - ":test", - ":test_main", ":tstring", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) cc_library( name = "test_main", testonly = 1, - srcs = ["test_main.cc"], - copts = tsl_copts(), - linkopts = select({ - "@local_xla//xla/tsl:windows": [], - "//conditions:default": ["-lm"], - }), deps = [ - ":platform", - ":stacktrace_handler", - ":test", - ":test_benchmark", - "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:test_main", ], alwayslink = 1, ) -tsl_cc_test( - name = "status_test", - size = "small", - srcs = ["status_test.cc"], - deps = [ - ":errors", - ":stack_frame", - ":status", - ":status_matchers", - ":status_to_from_proto", - ":test", - ":test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:cord", - "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - "@local_xla//xla/tsl/protobuf:status_proto_cc", - ], -) - -tsl_cc_test( - name = "statusor_test", - size = "small", - srcs = ["statusor_test.cc"], - deps = [ - ":errors", - ":macros", - ":statusor", - ":test", - ":test_benchmark", - ":test_main", - "@com_google_absl//absl/base:config", - ], -) - -tsl_cc_test( - name = "status_matchers_test", - size = "small", - srcs = ["status_matchers_test.cc"], - deps = [ - ":errors", - ":status", - ":status_matchers", - ":statusor", - ":test", - ":test_main", - "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", - ], -) - cc_library( name = "notification", hdrs = ["notification.h"], @@ -1411,7 +1280,7 @@ cc_library( hdrs = ["threadpool_options.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":threadpool_interface", + "@local_xla//xla/tsl/platform:threadpool_options", ], ) @@ -1429,13 +1298,13 @@ tsl_cc_test( srcs = ["unbounded_work_queue_test.cc"], deps = [ ":blocking_counter", - ":env", - ":env_impl", ":random", - ":test", - ":test_main", ":unbounded_work_queue", "@com_google_absl//absl/memory", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1450,7 +1319,7 @@ cc_library( name = "load_library", textual_hdrs = ["load_library.h"], deps = [ - ":status", + "@local_xla//xla/tsl/platform:status", ] + tf_windows_aware_platform_deps("load_library"), ) @@ -1459,7 +1328,7 @@ cc_library( srcs = ["abi.cc"], hdrs = ["abi.h"], deps = [ - ":types", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1467,9 +1336,9 @@ cc_library( name = "refcount", hdrs = ["refcount.h"], deps = [ - ":logging", ":mutex", ":thread_annotations", + "@local_xla//xla/tsl/platform:logging", ], ) @@ -1477,19 +1346,7 @@ cc_library( name = "null_file_system", hdrs = ["null_file_system.h"], deps = [ - ":env", - ], -) - -tsl_cc_test( - name = "errors_test", - size = "small", - srcs = ["errors_test.cc"], - deps = [ - ":errors", - ":test", - ":test_main", - "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:env", ], ) @@ -1502,8 +1359,8 @@ tsl_cc_test( deps = [ ":intrusive_ptr", ":refcount", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1518,9 +1375,10 @@ tsl_cc_test( "notap", ], deps = [ - ":logging", - ":test", - ":test_main", + ":platform_port", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1531,8 +1389,8 @@ tsl_cc_test( tags = ["noclang"], deps = [ ":setround", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1543,11 +1401,11 @@ tsl_cc_test( "refcount_test.cc", ], deps = [ - ":env", - ":env_impl", ":refcount", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1558,29 +1416,9 @@ tsl_cc_test( "integral_types_test.cc", ], deps = [ - ":test", - ":test_main", - ":types", - ], -) - -tsl_cc_test( - name = "logging_test", - size = "small", - srcs = [ - "logging_test.cc", - ], - deps = [ - ":logging", - ":path", - ":stacktrace_handler", - ":statusor", - ":test", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1591,10 +1429,6 @@ tsl_cc_test( "mutex_test.cc", ], deps = [ - ":env", - ":env_impl", - ":env_time", - ":logging", ":mutex", ":net", ":platform_port", @@ -1603,9 +1437,13 @@ tsl_cc_test( ":strcat", ":stringpiece", ":stringprintf", - ":test", - ":test_main", - ":types", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:env_time", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1616,10 +1454,10 @@ tsl_cc_test( "net_test.cc", ], deps = [ - ":logging", ":net", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1633,12 +1471,12 @@ tsl_cc_test( "notap", #TODO(b/245510532) : disabled due to flakiness. ], deps = [ - ":env", - ":env_impl", ":mutex", ":platform_port", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1650,8 +1488,8 @@ tsl_cc_test( ], deps = [ ":scanner", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1663,8 +1501,8 @@ tsl_cc_test( ], deps = [ ":str_util", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1677,10 +1515,10 @@ tsl_cc_test( deps = [ ":strcat", ":stringprintf", - ":test", - ":test_main", - ":types", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1692,8 +1530,8 @@ tsl_cc_test( ], deps = [ ":stringpiece", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1705,8 +1543,8 @@ tsl_cc_test( ], deps = [ ":stringprintf", - ":test", - ":test_main", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1718,8 +1556,10 @@ tsl_cc_test( ], deps = [ ":numbers", - ":test", - ":test_main", + "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", + "@local_xla//xla/tsl/platform:types", ], ) @@ -1733,12 +1573,12 @@ cc_library( ], copts = tsl_copts(), deps = [ - ":env", - ":errors", - ":logging", ":random", - ":status", "@com_google_absl//absl/time", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:status", ], ) @@ -1749,11 +1589,11 @@ cc_library( ], copts = tsl_copts(), deps = [ - ":env", - ":errors", ":random", ":retrying_utils", - ":status", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:status", ], ) @@ -1762,12 +1602,12 @@ tsl_cc_test( size = "small", srcs = ["retrying_file_system_test.cc"], deps = [ - ":env_impl", ":retrying_file_system", ":str_util", - ":test", - ":test_main", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -1776,14 +1616,14 @@ tsl_cc_test( size = "small", srcs = ["retrying_utils_test.cc"], deps = [ - ":env", - ":env_impl", - ":errors", ":retrying_utils", ":str_util", - ":test", - ":test_main", "@com_google_absl//absl/time", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:env_impl", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/abi.cc b/third_party/xla/third_party/tsl/tsl/platform/abi.cc index 8e886535d45039..9e969f31249c65 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/abi.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/abi.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/platform/abi.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #if defined(_MSC_VER) #include diff --git a/third_party/xla/third_party/tsl/tsl/platform/abi.h b/third_party/xla/third_party/tsl/tsl/platform/abi.h index b7106a0d7203a3..20f2fbf063ea38 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/abi.h +++ b/third_party/xla/third_party/tsl/tsl/platform/abi.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/abi_test.cc b/third_party/xla/third_party/tsl/tsl/platform/abi_test.cc index ff4fef46e7ae6d..02fe441ef41565 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/abi_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/abi_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/base64.cc b/third_party/xla/third_party/tsl/tsl/platform/base64.cc index 6ea29ad399d0ad..7c21d29c930327 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/base64.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/base64.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/base64.h b/third_party/xla/third_party/tsl/tsl/platform/base64.h index 2b8e204629bd59..08867207f6e76e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/base64.h +++ b/third_party/xla/third_party/tsl/tsl/platform/base64.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/stringpiece.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/blocking_counter.h b/third_party/xla/third_party/tsl/tsl/platform/blocking_counter.h index c085e4d66af54e..e46fc7591ba3ac 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/blocking_counter.h +++ b/third_party/xla/third_party/tsl/tsl/platform/blocking_counter.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/mutex.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/coding.cc b/third_party/xla/third_party/tsl/tsl/platform/coding.cc index f7d1cc387fc7b9..4f2be2f722f443 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/coding.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/coding.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tsl/platform/coding.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/byte_order.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/tstring.h" -#include "tsl/platform/types.h" namespace tsl { namespace core { diff --git a/third_party/xla/third_party/tsl/tsl/platform/coding.h b/third_party/xla/third_party/tsl/tsl/platform/coding.h index 5947b2ed3b4d5e..b8153c18de45fd 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/coding.h +++ b/third_party/xla/third_party/tsl/tsl/platform/coding.h @@ -21,9 +21,9 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_CODING_H_ #define TENSORFLOW_TSL_PLATFORM_CODING_H_ +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/tstring.h" -#include "tsl/platform/types.h" namespace tsl { namespace core { diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc index 1de5eb8031623d..5ed6c7ff4c0ade 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tsl/platform/cpu_info.h" #include "absl/base/call_once.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/platform.h" -#include "tsl/platform/types.h" #if defined(PLATFORM_IS_X86) #include // NOLINT #endif diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc index dbef5a57f47397..e4757931cf6cb4 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tsl/platform/cpu_info.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/criticality_test.cc b/third_party/xla/third_party/tsl/tsl/platform/criticality_test.cc index c3cf04f04cd540..1812fa4df444c7 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/criticality_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/criticality_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/platform/criticality.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace criticality { diff --git a/third_party/xla/third_party/tsl/tsl/platform/ctstring_test.cc b/third_party/xla/third_party/tsl/tsl/platform/ctstring_test.cc index 040881eccc847c..61f126a976d4cb 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/ctstring_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/ctstring_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "xla/tsl/platform/test.h" #include "tsl/platform/ctstring_internal.h" -#include "tsl/platform/test.h" static const char kLongString[] = "abcdefghij" diff --git a/third_party/xla/third_party/tsl/tsl/platform/demangle.h b/third_party/xla/third_party/tsl/tsl/platform/demangle.h index 95f07ff0ce1bcc..4b7576f8dc4f31 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/demangle.h +++ b/third_party/xla/third_party/tsl/tsl/platform/demangle.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DEMANGLE_H_ #define TENSORFLOW_TSL_PLATFORM_DEMANGLE_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/denormal.h b/third_party/xla/third_party/tsl/tsl/platform/denormal.h index 5b13ab1b0d752c..05e52d3ceae4f7 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/denormal.h +++ b/third_party/xla/third_party/tsl/tsl/platform/denormal.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DENORMAL_H_ #define TENSORFLOW_TSL_PLATFORM_DENORMAL_H_ -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/denormal_test.cc b/third_party/xla/third_party/tsl/tsl/platform/denormal_test.cc index 74102f7ab451ba..0b682c002bc5cc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/denormal_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/denormal_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.h b/third_party/xla/third_party/tsl/tsl/platform/env.h index f814e39339ecc8..806cbb1c9860bb 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env.h +++ b/third_party/xla/third_party/tsl/tsl/platform/env.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,720 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_ENV_H_ #define TENSORFLOW_TSL_PLATFORM_ENV_H_ -#include - -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/numa.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" -#include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" - -// Delete leaked Windows definitions. -#ifdef PLATFORM_WINDOWS -#undef CopyFile -#undef DeleteFile -#endif - -namespace tsl { - -class Thread; -struct ThreadOptions; - -/// \brief An interface used by the tensorflow implementation to -/// access operating system functionality like the filesystem etc. -/// -/// Callers may wish to provide a custom Env object to get fine grain -/// control. -/// -/// All Env implementations of file-system modifying functionality are safe -/// for concurrent access from multiple threads without any external -/// synchronization, *however*, Envs and their underlying file systems are -/// global objects, and therefore, if any thread modifies options, the modified -/// options take effect process-wide. The SetOption functions themselves are -/// also *not* thread safe. -class Env { - public: - Env(); - virtual ~Env() = default; - - /// \brief Returns a default environment suitable for the current operating - /// system. - /// - /// Sophisticated users may wish to provide their own Env - /// implementation instead of relying on this default environment. - /// - /// The result of Default() belongs to this library and must never be deleted. - static Env* Default(); - - /// \brief Returns the FileSystem object to handle operations on the file - /// specified by 'fname'. The FileSystem object is used as the implementation - /// for the file system related (non-virtual) functions that follow. - /// Returned FileSystem object is still owned by the Env object and will - // (might) be destroyed when the environment is destroyed. - virtual absl::Status GetFileSystemForFile(const std::string& fname, - FileSystem** result); - - /// \brief Returns the file system schemes registered for this Env. - virtual absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes); - - /// \brief Register a file system for a scheme. - virtual absl::Status RegisterFileSystem(const std::string& scheme, - FileSystemRegistry::Factory factory); - - /// \brief Register a modular file system for a scheme. - /// - /// Same as `RegisterFileSystem` but for filesystems provided by plugins. - /// - /// TODO(b/139060984): After all filesystems are converted, make this be the - /// canonical registration function. - virtual absl::Status RegisterFileSystem( - const std::string& scheme, std::unique_ptr filesystem); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::string& value); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::vector& values); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::vector& values); - - absl::Status SetOption(const std::string& scheme, const std::string& key, - const std::vector& values); - - /// \brief Flush filesystem caches for all registered filesystems. - absl::Status FlushFileSystemCaches(); - - /// \brief Creates a brand new random access read-only file with the - /// specified name. - - /// On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. If the file does not exist, returns a non-OK - /// status. - /// - /// The returned file may be concurrently accessed by multiple threads. - /// - /// The ownership of the returned RandomAccessFile is passed to the caller - /// and the object should be deleted when is not used. The file object - /// shouldn't live longer than the Env object. - absl::Status NewRandomAccessFile(const std::string& fname, - std::unique_ptr* result); - - absl::Status NewRandomAccessFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - // We duplicate these methods due to Google internal coding style prevents - // virtual functions with default arguments. See PR #41615. - return absl::OkStatus(); - } - - /// \brief Creates an object that writes to a new file with the specified - /// name. - /// - /// Deletes any existing file with the same name and creates a - /// new file. On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. The file object - /// shouldn't live longer than the Env object. - absl::Status NewWritableFile(const std::string& fname, - std::unique_ptr* result); - - absl::Status NewWritableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// \brief Creates an object that either appends to an existing file, or - /// writes to a new file (if the file does not exist to begin with). - /// - /// On success, stores a pointer to the new file in *result and - /// returns OK. On failure stores NULL in *result and returns - /// non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. The file object - /// shouldn't live longer than the Env object. - absl::Status NewAppendableFile(const std::string& fname, - std::unique_ptr* result); - - absl::Status NewAppendableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - /// \brief Creates a readonly region of memory with the file context. - /// - /// On success, it returns a pointer to read-only memory region - /// from the content of file fname. The ownership of the region is passed to - /// the caller. On failure stores nullptr in *result and returns non-OK. - /// - /// The returned memory region can be accessed from many threads in parallel. - /// - /// The ownership of the returned ReadOnlyMemoryRegion is passed to the caller - /// and the object should be deleted when is not used. The memory region - /// object shouldn't live longer than the Env object. - absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, std::unique_ptr* result); - - absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// Returns OK if the named path exists and NOT_FOUND otherwise. - absl::Status FileExists(const std::string& fname); - - absl::Status FileExists(const std::string& fname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// Returns true if all the listed files exist, false otherwise. - /// if status is not null, populate the vector with a detailed status - /// for each file. - bool FilesExist(const std::vector& files, - std::vector* status); - - bool FilesExist(const std::vector& files, TransactionToken* token, - std::vector* status) { - return true; - } - - /// \brief Stores in *result the names of the children of the specified - /// directory. The names are relative to "dir". - /// - /// Original contents of *results are dropped. - absl::Status GetChildren(const std::string& dir, std::vector* result); - - absl::Status GetChildren(const std::string& dir, TransactionToken* token, - std::vector* result) { - return absl::OkStatus(); - } - - /// \brief Returns true if the path matches the given pattern. The wildcards - /// allowed in pattern are described in FileSystem::GetMatchingPaths. - virtual bool MatchPath(const std::string& path, - const std::string& pattern) = 0; - - /// \brief Given a pattern, stores in *results the set of paths that matches - /// that pattern. *results is cleared. - /// - /// More details about `pattern` in FileSystem::GetMatchingPaths. - virtual absl::Status GetMatchingPaths(const std::string& pattern, - std::vector* results); - - absl::Status GetMatchingPaths(const std::string& pattern, - TransactionToken* token, - std::vector* results) { - return absl::OkStatus(); - } - - /// Deletes the named file. - absl::Status DeleteFile(const std::string& fname); - - absl::Status DeleteFile(const std::string& fname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Deletes the specified directory and all subdirectories and files - /// underneath it. This is accomplished by traversing the directory tree - /// rooted at dirname and deleting entries as they are encountered. - /// - /// If dirname itself is not readable or does not exist, *undeleted_dir_count - /// is set to 1, *undeleted_file_count is set to 0 and an appropriate status - /// (e.g. NOT_FOUND) is returned. - /// - /// If dirname and all its descendants were successfully deleted, TF_OK is - /// returned and both error counters are set to zero. - /// - /// Otherwise, while traversing the tree, undeleted_file_count and - /// undeleted_dir_count are updated if an entry of the corresponding type - /// could not be deleted. The returned error status represents the reason that - /// any one of these entries could not be deleted. - /// - /// REQUIRES: undeleted_files, undeleted_dirs to be not null. - /// - /// Typical return codes: - /// * OK - dirname exists and we were able to delete everything underneath. - /// * NOT_FOUND - dirname doesn't exist - /// * PERMISSION_DENIED - dirname or some descendant is not writable - /// * UNIMPLEMENTED - Some underlying functions (like Delete) are not - /// implemented - absl::Status DeleteRecursively(const std::string& dirname, - int64_t* undeleted_files, - int64_t* undeleted_dirs); - - absl::Status DeleteRecursively(const std::string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { - return absl::OkStatus(); - } - - /// \brief Creates the specified directory and all the necessary - /// subdirectories. Typical return codes. - /// * OK - successfully created the directory and sub directories, even if - /// they were already created. - /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. - absl::Status RecursivelyCreateDir(const std::string& dirname); - - absl::Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token) { - return absl::OkStatus(); - } - /// \brief Creates the specified directory. Typical return codes - /// * OK - successfully created the directory. - /// * ALREADY_EXISTS - directory already exists. - /// * PERMISSION_DENIED - dirname is not writable. - absl::Status CreateDir(const std::string& dirname); - - absl::Status CreateDir(const std::string& dirname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// Deletes the specified directory. - absl::Status DeleteDir(const std::string& dirname); - - absl::Status DeleteDir(const std::string& dirname, TransactionToken* token) { - return absl::OkStatus(); - } - - /// Obtains statistics for the given path. - absl::Status Stat(const std::string& fname, FileStatistics* stat); - - absl::Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) { - return absl::OkStatus(); - } - - /// \brief Returns whether the given path is a directory or not. - /// Typical return codes (not guaranteed exhaustive): - /// * OK - The path exists and is a directory. - /// * FAILED_PRECONDITION - The path exists and is not a directory. - /// * NOT_FOUND - The path entry does not exist. - /// * PERMISSION_DENIED - Insufficient permissions. - /// * UNIMPLEMENTED - The file factory doesn't support directories. - absl::Status IsDirectory(const std::string& fname); - - /// \brief Returns whether the given path is on a file system - /// that has atomic move capabilities. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// The second boolean argument has_atomic_move contains this information. - /// - /// Returns one of the following status codes (not guaranteed exhaustive): - /// * OK - The path is on a recognized file system, - /// so has_atomic_move holds the above information. - /// * UNIMPLEMENTED - The file system of the path hasn't been implemented in - /// TF - absl::Status HasAtomicMove(const std::string& path, bool* has_atomic_move); - - /// Returns whether the give path is on a file system - /// that has ability to create a new temp file. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// If this returns false, TensorFlow will write directly to output files - /// instead of creating a temporary file and swapping it in. This may mean - /// that incomplete writes are visible to consumers. - absl::Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); - - /// Stores the size of `fname` in `*file_size`. - absl::Status GetFileSize(const std::string& fname, uint64* file_size); - - absl::Status GetFileSize(const std::string& fname, TransactionToken* token, - uint64* file_size) { - return absl::OkStatus(); - } - - /// \brief Renames file src to target. If target already exists, it will be - /// replaced. - absl::Status RenameFile(const std::string& src, const std::string& target); - - absl::Status RenameFile(const std::string& src, const std::string& target, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Copy the src to target. - absl::Status CopyFile(const std::string& src, const std::string& target); - - absl::Status CopyFile(const std::string& src, const std::string& target, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief starts a new transaction on the filesystem that handles filename - absl::Status StartTransaction(const std::string& filename, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Adds `path` to transaction in `token` if token belongs to - /// filesystem that handles the path. - absl::Status AddToTransaction(const std::string& path, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Get token for `path` or start a new transaction and add `path` to - /// it. - absl::Status GetTokenOrStartTransaction(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Returns the transaction for `path` or nullptr in `token` - absl::Status GetTransactionForPath(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Finalizes the transaction - absl::Status EndTransaction(TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Returns the absolute path of the current executable. It resolves - /// symlinks if there is any. - std::string GetExecutablePath(); - - /// Creates a local unique temporary file name. Returns true if success. - bool LocalTempFilename(std::string* filename); - - /// Creates a local unique file name that starts with |prefix| and ends with - /// |suffix|. Returns true if success. - bool CreateUniqueFileName(std::string* prefix, const std::string& suffix); - - /// \brief Return the runfiles directory if running under bazel. Returns - /// the directory the executable is located in if not running under bazel. - virtual std::string GetRunfilesDir() = 0; - - // TODO(jeff,sanjay): Add back thread/thread-pool support if needed. - // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or - // provide a routine to get the absolute time. - - /// \brief Returns the number of nano-seconds since the Unix epoch. - virtual uint64 NowNanos() const { return EnvTime::NowNanos(); } - - /// \brief Returns the number of micro-seconds since the Unix epoch. - virtual uint64 NowMicros() const { return EnvTime::NowMicros(); } - - /// \brief Returns the number of seconds since the Unix epoch. - virtual uint64 NowSeconds() const { return EnvTime::NowSeconds(); } - - /// Sleeps/delays the thread for the prescribed number of micro-seconds. - virtual void SleepForMicroseconds(int64_t micros) = 0; - - /// Returns the process ID of the calling process. - int32 GetProcessId(); - - /// \brief Returns a new thread that is running fn() and is identified - /// (for debugging/performance-analysis) by "name". - /// - /// Caller takes ownership of the result and must delete it eventually - /// (the deletion will block until fn() stops running). - virtual Thread* StartThread( - const ThreadOptions& thread_options, const std::string& name, - absl::AnyInvocable fn) TF_MUST_USE_RESULT = 0; - - // Returns the thread id of calling thread. - // Posix: Returns pthread id which is only guaranteed to be unique within a - // process. - // Windows: Returns thread id which is unique. - virtual int32 GetCurrentThreadId() = 0; - - // Copies current thread name to "name". Returns true if success. - virtual bool GetCurrentThreadName(std::string* name) = 0; - - // \brief Schedules the given closure on a thread-pool. - // - // NOTE(mrry): This closure may block. - virtual void SchedClosure(absl::AnyInvocable closure) = 0; - - // \brief Schedules the given closure on a thread-pool after the given number - // of microseconds. - // - // NOTE(mrry): This closure must not block. - virtual void SchedClosureAfter(int64_t micros, - absl::AnyInvocable closure) = 0; - - // \brief Load a dynamic library. - // - // Pass "library_filename" to a platform-specific mechanism for dynamically - // loading a library. The rules for determining the exact location of the - // library are platform-specific and are not documented here. - // - // On success, returns a handle to the library in "*handle" and returns - // OK from the function. - // Otherwise returns nullptr in "*handle" and an error status from the - // function. - virtual absl::Status LoadDynamicLibrary(const char* library_filename, - void** handle) = 0; - - // \brief Get a pointer to a symbol from a dynamic library. - // - // "handle" should be a pointer returned from a previous call to LoadLibrary. - // On success, store a pointer to the located symbol in "*symbol" and return - // OK from the function. Otherwise, returns nullptr in "*symbol" and an error - // status from the function. - virtual absl::Status GetSymbolFromLibrary(void* handle, - const char* symbol_name, - void** symbol) = 0; - - // \brief build the name of dynamic library. - // - // "name" should be name of the library. - // "version" should be the version of the library or NULL - // returns the name that LoadLibrary() can use - virtual std::string FormatLibraryFileName(const std::string& name, - const std::string& version) = 0; - - // Returns a possible list of local temporary directories. - virtual void GetLocalTempDirectories(std::vector* list) = 0; - - private: - std::unique_ptr file_system_registry_; - Env(const Env&) = delete; - void operator=(const Env&) = delete; -}; - -/// \brief An implementation of Env that forwards all calls to another Env. -/// -/// May be useful to clients who wish to override just part of the -/// functionality of another Env. -class EnvWrapper : public Env { - public: - /// Initializes an EnvWrapper that delegates all calls to *t - explicit EnvWrapper(Env* t) : target_(t) {} - ~EnvWrapper() override; - - /// Returns the target to which this Env forwards all calls - Env* target() const { return target_; } - - absl::Status GetFileSystemForFile(const std::string& fname, - FileSystem** result) override { - return target_->GetFileSystemForFile(fname, result); - } - - absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes) override { - return target_->GetRegisteredFileSystemSchemes(schemes); - } - - absl::Status RegisterFileSystem( - const std::string& scheme, FileSystemRegistry::Factory factory) override { - return target_->RegisterFileSystem(scheme, factory); - } - - bool MatchPath(const std::string& path, const std::string& pattern) override { - return target_->MatchPath(path, pattern); - } - - uint64 NowMicros() const override { return target_->NowMicros(); } - void SleepForMicroseconds(int64_t micros) override { - target_->SleepForMicroseconds(micros); - } - Thread* StartThread(const ThreadOptions& thread_options, - const std::string& name, - absl::AnyInvocable fn) override { - return target_->StartThread(thread_options, name, std::move(fn)); - } - int32 GetCurrentThreadId() override { return target_->GetCurrentThreadId(); } - bool GetCurrentThreadName(std::string* name) override { - return target_->GetCurrentThreadName(name); - } - void SchedClosure(absl::AnyInvocable closure) override { - target_->SchedClosure(std::move(closure)); - } - void SchedClosureAfter(int64_t micros, - absl::AnyInvocable closure) override { - target_->SchedClosureAfter(micros, std::move(closure)); - } - absl::Status LoadDynamicLibrary(const char* library_filename, - void** handle) override { - return target_->LoadDynamicLibrary(library_filename, handle); - } - absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) override { - return target_->GetSymbolFromLibrary(handle, symbol_name, symbol); - } - std::string FormatLibraryFileName(const std::string& name, - const std::string& version) override { - return target_->FormatLibraryFileName(name, version); - } - - std::string GetRunfilesDir() override { return target_->GetRunfilesDir(); } - - private: - void GetLocalTempDirectories(std::vector* list) override { - target_->GetLocalTempDirectories(list); - } - - Env* target_; -}; - -/// Represents a thread used to run a TSL function. -class Thread { - public: - Thread() {} - - /// Blocks until the thread of control stops running. - virtual ~Thread(); - - private: - Thread(const Thread&) = delete; - void operator=(const Thread&) = delete; -}; - -/// \brief Cross-platform setenv. -/// -/// Since setenv() is not available on windows, we provide an -/// alternative with platform specific implementations here. -int setenv(const char* name, const char* value, int overwrite); - -/// Cross-platform unsetenv. -int unsetenv(const char* name); - -/// \brief Options to configure a Thread. -/// -/// Note that the options are all hints, and the -/// underlying implementation may choose to ignore it. -struct ThreadOptions { - /// Thread stack size to use (in bytes). - size_t stack_size = 0; // 0: use system default value - /// Guard area size to use near thread stacks to use (in bytes) - size_t guard_size = 0; // 0: use system default value - int numa_node = port::kNUMANoAffinity; -}; - -/// A utility routine: copy contents of `src` in file system `src_fs` -/// to `target` in file system `target_fs`. -absl::Status FileSystemCopyFile(FileSystem* src_fs, const std::string& src, - FileSystem* target_fs, - const std::string& target); - -/// A utility routine: reads contents of named file into `*data` -absl::Status ReadFileToString(Env* env, const std::string& fname, - std::string* data); - -/// A utility routine: write contents of `data` to file named `fname` -/// (overwriting existing contents, if any). -absl::Status WriteStringToFile(Env* env, const std::string& fname, - const absl::string_view& data); - -/// Write binary representation of "proto" to the named file. -absl::Status WriteBinaryProto(Env* env, const std::string& fname, - const protobuf::MessageLite& proto); - -/// Reads contents of named file and parse as binary encoded proto data -/// and store into `*proto`. -absl::Status ReadBinaryProto(Env* env, const std::string& fname, - protobuf::MessageLite* proto); - -/// Write the text representation of "proto" to the named file. -inline absl::Status WriteTextProto(Env* /* env */, - const std::string& /* fname */, - const protobuf::MessageLite& /* proto */) { - return errors::Unimplemented("Can't write text protos with protolite."); -} -absl::Status WriteTextProto(Env* env, const std::string& fname, - const protobuf::Message& proto); - -/// Read contents of named file and parse as text encoded proto data -/// and store into `*proto`. -inline absl::Status ReadTextProto(Env* /* env */, - const std::string& /* fname */, - protobuf::MessageLite* /* proto */) { - return errors::Unimplemented("Can't parse text protos with protolite."); -} -absl::Status ReadTextProto(Env* env, const std::string& fname, - protobuf::Message* proto); - -/// Read contents of named file and parse as either text or binary encoded proto -/// data and store into `*proto`. -absl::Status ReadTextOrBinaryProto(Env* env, const std::string& fname, - protobuf::Message* proto); -absl::Status ReadTextOrBinaryProto(Env* env, const std::string& fname, - protobuf::MessageLite* proto); - -// START_SKIP_DOXYGEN - -// The following approach to register filesystems is deprecated and will be -// replaced with modular filesystem plugins registration. -// TODO(b/139060984): After all filesystems are converted, remove this. -namespace register_file_system { - -template -struct Register { - Register(Env* env, const std::string& scheme, bool try_modular_filesystems) { - // TODO(yongtang): Remove legacy file system registration for hdfs/s3/gcs - // after TF 2.6+. - if (try_modular_filesystems) { - const char* env_value = getenv("TF_USE_MODULAR_FILESYSTEM"); - string load_plugin = env_value ? absl::AsciiStrToLower(env_value) : ""; - if (load_plugin == "true" || load_plugin == "1") { - // We don't register the static filesystem and wait for SIG IO one - LOG(WARNING) << "Using modular file system for '" << scheme << "'." - << " Please switch to tensorflow-io" - << " (https://github.com/tensorflow/io) for file system" - << " support of '" << scheme << "'."; - return; - } - // If the envvar is missing or not "true"/"1", then fall back to legacy - // implementation to be backwards compatible. - } - // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! - env->RegisterFileSystem(scheme, []() -> FileSystem* { return new Factory; }) - .IgnoreError(); - } -}; - -} // namespace register_file_system - -// END_SKIP_DOXYGEN - -} // namespace tsl - -// Register a FileSystem implementation for a scheme. Files with names that have -// "scheme://" prefixes are routed to use this implementation. -#define REGISTER_FILE_SYSTEM_ENV(env, scheme, factory, modular) \ - REGISTER_FILE_SYSTEM_UNIQ_HELPER(__COUNTER__, env, scheme, factory, modular) -#define REGISTER_FILE_SYSTEM_UNIQ_HELPER(ctr, env, scheme, factory, modular) \ - REGISTER_FILE_SYSTEM_UNIQ(ctr, env, scheme, factory, modular) -#define REGISTER_FILE_SYSTEM_UNIQ(ctr, env, scheme, factory, modular) \ - static ::tsl::register_file_system::Register register_ff##ctr \ - TF_ATTRIBUTE_UNUSED = \ - ::tsl::register_file_system::Register(env, scheme, modular) - -#define REGISTER_FILE_SYSTEM(scheme, factory) \ - REGISTER_FILE_SYSTEM_ENV(::tsl::Env::Default(), scheme, factory, false); - -#define REGISTER_LEGACY_FILE_SYSTEM(scheme, factory) \ - REGISTER_FILE_SYSTEM_ENV(::tsl::Env::Default(), scheme, factory, true); +#include "xla/tsl/platform/env.h" #endif // TENSORFLOW_TSL_PLATFORM_ENV_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/env_time.h b/third_party/xla/third_party/tsl/tsl/platform/env_time.h index 2ec888069ead32..eaadae805294a0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env_time.h +++ b/third_party/xla/third_party/tsl/tsl/platform/env_time.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,54 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef TENSORFLOW_TSL_PLATFORM_ENV_TIME_H_ #define TENSORFLOW_TSL_PLATFORM_ENV_TIME_H_ -#include - -#include "tsl/platform/types.h" - -namespace tsl { - -/// \brief An interface used by the tsl implementation to -/// access timer related operations. -class EnvTime { - public: - static constexpr uint64 kMicrosToPicos = 1000ULL * 1000ULL; - static constexpr uint64 kMicrosToNanos = 1000ULL; - static constexpr uint64 kMillisToMicros = 1000ULL; - static constexpr uint64 kMillisToNanos = 1000ULL * 1000ULL; - static constexpr uint64 kNanosToPicos = 1000ULL; - static constexpr uint64 kSecondsToMillis = 1000ULL; - static constexpr uint64 kSecondsToMicros = 1000ULL * 1000ULL; - static constexpr uint64 kSecondsToNanos = 1000ULL * 1000ULL * 1000ULL; - - EnvTime() = default; - virtual ~EnvTime() = default; - - /// \brief Returns the number of nano-seconds since the Unix epoch. - static uint64 NowNanos(); - - /// \brief Returns the number of micro-seconds since the Unix epoch. - static uint64 NowMicros() { return NowNanos() / kMicrosToNanos; } - - /// \brief Returns the number of seconds since the Unix epoch. - static uint64 NowSeconds() { return NowNanos() / kSecondsToNanos; } - - /// \brief A version of NowNanos() that may be overridden by a subclass. - virtual uint64 GetOverridableNowNanos() const { return NowNanos(); } - - /// \brief A version of NowMicros() that may be overridden by a subclass. - virtual uint64 GetOverridableNowMicros() const { - return GetOverridableNowNanos() / kMicrosToNanos; - } - - /// \brief A version of NowSeconds() that may be overridden by a subclass. - virtual uint64 GetOverridableNowSeconds() const { - return GetOverridableNowNanos() / kSecondsToNanos; - } -}; - -} // namespace tsl +#include "xla/tsl/platform/env_time.h" #endif // TENSORFLOW_TSL_PLATFORM_ENV_TIME_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/errors.h b/third_party/xla/third_party/tsl/tsl/platform/errors.h index 9be69959661e8a..0c28bd4188db21 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/errors.h +++ b/third_party/xla/third_party/tsl/tsl/platform/errors.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,631 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_ERRORS_H_ #define TENSORFLOW_TSL_PLATFORM_ERRORS_H_ -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/str_join.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/strcat.h" - -namespace tsl { -namespace error { -// NOLINTBEGIN(misc-unused-using-decls) -// TODO(aminim): figure out the protobuf migration story. -using tensorflow::error::ABORTED; -using tensorflow::error::ALREADY_EXISTS; -using tensorflow::error::CANCELLED; -using tensorflow::error::Code; -using tensorflow::error::DATA_LOSS; -using tensorflow::error::DEADLINE_EXCEEDED; -using tensorflow::error::FAILED_PRECONDITION; -using tensorflow::error::INTERNAL; -using tensorflow::error::INVALID_ARGUMENT; -using tensorflow::error::NOT_FOUND; -using tensorflow::error::OK; -using tensorflow::error::OUT_OF_RANGE; -using tensorflow::error::PERMISSION_DENIED; -using tensorflow::error::RESOURCE_EXHAUSTED; -using tensorflow::error::UNAUTHENTICATED; -using tensorflow::error::UNAVAILABLE; -using tensorflow::error::UNIMPLEMENTED; -using tensorflow::error::UNKNOWN; -// NOLINTEND(misc-unused-using-decls) -} // namespace error - -namespace errors { - -namespace internal { - -// The DECLARE_ERROR macro below only supports types that can be converted -// into StrCat's AlphaNum. For the other types we rely on a slower path -// through std::stringstream. To add support of a new type, it is enough to -// make sure there is an operator<<() for it: -// -// std::ostream& operator<<(std::ostream& os, const MyType& foo) { -// os << foo.ToString(); -// return os; -// } -// Eventually absl::strings will have native support for this and we will be -// able to completely remove PrepareForStrCat(). -template -typename std::enable_if::value, - std::string>::type -PrepareForStrCat(const T& t) { - std::stringstream ss; - ss << t; - return ss.str(); -} -inline const strings::AlphaNum& PrepareForStrCat(const strings::AlphaNum& a) { - return a; -} - -} // namespace internal - -// Maps UNIX errors into a Status. -absl::Status IOError(const string& context, int err_number); - -// Returns all payloads from a Status as a key-value map. -inline std::unordered_map GetPayloads( - const absl::Status& status) { - std::unordered_map payloads; - status.ForEachPayload( - [&payloads](absl::string_view key, const absl::Cord& value) { - payloads[std::string(key)] = std::string(value); - }); - return payloads; -} - -// Inserts all given payloads into the given status. Will overwrite existing -// payloads if they exist with the same key. -inline void InsertPayloads( - absl::Status& status, - const std::unordered_map& payloads) { - for (const auto& payload : payloads) { - status.SetPayload(payload.first, absl::Cord(payload.second)); - } -} - -// Copies all payloads from one Status to another. Will overwrite existing -// payloads in the destination if they exist with the same key. -inline void CopyPayloads(const absl::Status& from, absl::Status& to) { - from.ForEachPayload([&to](absl::string_view key, const absl::Cord& value) { - to.SetPayload(key, value); - }); -} - -#if defined(PLATFORM_GOOGLE) -// Creates a new status with the given code, message and payloads. -inline absl::Status Create( - absl::StatusCode code, absl::string_view message, - const std::unordered_map& payloads, - absl::SourceLocation loc = absl::SourceLocation::current()) { - absl::Status status(code, message, loc); - InsertPayloads(status, payloads); - return status; -} -// Returns a new Status, replacing its message with the given. -inline absl::Status CreateWithUpdatedMessage(const absl::Status& status, - absl::string_view message) { - auto locations = status.GetSourceLocations(); - auto initial_loc = - locations.empty() ? absl::SourceLocation::current() : locations[0]; - absl::Status new_status = Create(static_cast(status.code()), - message, GetPayloads(status), initial_loc); - if (locations.size() > 1) { - for (auto loc : locations.subspan(1)) { - new_status.AddSourceLocation(loc); - } - } - return new_status; -} - -#else -inline ::absl::Status Create( - absl::StatusCode code, ::tsl::StringPiece message, - const std::unordered_map& payloads) { - Status status(code, message); - InsertPayloads(status, payloads); - return status; -} -// Returns a new Status, replacing its message with the given. -inline ::tsl::Status CreateWithUpdatedMessage(const ::tsl::Status& status, - ::tsl::StringPiece message) { - return Create(static_cast(status.code()), message, - GetPayloads(status)); -} -#endif - -// Append some context to an error message. Each time we append -// context put it on a new line, since it is possible for there -// to be several layers of additional context. -template -void AppendToMessage(absl::Status* status, Args... args) { - auto new_status = CreateWithUpdatedMessage( - *status, ::tsl::strings::StrCat(status->message(), "\n\t", args...)); - CopyPayloads(*status, new_status); - *status = std::move(new_status); -} - -// For propagating errors when calling a function. -#define TF_RETURN_IF_ERROR(...) \ - do { \ - ::absl::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - MAYBE_ADD_SOURCE_LOCATION(_status) \ - return _status; \ - } \ - } while (0) - -#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ - do { \ - ::tsl::Status _status = (expr); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - ::tsl::errors::AppendToMessage(&_status, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// Convenience functions for generating and using error status. -// Example usage: -// status.Update(errors::InvalidArgument("The ", foo, " isn't right.")); -// if (errors::IsInvalidArgument(status)) { ... } -// switch (status.code()) { case error::INVALID_ARGUMENT: ... } - -// CANCELLED -template -absl::Status Cancelled(Args... args) { - return absl::Status(absl::StatusCode::kCancelled, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status CancelledWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kCancelled, message, payloads); -} - -// InvalidArgument -template -absl::Status InvalidArgument(Args... args) { - return absl::Status(absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} - -#if defined(PLATFORM_GOOGLE) -// Specialized overloads to capture source location for up to three arguments. -template -::absl::Status InvalidArgument( - Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3), - ::tsl::errors::internal::PrepareForStrCat(arg4)), - loc); -} -template -::absl::Status InvalidArgument( - Arg1 arg1, Arg2 arg2, Arg3 arg3, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3)), - loc); -} -template -::absl::Status InvalidArgument( - Arg1 arg1, Arg2 arg2, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2)), - loc); -} -template -::absl::Status InvalidArgument( - Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), - loc); -} -template -::absl::Status InvalidArgumentWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads, - loc); -} -#else -template -::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3) { - return ::absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3))); -} -template -::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) { - return ::absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2))); -} -template -::absl::Status InvalidArgument(Arg1 arg1) { - return ::absl::Status( - absl::StatusCode::kInvalidArgument, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); -} -template -::absl::Status InvalidArgumentWithPayloads( - const ::tsl::StringPiece& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads); -} -#endif - -// NotFound -template -absl::Status NotFound(Args... args) { - return absl::Status(absl::StatusCode::kNotFound, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -#if defined(PLATFORM_GOOGLE) -// Specialized overloads to capture source location for up to three arguments. -template -::absl::Status NotFound( - Arg1 arg1, Arg2 arg2, Arg3 arg3, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3)), - loc); -} -template -::absl::Status NotFound( - Arg1 arg1, Arg2 arg2, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2)), - loc); -} -template -::absl::Status NotFound( - Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { - return absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), - loc); -} -template -::absl::Status NotFoundWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return errors::Create(absl::StatusCode::kNotFound, message, payloads, loc); -} -#else -template -::absl::Status NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3) { - return ::absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2), - ::tsl::errors::internal::PrepareForStrCat(arg3))); -} -template -::absl::Status NotFound(Arg1 arg1, Arg2 arg2) { - return ::absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), - ::tsl::errors::internal::PrepareForStrCat(arg2))); -} -template -::absl::Status NotFound(Arg1 arg1) { - return ::absl::Status( - absl::StatusCode::kNotFound, - ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); -} -template -::absl::Status NotFoundWithPayloads( - const ::tsl::StringPiece& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kNotFound, message, payloads); -} -#endif - -// AlreadyExists -template -absl::Status AlreadyExists(Args... args) { - return absl::Status(absl::StatusCode::kAlreadyExists, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status AlreadyExistsWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kAlreadyExists, message, payloads); -} - -// ResourceExhausted -template -absl::Status ResourceExhausted(Args... args) { - return absl::Status(absl::StatusCode::kResourceExhausted, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status ResourceExhaustedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kResourceExhausted, message, - payloads); -} - -// Unavailable -template -absl::Status Unavailable(Args... args) { - return absl::Status(absl::StatusCode::kUnavailable, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnavailableWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnavailable, message, payloads); -} - -// FailedPrecondition -template -absl::Status FailedPrecondition(Args... args) { - return absl::Status(absl::StatusCode::kFailedPrecondition, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status FailedPreconditionWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kFailedPrecondition, message, - payloads); -} - -// OutOfRange -template -absl::Status OutOfRange(Args... args) { - return absl::Status(absl::StatusCode::kOutOfRange, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status OutOfRangeWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kOutOfRange, message, payloads); -} - -// Unimplemented -template -absl::Status Unimplemented(Args... args) { - return absl::Status(absl::StatusCode::kUnimplemented, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnimplementedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnimplemented, message, payloads); -} - -// Internal -template -absl::Status Internal(Args... args) { - return absl::Status(absl::StatusCode::kInternal, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status InternalWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kInternal, message, payloads); -} - -// Aborted -template -absl::Status Aborted(Args... args) { - return absl::Status(absl::StatusCode::kAborted, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status AbortedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kAborted, message, payloads); -} - -// DeadlineExceeded -template -absl::Status DeadlineExceeded(Args... args) { - return absl::Status(absl::StatusCode::kDeadlineExceeded, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status DeadlineExceededWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kDeadlineExceeded, message, payloads); -} - -// DataLoss -template -absl::Status DataLoss(Args... args) { - return absl::Status(absl::StatusCode::kDataLoss, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status DataLossWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kDataLoss, message, payloads); -} - -// Unknown -template -absl::Status Unknown(Args... args) { - return absl::Status(absl::StatusCode::kUnknown, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnknownPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnknown, message, payloads); -} -// PermissionDenied -template -absl::Status PermissionDenied(Args... args) { - return absl::Status(absl::StatusCode::kPermissionDenied, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status PermissionDeniedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kPermissionDenied, message, payloads); -} - -// Unauthenticated -template -absl::Status Unauthenticated(Args... args) { - return absl::Status(absl::StatusCode::kUnauthenticated, - ::tsl::strings::StrCat( - ::tsl::errors::internal::PrepareForStrCat(args)...)); -} -template -absl::Status UnauthenticatedWithPayloads( - const absl::string_view& message, - const std::unordered_map& payloads) { - return errors::Create(absl::StatusCode::kUnauthenticated, message, payloads); -} - -bool IsAborted(const absl::Status& status); -bool IsAlreadyExists(const absl::Status& status); -bool IsCancelled(const absl::Status& status); -bool IsDataLoss(const absl::Status& status); -bool IsDeadlineExceeded(const absl::Status& status); -bool IsFailedPrecondition(const absl::Status& status); -bool IsInternal(const absl::Status& status); -bool IsInvalidArgument(const absl::Status& status); -bool IsNotFound(const absl::Status& status); -bool IsOutOfRange(const absl::Status& status); -bool IsPermissionDenied(const absl::Status& status); -bool IsResourceExhausted(const absl::Status& status); -bool IsUnauthenticated(const absl::Status& status); -bool IsUnavailable(const absl::Status& status); -bool IsUnimplemented(const absl::Status& status); -bool IsUnknown(const absl::Status& status); - -// Produces a formatted string pattern from the name which can uniquely identify -// this node upstream to produce an informative error message. The pattern -// followed is: {{node }} -// Note: The pattern below determines the regex _NODEDEF_NAME_RE in the file -// tensorflow/python/client/session.py -// LINT.IfChange -inline std::string FormatNodeNameForError(absl::string_view name) { - return strings::StrCat("{{node ", name, "}}"); -} -// LINT.ThenChange(//tensorflow/python/client/session.py) -template -std::string FormatNodeNamesForError(const T& names) { - return absl::StrJoin( - names, ", ", [](std::string* output, absl::string_view s) { - ::tsl::strings::StrAppend(output, FormatNodeNameForError(s)); - }); -} -// LINT.IfChange -inline std::string FormatColocationNodeForError(absl::string_view name) { - return strings::StrCat("{{colocation_node ", name, "}}"); -} -// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) -template >> -std::string FormatColocationNodeForError(const T& names) { - return absl::StrJoin( - names, ", ", [](std::string* output, absl::string_view s) { - ::tsl::strings::StrAppend(output, FormatColocationNodeForError(s)); - }); -} - -inline std::string FormatFunctionForError(absl::string_view name) { - return strings::StrCat("{{function_node ", name, "}}"); -} - -inline absl::Status ReplaceErrorFromNonCommunicationOps( - const absl::Status s, absl::string_view op_name) { - assert(::tsl::errors::IsUnavailable(s)); - return absl::Status( - absl::StatusCode::kInternal, - strings::StrCat( - s.message(), "\nExecuting non-communication op <", op_name, - "> originally returned UnavailableError, and was replaced by " - "InternalError to avoid invoking TF network error handling logic.")); -} - -template -std::string FormatOriginalNodeLocationForError(const T& node_names, - const T& func_names) { - std::vector error_message; - for (int i = 0; i != node_names.size(); ++i) { - if (i != 0) { - error_message.push_back(", "); - } - if (i < func_names.size()) { - error_message.push_back(FormatFunctionForError(func_names[i])); - } - error_message.push_back(FormatNodeNameForError(node_names[i])); - } - return absl::StrJoin(error_message, ""); -} - -// The CanonicalCode() for non-errors. -using ::tsl::error::OK; // NOLINT - -} // namespace errors -} // namespace tsl +#include "xla/tsl/platform/errors.h" #endif // TENSORFLOW_TSL_PLATFORM_ERRORS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_statistics.h b/third_party/xla/third_party/tsl/tsl/platform/file_statistics.h index ebe50be46ae811..07bf908edbaf22 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_statistics.h +++ b/third_party/xla/third_party/tsl/tsl/platform/file_statistics.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,24 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FILE_STATISTICS_H_ #define TENSORFLOW_TSL_PLATFORM_FILE_STATISTICS_H_ -#include "tsl/platform/types.h" - -namespace tsl { - -struct FileStatistics { - // The length of the file or -1 if finding file length is not supported. - int64_t length = -1; - // The last modified time in nanoseconds. - int64_t mtime_nsec = 0; - // True if the file is a directory, otherwise false. - bool is_directory = false; - - FileStatistics() {} - FileStatistics(int64_t length, int64_t mtime_nsec, bool is_directory) - : length(length), mtime_nsec(mtime_nsec), is_directory(is_directory) {} - ~FileStatistics() {} -}; - -} // namespace tsl +#include "xla/tsl/platform/file_statistics.h" #endif // TENSORFLOW_TSL_PLATFORM_FILE_STATISTICS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.h b/third_party/xla/third_party/tsl/tsl/platform/file_system.h index 8b48788261368e..8d55471a5766f2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,921 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ #define TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ -#include - -#include -#include -#include -#include -#include -#include - -#include "tsl/platform/cord.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_statistics.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" - -#ifdef PLATFORM_WINDOWS -#undef DeleteFile -#undef CopyFile -#undef TranslateName -#endif - -namespace tsl { - -class FileAcl; -class RandomAccessFile; -class ReadOnlyMemoryRegion; -class WritableFile; - -class FileSystem; -struct TransactionToken { - FileSystem* owner; - void* token; -}; - -/// A generic interface for accessing a file system. Implementations -/// of custom filesystem adapters must implement this interface, -/// RandomAccessFile, WritableFile, and ReadOnlyMemoryRegion classes. -class FileSystem { - public: - /// \brief Creates a brand new random access read-only file with the - /// specified name. - /// - /// On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. If the file does not exist, returns a non-OK - /// status. - /// - /// The returned file may be concurrently accessed by multiple threads. - /// - /// The ownership of the returned RandomAccessFile is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewRandomAccessFile( - const std::string& fname, std::unique_ptr* result) { - return NewRandomAccessFile(fname, nullptr, result); - } - - virtual absl::Status NewRandomAccessFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - // We duplicate these methods due to Google internal coding style prevents - // virtual functions with default arguments. See PR #41615. - return absl::OkStatus(); - } - - /// \brief Creates an object that writes to a new file with the specified - /// name. - /// - /// Deletes any existing file with the same name and creates a - /// new file. On success, stores a pointer to the new file in - /// *result and returns OK. On failure stores NULL in *result and - /// returns non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewWritableFile(const std::string& fname, - std::unique_ptr* result) { - return NewWritableFile(fname, nullptr, result); - } - - virtual absl::Status NewWritableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// \brief Creates an object that either appends to an existing file, or - /// writes to a new file (if the file does not exist to begin with). - /// - /// On success, stores a pointer to the new file in *result and - /// returns OK. On failure stores NULL in *result and returns - /// non-OK. - /// - /// The returned file will only be accessed by one thread at a time. - /// - /// The ownership of the returned WritableFile is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewAppendableFile( - const std::string& fname, std::unique_ptr* result) { - return NewAppendableFile(fname, nullptr, result); - } - - virtual absl::Status NewAppendableFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// \brief Creates a readonly region of memory with the file context. - /// - /// On success, it returns a pointer to read-only memory region - /// from the content of file fname. The ownership of the region is passed to - /// the caller. On failure stores nullptr in *result and returns non-OK. - /// - /// The returned memory region can be accessed from many threads in parallel. - /// - /// The ownership of the returned ReadOnlyMemoryRegion is passed to the caller - /// and the object should be deleted when is not used. - virtual absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, std::unique_ptr* result) { - return NewReadOnlyMemoryRegionFromFile(fname, nullptr, result); - } - - virtual absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) { - return absl::OkStatus(); - } - - /// Returns OK if the named path exists and NOT_FOUND otherwise. - virtual absl::Status FileExists(const std::string& fname) { - return FileExists(fname, nullptr); - } - - virtual absl::Status FileExists(const std::string& fname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// Returns true if all the listed files exist, false otherwise. - /// if status is not null, populate the vector with a detailed status - /// for each file. - virtual bool FilesExist(const std::vector& files, - std::vector* status) { - return FilesExist(files, nullptr, status); - } - - virtual bool FilesExist(const std::vector& files, - TransactionToken* token, - std::vector* status); - - /// \brief Returns the immediate children in the given directory. - /// - /// The returned paths are relative to 'dir'. - virtual absl::Status GetChildren(const std::string& dir, - std::vector* result) { - return GetChildren(dir, nullptr, result); - } - - virtual absl::Status GetChildren(const std::string& dir, - TransactionToken* token, - std::vector* result) { - return absl::OkStatus(); - } - - /// \brief Given a pattern, stores in *results the set of paths that matches - /// that pattern. *results is cleared. - /// - /// pattern must match all of a name, not just a substring. - /// - /// pattern: { term } - /// term: - /// '*': matches any sequence of non-'/' characters - /// '?': matches a single non-'/' character - /// '[' [ '^' ] { match-list } ']': - /// matches any single character (not) on the list - /// c: matches character c (c != '*', '?', '\\', '[') - /// '\\' c: matches character c - /// character-range: - /// c: matches character c (c != '\\', '-', ']') - /// '\\' c: matches character c - /// lo '-' hi: matches character c for lo <= c <= hi - /// - /// Typical return codes: - /// * OK - no errors - /// * UNIMPLEMENTED - Some underlying functions (like GetChildren) are not - /// implemented - virtual absl::Status GetMatchingPaths(const std::string& pattern, - std::vector* results) { - return GetMatchingPaths(pattern, nullptr, results); - } - - virtual absl::Status GetMatchingPaths(const std::string& pattern, - TransactionToken* token, - std::vector* results) { - return absl::OkStatus(); - } - - /// \brief Checks if the given filename matches the pattern. - /// - /// This function provides the equivalent of posix fnmatch, however it is - /// implemented without fnmatch to ensure that this can be used for cloud - /// filesystems on windows. For windows filesystems, it uses PathMatchSpec. - virtual bool Match(const std::string& filename, const std::string& pattern); - - /// \brief Obtains statistics for the given path. - virtual absl::Status Stat(const std::string& fname, FileStatistics* stat) { - return Stat(fname, nullptr, stat); - } - - virtual absl::Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) { - return absl::OkStatus(); - } - - /// \brief Deletes the named file. - virtual absl::Status DeleteFile(const std::string& fname) { - return DeleteFile(fname, nullptr); - } - - virtual absl::Status DeleteFile(const std::string& fname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Creates the specified directory. - /// Typical return codes: - /// * OK - successfully created the directory. - /// * ALREADY_EXISTS - directory with name dirname already exists. - /// * PERMISSION_DENIED - dirname is not writable. - virtual absl::Status CreateDir(const std::string& dirname) { - return CreateDir(dirname, nullptr); - } - - virtual absl::Status CreateDir(const std::string& dirname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Creates the specified directory and all the necessary - /// subdirectories. - /// Typical return codes: - /// * OK - successfully created the directory and sub directories, even if - /// they were already created. - /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. - virtual absl::Status RecursivelyCreateDir(const std::string& dirname) { - return RecursivelyCreateDir(dirname, nullptr); - } - - virtual absl::Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token); - - /// \brief Deletes the specified directory. - virtual absl::Status DeleteDir(const std::string& dirname) { - return DeleteDir(dirname, nullptr); - } - - virtual absl::Status DeleteDir(const std::string& dirname, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Deletes the specified directory and all subdirectories and files - /// underneath it. This is accomplished by traversing the directory tree - /// rooted at dirname and deleting entries as they are encountered. - /// - /// If dirname itself is not readable or does not exist, *undeleted_dir_count - /// is set to 1, *undeleted_file_count is set to 0 and an appropriate status - /// (e.g. NOT_FOUND) is returned. - /// - /// If dirname and all its descendants were successfully deleted, TF_OK is - /// returned and both error counters are set to zero. - /// - /// Otherwise, while traversing the tree, undeleted_file_count and - /// undeleted_dir_count are updated if an entry of the corresponding type - /// could not be deleted. The returned error status represents the reason that - /// any one of these entries could not be deleted. - /// - /// REQUIRES: undeleted_files, undeleted_dirs to be not null. - /// - /// Typical return codes: - /// * OK - dirname exists and we were able to delete everything underneath. - /// * NOT_FOUND - dirname doesn't exist - /// * PERMISSION_DENIED - dirname or some descendant is not writable - /// * UNIMPLEMENTED - Some underlying functions (like Delete) are not - /// implemented - virtual absl::Status DeleteRecursively(const std::string& dirname, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { - return DeleteRecursively(dirname, nullptr, undeleted_files, undeleted_dirs); - } - - virtual absl::Status DeleteRecursively(const std::string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs); - - /// \brief Stores the size of `fname` in `*file_size`. - virtual absl::Status GetFileSize(const std::string& fname, - uint64* file_size) { - return GetFileSize(fname, nullptr, file_size); - } - - virtual absl::Status GetFileSize(const std::string& fname, - TransactionToken* token, uint64* file_size) { - return absl::OkStatus(); - } - - /// \brief Overwrites the target if it exists. - virtual absl::Status RenameFile(const std::string& src, - const std::string& target) { - return RenameFile(src, target, nullptr); - } - - virtual absl::Status RenameFile(const std::string& src, - const std::string& target, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Copy the src to target. - virtual absl::Status CopyFile(const std::string& src, - const std::string& target) { - return CopyFile(src, target, nullptr); - } - - virtual absl::Status CopyFile(const std::string& src, - const std::string& target, - TransactionToken* token); - - /// \brief Translate an URI to a filename for the FileSystem implementation. - /// - /// The implementation in this class cleans up the path, removing - /// duplicate /'s, resolving .. and removing trailing '/'. - /// This respects relative vs. absolute paths, but does not - /// invoke any system calls (getcwd(2)) in order to resolve relative - /// paths with respect to the actual working directory. That is, this is - /// purely string manipulation, completely independent of process state. - virtual std::string TranslateName(const std::string& name) const; - - /// \brief Returns whether the given path is a directory or not. - /// - /// Typical return codes (not guaranteed exhaustive): - /// * OK - The path exists and is a directory. - /// * FAILED_PRECONDITION - The path exists and is not a directory. - /// * NOT_FOUND - The path entry does not exist. - /// * PERMISSION_DENIED - Insufficient permissions. - /// * UNIMPLEMENTED - The file factory doesn't support directories. - virtual absl::Status IsDirectory(const std::string& fname) { - return IsDirectory(fname, nullptr); - } - - virtual absl::Status IsDirectory(const std::string& fname, - TransactionToken* token); - - /// \brief Returns whether the given path is on a file system - /// that has atomic move capabilities. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// The second boolean argument has_atomic_move contains this information. - /// - /// Returns one of the following status codes (not guaranteed exhaustive): - /// * OK - The path is on a recognized file system, - /// so has_atomic_move holds the above information. - /// * UNIMPLEMENTED - The file system of the path hasn't been implemented in - /// TF - virtual absl::Status HasAtomicMove(const std::string& path, - bool* has_atomic_move); - - /// Returns whether the give path is on a file system - /// that has ability to create a new temp file. This can be used - /// to determine if there needs to be a temp location to safely write objects. - /// If the file system cannot create a temp file, it's possibile that - /// uncomplete result may appear in the given file. - virtual absl::Status CanCreateTempFile(const std::string& fname, - bool* can_create_temp_file); - - /// \brief Flushes any cached filesystem objects from memory. - virtual void FlushCaches() { FlushCaches(nullptr); } - - virtual void FlushCaches(TransactionToken* token); - - /// \brief The separator this filesystem uses. - /// - /// This is implemented as a part of the filesystem, because even on windows, - /// a user may need access to filesystems with '/' separators, such as cloud - /// filesystems. - virtual char Separator() const; - - /// \brief Split a path to its basename and dirname. - /// - /// Helper function for Basename and Dirname. - std::pair SplitPath( - absl::string_view uri) const; - - /// \brief returns the final file name in the given path. - /// - /// Returns the part of the path after the final "/". If there is no - /// "/" in the path, the result is the same as the input. - virtual absl::string_view Basename(absl::string_view path) const; - - /// \brief Returns the part of the path before the final "/". - /// - /// If there is a single leading "/" in the path, the result will be the - /// leading "/". If there is no "/" in the path, the result is the empty - /// prefix of the input. - absl::string_view Dirname(absl::string_view path) const; - - /// \brief Returns the part of the basename of path after the final ".". - /// - /// If there is no "." in the basename, the result is empty. - absl::string_view Extension(absl::string_view path) const; - - /// \brief Clean duplicate and trailing, "/"s, and resolve ".." and ".". - /// - /// NOTE: This respects relative vs. absolute paths, but does not - /// invoke any system calls (getcwd(2)) in order to resolve relative - /// paths with respect to the actual working directory. That is, this is - /// purely string manipulation, completely independent of process state. - std::string CleanPath(absl::string_view path) const; - - /// \brief Creates a URI from a scheme, host, and path. - /// - /// If the scheme is empty, we just return the path. - std::string CreateURI(absl::string_view scheme, absl::string_view host, - absl::string_view path) const; - - /// \brief Return true if path is absolute. - bool IsAbsolutePath(absl::string_view path) const; - -#ifndef SWIG // variadic templates - /// \brief Join multiple paths together. - /// - /// This function also removes the unnecessary path separators. - /// For example: - /// - /// Arguments | JoinPath - /// ---------------------------+---------- - /// '/foo', 'bar' | /foo/bar - /// '/foo/', 'bar' | /foo/bar - /// '/foo', '/bar' | /foo/bar - /// - /// Usage: - /// string path = io::JoinPath("/mydir", filename); - /// string path = io::JoinPath(FLAGS_test_srcdir, filename); - /// string path = io::JoinPath("/full", "path", "to", "filename"); - template - std::string JoinPath(const T&... args) { - return JoinPathImpl({args...}); - } -#endif /* SWIG */ - - std::string JoinPathImpl(std::initializer_list paths); - - /// \brief Populates the scheme, host, and path from a URI. - /// - /// scheme, host, and path are guaranteed by this function to point into the - /// contents of uri, even if empty. - /// - /// Corner cases: - /// - If the URI is invalid, scheme and host are set to empty strings and the - /// passed string is assumed to be a path - /// - If the URI omits the path (e.g. file://host), then the path is left - /// empty. - void ParseURI(absl::string_view remaining, absl::string_view* scheme, - absl::string_view* host, absl::string_view* path) const; - - // Transaction related API - - /// \brief Starts a new transaction - virtual absl::Status StartTransaction(TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Adds `path` to transaction in `token` - virtual absl::Status AddToTransaction(const std::string& path, - TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Ends transaction - virtual absl::Status EndTransaction(TransactionToken* token) { - return absl::OkStatus(); - } - - /// \brief Get token for `path` or start a new transaction and add `path` to - /// it. - virtual absl::Status GetTokenOrStartTransaction(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Return transaction for `path` or nullptr in `token` - virtual absl::Status GetTransactionForPath(const std::string& path, - TransactionToken** token) { - *token = nullptr; - return absl::OkStatus(); - } - - /// \brief Decode transaction to human readable string. - virtual std::string DecodeTransaction(const TransactionToken* token); - - /// \brief Set File System Configuration Options - virtual absl::Status SetOption(const string& key, const string& value) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System Configuration Option - virtual absl::Status SetOption(const std::string& name, - const std::vector& values) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System Configuration Option - virtual absl::Status SetOption(const std::string& name, - const std::vector& values) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System Configuration Option - virtual absl::Status SetOption(const std::string& name, - const std::vector& values) { - return errors::Unimplemented("SetOption"); - } - - /// \brief Set File System ACL checker. - /// - /// No checks are enforced if a FileAcl is never set. - virtual absl::Status SetFileAcl(std::shared_ptr file_acl) { - return errors::Unimplemented("SetFileAcl"); - } - - FileSystem() {} - - virtual ~FileSystem() = default; -}; -/// This macro adds forwarding methods from FileSystem class to -/// used class since name hiding will prevent these to be accessed from -/// derived classes and would require all use locations to migrate to -/// Transactional API. This is an interim solution until ModularFileSystem class -/// becomes a singleton. -// TODO(sami): Remove this macro when filesystem plugins migration is complete. -#define TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT \ - using FileSystem::NewRandomAccessFile; \ - using FileSystem::NewWritableFile; \ - using FileSystem::NewAppendableFile; \ - using FileSystem::NewReadOnlyMemoryRegionFromFile; \ - using FileSystem::FileExists; \ - using FileSystem::GetChildren; \ - using FileSystem::GetMatchingPaths; \ - using FileSystem::Stat; \ - using FileSystem::DeleteFile; \ - using FileSystem::RecursivelyCreateDir; \ - using FileSystem::DeleteDir; \ - using FileSystem::DeleteRecursively; \ - using FileSystem::GetFileSize; \ - using FileSystem::RenameFile; \ - using FileSystem::CopyFile; \ - using FileSystem::IsDirectory; \ - using FileSystem::FlushCaches - -/// A Wrapper class for Transactional FileSystem support. -/// This provides means to make use of the transactions with minimal code change -/// Any operations that are done through this interface will be through the -/// transaction created at the time of construction of this instance. -/// See FileSystem documentation for method descriptions. -/// This class simply forwards all calls to wrapped filesystem either with given -/// transaction token or with token used in its construction. This allows doing -/// transactional filesystem access with minimal code change. -class WrappedFileSystem : public FileSystem { - public: - TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - - absl::Status NewRandomAccessFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewRandomAccessFile(fname, (token ? token : token_), result); - } - - absl::Status NewWritableFile(const std::string& fname, - TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewWritableFile(fname, (token ? token : token_), result); - } - - absl::Status NewAppendableFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewAppendableFile(fname, (token ? token : token_), result); - } - - absl::Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override { - return fs_->NewReadOnlyMemoryRegionFromFile(fname, (token ? token : token_), - result); - } - - absl::Status FileExists(const std::string& fname, - TransactionToken* token) override { - return fs_->FileExists(fname, (token ? token : token_)); - } - - bool FilesExist(const std::vector& files, TransactionToken* token, - std::vector* status) override { - return fs_->FilesExist(files, (token ? token : token_), status); - } - - absl::Status GetChildren(const std::string& dir, TransactionToken* token, - std::vector* result) override { - return fs_->GetChildren(dir, (token ? token : token_), result); - } - - absl::Status GetMatchingPaths(const std::string& pattern, - TransactionToken* token, - std::vector* results) override { - return fs_->GetMatchingPaths(pattern, (token ? token : token_), results); - } - - bool Match(const std::string& filename, const std::string& pattern) override { - return fs_->Match(filename, pattern); - } - - absl::Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) override { - return fs_->Stat(fname, (token ? token : token_), stat); - } - - absl::Status DeleteFile(const std::string& fname, - TransactionToken* token) override { - return fs_->DeleteFile(fname, (token ? token : token_)); - } - - absl::Status CreateDir(const std::string& dirname, - TransactionToken* token) override { - return fs_->CreateDir(dirname, (token ? token : token_)); - } - - absl::Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token) override { - return fs_->RecursivelyCreateDir(dirname, (token ? token : token_)); - } - - absl::Status DeleteDir(const std::string& dirname, - TransactionToken* token) override { - return fs_->DeleteDir(dirname, (token ? token : token_)); - } - - absl::Status DeleteRecursively(const std::string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) override { - return fs_->DeleteRecursively(dirname, (token ? token : token_), - undeleted_files, undeleted_dirs); - } - - absl::Status GetFileSize(const std::string& fname, TransactionToken* token, - uint64* file_size) override { - return fs_->GetFileSize(fname, (token ? token : token_), file_size); - } - - absl::Status RenameFile(const std::string& src, const std::string& target, - TransactionToken* token) override { - return fs_->RenameFile(src, target, (token ? token : token_)); - } - - absl::Status CopyFile(const std::string& src, const std::string& target, - TransactionToken* token) override { - return fs_->CopyFile(src, target, (token ? token : token_)); - } - - std::string TranslateName(const std::string& name) const override { - return fs_->TranslateName(name); - } - - absl::Status IsDirectory(const std::string& fname, - TransactionToken* token) override { - return fs_->IsDirectory(fname, (token ? token : token_)); - } - - absl::Status HasAtomicMove(const std::string& path, - bool* has_atomic_move) override { - return fs_->HasAtomicMove(path, has_atomic_move); - } - - void FlushCaches(TransactionToken* token) override { - return fs_->FlushCaches((token ? token : token_)); - } - - char Separator() const override { return fs_->Separator(); } - - absl::string_view Basename(absl::string_view path) const override { - return fs_->Basename(path); - } - - absl::Status StartTransaction(TransactionToken** token) override { - return fs_->StartTransaction(token); - } - - absl::Status AddToTransaction(const std::string& path, - TransactionToken* token) override { - return fs_->AddToTransaction(path, (token ? token : token_)); - } - - absl::Status EndTransaction(TransactionToken* token) override { - return fs_->EndTransaction(token); - } - - absl::Status GetTransactionForPath(const std::string& path, - TransactionToken** token) override { - return fs_->GetTransactionForPath(path, token); - } - - absl::Status GetTokenOrStartTransaction(const std::string& path, - TransactionToken** token) override { - return fs_->GetTokenOrStartTransaction(path, token); - } - - std::string DecodeTransaction(const TransactionToken* token) override { - return fs_->DecodeTransaction((token ? token : token_)); - } - - WrappedFileSystem(FileSystem* file_system, TransactionToken* token) - : fs_(file_system), token_(token) {} - - ~WrappedFileSystem() override = default; - - private: - FileSystem* fs_; - TransactionToken* token_; -}; - -/// A file abstraction for randomly reading the contents of a file. -class RandomAccessFile { - public: - RandomAccessFile() {} - virtual ~RandomAccessFile() = default; - - /// \brief Returns the name of the file. - /// - /// This is an optional operation that may not be implemented by every - /// filesystem. - virtual absl::Status Name(absl::string_view* result) const { - return errors::Unimplemented("This filesystem does not support Name()"); - } - - /// \brief Reads up to `n` bytes from the file starting at `offset`. - /// - /// `scratch[0..n-1]` may be written by this routine. Sets `*result` - /// to the data that was read (including if fewer than `n` bytes were - /// successfully read). May set `*result` to point at data in - /// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when - /// `*result` is used. - /// - /// On OK returned status: `n` bytes have been stored in `*result`. - /// On non-OK returned status: `[0..n]` bytes have been stored in `*result`. - /// - /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` - /// because of EOF. - /// - /// Safe for concurrent use by multiple threads. - virtual absl::Status Read(uint64 offset, size_t n, absl::string_view* result, - char* scratch) const = 0; - -#if defined(TF_CORD_SUPPORT) - /// \brief Read up to `n` bytes from the file starting at `offset`. - virtual absl::Status Read(uint64 offset, size_t n, absl::Cord* cord) const { - return errors::Unimplemented( - "Read(uint64, size_t, absl::Cord*) is not " - "implemented"); - } -#endif - - private: - RandomAccessFile(const RandomAccessFile&) = delete; - void operator=(const RandomAccessFile&) = delete; -}; - -/// \brief A file abstraction for sequential writing. -/// -/// The implementation must provide buffering since callers may append -/// small fragments at a time to the file. -class WritableFile { - public: - WritableFile() {} - virtual ~WritableFile() = default; - - /// \brief Append 'data' to the file. - virtual absl::Status Append(absl::string_view data) = 0; - -#if defined(TF_CORD_SUPPORT) - // \brief Append 'data' to the file. - virtual absl::Status Append(const absl::Cord& cord) { - for (absl::string_view chunk : cord.Chunks()) { - TF_RETURN_IF_ERROR(Append(chunk)); - } - return absl::OkStatus(); - } -#endif - - /// \brief Close the file. - /// - /// Flush() and de-allocate resources associated with this file - /// - /// Typical return codes (not guaranteed to be exhaustive): - /// * OK - /// * Other codes, as returned from Flush() - virtual absl::Status Close() = 0; - - /// \brief Flushes the file and optionally syncs contents to filesystem. - /// - /// This should flush any local buffers whose contents have not been - /// delivered to the filesystem. - /// - /// If the process terminates after a successful flush, the contents - /// may still be persisted, since the underlying filesystem may - /// eventually flush the contents. If the OS or machine crashes - /// after a successful flush, the contents may or may not be - /// persisted, depending on the implementation. - virtual absl::Status Flush() = 0; - - // \brief Returns the name of the file. - /// - /// This is an optional operation that may not be implemented by every - /// filesystem. - virtual absl::Status Name(absl::string_view* result) const { - return errors::Unimplemented("This filesystem does not support Name()"); - } - - /// \brief Syncs contents of file to filesystem. - /// - /// This waits for confirmation from the filesystem that the contents - /// of the file have been persisted to the filesystem; if the OS - /// or machine crashes after a successful Sync, the contents should - /// be properly saved. - virtual absl::Status Sync() = 0; - - /// \brief Retrieves the current write position in the file, or -1 on - /// error. - /// - /// This is an optional operation, subclasses may choose to return - /// errors::Unimplemented. - virtual absl::Status Tell(int64_t* position) { - *position = -1; - return errors::Unimplemented("This filesystem does not support Tell()"); - } - - private: - WritableFile(const WritableFile&) = delete; - void operator=(const WritableFile&) = delete; -}; - -/// \brief A readonly memmapped file abstraction. -/// -/// The implementation must guarantee that all memory is accessible when the -/// object exists, independently from the Env that created it. -class ReadOnlyMemoryRegion { - public: - ReadOnlyMemoryRegion() {} - virtual ~ReadOnlyMemoryRegion() = default; - - /// \brief Returns a pointer to the memory region. - virtual const void* data() = 0; - - /// \brief Returns the length of the memory region in bytes. - virtual uint64 length() = 0; -}; - -/// \brief A registry for file system implementations. -/// -/// Filenames are specified as an URI, which is of the form -/// [scheme://]. -/// File system implementations are registered using the REGISTER_FILE_SYSTEM -/// macro, providing the 'scheme' as the key. -/// -/// There are two `Register` methods: one using `Factory` for legacy filesystems -/// (deprecated mechanism of subclassing `FileSystem` and using -/// `REGISTER_FILE_SYSTEM` macro), and one using `std::unique_ptr` -/// for the new modular approach. -/// -/// Note that the new API expects a pointer to `ModularFileSystem` but this is -/// not checked as there should be exactly one caller to the API and doing the -/// check results in a circular dependency between `BUILD` targets. -/// -/// Plan is to completely remove the filesystem registration from `Env` and -/// incorporate it into `ModularFileSystem` class (which will be renamed to be -/// the only `FileSystem` class and marked as `final`). But this will happen at -/// a later time, after we convert all filesystems to the new API. -/// -/// TODO(b/139060984): After all filesystems are converted, remove old -/// registration and update comment. -class FileSystemRegistry { - public: - typedef std::function Factory; - - virtual ~FileSystemRegistry() = default; - virtual absl::Status Register(const std::string& scheme, Factory factory) = 0; - virtual absl::Status Register(const std::string& scheme, - std::unique_ptr filesystem) = 0; - virtual FileSystem* Lookup(const std::string& scheme) = 0; - virtual absl::Status GetRegisteredFileSystemSchemes( - std::vector* schemes) = 0; -}; - -/// \brief An abstraction for enforcing ACL checks in FileSystem. -class FileAcl { - public: - virtual absl::Status CheckAccess(std::string_view path) = 0; - virtual ~FileAcl() = default; -}; - -} // namespace tsl +#include "xla/tsl/platform/file_system.h" #endif // TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system_helper.h b/third_party/xla/third_party/tsl/tsl/platform/file_system_helper.h index e9e7df6aa68907..49a0bd1c2a8f82 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system_helper.h +++ b/third_party/xla/third_party/tsl/tsl/platform/file_system_helper.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,49 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ #define TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ -#include -#include - -#include "tsl/platform/env.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace tsl { - -class FileSystem; -class Env; - -namespace internal { - -// Given a pattern, stores in 'results' the set of paths (in the given file -// system) that match that pattern. -// -// This helper may be used by implementations of FileSystem::GetMatchingPaths() -// in order to provide parallel scanning of subdirectories (except on iOS). -// -// Arguments: -// fs: may not be null and will be used to identify directories and list -// their contents. -// env: may not be null and will be used to check if a match has been found. -// pattern: see FileSystem::GetMatchingPaths() for details. -// results: will be cleared and may not be null. -// -// Returns an error status if any call to 'fs' failed. -absl::Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, - std::vector* results); - -// Given a file path, determines whether the file exists. This helper simplifies -// the use of Env::FileExists. -// -// Arguments: -// env: may not be null. -// fname: the file path to look up -// -// Returns true if the file exists, false if it does not exist, or an error -// Status. -absl::StatusOr FileExists(Env* env, const string& fname); - -} // namespace internal -} // namespace tsl +#include "xla/tsl/platform/file_system_helper.h" #endif // TENSORFLOW_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/fingerprint.h b/third_party/xla/third_party/tsl/tsl/platform/fingerprint.h index b5be7200332e41..33d2b707092d6f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/fingerprint.h +++ b/third_party/xla/third_party/tsl/tsl/platform/fingerprint.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_FINGERPRINT_H_ #define TENSORFLOW_TSL_PLATFORM_FINGERPRINT_H_ +#include "xla/tsl/platform/types.h" #include "tsl/platform/platform.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" #if TSL_IS_IN_OSS #define USE_OSS_FARMHASH diff --git a/third_party/xla/third_party/tsl/tsl/platform/fingerprint_test.cc b/third_party/xla/third_party/tsl/tsl/platform/fingerprint_test.cc index 7cbdceb685cc06..2a40d863f78d66 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/fingerprint_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/fingerprint_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/hash.cc b/third_party/xla/third_party/tsl/tsl/platform/hash.cc index a9d3bd65d403a5..325aa93b088c9c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/hash.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/hash.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/raw_coding.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/hash.h b/third_party/xla/third_party/tsl/tsl/platform/hash.h index 2e18b440a263d3..174b233c2d3b25 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/hash.h +++ b/third_party/xla/third_party/tsl/tsl/platform/hash.h @@ -24,8 +24,8 @@ limitations under the License. #include #include +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/hash_test.cc b/third_party/xla/third_party/tsl/tsl/platform/hash_test.cc index 7b4752e729107c..010ccde8374694 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/hash_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/hash_test.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tsl/platform/hash.h" + #include #include #include -#include "tsl/platform/hash.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/host_info.h b/third_party/xla/third_party/tsl/tsl/platform/host_info.h index 630f9424525e04..687045c02c1a6e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/host_info.h +++ b/third_party/xla/third_party/tsl/tsl/platform/host_info.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/human_readable_json.h b/third_party/xla/third_party/tsl/tsl/platform/human_readable_json.h index ae7b9ee7fc4b38..3fedff0630e964 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/human_readable_json.h +++ b/third_party/xla/third_party/tsl/tsl/platform/human_readable_json.h @@ -20,8 +20,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/integral_types_test.cc b/third_party/xla/third_party/tsl/tsl/platform/integral_types_test.cc index 0ce3c497a067f5..80655dbee9407d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/integral_types_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/integral_types_test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/intrusive_ptr_test.cc b/third_party/xla/third_party/tsl/tsl/platform/intrusive_ptr_test.cc index ff7a28de648554..6257729e28cffa 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/intrusive_ptr_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/intrusive_ptr_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tsl/platform/intrusive_ptr.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/refcount.h" -#include "tsl/platform/test.h" namespace tsl { namespace core { diff --git a/third_party/xla/third_party/tsl/tsl/platform/logging.h b/third_party/xla/third_party/tsl/tsl/platform/logging.h index 93939888230464..193cb9b5118f5d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/logging.h +++ b/third_party/xla/third_party/tsl/tsl/platform/logging.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,14 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_LOGGING_H_ #define TENSORFLOW_TSL_PLATFORM_LOGGING_H_ -#include "tsl/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) || \ - defined(PLATFORM_GOOGLE_IOS) || defined(GOOGLE_LOGGING) || \ - defined(__EMSCRIPTEN__) || defined(PLATFORM_CHROMIUMOS) -#include "xla/tsl/platform/google/logging.h" // IWYU pragma: export -#else -#include "xla/tsl/platform/default/logging.h" // IWYU pragma: export -#endif +#include "xla/tsl/platform/logging.h" #endif // TENSORFLOW_TSL_PLATFORM_LOGGING_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/macros.h b/third_party/xla/third_party/tsl/tsl/platform/macros.h index cb91c4ff64e847..960d7ed2e2accf 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/macros.h +++ b/third_party/xla/third_party/tsl/tsl/platform/macros.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,147 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_MACROS_H_ #define TENSORFLOW_TSL_PLATFORM_MACROS_H_ -// Compiler attributes -#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) -// Compiler supports GCC-style attributes -#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn)) -#define TF_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) -#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline)) -#define TF_ATTRIBUTE_UNUSED __attribute__((unused)) -#define TF_ATTRIBUTE_COLD __attribute__((cold)) -#define TF_ATTRIBUTE_WEAK __attribute__((weak)) -#define TF_PACKED __attribute__((packed)) -#define TF_MUST_USE_RESULT __attribute__((warn_unused_result)) -#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) \ - __attribute__((__format__(__printf__, string_index, first_to_check))) -#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \ - __attribute__((__format__(__scanf__, string_index, first_to_check))) -#elif defined(_MSC_VER) -// Non-GCC equivalents -#define TF_ATTRIBUTE_NORETURN __declspec(noreturn) -#define TF_ATTRIBUTE_ALWAYS_INLINE __forceinline -#define TF_ATTRIBUTE_NOINLINE -#define TF_ATTRIBUTE_UNUSED -#define TF_ATTRIBUTE_COLD -#define TF_ATTRIBUTE_WEAK -#define TF_MUST_USE_RESULT -#define TF_PACKED -#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) -#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) -#else -// Non-GCC equivalents -#define TF_ATTRIBUTE_NORETURN -#define TF_ATTRIBUTE_ALWAYS_INLINE -#define TF_ATTRIBUTE_NOINLINE -#define TF_ATTRIBUTE_UNUSED -#define TF_ATTRIBUTE_COLD -#define TF_ATTRIBUTE_WEAK -#define TF_MUST_USE_RESULT -#define TF_PACKED -#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) -#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) -#endif - -// Control visibility outside .so -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_EXPORT __declspec(dllexport) -#else -#define TF_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 - -#ifdef __has_builtin -#define TF_HAS_BUILTIN(x) __has_builtin(x) -#else -#define TF_HAS_BUILTIN(x) 0 -#endif - -// C++11-style attributes (N2761) -#if defined(__has_cpp_attribute) -// Safely checks if an attribute is supported. Equivalent to -// ABSL_HAVE_CPP_ATTRIBUTE. -#define TF_HAS_CPP_ATTRIBUTE(n) __has_cpp_attribute(n) -#else -#define TF_HAS_CPP_ATTRIBUTE(n) 0 -#endif - -// [[clang::annotate("x")]] allows attaching custom strings (e.g. "x") to -// declarations (variables, functions, fields, etc.) for use by tools. They are -// represented in the Clang AST (as AnnotateAttr nodes) and in LLVM IR, but not -// in final output. -#if TF_HAS_CPP_ATTRIBUTE(clang::annotate) -#define TF_ATTRIBUTE_ANNOTATE(str) [[clang::annotate(str)]] -#else -#define TF_ATTRIBUTE_ANNOTATE(str) -#endif - -// A variable declaration annotated with the `TF_CONST_INIT` attribute will -// not compile (on supported platforms) unless the variable has a constant -// initializer. -#if TF_HAS_CPP_ATTRIBUTE(clang::require_constant_initialization) -#define TF_CONST_INIT [[clang::require_constant_initialization]] -#else -#define TF_CONST_INIT -#endif - -// Compilers can be told that a certain branch is not likely to be taken -// (for instance, a CHECK failure), and use that information in static -// analysis. Giving it this information can help it optimize for the -// common case in the absence of better information (ie. -// -fprofile-arcs). -#if TF_HAS_BUILTIN(__builtin_expect) || (defined(__GNUC__) && __GNUC__ >= 3) -#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) -#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) -#else -#define TF_PREDICT_FALSE(x) (x) -#define TF_PREDICT_TRUE(x) (x) -#endif - -// DEPRECATED: directly use the macro implementation instead. -// A macro to disallow the copy constructor and operator= functions -// This is usually placed in the private: declarations for a class. -#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName&) = delete; \ - void operator=(const TypeName&) = delete - -// The TF_ARRAYSIZE(arr) macro returns the # of elements in an array arr. -// -// The expression TF_ARRAYSIZE(a) is a compile-time constant of type -// size_t. -#define TF_ARRAYSIZE(a) \ - ((sizeof(a) / sizeof(*(a))) / \ - static_cast(!(sizeof(a) % sizeof(*(a))))) - -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ - (defined(_MSC_VER) && _MSC_VER >= 1900) -// Define this to 1 if the code is compiled in C++11 mode; leave it -// undefined otherwise. Do NOT define it to 0 -- that causes -// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. -#define LANG_CXX11 1 -#endif - -#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) -#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") -#define TF_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT -#endif -#endif - -#ifndef TF_FALLTHROUGH_INTENDED -#define TF_FALLTHROUGH_INTENDED \ - do { \ - } while (0) -#endif - -namespace tsl { -namespace internal { -template -void remove_unused_variable_compiler_warning(const T&){}; -} // namespace internal -} // namespace tsl -#define TF_UNUSED_VARIABLE(x) \ - tensorflow::internal::remove_unused_variable_compiler_warning(x) +#include "xla/tsl/platform/macros.h" #endif // TENSORFLOW_TSL_PLATFORM_MACROS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/mem.h b/third_party/xla/third_party/tsl/tsl/platform/mem.h index 6d0dc803e93b80..bc975ae17643b9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/mem.h +++ b/third_party/xla/third_party/tsl/tsl/platform/mem.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_TSL_PLATFORM_MEM_H_ // TODO(cwhipkey): remove this when callers use annotations directly. +#include "xla/tsl/platform/types.h" #include "tsl/platform/dynamic_annotations.h" #include "tsl/platform/platform.h" -#include "tsl/platform/types.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h index 89a40bd891e106..a03fa02447f3c6 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h @@ -18,8 +18,10 @@ limitations under the License. #include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "ml_dtypes/include/intn.h" // from @ml_dtypes +#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes namespace tsl { +using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn; using float8_e3m4 = ::ml_dtypes::float8_e3m4; using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; @@ -27,7 +29,10 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; +using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; +using int1 = ::ml_dtypes::int1; +using uint1 = ::ml_dtypes::uint1; using int2 = ::ml_dtypes::int2; using uint2 = ::ml_dtypes::uint2; using int4 = ::ml_dtypes::int4; diff --git a/third_party/xla/third_party/tsl/tsl/platform/mutex_test.cc b/third_party/xla/third_party/tsl/tsl/platform/mutex_test.cc index b5444ae721eaff..58c46c2b4a2327 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/mutex_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/mutex_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tsl/platform/mutex.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/net_test.cc b/third_party/xla/third_party/tsl/tsl/platform/net_test.cc index 2d39042df2ea93..d99c7cb3952777 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/net_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/net_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tsl/platform/net.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace internal { diff --git a/third_party/xla/third_party/tsl/tsl/platform/null_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/null_file_system.h index c04d2c1f0d6056..8c88298589b066 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/null_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/null_file_system.h @@ -20,9 +20,9 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/file_system_helper.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/file_system_helper.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/numa.h b/third_party/xla/third_party/tsl/tsl/platform/numa.h index 997d03d4974382..12a65894a0cc9d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/numa.h +++ b/third_party/xla/third_party/tsl/tsl/platform/numa.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_NUMA_H_ #define TENSORFLOW_TSL_PLATFORM_NUMA_H_ +#include "xla/tsl/platform/types.h" #include "tsl/platform/platform.h" -#include "tsl/platform/types.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/numa_test.cc b/third_party/xla/third_party/tsl/tsl/platform/numa_test.cc index d01a5d76a0a873..047053b1924e34 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/numa_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/numa_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tsl/platform/numa.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace internal { diff --git a/third_party/xla/third_party/tsl/tsl/platform/numbers.cc b/third_party/xla/third_party/tsl/tsl/platform/numbers.cc index 7239e6fff7a51d..54609b06f010de 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/numbers.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/numbers.cc @@ -20,18 +20,20 @@ limitations under the License. #include #include -#include +#include #include #include #include +#include +#include // NOLINT #include -#include "double-conversion/double-conversion.h" -#include "tsl/platform/str_util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringprintf.h" -#include "tsl/platform/types.h" namespace tsl { @@ -114,17 +116,6 @@ T locale_independent_strtonum(const char* str, const char** endptr) { return result; } -static inline const double_conversion::StringToDoubleConverter& -StringToFloatConverter() { - static const double_conversion::StringToDoubleConverter converter( - double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES | - double_conversion::StringToDoubleConverter::ALLOW_HEX | - double_conversion::StringToDoubleConverter::ALLOW_TRAILING_SPACES | - double_conversion::StringToDoubleConverter::ALLOW_CASE_INSENSIBILITY, - 0., 0., "inf", "nan"); - return converter; -} - } // namespace namespace strings { @@ -219,154 +210,6 @@ size_t DoubleToBuffer(double value, char* buffer) { return snprintf_result; } -namespace { -char SafeFirstChar(absl::string_view str) { - if (str.empty()) return '\0'; - return str[0]; -} -void SkipSpaces(absl::string_view* str) { - while (isspace(SafeFirstChar(*str))) str->remove_prefix(1); -} -} // namespace - -bool safe_strto64(absl::string_view str, int64_t* value) { - SkipSpaces(&str); - - int64_t vlimit = kint64max; - int sign = 1; - if (absl::ConsumePrefix(&str, "-")) { - sign = -1; - // Different limit for positive and negative integers. - vlimit = kint64min; - } - - if (!isdigit(SafeFirstChar(str))) return false; - - int64_t result = 0; - if (sign == 1) { - do { - int digit = SafeFirstChar(str) - '0'; - if ((vlimit - digit) / 10 < result) { - return false; - } - result = result * 10 + digit; - str.remove_prefix(1); - } while (isdigit(SafeFirstChar(str))); - } else { - do { - int digit = SafeFirstChar(str) - '0'; - if ((vlimit + digit) / 10 > result) { - return false; - } - result = result * 10 - digit; - str.remove_prefix(1); - } while (isdigit(SafeFirstChar(str))); - } - - SkipSpaces(&str); - if (!str.empty()) return false; - - *value = result; - return true; -} - -bool safe_strtou64(absl::string_view str, uint64_t* value) { - SkipSpaces(&str); - if (!isdigit(SafeFirstChar(str))) return false; - - uint64_t result = 0; - do { - int digit = SafeFirstChar(str) - '0'; - if ((kuint64max - digit) / 10 < result) { - return false; - } - result = result * 10 + digit; - str.remove_prefix(1); - } while (isdigit(SafeFirstChar(str))); - - SkipSpaces(&str); - if (!str.empty()) return false; - - *value = result; - return true; -} - -bool safe_strto32(absl::string_view str, int32_t* value) { - SkipSpaces(&str); - - int64_t vmax = kint32max; - int sign = 1; - if (absl::ConsumePrefix(&str, "-")) { - sign = -1; - // Different max for positive and negative integers. - ++vmax; - } - - if (!isdigit(SafeFirstChar(str))) return false; - - int64_t result = 0; - do { - result = result * 10 + SafeFirstChar(str) - '0'; - if (result > vmax) { - return false; - } - str.remove_prefix(1); - } while (isdigit(SafeFirstChar(str))); - - SkipSpaces(&str); - - if (!str.empty()) return false; - - *value = static_cast(result * sign); - return true; -} - -bool safe_strtou32(absl::string_view str, uint32_t* value) { - SkipSpaces(&str); - if (!isdigit(SafeFirstChar(str))) return false; - - int64_t result = 0; - do { - result = result * 10 + SafeFirstChar(str) - '0'; - if (result > kuint32max) { - return false; - } - str.remove_prefix(1); - } while (isdigit(SafeFirstChar(str))); - - SkipSpaces(&str); - if (!str.empty()) return false; - - *value = static_cast(result); - return true; -} - -bool safe_strtof(absl::string_view str, float* value) { - int processed_characters_count = -1; - auto len = str.size(); - - // If string length exceeds buffer size or int max, fail. - if (len >= kFastToBufferSize) return false; - if (len > std::numeric_limits::max()) return false; - - *value = StringToFloatConverter().StringToFloat( - str.data(), static_cast(len), &processed_characters_count); - return processed_characters_count > 0; -} - -bool safe_strtod(absl::string_view str, double* value) { - int processed_characters_count = -1; - auto len = str.size(); - - // If string length exceeds buffer size or int max, fail. - if (len >= kFastToBufferSize) return false; - if (len > std::numeric_limits::max()) return false; - - *value = StringToFloatConverter().StringToDouble( - str.data(), static_cast(len), &processed_characters_count); - return processed_characters_count > 0; -} - size_t FloatToBuffer(float value, char* buffer) { // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all // platforms these days. Just in case some system exists where FLT_DIG @@ -390,7 +233,7 @@ size_t FloatToBuffer(float value, char* buffer) { DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); float parsed_value; - if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) { + if (!absl::SimpleAtof(buffer, &parsed_value) || parsed_value != value) { snprintf_result = snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 3, value); @@ -401,51 +244,21 @@ size_t FloatToBuffer(float value, char* buffer) { } std::string FpToString(Fprint fp) { - char buf[17]; - snprintf(buf, sizeof(buf), "%016llx", static_cast(fp)); - return std::string(buf); + return absl::StrCat(absl::Hex(fp, absl::kZeroPad16)); } -bool StringToFp(const std::string& s, Fprint* fp) { - char junk; - uint64_t result; - if (sscanf(s.c_str(), "%" SCNx64 "%c", &result, &junk) == 1) { - *fp = result; - return true; - } else { +bool HexStringToUint64(absl::string_view s, uint64_t* result) { + auto end_ptr = s.data() + s.size(); + uint64_t parsed_result; + auto [ptr, ec] = + std::from_chars(s.data(), end_ptr, parsed_result, /*base=*/16); + if (ec != std::errc{}) { return false; } -} - -absl::string_view Uint64ToHexString(uint64_t v, char* buf) { - static const char* hexdigits = "0123456789abcdef"; - const int num_byte = 16; - buf[num_byte] = '\0'; - for (int i = num_byte - 1; i >= 0; i--) { - buf[i] = hexdigits[v & 0xf]; - v >>= 4; - } - return absl::string_view(buf, num_byte); -} - -bool HexStringToUint64(const absl::string_view& s, uint64_t* result) { - uint64_t v = 0; - if (s.empty()) { + if (ptr != end_ptr) { return false; } - for (size_t i = 0; i < s.size(); i++) { - char c = s[i]; - if (c >= '0' && c <= '9') { - v = (v << 4) + (c - '0'); - } else if (c >= 'a' && c <= 'f') { - v = (v << 4) + 10 + (c - 'a'); - } else if (c >= 'A' && c <= 'F') { - v = (v << 4) + 10 + (c - 'A'); - } else { - return false; - } - } - *result = v; + *result = parsed_result; return true; } diff --git a/third_party/xla/third_party/tsl/tsl/platform/numbers.h b/third_party/xla/third_party/tsl/tsl/platform/numbers.h index 0d62f425361927..0f4dc84e2fa18e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/numbers.h +++ b/third_party/xla/third_party/tsl/tsl/platform/numbers.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_NUMBERS_H_ #define TENSORFLOW_TSL_PLATFORM_NUMBERS_H_ +#include #include #include +#include "absl/base/macros.h" +#include "absl/strings/numbers.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace strings { @@ -46,7 +49,7 @@ namespace strings { // Int64, UInt64, Int, Uint: 22 bytes // Time: 30 bytes // Use kFastToBufferSize rather than hardcoding constants. -static const int kFastToBufferSize = 32; +inline constexpr int kFastToBufferSize = 32; // ---------------------------------------------------------------------- // FastInt32ToBufferLeft() @@ -77,75 +80,83 @@ size_t FloatToBuffer(float value, char* buffer); // Convert a 64-bit fingerprint value to an ASCII representation. std::string FpToString(Fprint fp); -// Attempt to parse a fingerprint in the form encoded by FpToString. If -// successful, stores the fingerprint in *fp and returns true. Otherwise, -// returns false. -bool StringToFp(const std::string& s, Fprint* fp); - -// Convert a 64-bit fingerprint value to an ASCII representation that -// is terminated by a '\0'. -// Buf must point to an array of at least kFastToBufferSize characters -absl::string_view Uint64ToHexString(uint64_t v, char* buf); - -// Attempt to parse a uint64 in the form encoded by FastUint64ToHexString. If -// successful, stores the value in *v and returns true. Otherwise, -// returns false. -bool HexStringToUint64(const absl::string_view& s, uint64_t* result); +// Attempt to parse a `uint64_t` in the form encoded by +// `absl::StrCat(absl::Hex(*result))`. If successful, stores the value in +// `result` and returns true. Otherwise, returns false. +bool HexStringToUint64(absl::string_view s, uint64_t* result); // Convert strings to 32bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strto32(absl::string_view str, int32_t* value); +ABSL_DEPRECATE_AND_INLINE() +inline bool safe_strto32(absl::string_view str, int32_t* value) { + return absl::SimpleAtoi(str, value); +} // Convert strings to unsigned 32bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strtou32(absl::string_view str, uint32_t* value); +ABSL_DEPRECATE_AND_INLINE() +inline bool safe_strtou32(absl::string_view str, uint32_t* value) { + return absl::SimpleAtoi(str, value); +} // Convert strings to 64bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strto64(absl::string_view str, int64_t* value); +ABSL_DEPRECATE_AND_INLINE() +inline bool safe_strto64(absl::string_view str, int64_t* value) { + return absl::SimpleAtoi(str, value); +} // Convert strings to unsigned 64bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strtou64(absl::string_view str, uint64_t* value); +ABSL_DEPRECATE_AND_INLINE() +inline bool safe_strtou64(absl::string_view str, uint64_t* value) { + return absl::SimpleAtoi(str, value); +} // Convert strings to floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. -bool safe_strtof(absl::string_view str, float* value); +ABSL_DEPRECATE_AND_INLINE() +inline bool safe_strtof(absl::string_view str, float* value) { + return absl::SimpleAtof(str, value); +} // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. -bool safe_strtod(absl::string_view str, double* value); +ABSL_DEPRECATE_AND_INLINE() +inline bool safe_strtod(absl::string_view str, double* value) { + return absl::SimpleAtod(str, value); +} inline bool ProtoParseNumeric(absl::string_view s, int32_t* value) { - return safe_strto32(s, value); + return absl::SimpleAtoi(s, value); } inline bool ProtoParseNumeric(absl::string_view s, uint32_t* value) { - return safe_strtou32(s, value); + return absl::SimpleAtoi(s, value); } inline bool ProtoParseNumeric(absl::string_view s, int64_t* value) { - return safe_strto64(s, value); + return absl::SimpleAtoi(s, value); } inline bool ProtoParseNumeric(absl::string_view s, uint64_t* value) { - return safe_strtou64(s, value); + return absl::SimpleAtoi(s, value); } inline bool ProtoParseNumeric(absl::string_view s, float* value) { - return safe_strtof(s, value); + return absl::SimpleAtof(s, value); } inline bool ProtoParseNumeric(absl::string_view s, double* value) { - return safe_strtod(s, value); + return absl::SimpleAtod(s, value); } // Convert strings to number of type T. diff --git a/third_party/xla/third_party/tsl/tsl/platform/numbers_test.cc b/third_party/xla/third_party/tsl/tsl/platform/numbers_test.cc index 0ce574e597dea9..2c90bee0c5256c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/numbers_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/numbers_test.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include -#include "tsl/platform/test.h" +#include "absl/strings/str_cat.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace strings { @@ -26,29 +28,28 @@ namespace strings { // NOTE: most of the routines in numbers.h are tested indirectly through // strcat_test.cc in this directory. -// Test StrCat of ints and longs of various sizes and signdedness. +// Test StrCat of ints and longs of various sizes and signedness. TEST(FpToString, Ints) { for (int s = 0; s < 64; s++) { for (int delta = -1; delta <= 1; delta++) { uint64 fp = (1ull << s) + delta; string s = FpToString(fp); uint64 fp2; - EXPECT_TRUE(StringToFp(s, &fp2)); + EXPECT_TRUE(HexStringToUint64(s, &fp2)); EXPECT_EQ(fp, fp2); } } Fprint dummy; - EXPECT_FALSE(StringToFp("", &dummy)); - EXPECT_FALSE(StringToFp("xyz", &dummy)); - EXPECT_FALSE(StringToFp("0000000000000000xyz", &dummy)); + EXPECT_FALSE(HexStringToUint64("", &dummy)); + EXPECT_FALSE(HexStringToUint64("xyz", &dummy)); + EXPECT_FALSE(HexStringToUint64("0000000000000000xyz", &dummy)); } TEST(Uint64ToHexString, Ints) { for (int s = 0; s < 64; s++) { for (int delta = -1; delta <= 1; delta++) { uint64 fp = (1ull << s) + delta; - char buf[kFastToBufferSize]; - absl::string_view s = Uint64ToHexString(fp, buf); + std::string s = absl::StrCat(absl::Hex(fp, absl::kZeroPad16)); uint64 fp2; EXPECT_TRUE(HexStringToUint64(s, &fp2)); EXPECT_EQ(fp, fp2) << s; @@ -121,262 +122,262 @@ TEST(HumanReadableElapsedTime, Basic) { TEST(safe_strto32, Int32s) { int32 result; - EXPECT_EQ(true, safe_strto32("1", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("1", &result)); EXPECT_EQ(1, result); - EXPECT_EQ(true, safe_strto32("123", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("123", &result)); EXPECT_EQ(123, result); - EXPECT_EQ(true, safe_strto32(" -123 ", &result)); + EXPECT_EQ(true, absl::SimpleAtoi(" -123 ", &result)); EXPECT_EQ(-123, result); - EXPECT_EQ(true, safe_strto32("2147483647", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("2147483647", &result)); EXPECT_EQ(2147483647, result); - EXPECT_EQ(true, safe_strto32("-2147483648", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("-2147483648", &result)); EXPECT_EQ(-2147483648, result); // Invalid argument - EXPECT_EQ(false, safe_strto32(" 132as ", &result)); - EXPECT_EQ(false, safe_strto32(" 132.2 ", &result)); - EXPECT_EQ(false, safe_strto32(" -", &result)); - EXPECT_EQ(false, safe_strto32("", &result)); - EXPECT_EQ(false, safe_strto32(" ", &result)); - EXPECT_EQ(false, safe_strto32("123 a", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" 132as ", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" 132.2 ", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" -", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" ", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("123 a", &result)); // Overflow - EXPECT_EQ(false, safe_strto32("2147483648", &result)); - EXPECT_EQ(false, safe_strto32("-2147483649", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("2147483648", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("-2147483649", &result)); // Check that the StringPiece's length is respected. - EXPECT_EQ(true, safe_strto32(absl::string_view("123", 1), &result)); + EXPECT_EQ(true, absl::SimpleAtoi(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_EQ(true, safe_strto32(absl::string_view(" -123", 4), &result)); + EXPECT_EQ(true, absl::SimpleAtoi(absl::string_view(" -123", 4), &result)); EXPECT_EQ(-12, result); - EXPECT_EQ(false, safe_strto32(absl::string_view(nullptr, 0), &result)); + EXPECT_EQ(false, absl::SimpleAtoi(absl::string_view(nullptr, 0), &result)); } TEST(safe_strtou32, UInt32s) { uint32 result; - EXPECT_TRUE(safe_strtou32("0", &result)); + EXPECT_TRUE(absl::SimpleAtoi("0", &result)); EXPECT_EQ(0, result); - EXPECT_TRUE(safe_strtou32("1", &result)); + EXPECT_TRUE(absl::SimpleAtoi("1", &result)); EXPECT_EQ(1, result); - EXPECT_TRUE(safe_strtou32("123", &result)); + EXPECT_TRUE(absl::SimpleAtoi("123", &result)); EXPECT_EQ(123, result); - EXPECT_TRUE(safe_strtou32("4294967295", &result)); + EXPECT_TRUE(absl::SimpleAtoi("4294967295", &result)); EXPECT_EQ(4294967295, result); // Invalid argument - EXPECT_FALSE(safe_strtou32(" 132as ", &result)); - EXPECT_FALSE(safe_strtou32(" 132.2 ", &result)); - EXPECT_FALSE(safe_strtou32(" -", &result)); - EXPECT_FALSE(safe_strtou32("", &result)); - EXPECT_FALSE(safe_strtou32(" ", &result)); - EXPECT_FALSE(safe_strtou32("123 a", &result)); - EXPECT_FALSE(safe_strtou32("123 456", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" 132as ", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" 132.2 ", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" -", &result)); + EXPECT_FALSE(absl::SimpleAtoi("", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" ", &result)); + EXPECT_FALSE(absl::SimpleAtoi("123 a", &result)); + EXPECT_FALSE(absl::SimpleAtoi("123 456", &result)); // Overflow - EXPECT_FALSE(safe_strtou32("4294967296", &result)); - EXPECT_FALSE(safe_strtou32("-1", &result)); + EXPECT_FALSE(absl::SimpleAtoi("4294967296", &result)); + EXPECT_FALSE(absl::SimpleAtoi("-1", &result)); // Check that the StringPiece's length is respected. - EXPECT_TRUE(safe_strtou32(absl::string_view("123", 1), &result)); + EXPECT_TRUE(absl::SimpleAtoi(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_TRUE(safe_strtou32(absl::string_view(" 123", 3), &result)); + EXPECT_TRUE(absl::SimpleAtoi(absl::string_view(" 123", 3), &result)); EXPECT_EQ(12, result); - EXPECT_FALSE(safe_strtou32(absl::string_view(nullptr, 0), &result)); + EXPECT_FALSE(absl::SimpleAtoi(absl::string_view(nullptr, 0), &result)); } TEST(safe_strto64, Int64s) { int64 result; - EXPECT_EQ(true, safe_strto64("1", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("1", &result)); EXPECT_EQ(1, result); - EXPECT_EQ(true, safe_strto64("123", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("123", &result)); EXPECT_EQ(123, result); - EXPECT_EQ(true, safe_strto64(" -123 ", &result)); + EXPECT_EQ(true, absl::SimpleAtoi(" -123 ", &result)); EXPECT_EQ(-123, result); - EXPECT_EQ(true, safe_strto64("9223372036854775807", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("9223372036854775807", &result)); EXPECT_EQ(9223372036854775807, result); - EXPECT_EQ(true, safe_strto64("-9223372036854775808", &result)); + EXPECT_EQ(true, absl::SimpleAtoi("-9223372036854775808", &result)); // kint64min == -9223372036854775808 // Use -9223372036854775808 directly results in out of range error EXPECT_EQ(kint64min, result); // Invalid argument - EXPECT_EQ(false, safe_strto64(" 132as ", &result)); - EXPECT_EQ(false, safe_strto64(" 132.2 ", &result)); - EXPECT_EQ(false, safe_strto64(" -", &result)); - EXPECT_EQ(false, safe_strto64("", &result)); - EXPECT_EQ(false, safe_strto64(" ", &result)); - EXPECT_EQ(false, safe_strto64("123 a", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" 132as ", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" 132.2 ", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" -", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("", &result)); + EXPECT_EQ(false, absl::SimpleAtoi(" ", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("123 a", &result)); // Overflow - EXPECT_EQ(false, safe_strto64("9223372036854775808", &result)); - EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("9223372036854775808", &result)); + EXPECT_EQ(false, absl::SimpleAtoi("-9223372036854775809", &result)); // Check that the StringPiece's length is respected. - EXPECT_EQ(true, safe_strto64(absl::string_view("123", 1), &result)); + EXPECT_EQ(true, absl::SimpleAtoi(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_EQ(true, safe_strto64(absl::string_view(" -123", 4), &result)); + EXPECT_EQ(true, absl::SimpleAtoi(absl::string_view(" -123", 4), &result)); EXPECT_EQ(-12, result); - EXPECT_EQ(false, safe_strto64(absl::string_view(nullptr, 0), &result)); + EXPECT_EQ(false, absl::SimpleAtoi(absl::string_view(nullptr, 0), &result)); } TEST(safe_strtou64, UInt64s) { uint64 result; - EXPECT_TRUE(safe_strtou64("0", &result)); + EXPECT_TRUE(absl::SimpleAtoi("0", &result)); EXPECT_EQ(0, result); - EXPECT_TRUE(safe_strtou64("1", &result)); + EXPECT_TRUE(absl::SimpleAtoi("1", &result)); EXPECT_EQ(1, result); - EXPECT_TRUE(safe_strtou64("123", &result)); + EXPECT_TRUE(absl::SimpleAtoi("123", &result)); EXPECT_EQ(123, result); - EXPECT_TRUE(safe_strtou64(" 345 ", &result)); + EXPECT_TRUE(absl::SimpleAtoi(" 345 ", &result)); EXPECT_EQ(345, result); - EXPECT_TRUE(safe_strtou64("18446744073709551615", &result)); + EXPECT_TRUE(absl::SimpleAtoi("18446744073709551615", &result)); EXPECT_EQ(18446744073709551615UL, result); // Invalid argument - EXPECT_FALSE(safe_strtou64(" 132.2 ", &result)); - EXPECT_FALSE(safe_strtou64(" 132.2 ", &result)); - EXPECT_FALSE(safe_strtou64(" -", &result)); - EXPECT_FALSE(safe_strtou64("", &result)); - EXPECT_FALSE(safe_strtou64(" ", &result)); - EXPECT_FALSE(safe_strtou64("123 a", &result)); - EXPECT_FALSE(safe_strtou64("123 456", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" 132.2 ", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" 132.2 ", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" -", &result)); + EXPECT_FALSE(absl::SimpleAtoi("", &result)); + EXPECT_FALSE(absl::SimpleAtoi(" ", &result)); + EXPECT_FALSE(absl::SimpleAtoi("123 a", &result)); + EXPECT_FALSE(absl::SimpleAtoi("123 456", &result)); // Overflow - EXPECT_FALSE(safe_strtou64("18446744073709551616", &result)); - EXPECT_FALSE(safe_strtou64("-1", &result)); + EXPECT_FALSE(absl::SimpleAtoi("18446744073709551616", &result)); + EXPECT_FALSE(absl::SimpleAtoi("-1", &result)); // Check that the StringPiece's length is respected. - EXPECT_TRUE(safe_strtou64(absl::string_view("123", 1), &result)); + EXPECT_TRUE(absl::SimpleAtoi(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_TRUE(safe_strtou64(absl::string_view(" 123", 3), &result)); + EXPECT_TRUE(absl::SimpleAtoi(absl::string_view(" 123", 3), &result)); EXPECT_EQ(12, result); - EXPECT_FALSE(safe_strtou64(absl::string_view(nullptr, 0), &result)); + EXPECT_FALSE(absl::SimpleAtoi(absl::string_view(nullptr, 0), &result)); } TEST(safe_strtof, Float) { float result = 0; - EXPECT_TRUE(safe_strtof("0.123456", &result)); + EXPECT_TRUE(absl::SimpleAtof("0.123456", &result)); EXPECT_EQ(0.123456f, result); - EXPECT_FALSE(safe_strtof("0.12345abc", &result)); + EXPECT_FALSE(absl::SimpleAtof("0.12345abc", &result)); // Overflow to infinity, underflow to 0. - EXPECT_TRUE(safe_strtof("1e39", &result)); + EXPECT_TRUE(absl::SimpleAtof("1e39", &result)); EXPECT_EQ(std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtof("-1e39", &result)); + EXPECT_TRUE(absl::SimpleAtof("-1e39", &result)); EXPECT_EQ(-std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtof("1e-50", &result)); + EXPECT_TRUE(absl::SimpleAtof("1e-50", &result)); EXPECT_EQ(0, result); - EXPECT_TRUE(safe_strtof("0xF", &result)); + EXPECT_TRUE(absl::SimpleAtof("0xF", &result)); EXPECT_EQ(0xF, result); - EXPECT_TRUE(safe_strtof("-0x2A", &result)); + EXPECT_TRUE(absl::SimpleAtof("-0x2A", &result)); EXPECT_EQ(-42.0f, result); - EXPECT_TRUE(safe_strtof(" -0x2", &result)); + EXPECT_TRUE(absl::SimpleAtof(" -0x2", &result)); EXPECT_EQ(-2.0f, result); - EXPECT_TRUE(safe_strtof("8 \t", &result)); + EXPECT_TRUE(absl::SimpleAtof("8 \t", &result)); EXPECT_EQ(8.0f, result); - EXPECT_TRUE(safe_strtof("\t20.0\t ", &result)); + EXPECT_TRUE(absl::SimpleAtof("\t20.0\t ", &result)); EXPECT_EQ(20.0f, result); - EXPECT_FALSE(safe_strtof("-infinity is awesome", &result)); + EXPECT_FALSE(absl::SimpleAtof("-infinity is awesome", &result)); // Make sure we exit cleanly if the string is too long char test_str[2 * kFastToBufferSize]; for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a'; test_str[kFastToBufferSize + 1] = '\0'; - EXPECT_FALSE(safe_strtof(test_str, &result)); + EXPECT_FALSE(absl::SimpleAtof(test_str, &result)); - EXPECT_TRUE(safe_strtof("-inf", &result)); + EXPECT_TRUE(absl::SimpleAtof("-inf", &result)); EXPECT_EQ(-std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtof("+inf", &result)); + EXPECT_TRUE(absl::SimpleAtof("+inf", &result)); EXPECT_EQ(std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtof("InF", &result)); + EXPECT_TRUE(absl::SimpleAtof("InF", &result)); EXPECT_EQ(std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtof("-INF", &result)); + EXPECT_TRUE(absl::SimpleAtof("-INF", &result)); EXPECT_EQ(-std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtof("nan", &result)); + EXPECT_TRUE(absl::SimpleAtof("nan", &result)); EXPECT_TRUE(std::isnan(result)); - EXPECT_TRUE(safe_strtof("-nan", &result)); + EXPECT_TRUE(absl::SimpleAtof("-nan", &result)); EXPECT_TRUE(std::isnan(result)); - EXPECT_TRUE(safe_strtof("-NaN", &result)); + EXPECT_TRUE(absl::SimpleAtof("-NaN", &result)); EXPECT_TRUE(std::isnan(result)); - EXPECT_TRUE(safe_strtof("+NAN", &result)); + EXPECT_TRUE(absl::SimpleAtof("+NAN", &result)); EXPECT_TRUE(std::isnan(result)); } TEST(safe_strtod, Double) { double result = 0; - EXPECT_TRUE(safe_strtod("0.1234567890123", &result)); + EXPECT_TRUE(absl::SimpleAtod("0.1234567890123", &result)); EXPECT_EQ(0.1234567890123, result); - EXPECT_FALSE(safe_strtod("0.1234567890123abc", &result)); + EXPECT_FALSE(absl::SimpleAtod("0.1234567890123abc", &result)); // Make sure we exit cleanly if the string is too long char test_str[2 * kFastToBufferSize]; for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a'; test_str[kFastToBufferSize + 1] = '\0'; - EXPECT_FALSE(safe_strtod(test_str, &result)); + EXPECT_FALSE(absl::SimpleAtod(test_str, &result)); // Overflow to infinity, underflow to 0. - EXPECT_TRUE(safe_strtod("1e310", &result)); + EXPECT_TRUE(absl::SimpleAtod("1e310", &result)); EXPECT_EQ(std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtod("-1e310", &result)); + EXPECT_TRUE(absl::SimpleAtod("-1e310", &result)); EXPECT_EQ(-std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtod("1e-325", &result)); + EXPECT_TRUE(absl::SimpleAtod("1e-325", &result)); EXPECT_EQ(0, result); - EXPECT_TRUE(safe_strtod(" -0x1c", &result)); + EXPECT_TRUE(absl::SimpleAtod(" -0x1c", &result)); EXPECT_EQ(-28.0, result); - EXPECT_TRUE(safe_strtod("50 \t", &result)); + EXPECT_TRUE(absl::SimpleAtod("50 \t", &result)); EXPECT_EQ(50.0, result); - EXPECT_TRUE(safe_strtod("\t82.0\t ", &result)); + EXPECT_TRUE(absl::SimpleAtod("\t82.0\t ", &result)); EXPECT_EQ(82.0, result); - EXPECT_FALSE(safe_strtod("infinity", &result)); + EXPECT_TRUE(absl::SimpleAtod("infinity", &result)); - EXPECT_TRUE(safe_strtod("-inf", &result)); + EXPECT_TRUE(absl::SimpleAtod("-inf", &result)); EXPECT_EQ(-std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtod("+inf", &result)); + EXPECT_TRUE(absl::SimpleAtod("+inf", &result)); EXPECT_EQ(std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtod("InF", &result)); + EXPECT_TRUE(absl::SimpleAtod("InF", &result)); EXPECT_EQ(std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtod("-INF", &result)); + EXPECT_TRUE(absl::SimpleAtod("-INF", &result)); EXPECT_EQ(-std::numeric_limits::infinity(), result); - EXPECT_TRUE(safe_strtod("nan", &result)); + EXPECT_TRUE(absl::SimpleAtod("nan", &result)); EXPECT_TRUE(std::isnan(result)); - EXPECT_TRUE(safe_strtod("-nan", &result)); + EXPECT_TRUE(absl::SimpleAtod("-nan", &result)); EXPECT_TRUE(std::isnan(result)); - EXPECT_TRUE(safe_strtod("-NaN", &result)); + EXPECT_TRUE(absl::SimpleAtod("-NaN", &result)); EXPECT_TRUE(std::isnan(result)); - EXPECT_TRUE(safe_strtod("+NAN", &result)); + EXPECT_TRUE(absl::SimpleAtod("+NAN", &result)); EXPECT_TRUE(std::isnan(result)); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/path.cc b/third_party/xla/third_party/tsl/tsl/platform/path.cc index 1d808f122eee76..a099b7a7384a68 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/path.cc @@ -30,13 +30,13 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/scanner.h" #include "tsl/platform/str_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/platform/path.h b/third_party/xla/third_party/tsl/tsl/platform/path.h index dd5567a3792e6c..bf9537c0ee8fed 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/path.h +++ b/third_party/xla/third_party/tsl/tsl/platform/path.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/platform/path_test.cc b/third_party/xla/third_party/tsl/tsl/platform/path_test.cc index ec43b631cf61cb..f644b0742ab1e2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/path_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/path_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/test.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/platform/port_test.cc b/third_party/xla/third_party/tsl/tsl/platform/port_test.cc index ba4dac4220d4b7..d238fec664ab51 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/port_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/port_test.cc @@ -15,12 +15,12 @@ limitations under the License. #include +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/env_time.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/protobuf.h b/third_party/xla/third_party/tsl/tsl/platform/protobuf.h index d35ccc79c0ed7f..a4525babba4bdd 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/protobuf.h +++ b/third_party/xla/third_party/tsl/tsl/platform/protobuf.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "tsl/platform/platform.h" -#include "tsl/platform/types.h" // Import whatever namespace protobuf comes from into the // ::tsl::protobuf namespace. diff --git a/third_party/xla/third_party/tsl/tsl/platform/ram_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/ram_file_system.h index 861b0666648266..626239e9af1657 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/ram_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/ram_file_system.h @@ -29,11 +29,11 @@ limitations under the License. #include #include "absl/strings/match.h" -#include "tsl/platform/env.h" -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" #ifdef PLATFORM_WINDOWS #undef DeleteFile diff --git a/third_party/xla/third_party/tsl/tsl/platform/random.cc b/third_party/xla/third_party/tsl/tsl/platform/random.cc index d7b05ab1e387a0..5d76de9a45424c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/random.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/random.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/platform/random.h b/third_party/xla/third_party/tsl/tsl/platform/random.h index 7e385387cf54f9..680520d08a4264 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/random.h +++ b/third_party/xla/third_party/tsl/tsl/platform/random.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_RANDOM_H_ #define TENSORFLOW_TSL_PLATFORM_RANDOM_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/platform/random_test.cc b/third_party/xla/third_party/tsl/tsl/platform/random_test.cc index 7a6e7a7fea09ad..2ca4e32fc08aff 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/random_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/random_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/platform/raw_coding.h b/third_party/xla/third_party/tsl/tsl/platform/raw_coding.h index f12c1d18ef7895..efa959af261d2e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/raw_coding.h +++ b/third_party/xla/third_party/tsl/tsl/platform/raw_coding.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "tsl/platform/byte_order.h" -#include "tsl/platform/types.h" namespace tsl { namespace core { diff --git a/third_party/xla/third_party/tsl/tsl/platform/refcount.h b/third_party/xla/third_party/tsl/tsl/platform/refcount.h index c3461c615a3064..5af30791b39800 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/refcount.h +++ b/third_party/xla/third_party/tsl/tsl/platform/refcount.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/refcount_test.cc b/third_party/xla/third_party/tsl/tsl/platform/refcount_test.cc index 0f6036fcadec69..0cf6dc49683237 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/refcount_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/refcount_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tsl/platform/refcount.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" namespace tsl { namespace core { diff --git a/third_party/xla/third_party/tsl/tsl/platform/resource_loader.cc b/third_party/xla/third_party/tsl/tsl/platform/resource_loader.cc index 97f3b0e212aa08..cff9ac257c53ad 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/resource_loader.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/resource_loader.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" -#include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h index 1eb8da393d3eb5..f915e5d471bcc9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system.h @@ -20,12 +20,12 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/random.h" #include "tsl/platform/retrying_utils.h" -#include "tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc index 33792c8ecfd293..4a856256712a31 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.cc index 14459e93b61ef3..a42cc83dd9788b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include "absl/time/time.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/random.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.h b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.h index 470b6a8f183412..5b1e802c420877 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.h +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils.h @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/time/time.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc index 00241685d00d5d..65707a651a7ea4 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/rocm_rocdl_path.h b/third_party/xla/third_party/tsl/tsl/platform/rocm_rocdl_path.h index 7432a6566d717a..7134df4932c575 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/rocm_rocdl_path.h +++ b/third_party/xla/third_party/tsl/tsl/platform/rocm_rocdl_path.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_ROCM_ROCDL_PATH_H_ #define TENSORFLOW_TSL_PLATFORM_ROCM_ROCDL_PATH_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/scanner.h b/third_party/xla/third_party/tsl/tsl/platform/scanner.h index d8be6caade08c3..4eb70b8244bc71 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/scanner.h +++ b/third_party/xla/third_party/tsl/tsl/platform/scanner.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" #include "tsl/platform/str_util.h" #include "tsl/platform/stringpiece.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/scanner_test.cc b/third_party/xla/third_party/tsl/tsl/platform/scanner_test.cc index 36681fa0496ff5..dead6fb18937dc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/scanner_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/scanner_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/platform/scanner.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace strings { diff --git a/third_party/xla/third_party/tsl/tsl/platform/setround.cc b/third_party/xla/third_party/tsl/tsl/platform/setround.cc index 0001031cf67bdd..27008762714629 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/setround.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/setround.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/platform/setround.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/setround.h b/third_party/xla/third_party/tsl/tsl/platform/setround.h index adfc3fd2ee29fa..503bda014819e7 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/setround.h +++ b/third_party/xla/third_party/tsl/tsl/platform/setround.h @@ -27,7 +27,7 @@ limitations under the License. #include // NOLINT #endif -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" namespace tsl { namespace port { diff --git a/third_party/xla/third_party/tsl/tsl/platform/setround_test.cc b/third_party/xla/third_party/tsl/tsl/platform/setround_test.cc index 5f19a8067ce9a5..6bbe24e8500868 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/setround_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/setround_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" // LLVM does not support . Disable these tests when building with it. // See b/35384639 for more information. diff --git a/third_party/xla/third_party/tsl/tsl/platform/snappy.h b/third_party/xla/third_party/tsl/tsl/platform/snappy.h index 151b4a9bce74df..d2acb88796350a 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/snappy.h +++ b/third_party/xla/third_party/tsl/tsl/platform/snappy.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_SNAPPY_H_ #define TENSORFLOW_TSL_PLATFORM_SNAPPY_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #if !defined(PLATFORM_WINDOWS) #include diff --git a/third_party/xla/third_party/tsl/tsl/platform/stacktrace_handler_test.cc b/third_party/xla/third_party/tsl/tsl/platform/stacktrace_handler_test.cc index 6d9cc5fd722061..71d45ad44b5ebc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/stacktrace_handler_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/stacktrace_handler_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/stacktrace_test.cc b/third_party/xla/third_party/tsl/tsl/platform/stacktrace_test.cc index 3b23165e51080a..2c91527fbf5107 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/stacktrace_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/stacktrace_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status.h b/third_party/xla/third_party/tsl/tsl/platform/status.h index 61238a13f5c883..fdd9343ac610f1 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,211 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/functional/function_ref.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/stack_frame.h" -#include "tsl/platform/types.h" - -// Include appropriate platform-dependent parts of status. -#if defined(PLATFORM_GOOGLE) -#include "xla/tsl/platform/google/status.h" // IWYU pragma: export -#else -#include "xla/tsl/platform/default/status.h" // IWYU pragma: export -#endif - -// TODO: b/323943471 - This macro should eventually be provided by Abseil. -#ifndef ABSL_DEPRECATE_AND_INLINE -#define ABSL_DEPRECATE_AND_INLINE() -#endif - -namespace tsl { - -// Since April 2023, tensorflow::Status is an alias to absl::Status. The first -// TF release including this change will be TF 2.14 (the latest release in -// April 2023 is 2.13). -// At the same time `tsl::errors::Code` aliases `absl::StatusCode`. -// -// Here is a set of correspondences: -// - Use `absl::OkStatus()` instead of `tsl::OkStatus()`. -typedef absl::Status Status ABSL_DEPRECATE_AND_INLINE(); - -namespace errors { -typedef absl::StatusCode Code ABSL_DEPRECATE_AND_INLINE(); -} // namespace errors -namespace error { -typedef ::tensorflow::error::Code Code; -} // namespace error -} // namespace tsl - -// Transparent comparison between tensorflow::error::Code protobuf enum and -// absl::Status. -// -// The longer term objective is to delete these when we have done the transition -// to absl::Status. -namespace tensorflow::error { -inline bool operator==(const ::tensorflow::error::Code& c1, - const absl::StatusCode& c2) { - return static_cast(c1) == static_cast(c2); -} - -inline bool operator!=(const ::tensorflow::error::Code& c1, - const absl::StatusCode& c2) { - return static_cast(c1) != static_cast(c2); -} -} // namespace tensorflow::error - -namespace absl { -inline bool operator==(const ::absl::StatusCode& c1, - const ::tensorflow::error::Code& c2) { - return static_cast(c1) == static_cast(c2); -} - -inline bool operator!=(const ::absl::StatusCode& c1, - const ::tensorflow::error::Code& c2) { - return static_cast(c1) != static_cast(c2); -} -} // namespace absl - -namespace tsl { - -// OkStatus() -// -// Returns an OK status, equivalent to a default constructed instance. Prefer -// usage of `OkStatus()` when constructing such an OK status. -ABSL_DEPRECATE_AND_INLINE() inline absl::Status OkStatus() { - return absl::OkStatus(); -}; - -ABSL_DEPRECATE_AND_INLINE() -inline absl::Status FromAbslStatus(const absl::Status& s) { return s; } -ABSL_DEPRECATE_AND_INLINE() -inline absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } - -// Given `Status.message()` does not guarantee to be always backed by a -// null-terminated string, we have this utility function when it's needed for -// the Tensorflow C-API. -// A more robust API would be to get both a `char*` of the beginning of the -// string, plus the size (see e.g. `XlaCustomCallStatusSetFailure`). -// NB: This Windows-only implementation is exists only to avoid a linker error. -// Remove if this is resolved. -#ifdef _WIN32 -const char* NullTerminatedMessage(const absl::Status& status); -#else -ABSL_DEPRECATE_AND_INLINE() -inline const char* NullTerminatedMessage(const absl::Status& status) { - return absl::StatusMessageAsCStr(status); -} -#endif - -// TODO(b/197552541) Move this namespace to errors.h. -namespace errors { - -void SetStackTrace(absl::Status& status, std::vector stack_trace); - -std::vector GetStackTrace(const absl::Status& status); -} // namespace errors - -// Helper class to manage multiple child status values. -class StatusGroup { - public: - StatusGroup(); - // Constructor to form a StatusGroup from any N set of Status arguments. - // Usage: StatusGroup({status_a, status_b, status_c}); - StatusGroup(std::initializer_list statuses); - - // Utility function to mark a Status as derived. By marking derived status, - // Derived status messages are ignored when reporting errors to end users. - static absl::Status MakeDerived(const absl::Status& s); - static bool IsDerived(const absl::Status& s); - - // Enable warning and error log collection for appending to the aggregated - // status. This function may be called more than once. - static void ConfigureLogHistory(); - - // Returns merged payloads of all statuses. In case multiple statuses have the - // same payload key, non-derived statuses have priority over derived ones, - // otherwise one payload value will be chosen in an unspecified but - // deterministic order. - // NOTE: The payload marking derived statuses as derived will not be returned. - std::unordered_map GetPayloads() const; - - // Return a merged status with combined child status messages with a summary. - absl::Status as_summary_status() const; - // Return a merged status with combined child status messages with - // concatenation. - absl::Status as_concatenated_status() const; - - bool ok() const { return ok_; } - - // Augment this group with the child status `status`. - void Update(const absl::Status& status); - - // Attach recent warning and error log messages - void AttachLogMessages(); - bool HasLogMessages() const { return !recent_logs_.empty(); } - - private: - bool ok_ = true; - size_t num_ok_ = 0; - - // Maintain a sorted collection of statuses. - struct CompareStatus { - bool operator()(const absl::Status& a, const absl::Status& b) const { - return a.ToString() > b.ToString(); - } - }; - // Using std::set instead of absl::btree_set to keep size for certain - // dependent libraries under the limit. - std::set derived_; - std::set non_derived_; - - std::vector recent_logs_; // recent warning and error logs -}; - -typedef std::function StatusCallback; - -extern ::tsl::string* TfCheckOpHelperOutOfLine(const absl::Status& v, - const char* msg); - -inline ::tsl::string* TfCheckOpHelper(absl::Status v, const char* msg) { - if (v.ok()) return nullptr; - return TfCheckOpHelperOutOfLine(v, msg); -} - -#define TF_DO_CHECK_OK(val, level) \ - while (auto* _result = ::tsl::TfCheckOpHelper(val, #val)) \ - LOG(level) << *(_result) - -#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) -#define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) - -// DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt -// mode. -#ifndef NDEBUG -#define TF_DCHECK_OK(val) TF_CHECK_OK(val) -#else -#define TF_DCHECK_OK(val) \ - while (false && (::tsl::OkStatus() == (val))) LOG(FATAL) -#endif - -} // namespace tsl +#include "xla/tsl/platform/status.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h index e7e12c269d28e0..e9a55986087a0e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,332 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_MATCHERS_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_MATCHERS_H_ -#include -#include -#include - -#include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -// Defines the following utilities: -// -// =============== -// IsOkAndHolds(m) -// =============== -// -// This matcher matches a StatusOr value whose status is OK and whose inner -// value matches matcher m. Example: -// -// using ::tsl::testing::IsOkAndHolds; -// using ::testing::HasSubstr; -// ... -// StatusOr status_or_message("Hello, world"); -// EXPECT_THAT(status_or_message, IsOkAndHolds("Hello, world"))); -// EXPECT_THAT(status_or_message, IsOkAndHolds(HasSubstr("Hello,"))); -// -// =============================== -// StatusIs(status_code_matcher, -// error_message_matcher) -// =============================== -// -// This matcher matches a Status or StatusOr if the following are true: -// -// - the status's code() matches status_code_matcher, and -// - the status's error_message() matches error_message_matcher. -// -// Example: -// -// using ::tsl::testing::StatusIs; -// using ::testing::HasSubstr; -// using ::testing::MatchesRegex; -// using ::testing::Ne; -// using ::testing::_; -// StatusOr GetMessage(int id); -// ... -// -// // The status code must be CANCELLED; the error message can be anything. -// EXPECT_THAT(GetName(42), -// StatusIs(tsl::error::CANCELLED, _)); -// -// // The status code can be anything; the error message must match the regex. -// EXPECT_THAT(GetName(43), -// StatusIs(_, MatchesRegex("server.*time-out"))); -// -// // The status code should not be CANCELLED; the error message can be -// // anything with "Cancelled" in it. -// EXPECT_THAT(GetName(44), -// StatusIs(Ne(tsl::error::CANCELLED), -// HasSubstr("Cancelled")))); -// -// ============================= -// StatusIs(status_code_matcher) -// ============================= -// -// This is a shorthand for -// StatusIs(status_code_matcher, ::testing::_) -// -// In other words, it's like the two-argument StatusIs(), except that it ignores -// error messages. -// -// ====== -// IsOk() -// ====== -// -// Matches a Status or StatusOr whose status value is OK. -// Equivalent to 'StatusIs(error::OK)'. -// -// Example: -// ... -// StatusOr message("Hello, world"); -// EXPECT_THAT(message, IsOk()); -// Status status = OkStatus(); -// EXPECT_THAT(status, IsOk()); - -namespace tsl { - -inline void PrintTo(const tsl::error::Code code, std::ostream* os) { - *os << Code_Name(code); -} - -template -void PrintTo(const StatusOr& status_or, std::ostream* os) { - *os << ::testing::PrintToString(status_or.status()); - if (status_or.ok()) { - *os << ": " << ::testing::PrintToString(status_or.value()); - } -} - -namespace testing { -namespace internal_status { - -inline const absl::Status& GetStatus(const absl::Status& status) { - return status; -} - -template -inline const absl::Status& GetStatus(const StatusOr& status) { - return status.status(); -} - -//////////////////////////////////////////////////////////// -// Implementation of IsOkAndHolds(). -// -// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a -// reference to StatusOr. -template -class IsOkAndHoldsMatcherImpl - : public ::testing::MatcherInterface { - public: - typedef - typename std::remove_reference::type::value_type value_type; - - template - explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) - : inner_matcher_(::testing::SafeMatcherCast( - std::forward(inner_matcher))) {} - - void DescribeTo(std::ostream* os) const override { - *os << "is OK and has a value that "; - inner_matcher_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - *os << "isn't OK or has a value that "; - inner_matcher_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - StatusOrType actual_value, - ::testing::MatchResultListener* result_listener) const override { - if (!actual_value.ok()) { - *result_listener << "which has status " << actual_value.status(); - return false; - } - - ::testing::StringMatchResultListener inner_listener; - const bool matches = - inner_matcher_.MatchAndExplain(*actual_value, &inner_listener); - const std::string inner_explanation = inner_listener.str(); - if (!inner_explanation.empty()) { - *result_listener << "which contains value " - << ::testing::PrintToString(*actual_value) << ", " - << inner_explanation; - } - return matches; - } - - private: - const ::testing::Matcher inner_matcher_; -}; - -// Implements IsOkAndHolds(m) as a polymorphic matcher. -template -class IsOkAndHoldsMatcher { - public: - explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher) - : inner_matcher_(std::move(inner_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic matcher of the given - // type. StatusOrType can be either StatusOr or a reference to StatusOr. - template - operator ::testing::Matcher() const { // NOLINT - return ::testing::Matcher( - new IsOkAndHoldsMatcherImpl(inner_matcher_)); - } - - private: - const InnerMatcher inner_matcher_; -}; - -//////////////////////////////////////////////////////////// -// Implementation of StatusIs(). -// -// StatusIs() is a polymorphic matcher. This class is the common -// implementation of it shared by all types T where StatusIs() can be used as -// a Matcher. - -class StatusIsMatcherCommonImpl { - public: - StatusIsMatcherCommonImpl( - ::testing::Matcher code_matcher, - ::testing::Matcher message_matcher) - : code_matcher_(std::move(code_matcher)), - message_matcher_(std::move(message_matcher)) {} - - void DescribeTo(std::ostream* os) const; - - void DescribeNegationTo(std::ostream* os) const; - - bool MatchAndExplain(const absl::Status& status, - ::testing::MatchResultListener* result_listener) const; - - private: - const ::testing::Matcher code_matcher_; - const ::testing::Matcher message_matcher_; -}; - -// Monomorphic implementation of matcher StatusIs() for a given type T. T can -// be Status, StatusOr<>, or a reference to either of them. -template -class MonoStatusIsMatcherImpl : public ::testing::MatcherInterface { - public: - explicit MonoStatusIsMatcherImpl(StatusIsMatcherCommonImpl common_impl) - : common_impl_(std::move(common_impl)) {} - - void DescribeTo(std::ostream* os) const override { - common_impl_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - common_impl_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - T actual_value, - ::testing::MatchResultListener* result_listener) const override { - return common_impl_.MatchAndExplain(GetStatus(actual_value), - result_listener); - } - - private: - StatusIsMatcherCommonImpl common_impl_; -}; - -// Implements StatusIs() as a polymorphic matcher. -class StatusIsMatcher { - public: - StatusIsMatcher(::testing::Matcher code_matcher, - ::testing::Matcher message_matcher) - : common_impl_( - ::testing::MatcherCast(code_matcher), - ::testing::MatcherCast(message_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic matcher of the given - // type. T can be StatusOr<>, Status, or a reference to either of them. - template - operator ::testing::Matcher() const { // NOLINT - return ::testing::MakeMatcher(new MonoStatusIsMatcherImpl(common_impl_)); - } - - private: - const StatusIsMatcherCommonImpl common_impl_; -}; - -// Monomorphic implementation of matcher IsOk() for a given type T. -// T can be Status, StatusOr<>, or a reference to either of them. -template -class MonoIsOkMatcherImpl : public ::testing::MatcherInterface { - public: - void DescribeTo(std::ostream* os) const override { *os << "is OK"; } - void DescribeNegationTo(std::ostream* os) const override { - *os << "is not OK"; - } - bool MatchAndExplain(T actual_value, - ::testing::MatchResultListener*) const override { - return GetStatus(actual_value).ok(); - } -}; - -// Implements IsOk() as a polymorphic matcher. -class IsOkMatcher { - public: - template - operator ::testing::Matcher() const { // NOLINT - return ::testing::Matcher(new MonoIsOkMatcherImpl()); - } -}; -} // namespace internal_status - -// Returns a matcher that matches a StatusOr<> whose status is OK and whose -// value matches the inner matcher. -template -internal_status::IsOkAndHoldsMatcher::type> -IsOkAndHolds(InnerMatcher&& inner_matcher) { - return internal_status::IsOkAndHoldsMatcher< - typename std::decay::type>( - std::forward(inner_matcher)); -} - -// Returns a matcher that matches a Status or StatusOr<> whose status code -// matches code_matcher, and whose error message matches message_matcher. -template -internal_status::StatusIsMatcher StatusIs(CodeMatcher code_matcher, - MessageMatcher message_matcher) { - return internal_status::StatusIsMatcher(std::move(code_matcher), - std::move(message_matcher)); -} -// Remove this specialization when tensorflow::Status is absl::Status -template -internal_status::StatusIsMatcher StatusIs(tensorflow::error::Code code_matcher, - MessageMatcher message_matcher) { - return internal_status::StatusIsMatcher( - static_cast(code_matcher), std::move(message_matcher)); -} - -// Returns a matcher that matches a Status or StatusOr<> whose status code -// matches code_matcher. -template -internal_status::StatusIsMatcher StatusIs(CodeMatcher code_matcher) { - return StatusIs(std::move(code_matcher), ::testing::_); -} -// Remove this specialization when tensorflow::Status is absl::Status -template <> -inline internal_status::StatusIsMatcher StatusIs( - tensorflow::error::Code code_matcher) { - return StatusIs(static_cast(code_matcher), ::testing::_); -} - -// Returns a matcher that matches a Status or StatusOr<> which is OK. -inline internal_status::IsOkMatcher IsOk() { - return internal_status::IsOkMatcher(); -} - -} // namespace testing -} // namespace tsl +#include "xla/tsl/platform/status_matchers.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUS_MATCHERS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h index 021e002ae4041d..89b0de80337619 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,32 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ -#include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/status.h" - -namespace tsl { - -// TODO(b/250921378): Merge this file with `status.h` once we figure out how to -// fix the following error with the MacOS build: -// -// ImportError: -// dlopen(/org_tensorflow/tensorflow/python/platform/_pywrap_tf2.so, 2): -// Symbol not found: tensorflow11StatusProtoC1EPN6protobuf5ArenaEb - -// Converts a `Status` to a `StatusProto`. -tensorflow::StatusProto StatusToProto(const absl::Status& s); - -#if defined(PLATFORM_GOOGLE) -// Constructs a `Status` from a `StatusProto`. -absl::Status StatusFromProto( - const tensorflow::StatusProto& proto, - absl::SourceLocation loc = absl::SourceLocation::current()); -#else -Status StatusFromProto(const tensorflow::StatusProto& proto); -#endif -} // namespace tsl +#include "xla/tsl/platform/status_to_from_proto.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/statusor.h b/third_party/xla/third_party/tsl/tsl/platform/statusor.h index ac27ede3133850..c4e6da3721d76d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/statusor.h +++ b/third_party/xla/third_party/tsl/tsl/platform/statusor.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,99 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T object. StatusOr models -// the concept of an object that is either a value, or an error Status -// explaining why such a value is not present. To this end, StatusOr does not -// allow its Status value to be Status::OK. -// -// The primary use-case for StatusOr is as the return value of a -// function which may fail. -// -// Example client usage for a StatusOr, where T is not a pointer: -// -// StatusOr result = DoBigCalculationThatCouldFail(); -// if (result.ok()) { -// float answer = result.value(); -// printf("Big calculation yielded: %f", answer); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr: -// -// StatusOr result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo(result.value()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr>: -// -// StatusOr> result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo = std::move(result.value()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example factory implementation returning StatusOr: -// -// StatusOr FooFactory::MakeNewFoo(int arg) { -// if (arg <= 0) { -// return tsl::InvalidArgument("Arg must be positive"); -// } else { -// return new Foo(arg); -// } -// } -// -// Note that the assignment operators require that destroying the currently -// stored value cannot invalidate the argument; in other words, the argument -// cannot be an alias for the current value, or anything owned by the current -// value. #ifndef TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ #define TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/status/statusor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/status.h" - -// Include appropriate platform-dependent `TF_ASSIGN_OR_RETURN`. -#if defined(PLATFORM_GOOGLE) -#include "xla/tsl/platform/google/statusor.h" // IWYU pragma: export -#else -#include "xla/tsl/platform/default/statusor.h" // IWYU pragma: export -#endif - -// TODO: b/323943471 - This macro should eventually be provided by Abseil. -#ifndef ABSL_DEPRECATE_AND_INLINE -#define ABSL_DEPRECATE_AND_INLINE() -#endif - -namespace tsl { - -template -using StatusOr ABSL_DEPRECATE_AND_INLINE() = absl::StatusOr; - -} // namespace tsl - -#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ - TF_ASSERT_OK_AND_ASSIGN_IMPL( \ - TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ - rexpr); - -#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ - lhs = std::move(statusor).value() - -#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) -#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y +#include "xla/tsl/platform/statusor.h" #endif // TENSORFLOW_TSL_PLATFORM_STATUSOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/str_util.cc b/third_party/xla/third_party/tsl/tsl/platform/str_util.cc index 19dfb640cb375e..f22bc6f0c45e3a 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/str_util.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/str_util.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/strings/ascii.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/stringpiece.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/str_util.h b/third_party/xla/third_party/tsl/tsl/platform/str_util.h index 685583faeb9670..ff7c4cd64e5484 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/str_util.h +++ b/third_party/xla/third_party/tsl/tsl/platform/str_util.h @@ -27,9 +27,9 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" // TODO: b/323943471 - This macro should eventually be provided by Abseil. #ifndef ABSL_DEPRECATE_AND_INLINE diff --git a/third_party/xla/third_party/tsl/tsl/platform/str_util_test.cc b/third_party/xla/third_party/tsl/tsl/platform/str_util_test.cc index ce52193109f721..607d7d1bbdf0c7 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/str_util_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/str_util_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/strcat.cc b/third_party/xla/third_party/tsl/tsl/platform/strcat.cc index afa4fd5e2630fe..0259c4bb4c0204 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/strcat.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/strcat.cc @@ -23,7 +23,7 @@ limitations under the License. #include #include "absl/meta/type_traits.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace strings { diff --git a/third_party/xla/third_party/tsl/tsl/platform/strcat.h b/third_party/xla/third_party/tsl/tsl/platform/strcat.h index d552a8a8977baf..dfea869466c0a0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/strcat.h +++ b/third_party/xla/third_party/tsl/tsl/platform/strcat.h @@ -22,10 +22,10 @@ limitations under the License. #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" // The AlphaNum type was designed to be used as the parameter type for StrCat(). // Any routine accepting either a string or a number may accept it. diff --git a/third_party/xla/third_party/tsl/tsl/platform/strcat_test.cc b/third_party/xla/third_party/tsl/tsl/platform/strcat_test.cc index d62fdb60361e9a..d98359458dd540 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/strcat_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/strcat_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringprintf.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" #ifdef _MSC_VER // ssize_t is not a standard C++ type. diff --git a/third_party/xla/third_party/tsl/tsl/platform/stringpiece_test.cc b/third_party/xla/third_party/tsl/tsl/platform/stringpiece_test.cc index b7a46ed5d7b149..f50c1275eba845 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/stringpiece_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/stringpiece_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/stringprintf.h b/third_party/xla/third_party/tsl/tsl/platform/stringprintf.h index 92bc6fc771967e..6e1268dfa352dc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/stringprintf.h +++ b/third_party/xla/third_party/tsl/tsl/platform/stringprintf.h @@ -26,8 +26,8 @@ limitations under the License. #include -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace strings { diff --git a/third_party/xla/third_party/tsl/tsl/platform/stringprintf_test.cc b/third_party/xla/third_party/tsl/tsl/platform/stringprintf_test.cc index 6421002a041aa1..94cfd688f9f386 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/stringprintf_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/stringprintf_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace strings { diff --git a/third_party/xla/third_party/tsl/tsl/platform/test.h b/third_party/xla/third_party/tsl/tsl/platform/test.h index 77591d8c04143e..31ca87536ac34f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/test.h +++ b/third_party/xla/third_party/tsl/tsl/platform/test.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,71 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_TEST_H_ #define TENSORFLOW_TSL_PLATFORM_TEST_H_ -#include -#include -#include - -#include // IWYU pragma: export -#include "tsl/platform/macros.h" -#include "tsl/platform/platform.h" -#include "tsl/platform/types.h" - -// Includes gmock.h and enables the use of gmock matchers in tensorflow tests. -// -// Test including this header can use the macros EXPECT_THAT(...) and -// ASSERT_THAT(...) in combination with gmock matchers. -// Example: -// std::vector vec = Foo(); -// EXPECT_THAT(vec, ::testing::ElementsAre(1,2,3)); -// EXPECT_THAT(vec, ::testing::UnorderedElementsAre(2,3,1)); -// -// For more details on gmock matchers see: -// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers -// -// The advantages of using gmock matchers instead of self defined matchers are -// better error messages, more maintainable tests and more test coverage. -#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && \ - !defined(PLATFORM_CHROMIUMOS) -#include -#include // IWYU pragma: export -#include // IWYU pragma: export -#endif -#include // IWYU pragma: export - -namespace tsl { -namespace testing { - -// Return a temporary directory suitable for temporary testing files. -// -// Where possible, consider using Env::LocalTempFilename over this function. -std::string TmpDir(); - -// Returns the path to TensorFlow in the directory containing data -// dependencies. -// -// A better alternative would be making use if -// tensorflow/tsl/platform/resource_loader.h:GetDataDependencyFilepath. That -// function should do the right thing both within and outside of tests allowing -// avoiding test specific APIs. -std::string TensorFlowSrcRoot(); - -// Returns the path to XLA in the directory containing data -// dependencies. -std::string XlaSrcRoot(); - -// Returns the path to TSL in the directory containing data -// dependencies. -std::string TslSrcRoot(); - -// Return a random number generator seed to use in randomized tests. -// Returns the same value for the lifetime of the process. -int RandomSeed(); - -// Returns an unused port number, for use in multi-process testing. -// NOTE: This function is not thread-safe. -int PickUnusedPortOrDie(); - -} // namespace testing -} // namespace tsl +#include "xla/tsl/platform/test.h" #endif // TENSORFLOW_TSL_PLATFORM_TEST_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/test_benchmark.h b/third_party/xla/third_party/tsl/tsl/platform/test_benchmark.h index d1ce3cdac3514a..6772a5f12ec9e1 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/test_benchmark.h +++ b/third_party/xla/third_party/tsl/tsl/platform/test_benchmark.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,36 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Simple benchmarking facility. #ifndef TENSORFLOW_TSL_PLATFORM_TEST_BENCHMARK_H_ #define TENSORFLOW_TSL_PLATFORM_TEST_BENCHMARK_H_ -#include "benchmark/benchmark.h" // from @com_google_benchmark // IWYU pragma: export -#include "tsl/platform/platform.h" - -// FIXME(vyng): Remove this. -// Background: During the benchmark-migration projects, all benchmarks were made -// to use "testing::benchmark::" prefix because that is what the internal -// Google benchmark library use. -namespace testing { -namespace benchmark { -using ::benchmark::State; // NOLINT -} // namespace benchmark -} // namespace testing - -namespace tsl { -namespace testing { - -inline void RunBenchmarks() { benchmark::RunSpecifiedBenchmarks(); } -inline void InitializeBenchmarks(int* argc, char** argv) { - benchmark::Initialize(argc, argv); -} - -template -void DoNotOptimize(const T& var) { - ::benchmark::DoNotOptimize(var); -} -} // namespace testing -} // namespace tsl +#include "xla/tsl/platform/test_benchmark.h" #endif // TENSORFLOW_TSL_PLATFORM_TEST_BENCHMARK_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool.h b/third_party/xla/third_party/tsl/tsl/platform/threadpool.h index df650f6eccfd4c..3ab00c4d498b2b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool.h +++ b/third_party/xla/third_party/tsl/tsl/platform/threadpool.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,230 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_H_ -#include -#include - -#include "absl/types/optional.h" -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/threadpool_interface.h" -#include "tsl/platform/types.h" - -namespace Eigen { -class Allocator; -class ThreadPoolInterface; -struct ThreadPoolDevice; - -template -class ThreadPoolTempl; -} // namespace Eigen - -namespace tsl { -namespace thread { - -struct EigenEnvironment; - -class ThreadPool { - public: - // Scheduling strategies for ParallelFor. The strategy governs how the given - // units of work are distributed among the available threads in the - // threadpool. - enum class SchedulingStrategy { - // The Adaptive scheduling strategy adaptively chooses the shard sizes based - // on the cost of each unit of work, and the cost model of the underlying - // threadpool device. - // - // The 'cost_per_unit' is an estimate of the number of CPU cycles (or - // nanoseconds if not CPU-bound) to complete a unit of work. Overestimating - // creates too many shards and CPU time will be dominated by per-shard - // overhead, such as Context creation. Underestimating may not fully make - // use of the specified parallelism, and may also cause inefficiencies due - // to load balancing issues and stragglers. - kAdaptive, - // The Fixed Block Size scheduling strategy shards the given units of work - // into shards of fixed size. In case the total number of units is not - // evenly divisible by 'block_size', at most one of the shards may be of - // smaller size. The exact number of shards may be found by a call to - // NumShardsUsedByFixedBlockSizeScheduling. - // - // Each shard may be executed on a different thread in parallel, depending - // on the number of threads available in the pool. Note that when there - // aren't enough threads in the pool to achieve full parallelism, function - // calls will be automatically queued. - kFixedBlockSize - }; - - // Contains additional parameters for either the Adaptive or the Fixed Block - // Size scheduling strategy. - class SchedulingParams { - public: - explicit SchedulingParams(SchedulingStrategy strategy, - absl::optional cost_per_unit, - absl::optional block_size) - : strategy_(strategy), - cost_per_unit_(cost_per_unit), - block_size_(block_size) {} - - SchedulingStrategy strategy() const { return strategy_; } - absl::optional cost_per_unit() const { return cost_per_unit_; } - absl::optional block_size() const { return block_size_; } - - private: - // The underlying Scheduling Strategy for which this instance contains - // additional parameters. - SchedulingStrategy strategy_; - - // The estimated cost per unit of work in number of CPU cycles (or - // nanoseconds if not CPU-bound). Only applicable for Adaptive scheduling - // strategy. - absl::optional cost_per_unit_; - - // The block size of each shard. Only applicable for Fixed Block Size - // scheduling strategy. - absl::optional block_size_; - }; - - // Constructs a pool that contains "num_threads" threads with specified - // "name". env->StartThread() is used to create individual threads with the - // given ThreadOptions. If "low_latency_hint" is true the thread pool - // implementation may use it as a hint that lower latency is preferred at the - // cost of higher CPU usage, e.g. by letting one or more idle threads spin - // wait. Conversely, if the threadpool is used to schedule high-latency - // operations like I/O the hint should be set to false. - // - // REQUIRES: num_threads > 0 - ThreadPool(Env* env, const ThreadOptions& thread_options, - const std::string& name, int num_threads, bool low_latency_hint, - Eigen::Allocator* allocator = nullptr); - - // Constructs a pool for low-latency ops that contains "num_threads" threads - // with specified "name". env->StartThread() is used to create individual - // threads. - // REQUIRES: num_threads > 0 - ThreadPool(Env* env, const std::string& name, int num_threads); - - // Constructs a pool for low-latency ops that contains "num_threads" threads - // with specified "name". env->StartThread() is used to create individual - // threads with the given ThreadOptions. - // REQUIRES: num_threads > 0 - ThreadPool(Env* env, const ThreadOptions& thread_options, - const std::string& name, int num_threads); - - // Constructs a pool that wraps around the thread::ThreadPoolInterface - // instance provided by the caller. Caller retains ownership of - // `user_threadpool` and must ensure its lifetime is longer than the - // ThreadPool instance. - explicit ThreadPool(thread::ThreadPoolInterface* user_threadpool); - - // Waits until all scheduled work has finished and then destroy the - // set of threads. - ~ThreadPool(); - - // Schedules fn() for execution in the pool of threads. - void Schedule(std::function fn); - - void SetStealPartitions( - const std::vector>& partitions); - - void ScheduleWithHint(std::function fn, int start, int limit); - - // Returns the number of shards used by ParallelForFixedBlockSizeScheduling - // with these parameters. - int NumShardsUsedByFixedBlockSizeScheduling(const int64_t total, - const int64_t block_size); - - // Returns the number of threads spawned by calling TransformRangeConcurrently - // with these parameters. - // Deprecated. Use NumShardsUsedByFixedBlockSizeScheduling. - int NumShardsUsedByTransformRangeConcurrently(const int64_t block_size, - const int64_t total); - - // ParallelFor shards the "total" units of work assuming each unit of work - // having roughly "cost_per_unit" cost, in cycles. Each unit of work is - // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work - // and the total cost of each shard is roughly the same. - // - // "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds - // if not CPU-bound) to complete a unit of work. Overestimating creates too - // many shards and CPU time will be dominated by per-shard overhead, such as - // Context creation. Underestimating may not fully make use of the specified - // parallelism, and may also cause inefficiencies due to load balancing - // issues and stragglers. - void ParallelFor(int64_t total, int64_t cost_per_unit, - const std::function& fn); - - // Similar to ParallelFor above, but takes the specified scheduling strategy - // into account. - void ParallelFor(int64_t total, const SchedulingParams& scheduling_params, - const std::function& fn); - - // Same as ParallelFor with Fixed Block Size scheduling strategy. - // Deprecated. Prefer ParallelFor with a SchedulingStrategy argument. - void TransformRangeConcurrently( - const int64_t block_size, const int64_t total, - const std::function& fn); - - // Shards the "total" units of work. For more details, see "ParallelFor". - // - // The function is passed a thread_id between 0 and NumThreads() *inclusive*. - // This is because some work can happen on the caller thread while the threads - // in the pool are also being used. - // - // The caller can allocate NumThreads() + 1 separate buffers for each thread. - // Each thread can safely write to the buffer given by its id without - // synchronization. However, the worker fn may be called multiple times - // sequentially with the same id. - // - // At most NumThreads() unique ids will actually be used, and only a few may - // be used for small workloads. If each buffer is expensive, the buffers - // should be stored in an array initially filled with null, and a buffer - // should be allocated by fn the first time that the id is used. - void ParallelForWithWorkerId( - int64_t total, int64_t cost_per_unit, - const std::function& fn); - - // Similar to ParallelForWithWorkerId above, but takes the specified - // scheduling strategy into account. - void ParallelForWithWorkerId( - int64_t total, const SchedulingParams& scheduling_params, - const std::function& fn); - - // Returns the number of threads in the pool. - int NumThreads() const; - - // Returns current thread id between 0 and NumThreads() - 1, if called from a - // thread in the pool. Returns -1 otherwise. - int CurrentThreadId() const; - - // If ThreadPool implementation is compatible with Eigen::ThreadPoolInterface, - // returns a non-null pointer. The caller does not own the object the returned - // pointer points to, and should not attempt to delete. - Eigen::ThreadPoolInterface* AsEigenThreadPool() const; - - private: - // Divides the work represented by the range [0, total) into k shards. - // Calls fn(i*block_size, (i+1)*block_size) from the ith shard (0 <= i < k). - // Each shard may be executed on a different thread in parallel, depending on - // the number of threads available in the pool. - // When (i+1)*block_size > total, fn(i*block_size, total) is called instead. - // Here, k = NumShardsUsedByFixedBlockSizeScheduling(total, block_size). - // Requires 0 < block_size <= total. - void ParallelForFixedBlockSizeScheduling( - const int64_t total, const int64_t block_size, - const std::function& fn); - - // underlying_threadpool_ is the user_threadpool if user_threadpool is - // provided in the constructor. Otherwise it is the eigen_threadpool_. - Eigen::ThreadPoolInterface* underlying_threadpool_; - // eigen_threadpool_ is instantiated and owned by thread::ThreadPool if - // user_threadpool is not in the constructor. - std::unique_ptr> eigen_threadpool_; - std::unique_ptr threadpool_device_; - ThreadPool(const ThreadPool&) = delete; - void operator=(const ThreadPool&) = delete; -}; - -} // namespace thread -} // namespace tsl +#include "xla/tsl/platform/threadpool.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor.h b/third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor.h index 59f14aab13234b..deadc951116856 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor.h +++ b/third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor.h @@ -16,35 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ -#include - -#include "xla/tsl/concurrency/async_value.h" -#include "tsl/platform/threadpool.h" - -namespace tsl::thread { - -// An adaptor for a ThreadPool that converts it into the AsyncValue:Executor. -// -// AsncValue::Executor task is a move-only absl::AnyInvocable, and ThreadPool -// expects a copyable std::function. This class adapts the two and makes sure -// that the task is deleted when it's done executing. -class ThreadPoolAsyncExecutor : public AsyncValue::Executor { - public: - explicit ThreadPoolAsyncExecutor(ThreadPool* thread_pool) - : thread_pool_(thread_pool) {} - - void Execute(Task task) final { - auto* task_ptr = new Task(std::move(task)); - thread_pool_->Schedule([task_ptr] { - (*task_ptr)(); - delete task_ptr; - }); - } - - private: - ThreadPool* thread_pool_; -}; - -} // namespace tsl::thread +#include "xla/tsl/platform/threadpool_async_executor.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool_interface.h b/third_party/xla/third_party/tsl/tsl/platform/threadpool_interface.h index 0dac04d5e7293d..930d8bcd26b7f8 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool_interface.h +++ b/third_party/xla/third_party/tsl/tsl/platform/threadpool_interface.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ -#include "unsupported/Eigen/CXX11/ThreadPool" // from @eigen_archive -#include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" - -namespace tsl { -namespace thread { - -class ThreadPoolInterface : public Eigen::ThreadPoolInterface {}; - -} // namespace thread -} // namespace tsl +#include "xla/tsl/platform/threadpool_interface.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool_options.h b/third_party/xla/third_party/tsl/tsl/platform/threadpool_options.h index 21c74fbaa5727f..ea884edfc380c8 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool_options.h +++ b/third_party/xla/third_party/tsl/tsl/platform/threadpool_options.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,20 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ #define TENSORFLOW_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ -#include "tsl/platform/threadpool_interface.h" - -namespace tsl { -namespace thread { - -struct ThreadPoolOptions { - // If not null, use this threadpool to schedule inter-op operation - thread::ThreadPoolInterface* inter_op_threadpool = nullptr; - - // If not null, use this threadpool to schedule intra-op operation - thread::ThreadPoolInterface* intra_op_threadpool = nullptr; -}; - -} // namespace thread -} // namespace tsl +#include "xla/tsl/platform/threadpool_options.h" #endif // TENSORFLOW_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/tracing.h b/third_party/xla/third_party/tsl/tsl/platform/tracing.h index 8541c2bf77feb9..07a725f2203106 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/tracing.h +++ b/third_party/xla/third_party/tsl/tsl/platform/tracing.h @@ -20,10 +20,10 @@ limitations under the License. #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/platform.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace tracing { diff --git a/third_party/xla/third_party/tsl/tsl/platform/tstring_test.cc b/third_party/xla/third_party/tsl/tsl/platform/tstring_test.cc index 78263471b61073..859f8676846e38 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/tstring_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/tstring_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include +#include "xla/tsl/platform/test.h" #include "tsl/platform/cord.h" #include "tsl/platform/platform.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/test.h" using ::tsl::tstring; diff --git a/third_party/xla/third_party/tsl/tsl/platform/types.h b/third_party/xla/third_party/tsl/tsl/platform/types.h index 1768d57bb7e2c6..90aa7993f7dbbc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/types.h +++ b/third_party/xla/third_party/tsl/tsl/platform/types.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,59 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_TYPES_H_ #define TENSORFLOW_TSL_PLATFORM_TYPES_H_ -#include - -#include "tsl/platform/bfloat16.h" -#include "tsl/platform/ml_dtypes.h" // IWYU pragma: export -#include "tsl/platform/platform.h" -#include "tsl/platform/tstring.h" - -// Include appropriate platform-dependent implementations -#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES) -#include "xla/tsl/platform/google/integral_types.h" // IWYU pragma: export -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ - defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ - defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) -#include "xla/tsl/platform/default/integral_types.h" // IWYU pragma: export -#else -#error Define the appropriate PLATFORM_ macro for this platform -#endif - -namespace tsl { - -// Alias tsl::string to std::string. -using std::string; - -static const uint4 kuint4max = static_cast(0x0F); -static const uint8 kuint8max = static_cast(0xFF); -static const uint16 kuint16max = static_cast(0xFFFF); -static const uint32 kuint32max = static_cast(0xFFFFFFFF); -static const uint64 kuint64max = static_cast(0xFFFFFFFFFFFFFFFFull); -static const int8_t kint8min = static_cast(~0x7F); -static const int8_t kint8max = static_cast(0x7F); -static const int4 kint4min = static_cast(0x08); -static const int4 kint4max = static_cast(0x07); -static const int16_t kint16min = static_cast(~0x7FFF); -static const int16_t kint16max = static_cast(0x7FFF); -static const int32_t kint32min = static_cast(~0x7FFFFFFF); -static const int32_t kint32max = static_cast(0x7FFFFFFF); -static const int64_t kint64min = static_cast(~0x7FFFFFFFFFFFFFFFll); -static const int64_t kint64max = static_cast(0x7FFFFFFFFFFFFFFFll); - -// A typedef for a uint64 used as a short fingerprint. -using Fprint = uint64; - -} // namespace tsl - -// Alias namespace ::stream_executor as ::tensorflow::se. -namespace stream_executor {} -namespace tensorflow { -namespace se = ::stream_executor; -} // namespace tensorflow - -#if defined(PLATFORM_WINDOWS) -#include -typedef std::ptrdiff_t ssize_t; -#endif +#include "xla/tsl/platform/types.h" #endif // TENSORFLOW_TSL_PLATFORM_TYPES_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/unbounded_work_queue_test.cc b/third_party/xla/third_party/tsl/tsl/platform/unbounded_work_queue_test.cc index 1efdf5a3842487..ce703010d4536c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/unbounded_work_queue_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/unbounded_work_queue_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include "tsl/platform/unbounded_work_queue.h" #include "absl/memory/memory.h" -#include "tsl/platform/random.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/blocking_counter.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" +#include "tsl/platform/random.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index 87039192e615ec..3cf573c3addd6f 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -62,11 +62,11 @@ cc_library( ]), deps = [ ":profiler_interface", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:status", "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:status", ], ) @@ -114,13 +114,13 @@ tsl_cc_test( ":profiler_factory", ":profiler_factory_impl", ":profiler_interface", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_main", "//tsl/profiler/protobuf:profiler_options_proto_cc", "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -134,8 +134,8 @@ cc_library( "@local_xla//xla/tsl/profiler:xla_profiler_backends", ]), deps = [ - "//tsl/platform:status", "//tsl/profiler/protobuf:xplane_proto_cc", + "@local_xla//xla/tsl/platform:status", ], ) @@ -149,11 +149,11 @@ cc_library( "@local_xla//xla/tsl/profiler:xla_internal", ]), deps = [ - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:statusor", "@local_xla//xla/tsl/util:env_var", ], ) @@ -163,9 +163,9 @@ tsl_cc_test( srcs = ["profiler_lock_test.cc"], deps = [ ":profiler_lock", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/status:statusor", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -175,14 +175,14 @@ cc_library( visibility = internal_visibility(["@local_xla//xla/tsl:internal"]), deps = [ "//tsl/platform", - "//tsl/platform:errors", "//tsl/platform:mutex", - "//tsl/platform:status", "//tsl/platform:thread_annotations", - "//tsl/platform:types", "//tsl/profiler/protobuf:profiler_options_proto_cc", "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:status", + "@local_xla//xla/tsl/platform:types", ] + if_not_android([ ":profiler_interface", ":profiler_lock", @@ -203,15 +203,15 @@ cc_library( "@local_xla//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform:errors", - "//tsl/platform:logging", "//tsl/platform:mutex", "//tsl/platform:thread_annotations", - "//tsl/platform:types", "//tsl/profiler/protobuf:profiler_options_proto_cc", "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:types", ] + if_not_android([ ":profiler_collection", ":profiler_factory", @@ -219,7 +219,7 @@ cc_library( ":profiler_lock", "//tsl/platform", "//tsl/platform:platform_port", - "//tsl/platform:status", + "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/profiler/convert:post_process_single_host_xplane", "@local_xla//xla/tsl/profiler/utils:time_utils", ]), @@ -231,10 +231,10 @@ cc_library( hdrs = ["traceme_encode.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:logging", - "//tsl/platform:macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:macros", ], ) @@ -244,11 +244,11 @@ tsl_cc_test( deps = [ ":traceme_encode", "//tsl/platform", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_benchmark", + "@local_xla//xla/tsl/platform:test_main", ], ) @@ -265,10 +265,10 @@ cc_library( deps = [ ":traceme_encode", "//tsl/platform", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:logging", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", "@local_xla//xla/tsl/profiler/utils:no_init", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder", @@ -321,9 +321,9 @@ cc_library( deps = [ ":nvtx_utils", "//tsl/platform", - "//tsl/platform:macros", - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:macros", + "@local_xla//xla/tsl/platform:types", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:annotation_stack", ]), @@ -335,10 +335,10 @@ tsl_cc_test( srcs = ["scoped_annotation_test.cc"], deps = [ ":scoped_annotation", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/platform:test", + "@local_xla//xla/tsl/platform:test_benchmark", + "@local_xla//xla/tsl/platform:test_main", "@local_xla//xla/tsl/profiler/backends/cpu:annotation_stack", "@local_xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", ], @@ -352,9 +352,9 @@ cc_library( ":context_types_hdrs", ":traceme", ":traceme_encode", - "//tsl/platform:types", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_xla//xla/tsl/platform:types", ], ) @@ -369,9 +369,9 @@ cc_library( ]), deps = [ ":profiler_interface", - "//tsl/platform:status", "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/platform:status", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h index d026a197da756c..422e8271ee4fc3 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_collection.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_collection.h index c3bede9af47c8d..e2b9fd3ef979db 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_collection.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_collection.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.cc index 55fc42706dfea5..d9c58717cdd801 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.h index ed88f8ec26b561..cc0334e9daf338 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_controller.h @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc index a1188b9fa5563d..84eda47a56fa68 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_factory_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/test.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_interface.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_interface.h index c949a50f463cbb..2b0b712425bbcc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_interface.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_interface.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_PROFILER_INTERFACE_H_ #define TENSORFLOW_TSL_PROFILER_LIB_PROFILER_INTERFACE_H_ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc index d32ea96fd2bf69..b226bd23925fec 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" #include "xla/tsl/util/env_var.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h index ef303663b3d142..719ed8f2452ba1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc index 2ddc56fb0b9a8d..f3e63bff6af66e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc index 2932415dceae2e..dc312efb24b655 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/mutex.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h index b503f428ff30d5..f65ff7c36ab59d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.h @@ -20,11 +20,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/platform.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h index 92898c4ebc3834..d39536401e7adb 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" #include "tsl/platform/platform.h" // IWYU pragma: keep #include "tsl/profiler/lib/nvtx_utils.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc index bcfe9356150862..9aa61fd4983e11 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h index 517f4a0f8b669b..566dfef0a876bb 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h @@ -24,9 +24,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" #include "xla/tsl/profiler/utils/no_init.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" #include "tsl/profiler/lib/traceme_encode.h" // IWYU pragma: export #if !defined(IS_MOBILE_PLATFORM) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode.h index 76c5f301e7d703..69f12dd0825e36 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode.h @@ -24,8 +24,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc index 4827bee4d820b6..f8dc39196b650d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme_encode_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/platform.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index 5cffcd00d255bd..f3a237dd7f70ca 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -111,9 +111,9 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "ca3a5316b8161214f8f22a578fb638f1fccd0585eee40301363ffd026310379a", - strip_prefix = "XNNPACK-a50369c0fdd15f0f35b1a91c964644327a88d480", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/a50369c0fdd15f0f35b1a91c964644327a88d480.zip"), + sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5", + strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) @@ -126,16 +126,16 @@ def _tf_repositories(): tf_http_archive( name = "pthreadpool", - sha256 = "b96413b10dd8edaa4f6c0a60c6cf5ef55eebeef78164d5d69294c8173457f0ec", - strip_prefix = "pthreadpool-b8374f80e42010941bda6c85b0e3f1a1bd77a1e0", - urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/b8374f80e42010941bda6c85b0e3f1a1bd77a1e0.zip"), + sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95", + strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8", + urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"), ) tf_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-5e63739504f0f8e18e941bd63b2d6d42536c7d90", - sha256 = "18eca9bc8d9c4ce5496d0d2be9f456d55cbbb5f0639a551ce9c8bac2e84d85fe", - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/5e63739504f0f8e18e941bd63b2d6d42536c7d90.tar.gz"), + sha256 = "52e0ffd7998d8cb3a927d8a6e1145763744d866d2be09c4eccea27fc157b6bb0", + strip_prefix = "cpuinfo-cebb0933058d7f181c979afd50601dc311e1bf8c", + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip"), ) tf_http_archive( @@ -163,6 +163,11 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", + "//third_party/mkl_dnn:onednn_acl_indirect_conv.patch", + "//third_party/mkl_dnn:onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch", + "//third_party/mkl_dnn:onednn_acl_fix_segfault_during_postop_execute.patch", + "//third_party/mkl_dnn:onednn_acl_add_bf16_platform_support_check.patch", + "//third_party/mkl_dnn:onednn_acl_add_sbgemm_matmul_primitive_definition.patch", ], sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3", strip_prefix = "oneDNN-3.2.1", @@ -475,14 +480,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/cython/cython/archive/3.0.3.tar.gz"), ) - tf_http_archive( - name = "double_conversion", - sha256 = "3dbcdf186ad092a8b71228a5962009b5c96abde9a315257a3452eb988414ea3b", - strip_prefix = "double-conversion-3.2.0", - system_build_file = "//third_party/systemlibs:double_conversion.BUILD", - urls = tf_mirror_urls("https://github.com/google/double-conversion/archive/v3.2.0.tar.gz"), - ) - tf_http_archive( name = "build_bazel_rules_android", sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806", diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 83f52d9af9970a..0079e66d203915 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -1,6 +1,6 @@ """Configurations of RBE builds used with remote config.""" -load("//tools/toolchains/remote_config:rbe_config.bzl", "sigbuild_tf_configs", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") +load("//tools/toolchains/remote_config:rbe_config.bzl", "ml_build_rbe_config", "sigbuild_tf_configs", "tensorflow_local_config", "tensorflow_rbe_config", "tensorflow_rbe_win_config") def initialize_rbe_configs(): tensorflow_local_config( @@ -47,6 +47,11 @@ def initialize_rbe_configs(): python_bin_path = "C:/Python37/python.exe", ) + # The `ml-build-rbe` image is identical to the `ml-build` image except for the base image. + # The `ml-build`'s base image is a standard `ubuntu22.04` image. + # The `ml-build-rbe`'s base image is `nvidia/cuda:12.3.2-base-ubuntu22.04` which has nvidia driver installed. + ml_build_rbe_config("docker://us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe@sha256:aaeb29799463729092c05f5ac8393113b3bb5d1ecf085f9f1f2016e3a1ece11c") + # TF-Version-Specific SIG Build RBE Configs. The crosstool generated from these # configs are python-version-independent because they only care about the # tooling paths; the container mapping is useful only so that TF RBE users diff --git a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl index 280b8d914283dd..dbfafdfb08c180 100644 --- a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl @@ -92,10 +92,24 @@ def _tensorflow_local_config(name): platform_constraint = "@%s_config_platform//:platform_constraint" % name, ) +def _ml_build_rbe_config(container_image): + exec_properties = { + "container-image": container_image, + "Pool": "default", + } + + remote_platform_configure( + name = "ml_build_config_platform", + platform = "linux", + platform_exec_properties = exec_properties, + ) + tensorflow_rbe_config = _tensorflow_rbe_config tensorflow_rbe_win_config = _tensorflow_rbe_win_config tensorflow_local_config = _tensorflow_local_config +ml_build_rbe_config = _ml_build_rbe_config +# TODO(b/369382309): Remove this once ml_build_rbe_config is used everywhere. # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles # These containers do not support ROCm and all have CUDA. diff --git a/third_party/xla/tools/toolchains/win/20240424/BUILD b/third_party/xla/tools/toolchains/win/20240424/BUILD index 93b3c90aff81d9..db4cf0eac92066 100644 --- a/third_party/xla/tools/toolchains/win/20240424/BUILD +++ b/third_party/xla/tools/toolchains/win/20240424/BUILD @@ -20,24 +20,6 @@ load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") package(default_visibility = ["//visibility:public"]) -cc_library(name = "empty_lib") - -# Label flag for extra libraries to be linked into every binary. -# TODO(bazel-team): Support passing flag multiple times to build a list. -label_flag( - name = "link_extra_libs", - build_setting_default = ":empty_lib", -) - -# The final extra library to be linked into every binary target. This collects -# the above flag, but may also include more libraries depending on config. -cc_library( - name = "link_extra_lib", - deps = [ - ":link_extra_libs", - ], -) - cc_library( name = "malloc", ) @@ -228,7 +210,8 @@ cc_toolchain_config( compiler = "msvc-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -240,24 +223,24 @@ cc_toolchain_config( default_link_flags = ["/MACHINE:X64"], fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", host_system_name = "local", - msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", - msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/lib.exe", - msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/link.exe", - msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/ml64.exe", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", supports_parse_showincludes = True, target_libc = "msvcrt", target_system_name = "local", tool_paths = { - "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/lib.exe", - "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/ml64.exe", - "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", - "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/cl.exe", + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", "gcov": "wrapper/bin/msvc_nop.bat", - "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x64/link.exe", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", "nm": "wrapper/bin/msvc_nop.bat", "objcopy": "wrapper/bin/msvc_nop.bat", "objdump": "wrapper/bin/msvc_nop.bat", @@ -303,7 +286,8 @@ cc_toolchain_config( compiler = "msvc-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -315,24 +299,24 @@ cc_toolchain_config( default_link_flags = ["/MACHINE:X86"], fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", host_system_name = "local", - msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", - msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/lib.exe", - msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/link.exe", - msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/ml.exe", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", supports_parse_showincludes = True, target_libc = "msvcrt", target_system_name = "local", tool_paths = { - "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/lib.exe", - "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/ml.exe", - "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", - "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/cl.exe", + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", "gcov": "wrapper/bin/msvc_nop.bat", - "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.39.33519/bin/HostX64/x86/link.exe", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", "nm": "wrapper/bin/msvc_nop.bat", "objcopy": "wrapper/bin/msvc_nop.bat", "objdump": "wrapper/bin/msvc_nop.bat", @@ -511,7 +495,8 @@ cc_toolchain_config( compiler = "clang-cl", cpu = "x64_windows", cxx_builtin_include_directories = [ - "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include", "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", @@ -521,13 +506,16 @@ cc_toolchain_config( "C:\\tools\\LLVM\\lib\\clang\\18\\include", ], dbg_mode_debug_flag = "/DEBUG", - default_link_flags = ["/MACHINE:X64"], + default_link_flags = [ + "/MACHINE:X64", + "/DEFAULTLIB:clang_rt.builtins-x86_64.lib", + ], fastbuild_mode_debug_flag = "/DEBUG", host_system_name = "local", msvc_cl_path = "C:/tools/LLVM/bin/clang-cl.exe", - msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", - msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", - msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.39.33519\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\ATLMFC\\lib\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", msvc_env_tmp = "C:\\Users\\ContainerAdministrator\\AppData\\Local\\Temp", msvc_lib_path = "C:/tools/LLVM/bin/llvm-lib.exe", msvc_link_path = "C:/tools/LLVM/bin/lld-link.exe", diff --git a/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl b/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl index 0a1fb6e0df84ce..f440b6083d71fb 100644 --- a/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl +++ b/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_clangcl @@ -3,3 +3,5 @@ that clang-cl reported. This file is a dependency of every compilation action an changes to it will be reflected in the action cache key. When some of these paths change, Bazel will make sure to rerun the action, even though none of declared action inputs or the action commandline changes. + + diff --git a/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc b/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc index 55ba44f761e2c1..1380bc62e15b60 100644 --- a/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc +++ b/third_party/xla/tools/toolchains/win/20240424/builtin_include_directory_paths_msvc @@ -4,3 +4,4 @@ changes to it will be reflected in the action cache key. When some of these paths change, Bazel will make sure to rerun the action, even though none of declared action inputs or the action commandline changes. + diff --git a/third_party/xla/tools/toolchains/win/20240424/toolchain_image_info b/third_party/xla/tools/toolchains/win/20240424/toolchain_image_info index 807a14bebbdb44..ffa6a8e33c7933 100644 --- a/third_party/xla/tools/toolchains/win/20240424/toolchain_image_info +++ b/third_party/xla/tools/toolchains/win/20240424/toolchain_image_info @@ -1,2 +1,2 @@ REPOSITORY TAG DIGEST IMAGE ID CREATED SIZE -gcr.io/tensorflow-testing/tf-win2019-docker-staging latest sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc b601adb43430 8 minutes ago 20.4GB \ No newline at end of file +gcr.io/tensorflow-testing/tf-win2019-rbe latest sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd b601adb43430 8 minutes ago 20.4GB \ No newline at end of file diff --git a/third_party/xla/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl b/third_party/xla/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl index 6d8e8af6d50e4a..03ff9b6b30078d 100644 --- a/third_party/xla/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl +++ b/third_party/xla/tools/toolchains/win/20240424/windows_cc_toolchain_config.bzl @@ -375,7 +375,6 @@ def _impl(ctx): compiler_param_file_feature = feature( name = "compiler_param_file", - enabled = True, ) copy_dynamic_libraries_to_binary_feature = feature( diff --git a/third_party/xla/tools/toolchains/win/BUILD b/third_party/xla/tools/toolchains/win/BUILD index 55ae6fb22b81f6..258ca032ecd1ea 100644 --- a/third_party/xla/tools/toolchains/win/BUILD +++ b/third_party/xla/tools/toolchains/win/BUILD @@ -17,7 +17,7 @@ platform( remote_execution_properties = """ properties:{ name: "container-image" - value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" } properties:{ name: "OSFamily" @@ -43,7 +43,7 @@ platform( remote_execution_properties = """ properties:{ name: "container-image" - value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" + value: "docker://gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:d3577d20dea75966faf7fd03479c71462441937df5694259109c2ee1d002a3dd" } properties:{ name: "OSFamily" diff --git a/third_party/xla/tools/toolchains/win2022/20241118/BUILD b/third_party/xla/tools/toolchains/win2022/20241118/BUILD new file mode 100644 index 00000000000000..7d1ac7d0dfa1f2 --- /dev/null +++ b/third_party/xla/tools/toolchains/win2022/20241118/BUILD @@ -0,0 +1,647 @@ +# Copyright 2018 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This becomes the BUILD file for @local_config_cc// under Windows. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_toolchain", "cc_toolchain_suite") +load(":armeabi_cc_toolchain_config.bzl", "armeabi_cc_toolchain_config") +load(":windows_cc_toolchain_config.bzl", "cc_toolchain_config") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "malloc", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "mingw_compiler_files", + srcs = [":builtin_include_directory_paths_mingw"], +) + +filegroup( + name = "clangcl_compiler_files", + srcs = [":builtin_include_directory_paths_clangcl"], +) + +filegroup( + name = "msvc_compiler_files", + srcs = [":builtin_include_directory_paths_msvc"], +) + +# Hardcoded toolchain, legacy behaviour. +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a", + "x64_windows|msvc-cl": ":cc-compiler-x64_windows", + "x64_x86_windows|msvc-cl": ":cc-compiler-x64_x86_windows", + "x64_arm_windows|msvc-cl": ":cc-compiler-x64_arm_windows", + "x64_arm64_windows|msvc-cl": ":cc-compiler-arm64_windows", + "arm64_windows|msvc-cl": ":cc-compiler-arm64_windows", + "x64_windows|msys-gcc": ":cc-compiler-x64_windows_msys", + "x64_windows|mingw-gcc": ":cc-compiler-x64_windows_mingw", + "x64_windows|clang-cl": ":cc-compiler-x64_windows-clang-cl", + "x64_windows_msys": ":cc-compiler-x64_windows_msys", + "x64_windows": ":cc-compiler-x64_windows", + "x64_x86_windows": ":cc-compiler-x64_x86_windows", + "x64_arm_windows": ":cc-compiler-x64_arm_windows", + "x64_arm64_windows": ":cc-compiler-arm64_windows", + "arm64_windows": ":cc-compiler-arm64_windows", + "x64_arm64_windows|clang-cl": ":cc-compiler-arm64_windows-clang-cl", + "arm64_windows|clang-cl": ":cc-compiler-arm64_windows-clang-cl", + "armeabi-v7a": ":cc-compiler-armeabi-v7a", + }, +) + +cc_toolchain( + name = "cc-compiler-x64_windows_msys", + all_files = ":empty", + ar_files = ":empty", + as_files = ":mingw_compiler_files", + compiler_files = ":mingw_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msys_x64", + toolchain_identifier = "msys_x64", +) + +cc_toolchain_config( + name = "msys_x64", + abi_libc_version = "local", + abi_version = "local", + compiler = "msys-gcc", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "c:/tools/msys64/usr/", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + target_libc = "msys", + target_system_name = "local", + tool_bin_path = "c:/tools/msys64/usr/bin", + tool_paths = { + "ar": "c:/tools/msys64/usr/bin/ar", + "cpp": "c:/tools/msys64/usr/bin/cpp", + "dwp": "c:/tools/msys64/usr/bin/dwp", + "gcc": "c:/tools/msys64/usr/bin/gcc", + "gcov": "c:/tools/msys64/usr/bin/gcov", + "ld": "c:/tools/msys64/usr/bin/ld", + "nm": "c:/tools/msys64/usr/bin/nm", + "objcopy": "c:/tools/msys64/usr/bin/objcopy", + "objdump": "c:/tools/msys64/usr/bin/objdump", + "strip": "c:/tools/msys64/usr/bin/strip", + }, +) + +toolchain( + name = "cc-toolchain-x64_windows_msys", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:msys", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows_msys", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows_mingw", + all_files = ":empty", + ar_files = ":empty", + as_files = ":mingw_compiler_files", + compiler_files = ":mingw_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 0, + toolchain_config = ":msys_x64_mingw", + toolchain_identifier = "msys_x64_mingw", +) + +cc_toolchain_config( + name = "msys_x64_mingw", + abi_libc_version = "local", + abi_version = "local", + compiler = "mingw-gcc", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "c:/tools/msys64/mingw64/", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + target_libc = "mingw", + target_system_name = "local", + tool_bin_path = "c:/tools/msys64/mingw64/bin", + tool_paths = { + "ar": "c:/tools/msys64/mingw64/bin/ar", + "cpp": "c:/tools/msys64/mingw64/bin/cpp", + "dwp": "c:/tools/msys64/mingw64/bin/dwp", + "gcc": "c:/tools/msys64/mingw64/bin/gcc", + "gcov": "c:/tools/msys64/mingw64/bin/gcov", + "ld": "c:/tools/msys64/mingw64/bin/ld", + "nm": "c:/tools/msys64/mingw64/bin/nm", + "objcopy": "c:/tools/msys64/mingw64/bin/objcopy", + "objdump": "c:/tools/msys64/mingw64/bin/objdump", + "strip": "c:/tools/msys64/mingw64/bin/strip", + }, +) + +toolchain( + name = "cc-toolchain-x64_windows_mingw", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:mingw", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows_mingw", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64", + toolchain_identifier = "msvc_x64", +) + +cc_toolchain_config( + name = "msvc_x64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X64"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + default_link_flags = ["/MACHINE:X64"], + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/ml64.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x64/link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64", +) + +toolchain( + name = "cc-toolchain-x64_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_x86_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64_x86", + toolchain_identifier = "msvc_x64_x86", +) + +cc_toolchain_config( + name = "msvc_x64_x86", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X86"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + ], + dbg_mode_debug_flag = "/DEBUG:FULL", + default_link_flags = ["/MACHINE:X86"], + fastbuild_mode_debug_flag = "/DEBUG:FASTLINK", + host_system_name = "local", + msvc_cl_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x86;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x86", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x86;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + msvc_link_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + msvc_ml_path = "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/lib.exe", + "ml": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/ml.exe", + "cpp": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcc": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/MSVC/14.42.34433/bin/HostX64/x86/link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64_x86", +) + +toolchain( + name = "cc-toolchain-x64_x86_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:x86_32", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_x86_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_arm_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_x64_arm", + toolchain_identifier = "msvc_x64_arm", +) + +cc_toolchain_config( + name = "msvc_x64_arm", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm.bat", + msvc_env_include = "msvc_not_found", + msvc_env_lib = "msvc_not_found", + msvc_env_path = "msvc_not_found", + msvc_env_tmp = "msvc_not_found", + msvc_lib_path = "vc_installation_error_arm.bat", + msvc_link_path = "vc_installation_error_arm.bat", + msvc_ml_path = "vc_installation_error_arm.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "vc_installation_error_arm.bat", + "ml": "vc_installation_error_arm.bat", + "cpp": "vc_installation_error_arm.bat", + "gcc": "vc_installation_error_arm.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_x64_arm", +) + +toolchain( + name = "cc-toolchain-x64_arm_windows", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:arm", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_arm_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-arm64_windows", + all_files = ":empty", + ar_files = ":empty", + as_files = ":msvc_compiler_files", + compiler_files = ":msvc_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":msvc_arm64", + toolchain_identifier = "msvc_arm64", +) + +cc_toolchain_config( + name = "msvc_arm64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM64"], + compiler = "msvc-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM64"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm64.bat", + msvc_env_include = "msvc_not_found", + msvc_env_lib = "msvc_not_found", + msvc_env_path = "msvc_not_found", + msvc_env_tmp = "msvc_not_found", + msvc_lib_path = "vc_installation_error_arm64.bat", + msvc_link_path = "vc_installation_error_arm64.bat", + msvc_ml_path = "vc_installation_error_arm64.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "vc_installation_error_arm64.bat", + "ml": "vc_installation_error_arm64.bat", + "cpp": "vc_installation_error_arm64.bat", + "gcc": "vc_installation_error_arm64.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm64.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "msvc_arm64", +) + +toolchain( + name = "cc-toolchain-arm64_windows", + exec_compatible_with = [ + "@platforms//os:windows", + ], + target_compatible_with = [ + "@platforms//cpu:arm64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-arm64_windows", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-x64_windows-clang-cl", + all_files = ":empty", + ar_files = ":empty", + as_files = ":clangcl_compiler_files", + compiler_files = ":clangcl_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":clang_cl_x64", + toolchain_identifier = "clang_cl_x64", +) + +cc_toolchain_config( + name = "clang_cl_x64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:X64"], + compiler = "clang-cl", + cpu = "x64_windows", + cxx_builtin_include_directories = [ + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include", + "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include", + "C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt", + "C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt", + "C:\\tools\\LLVM\\lib\\clang\\18\\include", + ], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = [ + "/MACHINE:X64", + "/DEFAULTLIB:clang_rt.builtins-x86_64.lib", + ], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "C:/tools/LLVM/bin/clang-cl.exe", + msvc_env_include = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\include;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Auxiliary\\VS\\include;C:\\Program Files (x86)\\Windows Kits\\10\\include\\10.0.22621.0\\ucrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\um;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\shared;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\winrt;C:\\Program Files (x86)\\Windows Kits\\10\\\\include\\10.0.22621.0\\\\cppwinrt;C:\\tools\\LLVM\\lib\\clang\\18\\include", + msvc_env_lib = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\lib\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\lib\\10.0.22621.0\\ucrt\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\\\lib\\10.0.22621.0\\\\um\\x64;C:\\tools\\LLVM\\lib\\clang\\18\\lib\\windows", + msvc_env_path = "C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX64\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\VCPackages;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TestWindow;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\CommonExtensions\\Microsoft\\TeamFoundation\\Team Explorer;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\MSBuild\\Current\\bin\\Roslyn;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\10.0.22621.0\\\\x64;C:\\Program Files (x86)\\Windows Kits\\10\\bin\\\\x64;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\\\MSBuild\\Current\\Bin\\amd64;C:\\Windows\\Microsoft.NET\\Framework64\\v4.0.30319;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\;;C:\\Windows\\system32;C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\IDE\\VC\\Linux\\bin\\ConnectionManagerExe", + msvc_env_tmp = "C:\\TMP", + msvc_lib_path = "C:/tools/LLVM/bin/llvm-lib.exe", + msvc_link_path = "C:/tools/LLVM/bin/lld-link.exe", + msvc_ml_path = "C:/tools/LLVM/bin/clang-cl.exe", + supports_parse_showincludes = True, + target_libc = "msvcrt", + target_system_name = "local", + tool_paths = { + "ar": "C:/tools/LLVM/bin/llvm-lib.exe", + "ml": "C:/tools/LLVM/bin/clang-cl.exe", + "cpp": "C:/tools/LLVM/bin/clang-cl.exe", + "gcc": "C:/tools/LLVM/bin/clang-cl.exe", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "C:/tools/LLVM/bin/lld-link.exe", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "clang_cl_x64", +) + +toolchain( + name = "cc-toolchain-x64_windows-clang-cl", + exec_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + target_compatible_with = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-x64_windows-clang-cl", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-arm64_windows-clang-cl", + all_files = ":empty", + ar_files = ":empty", + as_files = ":clangcl_compiler_files", + compiler_files = ":clangcl_compiler_files", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":clang_cl_arm64", + toolchain_identifier = "clang_cl_arm64", +) + +cc_toolchain_config( + name = "clang_cl_arm64", + abi_libc_version = "local", + abi_version = "local", + archiver_flags = ["/MACHINE:ARM64"], + compiler = "clang-cl", + cpu = "arm64_windows", + cxx_builtin_include_directories = [], + dbg_mode_debug_flag = "/DEBUG", + default_link_flags = ["/MACHINE:ARM64"], + fastbuild_mode_debug_flag = "/DEBUG", + host_system_name = "local", + msvc_cl_path = "vc_installation_error_arm64.bat", + msvc_env_include = "clang_cl_not_found", + msvc_env_lib = "clang_cl_not_found", + msvc_env_path = "clang_cl_not_found", + msvc_env_tmp = "clang_cl_not_found", + msvc_lib_path = "vc_installation_error_arm64.bat", + msvc_link_path = "vc_installation_error_arm64.bat", + msvc_ml_path = "vc_installation_error_arm64.bat", + supports_parse_showincludes = False, + target_libc = "msvcrt", + target_system_name = "aarch64-pc-windows-msvc", + tool_paths = { + "ar": "vc_installation_error_arm64.bat", + "ml": "vc_installation_error_arm64.bat", + "cpp": "vc_installation_error_arm64.bat", + "gcc": "vc_installation_error_arm64.bat", + "gcov": "wrapper/bin/msvc_nop.bat", + "ld": "vc_installation_error_arm64.bat", + "nm": "wrapper/bin/msvc_nop.bat", + "objcopy": "wrapper/bin/msvc_nop.bat", + "objdump": "wrapper/bin/msvc_nop.bat", + "strip": "wrapper/bin/msvc_nop.bat", + }, + toolchain_identifier = "clang_cl_arm64", +) + +toolchain( + name = "cc-toolchain-arm64_windows-clang-cl", + exec_compatible_with = [ + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + target_compatible_with = [ + "@platforms//cpu:arm64", + "@platforms//os:windows", + ], + toolchain = ":cc-compiler-arm64_windows-clang-cl", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain( + name = "cc-compiler-armeabi-v7a", + all_files = ":empty", + ar_files = ":empty", + as_files = ":empty", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":empty", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 1, + toolchain_config = ":stub_armeabi-v7a", + toolchain_identifier = "stub_armeabi-v7a", +) + +armeabi_cc_toolchain_config(name = "stub_armeabi-v7a") + +toolchain( + name = "cc-toolchain-armeabi-v7a", + exec_compatible_with = [ + ], + target_compatible_with = [ + "@platforms//cpu:armv7", + "@platforms//os:android", + ], + toolchain = ":cc-compiler-armeabi-v7a", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/third_party/xla/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl b/third_party/xla/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl new file mode 100644 index 00000000000000..72ef48ae6d6dfc --- /dev/null +++ b/third_party/xla/tools/toolchains/win2022/20241118/armeabi_cc_toolchain_config.bzl @@ -0,0 +1,82 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Starlark cc_toolchain configuration rule""" + +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "feature", + "tool_path", +) + +def _impl(ctx): + toolchain_identifier = "stub_armeabi-v7a" + host_system_name = "armeabi-v7a" + target_system_name = "armeabi-v7a" + target_cpu = "armeabi-v7a" + target_libc = "armeabi-v7a" + compiler = "compiler" + abi_version = "armeabi-v7a" + abi_libc_version = "armeabi-v7a" + cc_target_os = None + builtin_sysroot = None + action_configs = [] + + supports_pic_feature = feature(name = "supports_pic", enabled = True) + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + features = [supports_dynamic_linker_feature, supports_pic_feature] + + cxx_builtin_include_directories = [] + artifact_name_patterns = [] + make_variables = [] + + tool_paths = [ + tool_path(name = "ar", path = "/bin/false"), + tool_path(name = "cpp", path = "/bin/false"), + tool_path(name = "dwp", path = "/bin/false"), + tool_path(name = "gcc", path = "/bin/false"), + tool_path(name = "gcov", path = "/bin/false"), + tool_path(name = "ld", path = "/bin/false"), + tool_path(name = "llvm-profdata", path = "/bin/false"), + tool_path(name = "nm", path = "/bin/false"), + tool_path(name = "objcopy", path = "/bin/false"), + tool_path(name = "objdump", path = "/bin/false"), + tool_path(name = "strip", path = "/bin/false"), + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = make_variables, + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ) + +armeabi_cc_toolchain_config = rule( + implementation = _impl, + attrs = {}, + provides = [CcToolchainConfigInfo], +) diff --git a/third_party/xla/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl b/third_party/xla/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl new file mode 100644 index 00000000000000..f440b6083d71fb --- /dev/null +++ b/third_party/xla/tools/toolchains/win2022/20241118/builtin_include_directory_paths_clangcl @@ -0,0 +1,7 @@ +This file is generated by cc_configure and contains builtin include directories +that clang-cl reported. This file is a dependency of every compilation action and +changes to it will be reflected in the action cache key. When some of these +paths change, Bazel will make sure to rerun the action, even though none of +declared action inputs or the action commandline changes. + + diff --git a/third_party/xla/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc b/third_party/xla/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc new file mode 100644 index 00000000000000..1380bc62e15b60 --- /dev/null +++ b/third_party/xla/tools/toolchains/win2022/20241118/builtin_include_directory_paths_msvc @@ -0,0 +1,7 @@ +This file is generated by cc_configure and contains builtin include directories +that msvc reported. This file is a dependency of every compilation action and +changes to it will be reflected in the action cache key. When some of these +paths change, Bazel will make sure to rerun the action, even though none of +declared action inputs or the action commandline changes. + + diff --git a/third_party/xla/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl b/third_party/xla/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl new file mode 100644 index 00000000000000..03ff9b6b30078d --- /dev/null +++ b/third_party/xla/tools/toolchains/win2022/20241118/windows_cc_toolchain_config.bzl @@ -0,0 +1,1442 @@ +# Copyright 2019 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Starlark cc_toolchain configuration rule for Windows""" + +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "artifact_name_pattern", + "env_entry", + "env_set", + "feature", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", + "with_feature_set", +) + +all_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, + ACTION_NAMES.lto_backend, +] + +all_cpp_compile_actions = [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.clif_match, +] + +preprocessor_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, +] + +codegen_compile_actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, +] + +all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, +] + +def _use_msvc_toolchain(ctx): + return ctx.attr.cpu in ["x64_windows", "arm64_windows"] and (ctx.attr.compiler == "msvc-cl" or ctx.attr.compiler == "clang-cl") + +def _impl(ctx): + if _use_msvc_toolchain(ctx): + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "object_file", + prefix = "", + extension = ".obj", + ), + artifact_name_pattern( + category_name = "static_library", + prefix = "", + extension = ".lib", + ), + artifact_name_pattern( + category_name = "alwayslink_static_library", + prefix = "", + extension = ".lo.lib", + ), + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + artifact_name_pattern( + category_name = "dynamic_library", + prefix = "", + extension = ".dll", + ), + artifact_name_pattern( + category_name = "interface_library", + prefix = "", + extension = ".if.lib", + ), + ] + else: + artifact_name_patterns = [ + artifact_name_pattern( + category_name = "executable", + prefix = "", + extension = ".exe", + ), + ] + + if _use_msvc_toolchain(ctx): + cpp_link_nodeps_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_static_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_static_library, + implies = [ + "nologo", + "archiver_flags", + "input_param_flags", + "linker_param_file", + "msvc_env", + ], + tools = [tool(path = ctx.attr.msvc_lib_path)], + ) + + assemble_action = action_config( + action_name = ACTION_NAMES.assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + preprocess_assemble_action = action_config( + action_name = ACTION_NAMES.preprocess_assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + c_compile_action = action_config( + action_name = ACTION_NAMES.c_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + linkstamp_compile_action = action_config( + action_name = ACTION_NAMES.linkstamp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "default_compile_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_compile_action = action_config( + action_name = ACTION_NAMES.cpp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "user_compile_flags", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_link_executable_action = action_config( + action_name = ACTION_NAMES.cpp_link_executable, + implies = [ + "nologo", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + action_configs = [ + assemble_action, + preprocess_assemble_action, + c_compile_action, + linkstamp_compile_action, + cpp_compile_action, + cpp_link_executable_action, + cpp_link_dynamic_library_action, + cpp_link_nodeps_dynamic_library_action, + cpp_link_static_library_action, + ] + else: + action_configs = [] + + if _use_msvc_toolchain(ctx): + msvc_link_env_feature = feature( + name = "msvc_link_env", + env_sets = [ + env_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + env_entries = [env_entry(key = "LIB", value = ctx.attr.msvc_env_lib)], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["/DLL"])], + ), + ], + ) + + determinism_feature = feature( + name = "determinism", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "/wd4117", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ] + (["-Wno-builtin-macro-redefined"] if ctx.attr.compiler == "clang-cl" else []), + ), + ], + ), + ], + ) + + sysroot_feature = feature( + name = "sysroot", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + iterate_over = "sysroot", + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{unfiltered_compile_flags}"], + iterate_over = "unfiltered_compile_flags", + expand_if_available = "unfiltered_compile_flags", + ), + ], + ), + ], + ) + + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + compiler_param_file_feature = feature( + name = "compiler_param_file", + ) + + copy_dynamic_libraries_to_binary_feature = feature( + name = "copy_dynamic_libraries_to_binary", + ) + + input_param_flags_feature = feature( + name = "input_param_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{libopts}"], + iterate_over = "libopts", + expand_if_available = "libopts", + ), + ], + ), + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link.object_files", + flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + ], + ) + + fastbuild_feature = feature( + name = "fastbuild", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = [ctx.attr.fastbuild_mode_debug_flag, "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + flag_group( + flags = ctx.attr.archiver_flags, + ), + ], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ctx.attr.default_link_flags)], + ), + ], + ) + + static_link_msvcrt_feature = feature( + name = "static_link_msvcrt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MT"])], + with_features = [with_feature_set(not_features = ["dbg"])], + ), + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MTd"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + with_features = [with_feature_set(not_features = ["dbg"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + ], + ) + + dynamic_link_msvcrt_feature = feature( + name = "dynamic_link_msvcrt", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MD"])], + with_features = [with_feature_set(not_features = ["dbg", "static_link_msvcrt"])], + ), + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MDd"])], + with_features = [with_feature_set(features = ["dbg"], not_features = ["static_link_msvcrt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + with_features = [with_feature_set(not_features = ["dbg", "static_link_msvcrt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + with_features = [with_feature_set(features = ["dbg"], not_features = ["static_link_msvcrt"])], + ), + ], + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = [ctx.attr.dbg_mode_debug_flag, "/INCREMENTAL:NO"], + ), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/O2"])], + ), + ], + implies = ["frame_pointer"], + ) + + supports_interface_shared_libraries_feature = feature( + name = "supports_interface_shared_libraries", + enabled = True, + ) + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0601", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/bigobj", + "/Zm500", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + msvc_compile_env_feature = feature( + name = "msvc_compile_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ], + env_entries = [env_entry(key = "INCLUDE", value = ctx.attr.msvc_env_include)], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + generate_pdb_file_feature = feature( + name = "generate_pdb_file", + ) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + disable_assertions_feature = feature( + name = "disable_assertions", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/DNDEBUG"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + has_configured_linker_path_feature = feature(name = "has_configured_linker_path") + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) + + no_stripping_feature = feature(name = "no_stripping") + + linker_param_file_feature = feature( + name = "linker_param_file", + flag_sets = [ + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + ], + ) + + ignore_noisy_warnings_feature = feature( + name = "ignore_noisy_warnings", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [flag_group(flags = ["/ignore:4221"])], + ), + ], + ) + + no_legacy_features_feature = feature(name = "no_legacy_features") + + parse_showincludes_feature = feature( + name = "parse_showincludes", + enabled = ctx.attr.supports_parse_showincludes, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + ], + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + # Force English (and thus a consistent locale) output so that Bazel can parse + # the /showIncludes output without having to guess the encoding. + env_entries = [env_entry(key = "VSLANG", value = "1033")], + ), + ], + ) + + # MSVC does not emit .d files. + no_dotd_file_feature = feature( + name = "no_dotd_file", + enabled = True, + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile] + all_link_actions, + flag_groups = [flag_group(flags = ["/WX"])], + ), + ], + ) + + windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") + + no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + external_include_paths_feature = feature( + name = "external_include_paths", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.clif_match, + ACTION_NAMES.objc_compile, + ACTION_NAMES.objcpp_compile, + ], + flag_groups = [ + flag_group( + flags = ["/external:I", "%{external_include_paths}"], + iterate_over = "external_include_paths", + expand_if_available = "external_include_paths", + ), + ], + ), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + targets_windows_feature = feature( + name = "targets_windows", + enabled = True, + implies = ["copy_dynamic_libraries_to_binary"], + ) + + linker_subsystem_flag_feature = feature( + name = "linker_subsystem_flag", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], + ), + ], + ) + + frame_pointer_feature = feature( + name = "frame_pointer", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Oy-"])], + ), + ], + ) + + compiler_output_flags_feature = feature( + name = "compiler_output_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + expand_if_not_available = "output_preprocess_file", + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + nologo_feature = feature( + name = "nologo", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + flag_groups = [flag_group(flags = ["/nologo"])], + ), + ], + ) + + smaller_binary_feature = feature( + name = "smaller_binary", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Gy", "/Gw"])], + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/OPT:ICF", "/OPT:REF"])], + with_features = [with_feature_set(features = ["opt"])], + ), + ], + ) + + compiler_input_flags_feature = feature( + name = "compiler_input_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ) + + def_file_feature = feature( + name = "def_file", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ) + + msvc_env_feature = feature( + name = "msvc_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.msvc_env_path), + env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), + env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), + ], + ), + ], + implies = ["msvc_compile_env", "msvc_link_env"], + ) + features = [ + no_legacy_features_feature, + nologo_feature, + has_configured_linker_path_feature, + no_stripping_feature, + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + default_compile_flags_feature, + msvc_env_feature, + msvc_compile_env_feature, + msvc_link_env_feature, + include_paths_feature, + external_include_paths_feature, + preprocessor_defines_feature, + parse_showincludes_feature, + no_dotd_file_feature, + generate_pdb_file_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + archiver_flags_feature, + input_param_flags_feature, + linker_subsystem_flag_feature, + user_link_flags_feature, + default_link_flags_feature, + linker_param_file_feature, + static_link_msvcrt_feature, + dynamic_link_msvcrt_feature, + dbg_feature, + fastbuild_feature, + opt_feature, + frame_pointer_feature, + disable_assertions_feature, + determinism_feature, + treat_warnings_as_errors_feature, + smaller_binary_feature, + ignore_noisy_warnings_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + archive_param_file_feature, + compiler_param_file_feature, + compiler_output_flags_feature, + compiler_input_flags_feature, + def_file_feature, + windows_export_all_symbols_feature, + no_windows_export_all_symbols_feature, + supports_dynamic_linker_feature, + supports_interface_shared_libraries_feature, + ] + else: + targets_windows_feature = feature( + name = "targets_windows", + implies = ["copy_dynamic_libraries_to_binary"], + enabled = True, + ) + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + gcc_env_feature = feature( + name = "gcc_env", + enabled = True, + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.tool_bin_path), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [flag_group(flags = ["-std=gnu++14"])], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lstdc++"])], + ), + ], + ) + + supports_dynamic_linker_feature = feature( + name = "supports_dynamic_linker", + enabled = True, + ) + + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-g", "-Og"])], + ), + ], + ) + + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = [ + "-g0", + "-O3", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", + ])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])], + ), + ], + ) + + if ctx.attr.cpu == "x64_windows" and ctx.attr.compiler == "mingw-gcc": + archive_param_file_feature = feature( + name = "archive_param_file", + enabled = True, + ) + + compiler_param_file_feature = feature( + name = "compiler_param_file", + ) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + default_compile_flags_feature, + archive_param_file_feature, + compiler_param_file_feature, + default_link_flags_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + ] + else: + supports_pic_feature = feature( + name = "supports_pic", + enabled = True, + ) + + sysroot_feature = feature( + name = "sysroot", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + fdo_optimize_feature = feature( + name = "fdo_optimize", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-fprofile-use=%{fdo_profile_path}", + "-fprofile-correction", + ], + expand_if_available = "fdo_profile_path", + ), + ], + ), + ], + provides = ["profile"], + ) + + treat_warnings_as_errors_feature = feature( + name = "treat_warnings_as_errors", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-Werror"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-Wl,-fatal-warnings"])], + ), + ], + ) + + user_compile_flags_feature = feature( + name = "user_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + features = [ + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + gcc_env_feature, + supports_pic_feature, + default_compile_flags_feature, + default_link_flags_feature, + fdo_optimize_feature, + supports_dynamic_linker_feature, + dbg_feature, + opt_feature, + user_compile_flags_feature, + treat_warnings_as_errors_feature, + sysroot_feature, + ] + + tool_paths = [ + tool_path(name = name, path = path) + for name, path in ctx.attr.tool_paths.items() + ] + + return cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = artifact_name_patterns, + cxx_builtin_include_directories = ctx.attr.cxx_builtin_include_directories, + toolchain_identifier = ctx.attr.toolchain_identifier, + host_system_name = ctx.attr.host_system_name, + target_system_name = ctx.attr.target_system_name, + target_cpu = ctx.attr.cpu, + target_libc = ctx.attr.target_libc, + compiler = ctx.attr.compiler, + abi_version = ctx.attr.abi_version, + abi_libc_version = ctx.attr.abi_libc_version, + tool_paths = tool_paths, + ) + +cc_toolchain_config = rule( + implementation = _impl, + attrs = { + "cpu": attr.string(mandatory = True), + "compiler": attr.string(), + "toolchain_identifier": attr.string(), + "host_system_name": attr.string(), + "target_system_name": attr.string(), + "target_libc": attr.string(), + "abi_version": attr.string(), + "abi_libc_version": attr.string(), + "tool_paths": attr.string_dict(), + "cxx_builtin_include_directories": attr.string_list(), + "archiver_flags": attr.string_list(default = []), + "default_link_flags": attr.string_list(default = []), + "msvc_env_tmp": attr.string(default = "msvc_not_found"), + "msvc_env_path": attr.string(default = "msvc_not_found"), + "msvc_env_include": attr.string(default = "msvc_not_found"), + "msvc_env_lib": attr.string(default = "msvc_not_found"), + "msvc_cl_path": attr.string(default = "vc_installation_error.bat"), + "msvc_ml_path": attr.string(default = "vc_installation_error.bat"), + "msvc_link_path": attr.string(default = "vc_installation_error.bat"), + "msvc_lib_path": attr.string(default = "vc_installation_error.bat"), + "dbg_mode_debug_flag": attr.string(), + "fastbuild_mode_debug_flag": attr.string(), + "tool_bin_path": attr.string(default = "not_found"), + "supports_parse_showincludes": attr.bool(), + }, + provides = [CcToolchainConfigInfo], +) diff --git a/third_party/xla/tools/toolchains/win2022/BUILD b/third_party/xla/tools/toolchains/win2022/BUILD new file mode 100644 index 00000000000000..82434f82ddbdd3 --- /dev/null +++ b/third_party/xla/tools/toolchains/win2022/BUILD @@ -0,0 +1,37 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +java_runtime( + name = "windows_jdk8", + srcs = [], + java_home = "C:/openjdk", +) + +# Register a Windows 2022 (Clang) platform. +# Note that while this does support RBE, the current pool size is tiny, +# and this platform is meant to be used as a non-RBE one, for now. +platform( + name = "windows_ltsc2022_clang", + constraint_values = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + "@bazel_tools//tools/cpp:clang-cl", + ], + remote_execution_properties = """ + properties:{ + name: "container-image" + value: "docker://gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc" + } + properties:{ + name: "OSFamily" + value: "Windows" + } + properties:{ + name: "Pool" value: "win2022" + } + properties:{ + name: "dockerNetwork" value: "off" + } + """, +) diff --git a/third_party/xla/warnings.bazelrc b/third_party/xla/warnings.bazelrc index 259d5d0e624a43..5afe21f7c56bb4 100644 --- a/third_party/xla/warnings.bazelrc +++ b/third_party/xla/warnings.bazelrc @@ -36,6 +36,7 @@ build:warnings --copt=-Wno-deprecated-enum-compare-conditional build:warnings --copt=-Wno-deprecated-enum-float-conversion build:warnings --copt=-Wno-deprecated-this-capture build:warnings --copt=-Wno-return-type-c-linkage +build:warnings --copt=-Wno-nullability-completeness build:warnings --copt=-Wno-bitfield-constant-conversion build:warnings --copt=-Wno-bitwise-instead-of-logical build:warnings --copt=-Wno-comment diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index 86033f5239be12..e594e123c29100 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -42,12 +42,42 @@ def _tf_repositories(): # curl -L | sha256sum # and update the sha256 with the result. + # LINT.IfChange tf_http_archive( name = "XNNPACK", sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5", strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12", urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"), ) + # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) + + tf_http_archive( + name = "KleidiAI", + sha256 = "ad37707084a6d4ff41be10cbe8540c75bea057ba79d0de6c367c1bfac6ba0852", + strip_prefix = "kleidiai-40a926833857fb64786e02f97703e42b1537cb57", + urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/40a926833857fb64786e02f97703e42b1537cb57/kleidiai-40a926833857fb64786e02f97703e42b1537cb57.zip"), + ) + + tf_http_archive( + name = "FXdiv", + sha256 = "3d7b0e9c4c658a84376a1086126be02f9b7f753caa95e009d9ac38d11da444db", + strip_prefix = "FXdiv-63058eff77e11aa15bf531df5dd34395ec3017c8", + urls = tf_mirror_urls("https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip"), + ) + + tf_http_archive( + name = "cpuinfo", + sha256 = "52e0ffd7998d8cb3a927d8a6e1145763744d866d2be09c4eccea27fc157b6bb0", + strip_prefix = "cpuinfo-cebb0933058d7f181c979afd50601dc311e1bf8c", + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip"), + ) + + tf_http_archive( + name = "pthreadpool", + sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95", + strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8", + urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"), + ) tf_http_archive( name = "jsoncpp_git", @@ -61,9 +91,9 @@ def _tf_repositories(): name = "cudnn_frontend_archive", build_file = "//third_party:cudnn_frontend.BUILD", patch_file = ["//third_party:cudnn_frontend_header_fix.patch"], - sha256 = "5f77784dc3ccbca7aca5ea0b5a6e31b95aa85023c5942d22be5fa8dd6c339d81", - strip_prefix = "cudnn-frontend-1.8.0", - urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.8.0.zip"), + sha256 = "7be8afebc693f0ef75bbc673ce5c1cf422673e84ea7d53e488201756c046496e", + strip_prefix = "cudnn-frontend-1.9.0", + urls = tf_mirror_urls("https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.9.0.zip"), ) tf_http_archive( diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 39abcda6320730..bac1d681c21388 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1,7 +1,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") # copybara:uncomment load("@rules_python//python:proto.bzl", "py_proto_library") -load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") load("//xla:package_groups.bzl", "xla_package_groups") load("//xla:xla.bzl", "xla_bzl_library", "xla_cc_test", "xla_py_proto_library") load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") @@ -90,7 +90,7 @@ xla_cc_test( srcs = ["bit_cast_test.cc"], deps = [ ":bit_cast", - ":test", + "//xla/hlo/testlib:test", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:bfloat16", "@local_tsl//tsl/platform:test_main", @@ -127,9 +127,9 @@ xla_cc_test( srcs = ["comparison_util_test.cc"], deps = [ ":comparison_util", - ":test", ":types", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -157,7 +157,7 @@ xla_cc_test( srcs = ["ef57_test.cc"], deps = [ ":ef57", - ":test", + "//xla/hlo/testlib:test", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:log_streamer", "@com_google_absl//absl/random", @@ -200,12 +200,9 @@ cc_library( name = "test", testonly = 1, hdrs = ["test.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/testlib:test instead.", visibility = internal_visibility([":friends"]), - deps = [ - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:test", - ], + deps = ["//xla/hlo/testlib:test"], ) cc_library( @@ -226,8 +223,9 @@ xla_cc_test( srcs = ["types_test.cc"], visibility = ["//visibility:private"], deps = [ - ":test", ":types", + "//xla/hlo/testlib:test", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) @@ -240,6 +238,8 @@ cc_library( deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -257,10 +257,11 @@ xla_cc_test( srcs = ["status_macros_test.cc"], deps = [ ":status_macros", - ":test", - ":test_helpers", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", @@ -284,8 +285,8 @@ xla_cc_test( deps = [ ":bit_cast", ":fp_util", - ":test", ":util", + "//xla/hlo/testlib:test", "@com_google_absl//absl/base", "@com_google_absl//absl/numeric:bits", "@com_google_googletest//:gtest_main", @@ -344,11 +345,15 @@ xla_cc_test( name = "util_test", srcs = ["util_test.cc"], deps = [ - ":test", ":types", ":util", + "//xla/hlo/testlib:test", + "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test_main", @@ -374,7 +379,7 @@ xla_cc_test( srcs = ["permutation_util_test.cc"], deps = [ ":permutation_util", - ":test", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -401,8 +406,8 @@ xla_cc_test( name = "iterator_util_test", srcs = ["iterator_util_test.cc"], deps = [ - ":test", ":util", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -477,9 +482,10 @@ xla_cc_test( srcs = ["shape_test.cc"], deps = [ ":shape_util", - ":test", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", "@com_google_absl//absl/hash:hash_testing", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], @@ -490,13 +496,14 @@ xla_cc_test( srcs = ["shape_util_test.cc"], deps = [ ":shape_util", - ":test", ":util", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:test_benchmark", @@ -509,9 +516,9 @@ xla_cc_test( srcs = ["primitive_util_test.cc"], deps = [ ":shape_util", - ":test", - ":test_helpers", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], @@ -522,9 +529,10 @@ xla_cc_test( srcs = ["layout_util_test.cc"], deps = [ ":shape_util", - ":test", - ":test_helpers", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", + "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -538,8 +546,8 @@ xla_cc_test( srcs = ["layout_test.cc"], deps = [ ":shape_util", - ":test", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -549,8 +557,8 @@ xla_cc_test( srcs = ["index_util_test.cc"], deps = [ ":shape_util", - ":test", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test_main", ], @@ -607,10 +615,10 @@ xla_cc_test( ":literal_util", ":shape_tree", ":shape_util", - ":test", ":types", ":util", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", @@ -629,6 +637,33 @@ xla_cc_test( ], ) +cc_library( + name = "literal_pool", + srcs = ["literal_pool.cc"], + hdrs = ["literal_pool.h"], + visibility = ["//visibility:public"], + deps = [ + ":literal", + ":shape_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "literal_pool_test", + srcs = ["literal_pool_test.cc"], + deps = [ + ":literal", + ":literal_pool", + ":literal_util", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "literal_util", srcs = ["literal_util.cc"], @@ -671,8 +706,8 @@ xla_cc_test( ":error_spec", ":literal_comparison", ":literal_util", - ":test_helpers", ":xla_data_proto_cc", + "//xla/hlo/testlib:test_helpers", "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:ml_dtypes", @@ -749,7 +784,7 @@ xla_cc_test( srcs = ["array_test.cc"], deps = [ ":array", - ":test", + "//xla/hlo/testlib:test", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", @@ -775,7 +810,7 @@ xla_cc_test( srcs = ["array2d_test.cc"], deps = [ ":array2d", - ":test", + "//xla/hlo/testlib:test", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test_main", @@ -798,8 +833,8 @@ xla_cc_test( srcs = ["array3d_test.cc"], deps = [ ":array3d", - ":test", ":types", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -824,7 +859,7 @@ xla_cc_test( deps = [ ":array2d", ":array4d", - ":test", + "//xla/hlo/testlib:test", "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", @@ -871,12 +906,9 @@ cc_library( name = "test_helpers", testonly = 1, hdrs = ["test_helpers.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/testlib:test_helpers instead.", visibility = internal_visibility([":friends"]), - deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test", - ], + deps = ["//xla/hlo/testlib:test_helpers"], ) cc_library( @@ -910,10 +942,11 @@ xla_cc_test( deps = [ ":literal", ":shape_util", - ":test", ":text_literal_reader", ":types", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", ], @@ -944,11 +977,12 @@ xla_cc_test( deps = [ ":literal", ":literal_util", - ":test", - ":test_helpers", ":text_literal_writer", ":types", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", ], @@ -964,6 +998,7 @@ cc_library( "//xla/tsl/lib/gtl:iterator_range", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -979,8 +1014,9 @@ xla_cc_test( deps = [ ":shape_tree", ":shape_util", - ":test", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], @@ -995,6 +1031,7 @@ cc_library( ":printer", ":shape_util", ":util", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", @@ -1011,6 +1048,7 @@ cc_library( ":xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", @@ -1021,9 +1059,10 @@ xla_cc_test( name = "window_util_test", srcs = ["window_util_test.cc"], deps = [ - ":test", ":window_util", ":xla_data_proto_cc", + "//xla/hlo/testlib:test", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) @@ -1069,10 +1108,11 @@ xla_cc_test( ":literal", ":literal_util", ":reference_util", - ":test", ":xla_data_proto_cc", "//xla/hlo/builder:padding", + "//xla/hlo/testlib:test", "//xla/tests:literal_test_util", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) @@ -1173,28 +1213,6 @@ xla_cc_test( ], ) -cc_library( - name = "refcounting_hash_map", - hdrs = ["refcounting_hash_map.h"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - ], -) - -xla_cc_test( - name = "refcounting_hash_map_test", - srcs = ["refcounting_hash_map_test.cc"], - deps = [ - ":refcounting_hash_map", - ":test", - "@local_tsl//tsl/platform:test_main", - ], -) - cc_library( name = "union_find", hdrs = ["union_find.h"], @@ -1270,8 +1288,8 @@ xla_cc_test( ":autotune_result_wrapper", ":autotune_results_proto_cc", ":autotuning_proto_cc", - ":test", - ":test_helpers", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index 1d28388c563117..0bec1540e95f48 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -26,6 +26,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -442,16 +443,18 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } // Performs the equivalent of a slice operation on this array. + // When `out_of_bounds_value` is specified, the out of bounds accesses are ok + // and the slice is initialized to the given value. Array Slice(absl::Span starts, absl::Span limits, - bool out_of_bounds_ok = false) const { + std::optional out_of_bounds_value = std::nullopt) const { CHECK_EQ(starts.size(), num_dimensions()); CHECK_EQ(limits.size(), num_dimensions()); OwnedBuffer sizes(starts.size()); for (int64_t i = 0; i < starts.size(); ++i) { CHECK_GE(starts[i], 0); - if (!out_of_bounds_ok) { + if (!out_of_bounds_value.has_value()) { CHECK_LE(limits[i], dim(i)); } sizes[i] = limits[i] - starts[i]; @@ -460,11 +463,10 @@ class Array { if (result.num_elements() == 0) { return result; } - // Initializes the slice to the first value if out of bounds access are ok. - if (out_of_bounds_ok) { - CHECK_GT(num_elements(), 0); + // Initializes the slice to the given value if out of bounds access are ok. + if (out_of_bounds_value.has_value()) { for (int64_t i = 0; i < result.num_elements(); ++i) { - result.values_[i] = values_[0]; + result.values_[i] = out_of_bounds_value.value(); } } diff --git a/third_party/xla/xla/array2d_test.cc b/third_party/xla/xla/array2d_test.cc index 921da30256fa3d..055a6e77420819 100644 --- a/third_party/xla/xla/array2d_test.cc +++ b/third_party/xla/xla/array2d_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "Eigen/Core" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "tsl/platform/ml_dtypes.h" namespace xla { diff --git a/third_party/xla/xla/array3d_test.cc b/third_party/xla/xla/array3d_test.cc index 334d733266b41b..3ed4d7b2a7532f 100644 --- a/third_party/xla/xla/array3d_test.cc +++ b/third_party/xla/xla/array3d_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/types.h" namespace xla { diff --git a/third_party/xla/xla/array4d_test.cc b/third_party/xla/xla/array4d_test.cc index 1deb1bc81f3c7e..7d8bcb7c6930ad 100644 --- a/third_party/xla/xla/array4d_test.cc +++ b/third_party/xla/xla/array4d_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "Eigen/Core" #include "xla/array2d.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/array_test.cc b/third_party/xla/xla/array_test.cc index bf79aa98f40491..a20223d746c729 100644 --- a/third_party/xla/xla/array_test.cc +++ b/third_party/xla/xla/array_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "Eigen/Core" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/autotune_result_wrapper_test.cc b/third_party/xla/xla/autotune_result_wrapper_test.cc index 848024d7e4343e..8259a15c715cbe 100644 --- a/third_party/xla/xla/autotune_result_wrapper_test.cc +++ b/third_party/xla/xla/autotune_result_wrapper_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/backends/cpu/BUILD b/third_party/xla/xla/backends/cpu/BUILD index 80150ef859a8e0..9e8a4b8b2232c3 100644 --- a/third_party/xla/xla/backends/cpu/BUILD +++ b/third_party/xla/xla/backends/cpu/BUILD @@ -1,3 +1,4 @@ +load("//xla/backends/cpu:package_groups.bzl", "xla_cpu_backend_access") load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") @@ -8,6 +9,8 @@ package( licenses = ["notice"], ) +xla_cpu_backend_access() + package_group( name = "friends", includes = [ @@ -26,3 +29,23 @@ cc_library( hdrs = ["alignment.h"], deps = ["@eigen_archive//:eigen3"], ) + +cc_library( + name = "xnn_emitter", + srcs = ["xnn_emitter.cc"], + hdrs = ["xnn_emitter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime/xnnpack:xnn_interop", + "//xla/hlo/ir:hlo", + "//xla/tsl/platform:logging", + "@XNNPACK", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/BUILD b/third_party/xla/xla/backends/cpu/codegen/BUILD index f495a6357dbf12..1027191c80578e 100644 --- a/third_party/xla/xla/backends/cpu/codegen/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/BUILD @@ -10,7 +10,7 @@ load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], + default_visibility = ["//xla/backends/cpu:xla_backend_cpu_internal_access"], licenses = ["notice"], ) @@ -26,12 +26,13 @@ cc_library( srcs = ["contiguous_section_memory_manager.cc"], hdrs = ["contiguous_section_memory_manager.h"], deps = [ - "//xla:util", - "@llvm-project//llvm:Core", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@llvm-project//llvm:ExecutionEngine", - "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:logging", + # TODO(basioli): This dependency increases the binary size significantly. + # Consider reducing the dependency size, or use something alternative. + "//xla:util", ], ) @@ -82,15 +83,18 @@ cc_library( srcs = ["jit_compiler.cc"], hdrs = ["jit_compiler.h"], deps = [ + ":compiled_function_library", ":contiguous_section_memory_manager", ":cpu_features", ":ir_compiler", "//xla:util", "//xla/backends/cpu/runtime:function_library", "//xla/service/cpu:orc_jit_memory_mapper", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -132,6 +136,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:Core", @@ -156,6 +161,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:math_ops", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", @@ -199,3 +205,209 @@ cc_library( "@llvm-project//llvm:Support", ], ) + +cc_library( + name = "kernel_api_ir_builder", + srcs = ["kernel_api_ir_builder.cc"], + hdrs = ["kernel_api_ir_builder.h"], + deps = [ + "//xla:cpu_function_runtime", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", + "//xla/tsl/platform:errors", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "kernel_api_ir_builder_test", + srcs = ["kernel_api_ir_builder_test.cc"], + deps = [ + ":kernel_api_ir_builder", + "//xla:cpu_function_runtime", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_ordering", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", + "//xla/service:buffer_assignment", + "//xla/service:hlo_module_config", + "//xla/service:logical_buffer", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:ir_headers", + ], +) + +cc_library( + name = "llvm_ir_kernel_spec", + srcs = ["llvm_ir_kernel_spec.cc"], + hdrs = ["llvm_ir_kernel_spec.h"], + deps = [ + "//xla/codegen:kernel_spec", + "//xla/codegen:llvm_ir_kernel_source", + "//xla/service:buffer_assignment", + "//xla/stream_executor:launch_dim", + ], +) + +cc_library( + name = "elemental_kernel_emitter", + srcs = ["elemental_kernel_emitter.cc"], + hdrs = ["elemental_kernel_emitter.h"], + deps = [ + ":kernel_api_ir_builder", + ":llvm_ir_kernel_spec", + ":target_machine_features", + "//xla:shape_util", + "//xla:util", + "//xla/codegen:kernel_emitter", + "//xla/codegen:kernel_spec", + "//xla/codegen:llvm_ir_kernel_source", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:elemental_ir_emitter", + "//xla/service:hlo_module_config", + "//xla/service/cpu:backend_config_proto_cc", + "//xla/service/cpu:elemental_ir_emitter", + "//xla/service/cpu:ir_emitter", + "//xla/service/cpu:parallel_loop_emitter", + "//xla/service/cpu:shape_partition", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", + "//xla/service/llvm_ir:loop_emitter", + "//xla/stream_executor:launch_dim", + "//xla/tsl/platform:errors", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "elemental_kernel_emitter_test", + srcs = ["elemental_kernel_emitter_test.cc"], + deps = [ + ":elemental_kernel_emitter", + ":llvm_ir_kernel_spec", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_ordering", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", + "//xla/service:buffer_assignment", + "//xla/service:logical_buffer", + "//xla/service/cpu:target_machine_features_stub", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "compiled_function_library", + srcs = ["compiled_function_library.cc"], + hdrs = ["compiled_function_library.h"], + deps = [ + "//xla/backends/cpu/runtime:function_library", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:OrcJIT", + ], +) + +cc_library( + name = "object_loader", + srcs = ["object_loader.cc"], + hdrs = ["object_loader.h"], + deps = [ + ":compiled_function_library", + ":contiguous_section_memory_manager", + "//xla/backends/cpu/runtime:function_library", + "//xla/service/cpu:orc_jit_memory_mapper", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:OrcShared", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + ], +) + +xla_cc_test( + name = "object_loader_test", + srcs = ["object_loader_test.cc"], + deps = [ + ":ir_compiler", + ":jit_compiler", + ":object_loader", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime:function_library", + "//xla/service:cpu_plugin", + "//xla/service/cpu:executable_proto_cc", + "//xla/service/llvm_ir:llvm_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:AsmParser", + "@llvm-project//llvm:JITLink", + "@llvm-project//llvm:Object", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/compiled_function_library.cc b/third_party/xla/xla/backends/cpu/codegen/compiled_function_library.cc new file mode 100644 index 00000000000000..7f111e5a3566b2 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/compiled_function_library.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/compiled_function_library.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "xla/backends/cpu/runtime/function_library.h" + +namespace xla::cpu { + +CompiledFunctionLibrary::CompiledFunctionLibrary( + std::unique_ptr execution_session, + std::unique_ptr object_layer, + absl::flat_hash_map symbols_map) + : execution_session_(std::move(execution_session)), + object_layer_(std::move(object_layer)), + symbols_map_(std::move(symbols_map)) { + DCHECK(execution_session_) << "Execution session must not be null"; +} + +CompiledFunctionLibrary::~CompiledFunctionLibrary() { + if (execution_session_) { + if (auto err = execution_session_->endSession()) { + execution_session_->reportError(std::move(err)); + } + } +} + +absl::StatusOr CompiledFunctionLibrary::ResolveFunction( + TypeId type_id, absl::string_view name) { + if (auto it = symbols_map_.find(name); it != symbols_map_.end()) { + if (it->second.type_id != type_id) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrFormat("Symbol %s has type id %d, expected %d", name, + it->second.type_id.value(), type_id.value())); + } + return it->second.ptr; + } + return absl::Status(absl::StatusCode::kNotFound, + absl::StrFormat("Function %s not found (type id: %d)", + name, type_id.value())); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/compiled_function_library.h b/third_party/xla/xla/backends/cpu/codegen/compiled_function_library.h new file mode 100644 index 00000000000000..b91100a66dd10c --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/compiled_function_library.h @@ -0,0 +1,68 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_COMPILED_FUNCTION_LIBRARY_H_ +#define XLA_BACKENDS_CPU_CODEGEN_COMPILED_FUNCTION_LIBRARY_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "xla/backends/cpu/runtime/function_library.h" + +namespace xla::cpu { + +// A CompiledFunctionLibrary is a FunctionLibrary that resolves function names +// to compiled functions using LLVM's ORC JIT. +class CompiledFunctionLibrary : public FunctionLibrary { + public: + struct ResolvedSymbol { + TypeId type_id; + void* ptr; + }; + + // Constructs a new CompiledFunctionLibrary. + // + // `execution_session` is the LLVM ORC execution session to use. + // `object_layer` is the LLVM ORC object linking layer with preloaded object + // files. + // `symbols_map` is a map from symbol names to resolved symbols. + CompiledFunctionLibrary( + std::unique_ptr execution_session, + std::unique_ptr object_layer, + absl::flat_hash_map symbols_map); + + ~CompiledFunctionLibrary() final; + + // Resolves the function with the given name and type ID. + absl::StatusOr ResolveFunction(TypeId type_id, + absl::string_view name) final; + + private: + std::unique_ptr execution_session_; + // Owns resources required for the execution session. + std::unique_ptr object_layer_; + // Caches the resolved symbols so we don't have to look them up every time a + // function is resolved. + absl::flat_hash_map symbols_map_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_COMPILED_FUNCTION_LIBRARY_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc b/third_party/xla/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc index f30fa63be52ad9..ae15857de011c1 100644 --- a/third_party/xla/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc +++ b/third_party/xla/xla/backends/cpu/codegen/contiguous_section_memory_manager.cc @@ -20,12 +20,13 @@ limitations under the License. #include #include // NOLINT +#include "absl/log/check.h" +#include "absl/log/log.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Memory.h" #include "llvm/Support/Process.h" #include "xla/util.h" -#include "tsl/platform/logging.h" namespace xla::cpu { namespace { diff --git a/third_party/xla/xla/backends/cpu/codegen/cpu_features.cc b/third_party/xla/xla/backends/cpu/codegen/cpu_features.cc index 88829db2fc5ce5..6697676c583cc6 100644 --- a/third_party/xla/xla/backends/cpu/codegen/cpu_features.cc +++ b/third_party/xla/xla/backends/cpu/codegen/cpu_features.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "absl/algorithm/container.h" @@ -36,7 +35,7 @@ namespace xla::cpu { using tsl::port::CPUFeature; // Returns the earliest CPU generation that supports the instruction set. -std::string_view CpuTargetFromMaxFeature(CPUFeature max_feature) { +absl::string_view CpuTargetFromMaxFeature(CPUFeature max_feature) { switch (max_feature) { case CPUFeature::SSE4_2: return "nehalem"; @@ -60,7 +59,7 @@ std::string_view CpuTargetFromMaxFeature(CPUFeature max_feature) { } } -std::optional CpuFeatureFromString(std::string_view cpu_feature) { +std::optional CpuFeatureFromString(absl::string_view cpu_feature) { if (cpu_feature.empty()) return std::nullopt; // Non-exhaustive list of CPU features. (Only the ones we care about.) @@ -90,7 +89,7 @@ std::optional CpuFeatureFromString(std::string_view cpu_feature) { // switch statement is the most readable way to express the logic. // // NOLINTNEXTLINE(readability-function-cognitive-complexity) -bool ShouldEnableCpuFeature(std::string_view feature, CPUFeature max_feature) { +bool ShouldEnableCpuFeature(absl::string_view feature, CPUFeature max_feature) { // x86 CPUs have backward compatibility so newer CPUs have all features of // older CPUs. We go through switch cases from oldest features to newest. // - Each case looks for features that are introduced in the next diff --git a/third_party/xla/xla/backends/cpu/codegen/cpu_features.h b/third_party/xla/xla/backends/cpu/codegen/cpu_features.h index 5d0053c4093f96..c98ed1b4d37610 100644 --- a/third_party/xla/xla/backends/cpu/codegen/cpu_features.h +++ b/third_party/xla/xla/backends/cpu/codegen/cpu_features.h @@ -19,25 +19,25 @@ limitations under the License. #include #include #include -#include #include #include "absl/base/attributes.h" +#include "absl/strings/string_view.h" #include "tsl/platform/cpu_info.h" namespace xla::cpu { // Returns the earliest CPU generation that supports the instruction set. -std::string_view CpuTargetFromMaxFeature(tsl::port::CPUFeature max_feature); +absl::string_view CpuTargetFromMaxFeature(tsl::port::CPUFeature max_feature); // Converts a string representation of a CPU feature to a CPUFeature enum. // Returns std::nullopt if the string is not a valid CPU feature. std::optional CpuFeatureFromString( - std::string_view cpu_feature); + absl::string_view cpu_feature); // Returns true if `feature` can be enabled given the maximum allowed CPU // feature `max_feature`. -bool ShouldEnableCpuFeature(std::string_view feature, +bool ShouldEnableCpuFeature(absl::string_view feature, tsl::port::CPUFeature max_feature); struct DetectedMachineAttributes { diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter.cc new file mode 100644 index 00000000000000..19c15a903c8cea --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter.cc @@ -0,0 +1,380 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/elemental_kernel_emitter.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/codegen/kernel_spec.h" +#include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/elemental_ir_emitter.h" +#include "xla/service/cpu/ir_emitter.h" +#include "xla/service/cpu/parallel_loop_emitter.h" +#include "xla/service/cpu/shape_partition.h" +#include "xla/service/elemental_ir_emitter.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/platform/errors.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +namespace { + +KernelApiIrBuilder::Options KernelApiIrBuilderOptionsFromHloModuleConfig( + const HloModule* hlo_module) { + if (hlo_module == nullptr) { + return {true, 256}; + } + + const HloModuleConfig& config = hlo_module->config(); + return KernelApiIrBuilder::Options{ + config.debug_options().xla_llvm_enable_invariant_load_metadata(), + config.debug_options().xla_cpu_prefer_vector_width()}; +} + +struct ParallelConfig { + std::vector outer_dimension_partitions; +}; + +// Parallel partition bounds for parallelized outer dimensions: +// vector<[i64 lower_bound, i64 upper_bound]> +using ParallelPartitionBounds = + std::vector>; + +std::optional GetParallelConfig(const HloInstruction* instr) { + // Check if the instruction is marked for parallel execution. + auto backend_config = instr->backend_config(); + if (!backend_config.ok() || + backend_config->outer_dimension_partitions().empty()) { + return std::nullopt; + } + + ParallelConfig config; + config.outer_dimension_partitions.assign( + backend_config->outer_dimension_partitions().begin(), + backend_config->outer_dimension_partitions().end()); + + return config; +} + +ParallelPartitionBounds EmitParallelPartitionBounds( + llvm::IRBuilderBase& b, + const KernelApiIrBuilder::KernelPrototype& kernel_prototype, + const ParallelConfig& parallel_config, const Shape& shape, + absl::string_view name) { + ShapePartitionIterator it(shape, parallel_config.outer_dimension_partitions); + + size_t num_parallel_dimensions = + parallel_config.outer_dimension_partitions.size(); + + // Create a constant array of all partition bounds. We will be indexing into + // this array using block and thread dimension indices passed in a call frame. + // + // Type: [#partitions x [#outer_dimensions x [lower_bound, upper_bound]]] + // + llvm::ArrayType* dim_bounds_ty = llvm::ArrayType::get(b.getInt64Ty(), 2); + llvm::ArrayType* partition_bounds_ty = + llvm::ArrayType::get(dim_bounds_ty, num_parallel_dimensions); + llvm::ArrayType* parallel_bounds_ty = + llvm::ArrayType::get(partition_bounds_ty, it.GetTotalPartitionCount()); + + // Build a nested array of partition bounds from shape partition iterator. + std::vector partition_bounds; + for (int64_t i = 0; i < it.GetTotalPartitionCount(); ++i) { + std::vector dim_counts; + for (auto [lower, size] : it.GetPartition(i)) { + dim_counts.push_back(llvm::ConstantArray::get( + dim_bounds_ty, {b.getInt64(lower), b.getInt64(lower + size)})); + } + partition_bounds.push_back( + llvm::ConstantArray::get(partition_bounds_ty, dim_counts)); + } + + llvm::Constant* parallel_bounds = + llvm::ConstantArray::get(parallel_bounds_ty, partition_bounds); + + llvm::Module* module = b.GetInsertBlock()->getParent()->getParent(); + llvm::GlobalVariable* parallel_bounds_global = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/parallel_bounds_ty, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/parallel_bounds, + /*Name=*/absl::StrCat(name, "_parallel_bounds")); + + // Construct IR to load bounds for all parallel dimensions. + ParallelPartitionBounds bounds; + for (size_t i = 0; i < num_parallel_dimensions; ++i) { + llvm::Value* partition = kernel_prototype.thread_id.x; + llvm::Value* parallel_dim = b.getInt32(i); + + llvm::Value* lower_gep = b.CreateInBoundsGEP( + parallel_bounds_ty, parallel_bounds_global, + {b.getInt32(0), partition, parallel_dim, b.getInt32(0)}, + absl::StrCat("lo_dim_", i, "_gep")); + + llvm::Value* upper_gep = b.CreateInBoundsGEP( + parallel_bounds_ty, parallel_bounds_global, + {b.getInt32(0), partition, parallel_dim, b.getInt32(1)}, + absl::StrCat("up_dim_", i, "_gep")); + + bounds.emplace_back( + b.CreateLoad(b.getInt64Ty(), lower_gep, absl::StrCat("lo_dim_", i)), + b.CreateLoad(b.getInt64Ty(), upper_gep, absl::StrCat("up_dim_", i))); + } + + return bounds; +} + +// Implementation detail for ComputationsTransitivelyContainCustomCall, which +// recursively checks whether a computation contains a custom call. +bool RecursivelyCheckForCustomCall( + const HloComputation& computation, + absl::flat_hash_map& custom_call_map) { + bool contains_custom_call = computation.IsCustomCallComputation(); + + for (const HloInstruction* instruction : computation.instructions()) { + for (const HloComputation* nested_computation : + instruction->called_computations()) { + if (const auto itr = custom_call_map.find(nested_computation); + itr != custom_call_map.end()) { + return itr->second; + } + contains_custom_call |= + RecursivelyCheckForCustomCall(*nested_computation, custom_call_map); + } + } + + custom_call_map[&computation] = contains_custom_call; + return contains_custom_call; +} + +// For each called computation in operation, determines whether that computation +// calls a custom-call function, either directly or indirectly (e.g. because it +// calls another computation that does). +absl::flat_hash_map +ComputationsTransitivelyContainCustomCall(const HloInstruction* instr) { + absl::flat_hash_map custom_call_map; + + for (const HloComputation* computation : instr->called_computations()) { + RecursivelyCheckForCustomCall(*computation, custom_call_map); + } + + return custom_call_map; +} + +} // namespace + +ElementalKernelEmitter::ElementalKernelEmitter(const HloInstruction* instr) + : ElementalKernelEmitter(instr, nullptr, nullptr) {} + +ElementalKernelEmitter::ElementalKernelEmitter( + const HloInstruction* instr, const BufferAssignment* buffer_assignment, + const TargetMachineFeatures* target_machine) + : instr_(instr), + buffer_assignment_(buffer_assignment), + target_machine_(target_machine), + context_(std::make_unique()), + kernel_api_ir_builder_( + *context_.getContext(), + KernelApiIrBuilderOptionsFromHloModuleConfig(instr_->GetModule())) {} + +absl::StatusOr> +ElementalKernelEmitter::EmitKernelSpec() { + VLOG(2) << "Emit elemental host kernel: " << instr_->name(); + + llvm::LLVMContext& ctx = *context_.getContext(); + + // A module identifier (prefix) for emitted LLVM modules. + // (Module must be prefixed with this to ensure the cpu_compiler gives correct + // name to the dumped IR file) + static constexpr absl::string_view kXlaModuleIdentifier = "__compute_module"; + auto module = std::make_unique( + absl::StrCat(kXlaModuleIdentifier, "_", instr_->name(), + "_elemental_kernel_module"), + ctx); + + TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype, + kernel_api_ir_builder_.EmitKernelPrototype( + *module, instr_, buffer_assignment_, "_kernel")); + + llvm::IRBuilder<> ir_builder(ctx); + ir_builder.SetInsertPoint( + kernel_prototype.function->getEntryBlock().getTerminator()); + + TF_ASSIGN_OR_RETURN( + CpuElementalIrEmitter::ThreadLocalCallCallback thread_local_call_fn, + ThreadLocalCallbackFactory(ir_builder, *module)); + + CpuElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (int64_t i = 0; i < instr_->operand_count(); ++i) { + const HloInstruction* operand = instr_->operand(i); + operand_to_generator[operand] = [&, i](const llvm_ir::IrArray::Index& idx) { + return kernel_prototype.arguments[i].EmitReadArrayElement(idx, + &ir_builder); + }; + } + + const HloModule* hlo_module = instr_->GetModule(); + bool enable_fast_min_max = + hlo_module + ? hlo_module->config().debug_options().xla_cpu_enable_fast_min_max() + : true; + CpuElementalIrEmitter elemental_ir_emitter(module.get(), &ir_builder, + std::move(thread_local_call_fn), + true, enable_fast_min_max); + + llvm_ir::ElementGenerator element_generator = + elemental_ir_emitter.MakeElementGenerator(instr_, operand_to_generator); + + TF_ASSIGN_OR_RETURN(se::ThreadDim thread_dims, + EmitElementalLoops(ir_builder, instr_, kernel_prototype, + element_generator)); + + auto source = std::make_unique( + context_, std::move(module), + std::string(kernel_prototype.function->getName())); + + // TODO(willfroom): what do we do with buffer allocations? + // The same data should be in buffer_uses? + std::vector buffer_allocations; + + return std::make_unique( + thread_dims, std::move(buffer_allocations), + std::move(kernel_prototype.buffer_uses), std::move(source)); +} + +absl::StatusOr ElementalKernelEmitter::EmitElementalLoops( + llvm::IRBuilderBase& b, const HloInstruction* instr, + const KernelApiIrBuilder::KernelPrototype& kernel_prototype, + const llvm_ir::ElementGenerator& element_generator) { + // We can emit loops for instruction with multiple results only if it is a + // fusion, reduce or reduce window. + bool multiple_results = kernel_prototype.results.size() > 1; + bool support_multiple_results = instr->opcode() == HloOpcode::kFusion || + instr->opcode() == HloOpcode::kReduce || + instr->opcode() == HloOpcode::kReduceWindow; + + auto parallel_config = GetParallelConfig(instr); + bool has_parallel_config = parallel_config.has_value(); + + if (multiple_results && !support_multiple_results) { + return Internal( + "Multi-output host kernels are not supported for %s instruction", + HloOpcodeString(instr->opcode())); + } + + // TODO(ezhulenev): Support multiple results for parallel loops. + if (multiple_results) { + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, kernel_prototype.results, &b) + .EmitLoop(llvm_ir::IrName(instr))); + return se::ThreadDim(); + } + + const llvm_ir::IrArray& result = kernel_prototype.results.front(); + + // Emit a loop for a single parallel partition with dynamic bounds computed + // from thread index. + if (has_parallel_config) { + ParallelPartitionBounds parallel_bounds = EmitParallelPartitionBounds( + b, kernel_prototype, *parallel_config, instr->shape(), instr->name()); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(element_generator, result, ¶llel_bounds, &b) + .EmitLoop(llvm_ir::IrName(instr))); + return se::ThreadDim(ShapePartitionAssigner::GetTotalPartitionCount( + parallel_config->outer_dimension_partitions)); + } + + // Emit a whole loop for the instruction. + TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(element_generator, result, &b) + .EmitLoop(llvm_ir::IrName(instr))); + return se::ThreadDim(); +} + +absl::StatusOr +ElementalKernelEmitter::ThreadLocalCallbackFactory(llvm::IRBuilderBase& builder, + llvm::Module& module) const { + const HloModule* hlo_module = instr_->GetModule(); + if (hlo_module == nullptr) { + return nullptr; + } + + auto ir_emitter = std::make_unique( + nullptr, *hlo_module, *buffer_assignment_, &module, + /*instruction_to_profile_idx=*/ + absl::flat_hash_map{}, + /*computation_to_profile_idx=*/ + absl::flat_hash_map{}, + ComputationsTransitivelyContainCustomCall(instr_), target_machine_, + /*emit_code_for_msan=*/false); + IrEmitter::IRBuilderGuard builder_guard = ir_emitter->WithBuilder(builder); + + TF_RETURN_IF_ERROR(ir_emitter->EmitSmallConstantGlobals()); + + if (instr_->has_to_apply()) { + HloComputation* nested_computation = instr_->to_apply(); + bool is_reducer = instr_->opcode() == HloOpcode::kReduce || + instr_->opcode() == HloOpcode::kReduceWindow; + TF_RETURN_IF_ERROR(ir_emitter->EmitNestedComputation( + *nested_computation, llvm_ir::IrName(nested_computation->name()), + is_reducer)); + } + + return [ir_emitter = std::move(ir_emitter), &builder]( + const HloComputation& callee, + absl::Span parameters, absl::string_view name, + bool is_reducer) { + IrEmitter::IRBuilderGuard builder_guard = ir_emitter->WithBuilder(builder); + return ir_emitter->EmitThreadLocalCall(callee, parameters, name, is_reducer, + /*in_compute_function=*/false); + }; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter.h b/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter.h new file mode 100644 index 00000000000000..337cfce6a7b2b8 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter.h @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_KERNEL_EMITTER_H_ +#define XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_KERNEL_EMITTER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/codegen/kernel_emitter.h" +#include "xla/codegen/kernel_spec.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/elemental_ir_emitter.h" +#include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/stream_executor/launch_dim.h" + +namespace xla::cpu { + +class ElementalKernelEmitter final : public KernelEmitter { + public: + explicit ElementalKernelEmitter(const HloInstruction* instr); + + ElementalKernelEmitter(const HloInstruction* instr, + const BufferAssignment* buffer_assignment, + const TargetMachineFeatures* target_machine); + + absl::StatusOr> EmitKernelSpec() override; + + private: + // Emits LLVM IR using elemental loop emitter and the given element generator. + // If the instruction is parallelized, it will emit a parallel loop partition + // and return the requested number of execution threads. + absl::StatusOr EmitElementalLoops( + llvm::IRBuilderBase& b, const HloInstruction* instr, + const KernelApiIrBuilder::KernelPrototype& kernel_prototype, + const llvm_ir::ElementGenerator& element_generator); + + // Create a thread local call callback, can be empty if no IrEmitter is + // registered. + absl::StatusOr + ThreadLocalCallbackFactory(llvm::IRBuilderBase& builder, + llvm::Module& module) const; + + private: + const HloInstruction* instr_; + + const BufferAssignment* buffer_assignment_ = nullptr; + const TargetMachineFeatures* target_machine_ = nullptr; + + llvm::orc::ThreadSafeContext context_; + + KernelApiIrBuilder kernel_api_ir_builder_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_ELEMENTAL_KERNEL_EMITTER_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter_test.cc b/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter_test.cc new file mode 100644 index 00000000000000..809d5d93437091 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/elemental_kernel_emitter_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/elemental_kernel_emitter.h" + +#include +#include + +#include +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Type.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/target_machine_features_stub.h" +#include "xla/service/logical_buffer.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +class ElementalKernelEmitterTest : public HloTestBase { + public: + ElementalKernelEmitterTest() + : target_machine_features_([](int64_t size) { return 1; }) {} + + absl::StatusOr> EmitKernelSpec( + const HloInstruction* instr, const BufferAssignment* buffer_assignment) { + ElementalKernelEmitter emitter(instr, buffer_assignment, + &target_machine_features_); + + TF_ASSIGN_OR_RETURN(auto kernel_spec, emitter.EmitKernelSpec()); + + return absl::WrapUnique( + tsl::down_cast(kernel_spec.release())); + } + + absl::StatusOr> RunBufferAssignment( + const HloModule& hlo) { + return BufferAssigner::Run( + &hlo, std::make_unique(&hlo), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; }); + } + + private: + TargetMachineFeaturesStub target_machine_features_; +}; + +namespace { + +TEST_F(ElementalKernelEmitterTest, EmitElementalKernel) { + const char* hlo_text = R"( + HloModule m + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT convert = s32[2,2] convert(p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignement, RunBufferAssignment(*hlo)); + TF_ASSERT_OK_AND_ASSIGN( + auto spec, EmitKernelSpec(hlo->entry_computation()->root_instruction(), + buffer_assignement.get())); + + ASSERT_TRUE(*RunFileCheck(spec->kernel_source().ToString(), R"( + CHECK: define ptr @convert_kernel(ptr %0) #0 { + CHECK: fptosi float {{.*}} to i32 + CHECK: } + )")); +} + +TEST_F(ElementalKernelEmitterTest, EmitParallelKernel) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m + ENTRY main { + p0 = f32[1,2,1,16384,256] parameter(0) + ROOT convert = s32[1,2,1,16384,256] convert(p0), + backend_config={"outer_dimension_partitions":["1","2","1","4"]} + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignement, RunBufferAssignment(*hlo)); + TF_ASSERT_OK_AND_ASSIGN( + auto spec, EmitKernelSpec(hlo->entry_computation()->root_instruction(), + buffer_assignement.get())); + + ASSERT_TRUE(*RunFileCheck(spec->kernel_source().ToString(), R"( + CHECK: @convert_parallel_bounds = private constant [8 x [4 x [2 x i64]]] + + CHECK: define ptr @convert_kernel(ptr %0) #0 { + CHECK: %lo_dim_0_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 0, i32 0 + CHECK: %up_dim_0_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 0, i32 1 + CHECK: %lo_dim_1_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 1, i32 0 + CHECK: %up_dim_1_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 1, i32 1 + CHECK: %lo_dim_2_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 2, i32 0 + CHECK: %up_dim_2_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 2, i32 1 + CHECK: %lo_dim_3_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 3, i32 0 + CHECK: %up_dim_3_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 3, i32 1 + CHECK: fptosi float {{.*}} to i32 + CHECK: } + )")); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD b/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD index c33df92b01cc32..10228fcc460af8 100644 --- a/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD @@ -10,7 +10,7 @@ lit_test_suite( srcs = glob(["*.mlir"]), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/backends/cpu/codegen/tools:xla_cpu_opt", + "//xla/codegen/tools:emitters_opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir b/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir index 0e7faa0a235242..0cc695d10d3471 100644 --- a/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir +++ b/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir @@ -1,4 +1,4 @@ -// RUN: xla_cpu_opt %s --split-input-file | FileCheck %s +// RUN: emitters_opt %s --split-input-file | FileCheck %s func.func @load(%arg0: !xla_cpu.call_frame) -> tensor<32x32xf32> { %0 = xla_cpu.load %arg0, 0 : tensor<32x32xf32> diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir b/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir index 504dfa29976c33..67c73db0d1ae53 100644 --- a/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir +++ b/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir @@ -1,4 +1,4 @@ -// RUN: xla_cpu_opt %s | FileCheck %s +// RUN: emitters_opt %s | FileCheck %s func.func @call_frame_arg(%arg0: !xla_cpu.call_frame) { return diff --git a/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc index 6812d8b6ef1203..2f746a1fa1947a 100644 --- a/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc +++ b/third_party/xla/xla/backends/cpu/codegen/ir_compiler.cc @@ -185,7 +185,7 @@ llvm::Expected> IrCompiler::operator()( llvm::Expected> obj_file = llvm::object::ObjectFile::createObjectFile(*mc_memory_buffer); if (obj_file) { - hooks_.post_codegen(*obj_file.get()); + hooks_.post_codegen(module, *obj_file.get()); } else { LOG(WARNING) << "Could not convert memory buffer to object file"; } diff --git a/third_party/xla/xla/backends/cpu/codegen/ir_compiler.h b/third_party/xla/xla/backends/cpu/codegen/ir_compiler.h index 9be22a78eff78f..9c6678bd9196f3 100644 --- a/third_party/xla/xla/backends/cpu/codegen/ir_compiler.h +++ b/third_party/xla/xla/backends/cpu/codegen/ir_compiler.h @@ -68,7 +68,8 @@ class IrCompiler : public llvm::orc::IRCompileLayer::IRCompiler { struct CompilationHooks { std::function pre_optimization; std::function post_optimization; - std::function post_codegen; + std::function + post_codegen; }; IrCompiler(TargetMachineBuilder target_machine_builder, Options options, diff --git a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc index 03dcfad9033b8f..e91e89a0007ff1 100644 --- a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc +++ b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc @@ -20,14 +20,17 @@ limitations under the License. #include #include #include -#include #include +#include "absl/base/call_once.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" @@ -47,6 +50,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Host.h" +#include "xla/backends/cpu/codegen/compiled_function_library.h" #include "xla/backends/cpu/codegen/contiguous_section_memory_manager.h" #include "xla/backends/cpu/codegen/cpu_features.h" #include "xla/backends/cpu/codegen/ir_compiler.h" @@ -54,7 +58,6 @@ limitations under the License. #include "xla/service/cpu/orc_jit_memory_mapper.h" #include "xla/util.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" @@ -64,6 +67,14 @@ namespace xla::cpu { using tsl::profiler::TraceMe; using tsl::profiler::TraceMeEncode; +// Initialize LLVM the first time `JitCompiler` is created. +static void InitializeLLVMTarget() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); +} + +absl::once_flag initialize_llvm_flag; + absl::StatusOr> JitCompiler::InferTargetMachine( const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level, @@ -76,10 +87,11 @@ JitCompiler::InferTargetMachine( // If `max_cpu_feature` is newer than the host CPU, we should keep the host // CPU name, e.g., we don't want to set the target CPU to Skylake when we are // on a Broadwell host. - std::string_view cpu = result.num_filtered_features - ? CpuTargetFromMaxFeature(*max_cpu_feature) - : std::string_view(llvm::sys::getHostCPUName()); + absl::string_view cpu = result.num_filtered_features + ? CpuTargetFromMaxFeature(*max_cpu_feature) + : absl::string_view(llvm::sys::getHostCPUName()); + absl::call_once(initialize_llvm_flag, InitializeLLVMTarget); std::unique_ptr target_machine( llvm::EngineBuilder() .setTargetOptions(target_options) @@ -107,13 +119,7 @@ IrCompiler::TargetMachineBuilder JitCompiler::InferTargetMachineBuilder( absl::StatusOr JitCompiler::Create( llvm::TargetOptions target_options, Options options, TaskRunner task_runner) { - // Initialize LLVM the first time `JitCompiler` is created. - static bool llvm_initialized = [] { - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - return true; - }(); - CHECK(llvm_initialized) << "LLVM must be initialized"; + absl::call_once(initialize_llvm_flag, InitializeLLVMTarget); // Infer target machine from the current host CPU. IrCompiler::TargetMachineBuilder target_machine_builder = @@ -257,7 +263,7 @@ absl::StatusOr> JitCompiler::Compile( // Mangle symbol names for the target machine data layout. llvm::DataLayout data_layout = target_machine_->createDataLayout(); - auto mangle = [&](std::string_view name) { + auto mangle = [&](absl::string_view name) { llvm::SmallVector mangled; llvm::Mangler::getNameWithPrefix(mangled, name, data_layout); return std::string(mangled.begin(), mangled.end()); @@ -334,43 +340,14 @@ void JitCompiler::TaskDispatcher::dispatch( absl::MutexLock lock(&mu_); --num_dispatched_tasks_; - cv_.SignalAll(); }); } void JitCompiler::TaskDispatcher::shutdown() { - absl::MutexLock lock(&mu_); - while (num_dispatched_tasks_ > 0) { - cv_.Wait(&mu_); - } -} - -JitCompiler::CompiledFunctionLibrary::CompiledFunctionLibrary( - std::unique_ptr execution_session, - std::unique_ptr object_layer, - absl::flat_hash_map symbols_map) - : execution_session_(std::move(execution_session)), - object_layer_(std::move(object_layer)), - symbols_map_(std::move(symbols_map)) { - DCHECK(execution_session_) << "Execution session must not be null"; -} - -JitCompiler::CompiledFunctionLibrary::~CompiledFunctionLibrary() { - if (auto err = execution_session_->endSession()) { - execution_session_->reportError(std::move(err)); - } -} - -absl::StatusOr JitCompiler::CompiledFunctionLibrary::ResolveFunction( - TypeId type_id, std::string_view name) { - if (auto it = symbols_map_.find(name); it != symbols_map_.end()) { - if (it->second.type_id != type_id) { - return Internal("Symbol %s has type id %d, expected %d", name, - it->second.type_id.value(), type_id.value()); - } - return it->second.ptr; - } - return NotFound("Function %s not found (type id: %d)", name, type_id.value()); + auto all_tasks_finished = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return num_dispatched_tasks_ == 0; + }; + absl::MutexLock lock(&mu_, absl::Condition(&all_tasks_finished)); } } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h index 771e65380780e9..e98a999ddeb52c 100644 --- a/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h +++ b/third_party/xla/xla/backends/cpu/codegen/jit_compiler.h @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/base/thread_annotations.h" @@ -157,34 +156,9 @@ class JitCompiler { TaskRunner task_runner_; absl::Mutex mu_; - absl::CondVar cv_; size_t num_dispatched_tasks_ ABSL_GUARDED_BY(mu_) = 0; }; - // Function library constructed from the set of jit-compiled symbols. - class CompiledFunctionLibrary : public FunctionLibrary { - public: - struct ResolvedSymbol { - TypeId type_id; - void* ptr; - }; - - CompiledFunctionLibrary( - std::unique_ptr execution_session, - std::unique_ptr object_layer, - absl::flat_hash_map symbols_map); - - ~CompiledFunctionLibrary() final; - - absl::StatusOr ResolveFunction(TypeId type_id, - std::string_view name) final; - - private: - std::unique_ptr execution_session_; - std::unique_ptr object_layer_; - absl::flat_hash_map symbols_map_; - }; - JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder, std::shared_ptr target_machine, TaskDispatcher* task_dispatcher, diff --git a/third_party/xla/xla/backends/cpu/codegen/jit_compiler_test.cc b/third_party/xla/xla/backends/cpu/codegen/jit_compiler_test.cc index 94ee288e8bc75f..0df61106f09320 100644 --- a/third_party/xla/xla/backends/cpu/codegen/jit_compiler_test.cc +++ b/third_party/xla/xla/backends/cpu/codegen/jit_compiler_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/AsmParser/Parser.h" #include "llvm/ExecutionEngine/JITSymbol.h" @@ -36,7 +36,6 @@ limitations under the License. #include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/Support/CodeGen.h" #include "llvm/Support/Error.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Target/TargetMachine.h" @@ -63,8 +62,8 @@ static absl::StatusOr> Compile( // Parses the LLVM IR into a ThreadSafeModule. static absl::StatusOr ParseModule( - llvm::orc::ThreadSafeContext& context, std::string_view ir, - std::string_view name) { + llvm::orc::ThreadSafeContext& context, absl::string_view ir, + absl::string_view name) { llvm::SMDiagnostic diagnostic; llvm::MemoryBufferRef ir_buffer(ir, name); @@ -97,7 +96,7 @@ TEST(JitCompilerTest, Compile) { JitCompiler::Create(llvm::TargetOptions(), std::move(options), std::move(task_runner))); - constexpr std::string_view add_in_place_ir = R"( + constexpr absl::string_view add_in_place_ir = R"( define void @AddInplace(ptr %arg) { %v0 = load float, ptr %arg %v1 = fadd float %v0, %v0 @@ -105,7 +104,7 @@ TEST(JitCompilerTest, Compile) { ret void })"; - constexpr std::string_view mul_in_place_ir = R"( + constexpr absl::string_view mul_in_place_ir = R"( define void @MulInplace(ptr %arg) { %v0 = load float, ptr %arg %v1 = fmul float %v0, %v0 @@ -113,7 +112,7 @@ TEST(JitCompilerTest, Compile) { ret void })"; - auto add_module = [&](std::string_view ir, std::string_view name, + auto add_module = [&](absl::string_view ir, absl::string_view name, size_t dylib_index) -> absl::Status { TF_ASSIGN_OR_RETURN(llvm::orc::ThreadSafeModule tsm, ParseModule(tsc, ir, name)); @@ -189,7 +188,7 @@ TEST(JitCompilerTest, ExternalDefinitionGenerator) { JitCompiler::Create(llvm::TargetOptions(), std::move(options), /*task_runner=*/nullptr)); - constexpr std::string_view call_external_fn_ir = R"( + constexpr absl::string_view call_external_fn_ir = R"( declare void @__external_fn(ptr %arg) define void @CallExternalFn(ptr %arg) { diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc new file mode 100644 index 00000000000000..d6b49888faad72 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc @@ -0,0 +1,518 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" + +#include +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CodeGen.h" +#include "xla/cpu_function_runtime.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +namespace { + +class MemoryDependencyAnalyzer { + public: + MemoryDependencyAnalyzer( + llvm::LLVMContext& context, absl::string_view name, + absl::Span results) + : context_(context), mb_(context) { + // Create an alias domain for the host kernel function. + llvm::MDNode* domain = mb_.createAliasScopeDomain( + absl::StrFormat("XLA host kernel %s AA domain", name)); + + result_slices_.reserve(results.size()); + for (const KernelApiIrBuilder::KernelParameter& result : results) { + result_slices_.insert(result.slice); + + // Skip result buffers that are aliased with entry parameters as we don't + // know if they can alias with any other buffers. + if (result.slice.allocation()->is_parameter_aliased_with_output()) { + continue; + } + alias_scopes_[result.slice] = mb_.createAliasScope( + absl::StrFormat("result slice: %s", result.slice.ToString()), domain); + } + } + + // Returns alias scope for the given buffer slice. + llvm::MDNode* GetAliasScope(BufferAllocation::Slice slice) { + if (slice.allocation() == nullptr) { + return nullptr; + } + + auto it = alias_scopes_.find(slice); + return it == alias_scopes_.end() ? nullptr + : llvm::MDNode::get(context_, it->second); + }; + + // Construct !noalias metadata for buffer slice. + llvm::MDNode* GetNoAlias(BufferAllocation::Slice slice) { + llvm::SmallVector scopes; + for (const auto& [alias_slice, alias_scope] : alias_scopes_) { + if (!slice.OverlapsWith(alias_slice)) { + scopes.push_back(alias_scope); + } + } + return scopes.empty() ? nullptr : llvm::MDNode::get(context_, scopes); + }; + + bool ResultContainsSlice(BufferAllocation::Slice slice) { + if (slice.allocation() == nullptr) { + return false; + } + return result_slices_.contains(slice); + } + + private: + llvm::LLVMContext& context_; + llvm::MDBuilder mb_; + + absl::btree_map alias_scopes_; + absl::flat_hash_set result_slices_; +}; + +// Following struct types correspond to HostKernel C API. +// See: xla/backends/cpu/runtime/kernel_c_api.h + +llvm::StructType* Dim3StructTy(llvm::LLVMContext& ctx, absl::string_view name) { + llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); + return llvm::StructType::create(name, i64, i64, i64); +} + +llvm::StructType* KernelThreadDimTy(llvm::LLVMContext& ctx) { + return Dim3StructTy(ctx, "XLA_CPU_KernelThreadDim"); +} + +llvm::StructType* KernelThreadTy(llvm::LLVMContext& ctx) { + return Dim3StructTy(ctx, "XLA_CPU_KernelThread"); +} + +llvm::StructType* KernelArgTy(llvm::LLVMContext& ctx) { + llvm::PointerType* ptr = llvm::PointerType::getUnqual(ctx); + llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); + return llvm::StructType::create("XLA_CPU_KernelArg", ptr, i64); +} + +llvm::StructType* KernelCallFrameTy(llvm::LLVMContext& ctx) { + llvm::PointerType* ptr = llvm::PointerType::getUnqual(ctx); + llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); + return llvm::StructType::create("XLA_CPU_KernelCallFrame", ptr, ptr, i64, + ptr); +} + +llvm::FunctionType* KernelFunctionTy(llvm::LLVMContext& ctx) { + return llvm::FunctionType::get(llvm::PointerType::getUnqual(ctx), + llvm::PointerType::getUnqual(ctx), + /*isVarArg=*/false); +} + +// Check that all kernel arguments are coming from non-overlapping slices. It +// is fine to pass same slice as different arguments. This property is not +// used anywhere during the codegen, it acts mostly as a sanity check for +// the buffer assignment. In the future we might emit better aliasing metadata +// based on this property. +absl::Status VerifyKernelArgumentsNonOverlapping( + absl::Span arguments) { + for (size_t i = 0; i < arguments.size(); ++i) { + for (size_t j = i + 1; j < arguments.size(); ++j) { + const KernelApiIrBuilder::KernelParameter& a = arguments[i]; + const KernelApiIrBuilder::KernelParameter& b = arguments[j]; + + if (a.slice != b.slice && a.slice.OverlapsWith(b.slice)) { + return Internal( + "Kernel arguments must not overlap: result #%d (%s) overlaps " + "with result #%d (%s)", + i, a.slice.ToString(), j, b.slice.ToString()); + } + } + } + + return absl::OkStatus(); +} + +// Check that all kernel results are unique and coming from non-overlapping +// slices. We rely on this property to create LLVM `!alias.scope` for each +// kernel result buffer and to construct `!noalias` metadata for arguments. +absl::Status VerifyKernelResultsNonOverlapping( + absl::Span results) { + for (size_t i = 0; i < results.size(); ++i) { + for (size_t j = i + 1; j < results.size(); ++j) { + const KernelApiIrBuilder::KernelParameter& a = results[i]; + const KernelApiIrBuilder::KernelParameter& b = results[j]; + + if (a.slice.OverlapsWith(b.slice)) { + return Internal( + "Kernel results must not overlap: result #%d (%s) overlaps " + "with result #%d (%s)", + i, a.slice.ToString(), j, b.slice.ToString()); + } + } + } + + return absl::OkStatus(); +} + +// Check that results do not overlap with arguments, or if they do, they must +// be the same as one of the arguments, which can happen for inplace kernels. +absl::Status VerifyKernelResultsNonOverlappingWithArguments( + absl::Span arguments, + absl::Span results) { + for (size_t i = 0; i < results.size(); ++i) { + for (size_t j = 0; j < arguments.size(); ++j) { + const KernelApiIrBuilder::KernelParameter& result = results[i]; + const KernelApiIrBuilder::KernelParameter& argument = arguments[j]; + + if (result.slice.OverlapsWith(argument.slice) && + result.slice != argument.slice) { + return Internal( + "Kernel results must not partially overlap with arguments: result " + "#%d (%s) overlaps with argument #%d (%s)", + i, result.slice.ToString(), j, argument.slice.ToString()); + } + } + } + + return absl::OkStatus(); +} + +absl::Status VerifyKernelParameters( + absl::Span arguments, + absl::Span results) { + // IMPORTANT: Buffer slice non-overlapping property checked below does not + // necessarily mean that the buffers do not alias. Parameter allocations + // might have different index but at run time might be backed by the same + // memory (or aliased memory). We conservatively do not emit noalias metadata + // for buffers coming from parameter allocations. + + TF_RETURN_IF_ERROR(VerifyKernelArgumentsNonOverlapping(arguments)); + TF_RETURN_IF_ERROR(VerifyKernelResultsNonOverlapping(results)); + TF_RETURN_IF_ERROR( + VerifyKernelResultsNonOverlappingWithArguments(arguments, results)); + + return absl::OkStatus(); +} + +absl::StatusOr GetUniqueSlice( + const BufferAssignment* buffer_assignment, + const HloInstruction* instruction, const ShapeIndex& index) { + if (buffer_assignment == nullptr) { + return BufferAllocation::Slice{}; + } + + return buffer_assignment->GetUniqueSlice(instruction, index); +} + +absl::StatusOr> +GetKernelArgumentsParameters(const HloInstruction* instruction, + const BufferAssignment* buffer_assignment) { + std::vector arguments; + + for (HloInstruction* operand : instruction->operands()) { + for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) { + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice slice, + GetUniqueSlice(buffer_assignment, operand, indexed.index)); + arguments.push_back( + KernelApiIrBuilder::KernelParameter{indexed.shape, slice}); + } + } + return arguments; +} + +absl::StatusOr> +GetKernelResultsParameters(const HloInstruction* instruction, + const BufferAssignment* buffer_assignment) { + std::vector results; + for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) { + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice slice, + GetUniqueSlice(buffer_assignment, instruction, indexed.index)); + results.push_back( + KernelApiIrBuilder::KernelParameter{indexed.shape, slice}); + } + return results; +} + +} // namespace + +KernelApiIrBuilder::KernelApiIrBuilder(llvm::LLVMContext& context, + Options options) + : context_(context), options_(std::move(options)) { + thread_dim_ty_ = KernelThreadDimTy(context_); + thread_ty_ = KernelThreadTy(context_); + arg_ty_ = KernelArgTy(context_); + call_frame_ty_ = KernelCallFrameTy(context_); + kernel_function_ty_ = KernelFunctionTy(context_); +} + +auto KernelApiIrBuilder::EmitKernelPrototype( + llvm::Module& module, const HloInstruction* instr, + const BufferAssignment* buffer_assignment, absl::string_view suffix) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(std::vector arguments, + GetKernelArgumentsParameters(instr, buffer_assignment)); + TF_ASSIGN_OR_RETURN(std::vector results, + GetKernelResultsParameters(instr, buffer_assignment)); + + bool compute_alias_metadata = buffer_assignment != nullptr; + return EmitKernelPrototype(module, absl::StrCat(instr->name(), suffix), + arguments, results, compute_alias_metadata); +} + +auto KernelApiIrBuilder::EmitKernelPrototype( + llvm::Module& module, absl::string_view name, + absl::Span arguments, + absl::Span results, bool compute_alias_metadata) + -> absl::StatusOr { + CHECK(&module.getContext() == &context_) << "Module context mismatch"; + + VLOG(3) << "Emit kernel prototype: " << name + << ", #arguments=" << arguments.size() + << ", #results=" << results.size(); + for (const KernelParameter& argument : arguments) { + VLOG(3) << " argument: " << argument.shape.ToString(true) << " in " + << argument.slice.ToString(); + } + for (const KernelParameter& result : results) { + VLOG(3) << " result: " << result.shape.ToString(true) << " in " + << result.slice.ToString(); + } + + if (compute_alias_metadata) { + TF_RETURN_IF_ERROR(VerifyKernelParameters(arguments, results)); + } + + MemoryDependencyAnalyzer memory_dependency_analyzer( + context_, name, + compute_alias_metadata ? results : absl::Span{}); + + llvm::IRBuilder<> b(context_); + + // Create a kernel function with HostKernel API. + llvm::Function* function = EmitKernelFunction(module, name); + + // Create an entry basic block and set insert point to the end of it. + b.SetInsertPoint(llvm::BasicBlock::Create(context_, "", function)); + + llvm::Value* call_frame = function->getArg(0); + // Build thread coordinates from the call frame. + KernelApiIrBuilder::ThreadDims kernel_thread_dims = + EmitKernelThreadDims(b, call_frame); + KernelApiIrBuilder::ThreadId kernel_thread = EmitKernelThread(b, call_frame); + + int64_t idx = 0; + + // A set of invariant (read-only) buffer indices, feeded in the loop array in + // the next section. + absl::flat_hash_set invariant_arguments; + + // IrArrays for the parameters. + std::vector ir_arguments; + for (int64_t i = 0; i < arguments.size(); ++i) { + const KernelParameter& argument = arguments[i]; + auto ir_argument = EmitKernelArgument(b, call_frame, idx++, argument.shape); + if (auto* noalias = memory_dependency_analyzer.GetNoAlias(argument.slice)) { + ir_argument.AddNoaliasMetadata(noalias); + } + + // If a buffer slice is not a part of result set, then it must be invariant + // (read-only). + if (!memory_dependency_analyzer.ResultContainsSlice(argument.slice)) { + ir_argument.MarkInvariantOverWholeProgram(&context_); + invariant_arguments.insert(i); + } + + ir_arguments.push_back(std::move(ir_argument)); + } + + // IrArrays for the results. + std::vector ir_results; + for (const KernelParameter& result : results) { + auto ir_result = EmitKernelArgument(b, call_frame, idx++, result.shape); + if (auto* noalias = memory_dependency_analyzer.GetNoAlias(result.slice)) { + ir_result.AddNoaliasMetadata(noalias); + } + if (auto* alias_scope = + memory_dependency_analyzer.GetAliasScope(result.slice)) { + ir_result.AddAliasScopeMetadata(alias_scope); + } + ir_results.push_back(std::move(ir_result)); + } + + // Return null pointer to signal success as we do not support error handling + // in the compiled host kernel. + llvm::BasicBlock* return_block = + llvm::BasicBlock::Create(context_, "return", function); + + b.CreateBr(return_block); + + b.SetInsertPoint(return_block); + b.CreateRet( + llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(context_))); + + absl::InlinedVector buffer_uses; + if (compute_alias_metadata) { + for (const KernelParameter& argument : arguments) { + buffer_uses.push_back(BufferUse::Read(argument.slice)); + } + for (const KernelParameter& result : results) { + buffer_uses.push_back(BufferUse::Write(result.slice)); + } + } + + return KernelPrototype{function, + return_block, + kernel_thread_dims, + kernel_thread, + std::move(ir_arguments), + std::move(ir_results), + std::move(invariant_arguments), + std::move(buffer_uses)}; +} + +auto KernelApiIrBuilder::EmitKernelThreadDims(llvm::IRBuilderBase& builder, + llvm::Value* call_frame) + -> ThreadDims { + llvm::Value* td_gep = + builder.CreateStructGEP(call_frame_ty_, call_frame, 0, "tdims_gep"); + llvm::Value* tdims = builder.CreateLoad(builder.getPtrTy(), td_gep, "tdims"); + llvm::Value* x_gep = + builder.CreateStructGEP(thread_dim_ty_, tdims, 0, "tdim_x_gep"); + llvm::Value* y_gep = + builder.CreateStructGEP(thread_dim_ty_, tdims, 1, "tdim_y_gep"); + llvm::Value* z_gep = + builder.CreateStructGEP(thread_dim_ty_, tdims, 2, "tdim_z_gep"); + + return {builder.CreateLoad(builder.getInt64Ty(), x_gep, "tdim_x"), + builder.CreateLoad(builder.getInt64Ty(), y_gep, "tdim_y"), + builder.CreateLoad(builder.getInt64Ty(), z_gep, "tdim_z")}; +} + +auto KernelApiIrBuilder::EmitKernelThread(llvm::IRBuilderBase& builder, + llvm::Value* call_frame) -> ThreadId { + llvm::Value* t_gep = + builder.CreateStructGEP(call_frame_ty_, call_frame, 1, "tid_gep"); + llvm::LoadInst* tids = builder.CreateLoad(builder.getPtrTy(), t_gep, "tids"); + llvm::Value* x_gep = + builder.CreateStructGEP(thread_ty_, tids, 0, "tid_x_gep"); + llvm::Value* y_gep = + builder.CreateStructGEP(thread_ty_, tids, 1, "tid_y_gep"); + llvm::Value* z_gep = + builder.CreateStructGEP(thread_ty_, tids, 2, "tid_z_gep"); + + return {builder.CreateLoad(builder.getInt64Ty(), x_gep, "tid_x"), + builder.CreateLoad(builder.getInt64Ty(), y_gep, "tid_y"), + builder.CreateLoad(builder.getInt64Ty(), z_gep, "tid_z")}; +} + +llvm_ir::IrArray KernelApiIrBuilder::EmitKernelArgument( + llvm::IRBuilderBase& builder, llvm::Value* call_frame, int64_t index, + const Shape& shape) { + llvm::LLVMContext& ctx = builder.getContext(); + + llvm::Type* ptr = llvm::PointerType::get(ctx, 0); + std::string name = absl::StrCat("arg", index); + + llvm::Value* args_gep = + builder.CreateStructGEP(call_frame_ty_, call_frame, 3, "args_gep"); + llvm::LoadInst* args = builder.CreateLoad(ptr, args_gep, "args"); + llvm::Value* data_gep = + builder.CreateConstGEP2_32(arg_ty_, args, index, 0, name + "_gep"); + llvm::LoadInst* data = builder.CreateLoad(ptr, data_gep, name); + + // All buffers passed to host kernels are expected to be properly aligned, + // emit metadata to allow LLVM to use that information for optimization. + llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign()); + + // All buffers pointers passed to host kernels are expected to be + // dereferenceable. + llvm_ir::SetDereferenceableMetadataForLoad(data, + ShapeUtil::ByteSizeOf(shape)); + + // All buffers pointers passed to host kernels are expected to be invariant + // over the whole program. Note the metadata is attached only to loading + // buffer pointers, not to loading actual buffers. + if (options_.enable_invariant_load_metadata) { + data->setMetadata(llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(data->getContext(), /*MDs=*/{})); + } + + return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, ctx), shape); +} + +llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module, + absl::string_view name) { + llvm::Function* function = llvm::Function::Create( + kernel_function_ty_, llvm::GlobalValue::ExternalLinkage, name, module); + + // We use external linkage because we'll be resolving this function from the + // XLA runtime. + function->setCallingConv(llvm::CallingConv::C); + + // Generate unwind information so that GDB can crawl through the stack frames + // created by the JIT compiled code. + function->setUWTableKind(llvm::UWTableKind::Default); + + // Set prefer-vector-width attribute to allow LLVM to use wider vector + // registers (by default LLVM uses at most 256-bit registers). + function->addFnAttr("prefer-vector-width", + absl::StrCat(options_.prefer_vector_width)); + + // Always keep a frame pointer for the host kernel so we can see them in all + // performance profiling tools. + function->addFnAttr("frame-pointer", "all"); + + return function; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h new file mode 100644 index 00000000000000..06b193ab9c6e09 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h @@ -0,0 +1,133 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_KERNEL_API_IR_BUILDER_H_ +#define XLA_BACKENDS_CPU_CODEGEN_KERNEL_API_IR_BUILDER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" + +namespace xla::cpu { + +class KernelApiIrBuilder { + public: + struct Options { + bool enable_invariant_load_metadata; + int32_t prefer_vector_width; + }; + + // Thread dimensions of the kernel invocation. + struct ThreadDims { + llvm::Value* x; + llvm::Value* y; + llvm::Value* z; + }; + + // Thread coordinates of the kernel invocation. + struct ThreadId { + llvm::Value* x; + llvm::Value* y; + llvm::Value* z; + }; + + // Kernel parameter (argument or result buffer) passed to a kernel function. + // We rely on buffer allocation slice information to infer buffer aliasing + // scopes for LLVM codegen. + struct KernelParameter { + Shape shape; + BufferAllocation::Slice slice; + }; + + // A kernel function prototype with all the LLVM values that might be needed + // to emit the actual kernel body. + struct KernelPrototype { + llvm::Function* function; + llvm::BasicBlock* return_block; + + // LLVM values identifying kernel invocation thread coordinates. + ThreadDims thread_dims; + ThreadId thread_id; + + // LLVM values corresponding to the kernel arguments and results arrays. All + // tuples are flattened as we do not have any tuples at run time and only + // read and write data from/to leaf arrays. + std::vector arguments; + std::vector results; + + // Set containing all invariant (read-only) buffers indices. A buffer is + // read-only if it is not aliased with any result. + absl::flat_hash_set invariant_arguments; + + // the set of buffer uses for this kernel, can be empty if buffer + // was not provided. + absl::InlinedVector buffer_uses; + }; + + KernelApiIrBuilder(llvm::LLVMContext& context_, Options options); + + // Emits a kernel prototype for the given HLO instruction. + // buffer_assignment may be null, in which case we will not compute alias + // metadata. + absl::StatusOr EmitKernelPrototype( + llvm::Module& module, const HloInstruction* instr, + const BufferAssignment* buffer_assignment, absl::string_view suffix = ""); + + absl::StatusOr EmitKernelPrototype( + llvm::Module& module, absl::string_view name, + absl::Span arguments, + absl::Span results, + bool compute_alias_metadata = true); + + private: + ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder, + llvm::Value* call_frame); + ThreadId EmitKernelThread(llvm::IRBuilderBase& builder, + llvm::Value* call_frame); + llvm_ir::IrArray EmitKernelArgument(llvm::IRBuilderBase& builder, + llvm::Value* call_frame, int64_t index, + const Shape& shape); + llvm::Function* EmitKernelFunction(llvm::Module& module, + absl::string_view name); + + private: + llvm::LLVMContext& context_; + + Options options_; + + llvm::StructType* thread_dim_ty_; + llvm::StructType* thread_ty_; + llvm::StructType* arg_ty_; + llvm::StructType* call_frame_ty_; + llvm::FunctionType* kernel_function_ty_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_KERNEL_API_IR_BUILDER_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc new file mode 100644 index 00000000000000..04b25ec25c5fa7 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc @@ -0,0 +1,298 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" + +#include +#include +#include + +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Type.h" +#include "xla/cpu_function_runtime.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/service/logical_buffer.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class KernelApiIrBuilderTest : public HloTestBase { + public: + KernelApiIrBuilderTest() + : module_("KernelApiIrBuilderTest", context_), + kernel_api_ir_builder_(context_, + KernelApiIrBuilder::Options{true, 256}) {} + + llvm::IRBuilder<> getBuilder() { return llvm::IRBuilder<>(context_); } + + auto EmitKernelPrototype(const HloInstruction* instr, + const BufferAssignment* buffer_assignment) { + return kernel_api_ir_builder_.EmitKernelPrototype(module_, instr, + buffer_assignment); + } + + auto EmitKernelPrototype( + absl::string_view name, + absl::Span arguments, + absl::Span results) { + return kernel_api_ir_builder_.EmitKernelPrototype(module_, name, arguments, + results); + } + + absl::StatusOr> RunBufferAssignment( + const HloModule& hlo) { + return BufferAssigner::Run( + &hlo, std::make_unique(&hlo), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; }); + } + + llvm::LLVMContext& context() { return context_; } + std::string DumpToString() { return llvm_ir::DumpToString(&module_); } + + private: + llvm::LLVMContext context_; + llvm::Module module_; + KernelApiIrBuilder kernel_api_ir_builder_; +}; + +namespace { + +TEST_F(KernelApiIrBuilderTest, BuildKernelPrototype) { + auto hlo = std::make_unique("test", HloModuleConfig()); + + auto shape = ShapeUtil::MakeShape(PrimitiveType::F32, {4, 2}); + + BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); + BufferAllocation::Slice arg0(&alloc, /*offset=*/0, /*size=*/256); + BufferAllocation::Slice arg1(&alloc, /*offset=*/256, /*size=*/256); + BufferAllocation::Slice res0(&alloc, /*offset=*/512, /*size=*/256); + BufferAllocation::Slice res1(&alloc, /*offset=*/768, /*size=*/256); + + std::vector arguments = {{shape, arg0}, + {shape, arg1}}; + std::vector results = {{shape, res0}, + {shape, res1}}; + + TF_ASSERT_OK_AND_ASSIGN(auto prototype, + EmitKernelPrototype("test", arguments, results)); + llvm::IRBuilder<> builder = getBuilder(); + builder.SetInsertPoint(prototype.function->getEntryBlock().getTerminator()); + + auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context()), 0); + llvm_ir::IrArray::Index index(zero, shape, &builder); + + // Emit loads from arguments and results buffers to test alias scope metadata. + EXPECT_NE(prototype.arguments[0].EmitReadArrayElement(index, &builder), + nullptr); + EXPECT_NE(prototype.arguments[1].EmitReadArrayElement(index, &builder), + nullptr); + EXPECT_NE(prototype.results[0].EmitReadArrayElement(index, &builder), + nullptr); + EXPECT_NE(prototype.results[1].EmitReadArrayElement(index, &builder), + nullptr); + + // clang-format off + ASSERT_TRUE(*RunFileCheck(DumpToString(), + absl::StrCat(R"( + CHECK: define ptr @test(ptr %0) #0 { + + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThreadDim, {{.*}} i32 2 + CHECK: load i64 + CHECK: load i64 + CHECK: load i64 + + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %XLA_CPU_KernelThread, {{.*}} i32 2 + CHECK: load i64 + CHECK: load i64 + CHECK: load i64 + + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 + CHECK: load ptr + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 0, i32 0 + CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.+]], !align ![[ALIGNMENT:.+]] + + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 + CHECK: load ptr + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 1, i32 0 + CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] + + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 + CHECK: load ptr + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 2, i32 0 + CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] + + CHECK-NEXT: getelementptr inbounds nuw %XLA_CPU_KernelCallFrame, {{.*}} i32 3 + CHECK: load ptr + CHECK: getelementptr %XLA_CPU_KernelArg, {{.*}} i32 3, i32 0 + CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] + + CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]] + CHECK: load float, ptr %[[PTR0]], align 4, + CHECK-SAME: !invariant.load ![[SCOPE0]], + CHECK-SAME: !noalias ![[SCOPE1:.+]] + + CHECK-NEXT: %[[PTR1:.+]] = getelementptr inbounds float, ptr %[[ARG1]] + CHECK: load float, ptr %[[PTR1]], align 4, + CHECK-SAME: !invariant.load ![[SCOPE0]], + CHECK-SAME: !noalias ![[SCOPE1]] + + CHECK-NEXT: %[[PTR2:.+]] = getelementptr inbounds float, ptr %[[ARG2]] + CHECK: load float, ptr %[[PTR2]], align 4, !alias.scope ![[SCOPE2:.+]], + CHECK: !noalias ![[SCOPE3:.+]] + + CHECK-NEXT: %[[PTR3:.+]] = getelementptr inbounds float, ptr %[[ARG3]] + CHECK: load float, ptr %[[PTR3]], align 4, !alias.scope ![[SCOPE3]], + CHECK: !noalias ![[SCOPE2]] + + CHECK: ret ptr null + CHECK: } + + #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" } + CHECK-DAG: ![[ALIGNMENT]] = !{i64 )", cpu_function_runtime::MinAlign(), R"(} + CHECK-DAG: ![[SCOPE0]] = !{} + CHECK-DAG: ![[SCOPE1]] = !{![[RES0:.+]], ![[RES1:.+]]} + CHECK-DAG: ![[SCOPE2]] = !{![[RES0]]} + CHECK-DAG: ![[SCOPE3]] = !{![[RES1]]} + CHECK-DAG: ![[RES0]] = !{!"{{.*}}, offset:512, {{.*}}", ![[DOMAIN:.+]]} + CHECK-DAG: ![[RES1]] = !{!"{{.*}}, offset:768, {{.*}}", ![[DOMAIN]]} + CHECK-DAG: ![[DOMAIN]] = !{!"XLA host kernel test AA domain"} + )"))); + // clang-format on + + // Match for dereferenceable metadata in separate check, because depending on + // the alignment value, it may be the same scope as align, and may be a + // separate one. It's impossible to match both these cases in one FileCheck. + ASSERT_TRUE(*RunFileCheck(DumpToString(), R"( + CHECK: {{.+}} = load ptr, {{.*}}, !dereferenceable ![[DEREF_BYTES:.+]], + CHECK: ![[DEREF_BYTES]] = !{i64 32} + )")); +} + +TEST_F(KernelApiIrBuilderTest, AllInvariantBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m + ENTRY main { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT add.0 = f32[2,2] add(p0, p1) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignement, RunBufferAssignment(*hlo)); + TF_ASSERT_OK_AND_ASSIGN( + KernelApiIrBuilder::KernelPrototype prototype, + EmitKernelPrototype(hlo->entry_computation()->root_instruction(), + buffer_assignement.get())); + + ASSERT_EQ(prototype.invariant_arguments.size(), 2); +} + +TEST_F(KernelApiIrBuilderTest, InvariantBufferPassedTwice) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT add.0 = f32[2,2] add(p0, p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignement, RunBufferAssignment(*hlo)); + TF_ASSERT_OK_AND_ASSIGN( + KernelApiIrBuilder::KernelPrototype prototype, + EmitKernelPrototype(hlo->entry_computation()->root_instruction(), + buffer_assignement.get())); + + // Invariant buffers contains indices of both arguments, even though it is the + // same buffer slice. + ASSERT_EQ(prototype.invariant_arguments.size(), 2); +} + +TEST_F(KernelApiIrBuilderTest, NoInvariantBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m, input_output_alias={ {}: (0, {}, must-alias) } + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT add.0 = f32[2,2] add(p0, p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignement, RunBufferAssignment(*hlo)); + TF_ASSERT_OK_AND_ASSIGN( + KernelApiIrBuilder::KernelPrototype prototype, + EmitKernelPrototype(hlo->entry_computation()->root_instruction(), + buffer_assignement.get())); + + ASSERT_EQ(prototype.invariant_arguments.size(), 0); +} + +TEST_F(KernelApiIrBuilderTest, MixedBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m, input_output_alias={ {}: (1, {}, must-alias) } + ENTRY main { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT add.0 = f32[2,2] add(p0, p1) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignement, RunBufferAssignment(*hlo)); + TF_ASSERT_OK_AND_ASSIGN( + KernelApiIrBuilder::KernelPrototype prototype, + EmitKernelPrototype(hlo->entry_computation()->root_instruction(), + buffer_assignement.get())); + + // The first argument is invariant, the second is not because it's aliased to + // the output. + EXPECT_EQ(prototype.invariant_arguments.size(), 1); + EXPECT_TRUE(prototype.invariant_arguments.contains(0)); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_spec.cc b/third_party/xla/xla/backends/cpu/codegen/llvm_ir_kernel_spec.cc similarity index 95% rename from third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_spec.cc rename to third_party/xla/xla/backends/cpu/codegen/llvm_ir_kernel_spec.cc index b54637f87f5d63..482f002a6fb2fc 100644 --- a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_spec.cc +++ b/third_party/xla/xla/backends/cpu/codegen/llvm_ir_kernel_spec.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/backends/cpu/testlib/llvm_ir_kernel_spec.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include #include diff --git a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_spec.h b/third_party/xla/xla/backends/cpu/codegen/llvm_ir_kernel_spec.h similarity index 86% rename from third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_spec.h rename to third_party/xla/xla/backends/cpu/codegen/llvm_ir_kernel_spec.h index 1bd97d52f6a141..cedd7e6db4f1bc 100644 --- a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_spec.h +++ b/third_party/xla/xla/backends/cpu/codegen/llvm_ir_kernel_spec.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_CPU_TESTLIB_LLVM_IR_KERNEL_SPEC_H_ -#define XLA_BACKENDS_CPU_TESTLIB_LLVM_IR_KERNEL_SPEC_H_ +#ifndef XLA_BACKENDS_CPU_CODEGEN_LLVM_IR_KERNEL_SPEC_H_ +#define XLA_BACKENDS_CPU_CODEGEN_LLVM_IR_KERNEL_SPEC_H_ #include #include @@ -36,7 +36,7 @@ class LlvmIrKernelSpec final : public xla::KernelSpec { std::unique_ptr kernel_source); LlvmIrKernelSpec(LlvmIrKernelSpec&& other) = default; - LlvmIrKernelSpec& operator=(LlvmIrKernelSpec&& other) = default; + LlvmIrKernelSpec& operator=(LlvmIrKernelSpec&& other) noexcept = default; LlvmIrKernelSource& kernel_source() override { return *kernel_source_; } @@ -47,4 +47,4 @@ class LlvmIrKernelSpec final : public xla::KernelSpec { } // namespace xla::cpu -#endif // XLA_BACKENDS_CPU_TESTLIB_LLVM_IR_KERNEL_SPEC_H_ +#endif // XLA_BACKENDS_CPU_CODEGEN_LLVM_IR_KERNEL_SPEC_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/object_loader.cc b/third_party/xla/xla/backends/cpu/codegen/object_loader.cc new file mode 100644 index 00000000000000..ca70110d1e188f --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/object_loader.cc @@ -0,0 +1,174 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/object_loader.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Mangler.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/MemoryBuffer.h" +#include "xla/backends/cpu/codegen/compiled_function_library.h" +#include "xla/backends/cpu/codegen/contiguous_section_memory_manager.h" +#include "xla/backends/cpu/runtime/function_library.h" +#include "xla/service/cpu/orc_jit_memory_mapper.h" + +namespace xla::cpu { + +static std::unique_ptr +CreateObjectLinkingLayer(llvm::orc::ExecutionSession& execution_session) { + return std::make_unique( + execution_session, [] { + return std::make_unique( + orc_jit_memory_mapper::GetInstance()); + }); +} + +ObjectLoader::ObjectLoader(size_t num_dylibs) +/*: target_machine_(std::move(target_machine))*/ { + // LLVM execution session that holds jit-compiled functions. + execution_session_ = std::make_unique( + std::make_unique( + /*SSP=*/nullptr, /*D=*/nullptr)); + + execution_session_->setErrorReporter([](llvm::Error err) { + LOG(ERROR) << "LLVM compilation error: " << llvm::toString(std::move(err)); + }); + + // Create at least one dynamic library for the given jit compiler. + dylibs_.resize(std::max(1, num_dylibs)); + for (size_t i = 0; i < dylibs_.size(); ++i) { + dylibs_[i] = &execution_session_->createBareJITDylib( + absl::StrCat("")); + // TODO using target machine might bring some deps we don't need. + // as a first attempt fully remove it, consider pruning the reqs + // if (definition_generator) { + // dylibs_[i]->addGenerator(definition_generator(target_machine_.get())); + // } + } + + object_layer_ = CreateObjectLinkingLayer(*execution_session_); +} + +absl::Status ObjectLoader::AddObjFile(const std::string& obj_file, + const std::string& memory_buffer_name, + size_t dylib_index) { + if (dylib_index >= dylibs_.size()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid dylib index %d (num dylibs: %d))", dylib_index, + dylibs_.size())); + } + + llvm::StringRef data(obj_file.data(), obj_file.size()); + + auto obj_file_mem_buffer = + llvm::MemoryBuffer::getMemBuffer(data, memory_buffer_name); + + if (!obj_file_mem_buffer) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Failed to create memory buffer"); + } + + llvm::orc::JITDylib* dylib = dylibs_[dylib_index]; + if (auto err = object_layer_->add(*dylib, std::move(obj_file_mem_buffer))) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Failed to add object file to dylib %d: %s", + dylib_index, llvm::toString(std::move(err)))); + } + + return absl::OkStatus(); +} + +absl::StatusOr> ObjectLoader::Load( + absl::Span symbols, const llvm::DataLayout& data_layout) && { + // Mangle symbol names for the target machine data layout. + auto mangle = [&](absl::string_view name) { + llvm::SmallVector mangled; + llvm::Mangler::getNameWithPrefix(mangled, name, data_layout); + return std::string(mangled.begin(), mangled.end()); + }; + + // Build a symbol lookup set. + llvm::orc::SymbolLookupSet lookup_set; + for (const auto& symbol : symbols) { + VLOG(5) << absl::StreamFormat(" - look up symbol: %s", symbol.name); + lookup_set.add(execution_session_->intern(mangle(symbol.name))); + } + + // Build a search order for the dynamic libraries. + llvm::orc::JITDylibSearchOrder search_order(dylibs_.size()); + for (size_t i = 0; i < dylibs_.size(); ++i) { + search_order[i] = std::make_pair( + dylibs_[i], llvm::orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); + } + + // Look up all requested symbols in the execution session. + auto symbol_map = execution_session_->lookup(std::move(search_order), + std::move(lookup_set)); + + if (auto err = symbol_map.takeError()) { + return absl::Status(absl::StatusCode::kInternal, + absl::StrFormat("%s", llvm::toString(std::move(err)))); + } + + // Resolve type-erased symbol pointers from the symbol map. + using ResolvedSymbol = CompiledFunctionLibrary::ResolvedSymbol; + absl::flat_hash_map resolved_map; + + for (const auto& symbol : symbols) { + auto symbol_name = execution_session_->intern(mangle(symbol.name)); + llvm::orc::ExecutorSymbolDef symbol_def = symbol_map->at(symbol_name); + llvm::orc::ExecutorAddr symbol_addr = symbol_def.getAddress(); + void* ptr = reinterpret_cast(symbol_addr.getValue()); + resolved_map[symbol.name] = ResolvedSymbol{symbol.type_id, ptr}; + } + + return std::make_unique( + std::move(execution_session_), std::move(object_layer_), + std::move(resolved_map)); +} + +ObjectLoader::~ObjectLoader() { + if (execution_session_) { + if (auto err = execution_session_->endSession()) { + execution_session_->reportError(std::move(err)); + } + } +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/object_loader.h b/third_party/xla/xla/backends/cpu/codegen/object_loader.h new file mode 100644 index 00000000000000..00739eca9f9bf6 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/object_loader.h @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_OBJECT_LOADER_H_ +#define XLA_BACKENDS_CPU_CODEGEN_OBJECT_LOADER_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/IR/DataLayout.h" +#include "xla/backends/cpu/runtime/function_library.h" + +namespace xla::cpu { + +class ObjectLoader { + public: + using Symbol = FunctionLibrary::Symbol; + + explicit ObjectLoader(size_t num_dylibs); + + absl::Status AddObjFile(const std::string& obj_file, + const std::string& memory_buffer_name, + size_t dylib_index = 0); + + absl::StatusOr> Load( + absl::Span symbols, const llvm::DataLayout& data_layout) &&; + + llvm::orc::RTDyldObjectLinkingLayer* object_layer() { + return object_layer_.get(); + } + + llvm::orc::ExecutionSession* execution_session() { + return execution_session_.get(); + } + + absl::StatusOr dylib(size_t dylib_index) { + if (dylib_index >= dylibs_.size()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Invalid dylib index %d (num dylibs: %d))", + dylib_index, dylibs_.size())); + } + return dylibs_[dylib_index]; + } + + ~ObjectLoader(); + + private: + std::unique_ptr object_layer_; + std::unique_ptr execution_session_; + + // Non-owning pointers to dynamic libraries created for the execution session. + std::vector dylibs_; + + // std::shared_ptr target_machine_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_OBJECT_LOADER_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/object_loader_test.cc b/third_party/xla/xla/backends/cpu/codegen/object_loader_test.cc new file mode 100644 index 00000000000000..35bec67e6324aa --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/object_loader_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/object_loader.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "xla/backends/cpu/codegen/ir_compiler.h" +#include "xla/backends/cpu/codegen/jit_compiler.h" +#include "xla/backends/cpu/runtime/function_library.h" +#include "xla/service/cpu/executable.pb.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { +namespace { + +// Parses the LLVM IR into a ThreadSafeModule. +static absl::StatusOr ParseModule( + llvm::orc::ThreadSafeContext& context, absl::string_view ir, + absl::string_view name) { + llvm::SMDiagnostic diagnostic; + llvm::MemoryBufferRef ir_buffer(ir, name); + + auto m = llvm::parseAssembly(ir_buffer, diagnostic, *context.getContext()); + if (m == nullptr) { + return Internal("Failed to parse LLVM IR: %s", + diagnostic.getMessage().str()); + } + + return llvm::orc::ThreadSafeModule(std::move(m), context); +} + +static absl::StatusOr> Compile( + JitCompiler compiler, absl::Span symbols) { + return std::move(compiler).Compile(symbols); +}; + +TEST(ObjectLoader, Load) { + constexpr size_t kNumDyLibs = 1; + auto context = std::make_unique(); + llvm::orc::ThreadSafeContext tsc(std::move(context)); + + std::vector object_files; + auto object_files_saver = + [&object_files](const llvm::Module& /*module*/, + const llvm::object::ObjectFile& object_file) -> void { + object_files.emplace_back(object_file.getData().data(), + object_file.getData().size()); + }; + + JitCompiler::Options options; + options.num_dylibs = kNumDyLibs; + options.ir_compiler_hooks.post_codegen = object_files_saver; + + TF_ASSERT_OK_AND_ASSIGN( + auto compiler, + JitCompiler::Create(llvm::TargetOptions(), std::move(options))); + + constexpr absl::string_view add_in_place_ir = R"( + define void @AddInplace(ptr %arg) { + %v0 = load float, ptr %arg + %v1 = fadd float %v0, %v0 + store float %v1, ptr %arg + ret void + })"; + + auto add_module = [&](absl::string_view ir, absl::string_view name, + size_t dylib_index) -> absl::Status { + TF_ASSIGN_OR_RETURN(llvm::orc::ThreadSafeModule tsm, + ParseModule(tsc, ir, name)); + TF_RETURN_IF_ERROR(compiler.AddModule(std::move(tsm), dylib_index)); + return absl::OkStatus(); + }; + + TF_ASSERT_OK(add_module(add_in_place_ir, "AddInplace", 0)); + + using ScalarFn = void(float*); + std::vector symbols = { + FunctionLibrary::Sym("AddInplace")}; + + llvm::DataLayout data_layout = compiler.target_machine()->createDataLayout(); + TF_ASSERT_OK_AND_ASSIGN(auto function_library_compiled, + Compile(std::move(compiler), symbols)); + + TF_ASSERT_OK_AND_ASSIGN( + ScalarFn * add_in_place_compiled, + function_library_compiled->ResolveFunction("AddInplace")); + + EXPECT_NE(add_in_place_compiled, nullptr); + + auto object_loader(std::make_unique(/*num_dylibs=*/kNumDyLibs)); + { + size_t obj_file_index = 0; + for (auto& obj_file : object_files) { + llvm::StringRef data(obj_file.data(), obj_file.size()); + TF_ASSERT_OK(object_loader->AddObjFile( + obj_file, absl::StrCat("loaded_obj_file_", obj_file_index++))); + } + } + + TF_ASSERT_OK_AND_ASSIGN(auto loaded_function_library, + std::move(*object_loader).Load(symbols, data_layout)); + + TF_ASSERT_OK_AND_ASSIGN( + ScalarFn * loaded_add_in_place, + loaded_function_library->ResolveFunction("AddInplace")); + + EXPECT_NE(loaded_add_in_place, nullptr); + + constexpr float kInputValue = 1.0f; + constexpr float kExpectedOutput = kInputValue + kInputValue; + + float compiled_function_input = kInputValue; + add_in_place_compiled(&compiled_function_input); + EXPECT_EQ(compiled_function_input, kExpectedOutput); + + float loaded_function_input = 1.0f; + loaded_add_in_place(&loaded_function_input); + EXPECT_EQ(loaded_function_input, compiled_function_input); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc b/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc index e1b6caf4f6c7c8..df7274d63e391f 100644 --- a/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc +++ b/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -84,7 +84,7 @@ void RemoveFunctionFromUsedList(llvm::Module* module, llvm::Function* fn) { // vector_width f32s, and that fn_body_generator generates a function body with // the same inputs/outputs as fn_name. void RewriteCalls( - llvm::Module* module, std::string_view fn_name, + llvm::Module* module, absl::string_view fn_name, std::function fn_body_generator, @@ -399,15 +399,15 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilderBase* b, llvm::Value* input, } } // namespace -static constexpr std::string_view kTanhV4F32Sym = "__xla_cpu_TanhV4F32"; -static constexpr std::string_view kTanhV8F32Sym = "__xla_cpu_TanhV8F32"; -static constexpr std::string_view kTanhV16F32Sym = "__xla_cpu_TanhV16F32"; -static constexpr std::string_view kExpV4F32Sym = "__xla_cpu_ExpV4F32"; -static constexpr std::string_view kExpV8F32Sym = "__xla_cpu_ExpV8F32"; -static constexpr std::string_view kExpV16F32Sym = "__xla_cpu_ExpV16F32"; -static constexpr std::string_view kLogV4F32Sym = "__xla_cpu_LogV4F32AVX"; -static constexpr std::string_view kLogV8F32Sym = "__xla_cpu_LogV8F32AVX"; -static constexpr std::string_view kLogV16F32Sym = "__xla_cpu_LogV16F32AVX"; +static constexpr absl::string_view kTanhV4F32Sym = "__xla_cpu_TanhV4F32"; +static constexpr absl::string_view kTanhV8F32Sym = "__xla_cpu_TanhV8F32"; +static constexpr absl::string_view kTanhV16F32Sym = "__xla_cpu_TanhV16F32"; +static constexpr absl::string_view kExpV4F32Sym = "__xla_cpu_ExpV4F32"; +static constexpr absl::string_view kExpV8F32Sym = "__xla_cpu_ExpV8F32"; +static constexpr absl::string_view kExpV16F32Sym = "__xla_cpu_ExpV16F32"; +static constexpr absl::string_view kLogV4F32Sym = "__xla_cpu_LogV4F32AVX"; +static constexpr absl::string_view kLogV8F32Sym = "__xla_cpu_LogV8F32AVX"; +static constexpr absl::string_view kLogV16F32Sym = "__xla_cpu_LogV16F32AVX"; std::vector PolynomialApproximationsVectorization() { return std::vector{ diff --git a/third_party/xla/xla/backends/cpu/codegen/target_machine_features.h b/third_party/xla/xla/backends/cpu/codegen/target_machine_features.h index 5148ef1af1c020..e47acef5569a8e 100644 --- a/third_party/xla/xla/backends/cpu/codegen/target_machine_features.h +++ b/third_party/xla/xla/backends/cpu/codegen/target_machine_features.h @@ -38,6 +38,9 @@ class TargetMachineFeatures { explicit TargetMachineFeatures(llvm::TargetMachine* target_machine); virtual ~TargetMachineFeatures() = default; + TargetMachineFeatures(TargetMachineFeatures&&) = default; + TargetMachineFeatures& operator=(TargetMachineFeatures&&) = default; + // Return the vectorization factor, which is the number of bytes of data // explicitly vectorized routines will try to process at once. virtual int32_t vectorization_factor_in_bytes() const; diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/BUILD b/third_party/xla/xla/backends/cpu/codegen/tools/BUILD deleted file mode 100644 index cfc8a5a33f6b41..00000000000000 --- a/third_party/xla/xla/backends/cpu/codegen/tools/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//xla:xla.bzl", "xla_cc_binary") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -xla_cc_binary( - name = "xla_cpu_opt", - srcs = ["xla_cpu_opt.cc"], - visibility = ["//xla/backends/cpu/codegen:__subpackages__"], - deps = [ - "//xla/backends/cpu/codegen/ir:xla_cpu", - "//xla/backends/cpu/codegen/transforms:passes", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/xla_cpu_opt.cc b/third_party/xla/xla/backends/cpu/codegen/tools/xla_cpu_opt.cc deleted file mode 100644 index 109b4d5489526f..00000000000000 --- a/third_party/xla/xla/backends/cpu/codegen/tools/xla_cpu_opt.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "mlir/Transforms/Passes.h" -#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" -#include "xla/backends/cpu/codegen/transforms/passes.h" - -int main(int argc, char** argv) { - mlir::DialectRegistry registry; - registry.insert(); - - // Register builtin MLIR passes. - mlir::func::registerAllExtensions(registry); - mlir::registerCanonicalizerPass(); - mlir::registerCSEPass(); - - // Register XLA:CPU passes. - xla::cpu::registerXlaCpuTransformsPasses(); - - return mlir::failed( - MlirOptMain(argc, argv, "XLA:CPU Pass Driver\n", registry)); -} diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD index c33df92b01cc32..10228fcc460af8 100644 --- a/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD @@ -10,7 +10,7 @@ lit_test_suite( srcs = glob(["*.mlir"]), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/backends/cpu/codegen/tools:xla_cpu_opt", + "//xla/codegen/tools:emitters_opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir index a7d4a117f0f005..363620bf8b645a 100644 --- a/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir @@ -1,4 +1,4 @@ -// RUN: xla_cpu_opt %s --xla-cpu-lower-trivial | FileCheck %s +// RUN: emitters_opt %s --xla-cpu-lower-trivial | FileCheck %s func.func @call_frame_arg(%arg0: !xla_cpu.call_frame) { %0 = xla_cpu.load %arg0, 0 : tensor<32x32xf32> diff --git a/third_party/xla/xla/backends/cpu/codegen/vector_ir_builder.cc b/third_party/xla/xla/backends/cpu/codegen/vector_ir_builder.cc index 35dae9ec77de13..e68b0055b6228a 100644 --- a/third_party/xla/xla/backends/cpu/codegen/vector_ir_builder.cc +++ b/third_party/xla/xla/backends/cpu/codegen/vector_ir_builder.cc @@ -53,8 +53,8 @@ VectorIrBuilder::VectorIrBuilder(PrimitiveType primitive_type, primitive_type_(primitive_type), b_(b), name_(std::move(name)) { - scalar_type_ = llvm_ir::PrimitiveTypeToIrType( - primitive_type, b_->GetInsertBlock()->getModule()); + scalar_type_ = + llvm_ir::PrimitiveTypeToIrType(primitive_type, b_->getContext()); scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false); vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); diff --git a/third_party/xla/xla/backends/cpu/collectives/BUILD b/third_party/xla/xla/backends/cpu/collectives/BUILD new file mode 100644 index 00000000000000..2608bb2eed8c3d --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/BUILD @@ -0,0 +1,383 @@ +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "cpu_clique_key", + srcs = ["cpu_clique_key.cc"], + hdrs = ["cpu_clique_key.h"], + deps = [ + "//xla/core/collectives:clique_key", + "//xla/service:global_device_id", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:casts", + ], +) + +cc_library( + name = "cpu_clique", + srcs = ["cpu_clique.cc"], + hdrs = ["cpu_clique.h"], + deps = [ + ":cpu_clique_key", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "cpu_cliques", + srcs = ["cpu_cliques.cc"], + hdrs = ["cpu_cliques.h"], + deps = [ + ":cpu_clique", + ":cpu_clique_key", + ":cpu_collectives", + "//xla:util", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "cpu_collectives", + srcs = ["cpu_collectives.cc"], + hdrs = ["cpu_collectives.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:collectives_registry", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:casts", + ], +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "in_process_collectives", + srcs = ["in_process_collectives.cc"], + hdrs = ["in_process_collectives.h"], + deps = [ + ":cpu_collectives", + ":in_process_communicator", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "in_process_communicator", + srcs = ["in_process_communicator.cc"], + hdrs = ["in_process_communicator.h"], + deps = [ + ":cpu_collectives", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service:rendezvous", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "gloo_kv_store", + srcs = ["gloo_kv_store.cc"], + hdrs = ["gloo_kv_store.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = [ + "//xla/pjrt/cpu:legacy_cpu_internal_users", + ], + deps = [ + "//xla/pjrt:status_casters", + "//xla/pjrt/distributed:key_value_store_interface", + "@com_google_absl//absl/time", + "@gloo", + ], +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "gloo_collectives", + srcs = ["gloo_collectives.cc"], + hdrs = ["gloo_collectives.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:gloo_communicator", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@gloo", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "gloo_collectives_test", + srcs = ["gloo_collectives_test.cc"], + linkstatic = True, + deps = [ + ":gloo_collectives", + ":gloo_kv_store", + "//xla:executable_run_options", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_clique_key", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/pjrt/distributed:in_memory_key_value_store", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ] + select({ + # Gloo's transport_tcp is not available on MacOS + "//xla/tsl:macos": [ + "@gloo//:transport_uv", + ], + "//conditions:default": [ + "@gloo//:transport_tcp", + ], + }), +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "gloo_communicator", + srcs = ["gloo_communicator.cc"], + hdrs = ["gloo_communicator.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cpu_collectives", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@gloo", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "mpi_collectives", + srcs = ["mpi_collectives.cc"], + hdrs = ["mpi_collectives.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + # copybara:uncomment_begin(google-only) + # "-Ithird_party/openmpi/ompi/include", + # copybara:uncomment_end + ], + features = ["-use_header_modules"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:mpi_communicator", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", + "//xla/core/collectives:communicator", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@mpitrampoline", + ], +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "mpi_communicator", + srcs = ["mpi_communicator.cc"], + hdrs = ["mpi_communicator.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + # copybara:uncomment_begin(google-only) + # "-Ithird_party/openmpi/ompi/include", + # copybara:uncomment_end + ], + features = ["-use_header_modules"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@mpitrampoline", + ], +) diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_clique.cc b/third_party/xla/xla/backends/cpu/collectives/cpu_clique.cc new file mode 100644 index 00000000000000..a81dd80392f9f1 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_clique.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_clique.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/core/collectives/clique.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/tsl/platform/logging.h" + +namespace xla::cpu { + +CpuClique::CpuClique(CpuCliqueKey key) : Clique({}), key_(std::move(key)) {} + +std::string CpuClique::DebugString() const { + std::string out = + absl::StrFormat("key: %s; size: %d; communicators: ", key_.ToString(), + num_communicators()); + int32_t cnt = 0; + ForEachComm([&](RankId rank, Communicator* comm) { + if (cnt++) absl::StrAppend(&out, ", "); + absl::StrAppendFormat(&out, "[rank=%d, comm=%s]", rank.value(), + comm->ToString()); + }); + return out; +} + +absl::Status CpuClique::HealthCheck() const { + absl::Status health_check = absl::OkStatus(); + ForEachComm([&health_check](RankId rank, Communicator* comm) { + if (auto s = comm->HealthCheck(); !s.ok()) { + LOG(ERROR) << "CPU communicator error (rank " << rank << "): " << s; + if (health_check.ok()) health_check = std::move(s); // return first error + } + }); + return health_check; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_clique.h b/third_party/xla/xla/backends/cpu/collectives/cpu_clique.h new file mode 100644 index 00000000000000..e1ff3025a955b0 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_clique.h @@ -0,0 +1,42 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ + +#include + +#include "absl/status/status.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/core/collectives/clique.h" + +namespace xla::cpu { + +// A group of CPU communicators making up a clique. +class CpuClique final : public Clique { + public: + explicit CpuClique(CpuCliqueKey key); + + absl::Status HealthCheck() const final; + + std::string DebugString() const final; + + private: + CpuCliqueKey key_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_clique_key.cc b/third_party/xla/xla/backends/cpu/collectives/cpu_clique_key.cc new file mode 100644 index 00000000000000..b66c844d4983ed --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_clique_key.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_clique_key.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_format.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/service/global_device_id.h" +#include "tsl/platform/casts.h" + +namespace xla::cpu { + +bool CpuCliqueKey::IsSubsetOf(const CliqueKey& other) const { + auto* other_cpu = tsl::down_cast(&other); + if (other_cpu == nullptr) return false; + + return absl::c_all_of(devices(), [&](GlobalDeviceId id) { + return absl::c_linear_search(other_cpu->devices(), id); + }); +} + +std::string CpuCliqueKey::ToString() const { + return absl::StrFormat("devices=[%s]", GlobalDeviceIdsToString(devices())); +} + +void CpuCliqueKey::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), devices()); +} + +bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() == b.devices(); +} + +bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() < b.devices(); +} + +bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() > b.devices(); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_clique_key.h b/third_party/xla/xla/backends/cpu/collectives/cpu_clique_key.h new file mode 100644 index 00000000000000..30b257c1a0d0c0 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_clique_key.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ + +#include + +#include "absl/hash/hash.h" +#include "xla/core/collectives/clique_key.h" + +namespace xla::cpu { + +// Clique key for identifying a particular CPU collectives clique. +class CpuCliqueKey final : public CliqueKey { + public: + using CliqueKey::CliqueKey; + + bool IsSubsetOf(const CliqueKey& other) const final; + std::string ToString() const final; + + friend bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b); + friend bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b); + friend bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b); + + private: + void HashValue(absl::HashState state) const final; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc new file mode 100644 index 00000000000000..c52b400e4b5797 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.cc @@ -0,0 +1,121 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_cliques.h" + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/collectives/cpu_clique.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla::cpu { + +//===----------------------------------------------------------------------===// +// ProcessCpuCliques +//===----------------------------------------------------------------------===// + +namespace { + +// CpuClique is not thread-safe, so we wrap it in a thread-safe container as we +// create new communicators lazily and potentially from multiple threads. +struct ThreadSafeClique { + explicit ThreadSafeClique(CpuCliqueKey key) : clique(key) {} + + absl::Mutex mu; + CpuClique clique ABSL_GUARDED_BY(mu); +}; + +// Container for initialized and ready to use CPU cliques. In contrast to GPU +// cliques, CPU cliques are not lockable, and we create communicators lazily +// when needed. +struct ProcessCpuCliques { + absl::Mutex mu; + absl::node_hash_map map ABSL_GUARDED_BY(mu); +}; +} // namespace + +// Returns process-local CPU cliques. +static ProcessCpuCliques& GetProcessCpuCliques() { + static auto* cliques = new ProcessCpuCliques; + return *cliques; +} + +//===----------------------------------------------------------------------===// + +// TODO(b/380457503): Consider switching to a lockable CPU clique model similar +// to GPU cliques, and creating all communicators upfront. +absl::StatusOr AcquireCommunicator( + CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank) { + VLOG(3) << "Acquire communicator for clique key " << clique_key.ToString() + << " and rank " << rank; + + ProcessCpuCliques& cliques = GetProcessCpuCliques(); + + // Synchronize access to the process cliques. + ThreadSafeClique& thread_safe_clique = [&]() -> ThreadSafeClique& { + absl::MutexLock lock(&cliques.mu); + auto [it, emplaced] = cliques.map.try_emplace(clique_key, clique_key); + return it->second; + }(); + + // Check if we already have a communicator for this rank. + std::optional comm = [&]() -> std::optional { + absl::MutexLock lock(&thread_safe_clique.mu); + return thread_safe_clique.clique.comm(rank); + }(); + + if (comm.has_value()) return *comm; + + VLOG(3) << "Create a new communicator for clique key " + << clique_key.ToString() << " and rank " << rank; + + // Create a new communicator and add it to the clique. + CpuCollectives::DeviceRank device_rank(/*device=*/nullptr, rank); + CpuCollectives::Config config; + + TF_ASSIGN_OR_RETURN(std::vector> communicators, + collectives->CreateCommunicators(clique_key, std::nullopt, + {device_rank}, config)); + + // We expect to create communicators lazily on at a time. + if (communicators.size() != 1) { + return Internal( + "Expected to create a single communicator for a clique key %s and rank " + "%d, but got %d", + clique_key.ToString(), rank.value(), communicators.size()); + } + + absl::MutexLock lock(&thread_safe_clique.mu); + TF_RETURN_IF_ERROR(thread_safe_clique.clique.AddComm( + rank, std::move(communicators.front()))); + + return *thread_safe_clique.clique.comm(rank); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.h b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.h new file mode 100644 index 00000000000000..b42774619fe4b2 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_cliques.h @@ -0,0 +1,33 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ + +#include "absl/status/statusor.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" + +namespace xla::cpu { + +// Returns a communicator for a given clique key and rank. +absl::StatusOr AcquireCommunicator( + CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/cpu_collectives.cc new file mode 100644 index 00000000000000..1500eef4eb8c8a --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_collectives.cc @@ -0,0 +1,63 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_collectives.h" + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "xla/core/collectives/collectives.h" +#include "xla/core/collectives/collectives_registry.h" +#include "xla/core/collectives/communicator.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" + +namespace xla::cpu { + +CpuCollectives* CpuCollectives::Default() { + absl::StatusOr collectives = + CollectivesRegistry::Default("host"); + CHECK_OK(collectives) << "Failed to get CPU collectives"; // Crash OK + + if (auto* cpu_collectives = tsl::down_cast(*collectives)) { + return cpu_collectives; + } + + LOG(FATAL) << "Unsupported collectives implementation for CPU"; +} + +absl::StatusOr CpuCollectives::TryCast( + const Collectives::Device* device) { + if (auto* cpu_device = tsl::down_cast(device)) { + return cpu_device; + } + return InvalidArgument("Collectives device is not a CPU device"); +} + +absl::StatusOr CpuCollectives::TryCast( + const Communicator::Executor* executor) { + if (auto* cpu_executor = tsl::down_cast(executor)) { + return cpu_executor; + } + return InvalidArgument("Collectives executor is not a CPU executor"); +} + +CpuCollectives::Executor::Executor(RendezvousKey rendezvous_key, + absl::Duration timeout) + : rendezvous_key_(rendezvous_key), timeout_(timeout) {} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/cpu_collectives.h b/third_party/xla/xla/backends/cpu/collectives/cpu_collectives.h new file mode 100644 index 00000000000000..330b35f52146d1 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/cpu_collectives.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +// XLA:CPU extension of the Collectives interface with CPU-specific APIs. +class CpuCollectives : public Collectives { + public: + // Returns the default collectives implementation for CPU backend. + static CpuCollectives* Default(); + + class Device : public Collectives::Device { + public: + Device() = default; + }; + + // Executor allows CPU collectives clients to pass additional information to + // the collectives implementation. + class Executor : public Communicator::Executor { + public: + Executor(RendezvousKey rendezvous_key, absl::Duration timeout); + + const RendezvousKey& rendezvous_key() const { return rendezvous_key_; } + const absl::Duration& timeout() const { return timeout_; } + + private: + RendezvousKey rendezvous_key_; + absl::Duration timeout_; + }; + + absl::StatusOr CreateUniqueCliqueId() const final { + return Unimplemented("CPU collectives do not support clique ids"); + } + + absl::StatusOr>> SplitCommunicators( + absl::Span comms, int32_t color, + absl::Span keys, const Config& config) final { + return Unimplemented( + "CPU collectives do not support communicator splitting"); + } + + // Tries to cast a Collectives::Device to a CpuCollectives::Device. + static absl::StatusOr TryCast( + const Collectives::Device* device); + + // Tries to cast a Communicator::Executor to a CpuCollectives::Executor. + static absl::StatusOr TryCast( + const Communicator::Executor* executor); +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc new file mode 100644 index 00000000000000..eb8705b81fd5f8 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.cc @@ -0,0 +1,86 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/gloo_collectives.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "gloo/context.h" +#include "gloo/rendezvous/context.h" +#include "gloo/rendezvous/prefix_store.h" +#include "gloo/rendezvous/store.h" +#include "gloo/transport/device.h" +#include "xla/backends/cpu/collectives/gloo_communicator.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/core/collectives/communicator.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +GlooCollectives::GlooCollectives( + std::unique_ptr store, + std::shared_ptr device) + : store_(std::move(store)), device_(std::move(device)) {} + +GlooCollectives::~GlooCollectives() = default; + +absl::StatusOr>> +GlooCollectives::CreateCommunicators(const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) { + std::vector> communicators; + for (auto& device_rank : ranks) { + size_t rank = device_rank.rank.value(); + + auto gloo_context = std::make_shared( + rank, clique_key.num_devices()); + auto prefix_store = gloo::rendezvous::PrefixStore( + absl::StrCat("gloo/", + absl::StrJoin(clique_key.devices(), ",", + [](std::string* out, GlobalDeviceId id) { + absl::StrAppend(out, id.value()); + })), + *store_); + + try { + gloo_context->connectFullMesh(prefix_store, device_); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo context initialization failed: ", e.what())); + } + + communicators.push_back(std::make_unique( + std::move(gloo_context), rank, clique_key.num_devices())); + } + + return communicators; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.h b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.h new file mode 100644 index 00000000000000..9b52a05ea5e342 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives.h @@ -0,0 +1,56 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COLLECTIVES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COLLECTIVES_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "gloo/context.h" +#include "gloo/rendezvous/store.h" +#include "gloo/transport/device.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/core/collectives/communicator.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class GlooCollectives : public CpuCollectives { + public: + GlooCollectives(std::unique_ptr store, + std::shared_ptr device); + ~GlooCollectives() override; + + absl::StatusOr>> + CreateCommunicators(const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) final; + + private: + std::unique_ptr store_; + std::shared_ptr device_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COLLECTIVES_H_ diff --git a/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc similarity index 75% rename from third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc rename to third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc index b8bb7810dd3909..c4a9009e73c884 100644 --- a/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_collectives_test.cc @@ -13,38 +13,45 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/cpu/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" #include #include #include #include +#include +#include #include #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" -#if defined(__linux__) -#include "gloo/transport/tcp/attr.h" -#include "gloo/transport/tcp/device.h" -#elif defined(__APPLE__) -#include "gloo/transport/uv/device.h" -#endif // defined(__linux__) +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" -#include "xla/pjrt/cpu/gloo_kv_store.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" + +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#endif // defined(__linux__) namespace xla::cpu { @@ -56,7 +63,7 @@ constexpr int kNumParticipants = 2; constexpr size_t kBufferSize = 256; constexpr absl::Duration kTimeout = absl::Seconds(5); -absl::StatusOr> GetCommunicator( +absl::StatusOr> GetCommunicator( size_t kNumParticipants, absl::Span global_devices, const std::shared_ptr& kv_store, int rank) { auto collectives = std::make_shared( @@ -66,7 +73,16 @@ absl::StatusOr> GetCommunicator( #elif defined(__APPLE__) gloo::transport::uv::CreateDevice(gloo::transport::uv::attr())); #endif // defined(__linux__) - return collectives->GetCommunicator(global_devices, rank); + + CpuCliqueKey clique_key(global_devices); + CpuCollectives::DeviceRank device_rank(nullptr, RankId(rank)); + + TF_ASSIGN_OR_RETURN( + auto communicators, + collectives->CreateCommunicators(clique_key, std::nullopt, {device_rank}, + CpuCollectives::Config())); + + return std::move(communicators[0]); } RendezvousKey MakeRendezvousKey(std::vector global_devices) { @@ -77,6 +93,12 @@ RendezvousKey MakeRendezvousKey(std::vector global_devices) { // TODO(cobley) - add tests for other collectives. +template +static se::DeviceMemoryBase AsDeviceMemory(const std::vector& data) { + return se::DeviceMemoryBase(const_cast(data.data()), + data.size() * sizeof(T)); +} + absl::StatusOr> AllReduce( const std::shared_ptr& kv_store, const std::vector& input_buffer, @@ -87,9 +109,10 @@ absl::StatusOr> AllReduce( auto communicator, GetCommunicator(kNumParticipants, global_devices, kv_store, rank)); + CpuCollectives::Executor executor(rendezvous_key, kTimeout); TF_RETURN_IF_ERROR(communicator->AllReduce( - rendezvous_key, xla::ReductionKind::SUM, xla::PrimitiveType::U8, - kBufferSize, input_buffer.data(), output_buffer.data(), kTimeout)); + AsDeviceMemory(input_buffer), AsDeviceMemory(output_buffer), + xla::PrimitiveType::U8, kBufferSize, xla::ReductionKind::SUM, executor)); return output_buffer; } diff --git a/third_party/xla/xla/pjrt/cpu/gloo_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc similarity index 58% rename from third_party/xla/xla/pjrt/cpu/gloo_collectives.cc rename to third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc index bfe17be6f2ad90..e5e19aa3a1cfed 100644 --- a/third_party/xla/xla/pjrt/cpu/gloo_collectives.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/cpu/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_communicator.h" #include #include @@ -22,16 +22,12 @@ limitations under the License. #include #include #include -#include -#include #include #include +#include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" @@ -41,39 +37,40 @@ limitations under the License. #include "gloo/context.h" #include "gloo/math.h" #include "gloo/reduce_scatter.h" -#include "gloo/rendezvous/context.h" -#include "gloo/rendezvous/prefix_store.h" -#include "gloo/rendezvous/store.h" #include "gloo/transport/device.h" #include "gloo/transport/unbound_buffer.h" #include "gloo/types.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/rank_id.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" -#include "xla/service/global_device_id.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" namespace xla::cpu { -GlooCollectivesCommunicator::GlooCollectivesCommunicator( - std::shared_ptr context) - : context_(std::move(context)) {} -GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default; +GlooCommunicator::GlooCommunicator(std::shared_ptr context, + size_t rank, size_t num_ranks) + : context_(std::move(context)), rank_(rank), num_ranks_(num_ranks) {} + +GlooCommunicator::~GlooCommunicator() = default; template static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, - const void* input_buffer, - void* output_buffer, + se::DeviceMemoryBase input_buffer, + se::DeviceMemoryBase output_buffer, size_t num_elements, gloo::AllreduceOptions& options) { - options.setInput(reinterpret_cast(const_cast(input_buffer)), - num_elements); - options.setOutput(reinterpret_cast(const_cast(output_buffer)), - num_elements); + options.setInput( + reinterpret_cast(const_cast(input_buffer.opaque())), + num_elements); + options.setOutput( + reinterpret_cast(const_cast(output_buffer.opaque())), + num_elements); using ReductionFn = void (*)(void*, const void*, const void*, size_t); @@ -104,76 +101,79 @@ static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, return absl::OkStatus(); } -absl::Status GlooCollectivesCommunicator::AllReduce( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t num_elements, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { +absl::Status GlooCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + gloo::AllreduceOptions options(context_); // TODO(phawkins): how to do tags? // options.setTag(tag); - switch (element_type) { + switch (dtype) { case S8: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case PRED: case U8: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case S16: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case U16: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case S32: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case U32: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case S64: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case U64: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case F16: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case BF16: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case F32: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case F64: TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case C64: TF_RETURN_IF_ERROR(SetAllReduceOptions>( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; case C128: TF_RETURN_IF_ERROR(SetAllReduceOptions>( - reduction_kind, input_buffer, output_buffer, num_elements, options)); + reduction_kind, send_buffer, recv_buffer, count, options)); break; default: return absl::InvalidArgumentError("Unknown datatype in allreduce"); } options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING); - options.setTimeout(absl::ToChronoMilliseconds(timeout)); + options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); try { gloo::allreduce(options); @@ -186,38 +186,42 @@ absl::Status GlooCollectivesCommunicator::AllReduce( static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40; -absl::Status GlooCollectivesCommunicator::CollectivePermute( - const RendezvousKey& key, size_t num_bytes, std::optional source_rank, - absl::Span target_ranks, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { +absl::Status GlooCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { uint32_t tag = 0; // TODO(phawkins): come up with better tags. const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag); + + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + try { std::unique_ptr in; std::unique_ptr out; - for (int target : target_ranks) { + for (RankId target : target_ranks) { if (target != context_->rank) { - VLOG(1) << "send from " << context_->rank << " to " << target; + VLOG(1) << "send from " << context_->rank << " to " << target.value(); if (!in) { - in = context_->createUnboundBuffer(const_cast(input_buffer), - num_bytes); + in = context_->createUnboundBuffer(send_buffer.opaque(), num_bytes); } - in->send(target, slot); + in->send(target.value(), slot); } } if (source_rank) { if (*source_rank == context_->rank) { - std::memcpy(output_buffer, input_buffer, num_bytes); + std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); } else { - VLOG(1) << "recv at " << context_->rank << " from " << *source_rank; - out = context_->createUnboundBuffer(output_buffer, num_bytes); - out->recv(*source_rank, slot); + VLOG(1) << "recv at " << context_->rank << " from " + << source_rank->value(); + out = context_->createUnboundBuffer(recv_buffer.opaque(), num_bytes); + out->recv(source_rank->value(), slot); } } else { - std::memset(output_buffer, 0, num_bytes); + std::memset(recv_buffer.opaque(), 0, num_bytes); } VLOG(1) << "wait for send at " << context_->rank; - auto deadline = absl::ToChronoTime(absl::Now() + timeout); + auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); if (in) { in->waitSend(deadline); } @@ -233,10 +237,10 @@ absl::Status GlooCollectivesCommunicator::CollectivePermute( return absl::OkStatus(); } -absl::Status GlooCollectivesCommunicator::AllToAll( - const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, absl::Duration timeout) { +absl::Status GlooCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { // We can't use Gloo's all-to-all implementation directly because it assumes // that the inputs and outputs are contiguous. No big deal; it's just built // on top of send/recv and we can do the same as it. @@ -244,8 +248,11 @@ absl::Status GlooCollectivesCommunicator::AllToAll( int my_rank = context_->rank; int world_size = context_->size; - TF_RET_CHECK(world_size == input_buffers.size()); - TF_RET_CHECK(world_size == output_buffers.size()); + TF_RET_CHECK(world_size == send_buffers.size()); + TF_RET_CHECK(world_size == recv_buffers.size()); + + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); try { const auto slot = gloo::Slot::build(gloo::kAlltoallSlotPrefix, tag); @@ -256,8 +263,9 @@ absl::Status GlooCollectivesCommunicator::AllToAll( for (size_t i = 0; i < world_size; ++i) { if (i != my_rank) { ins[i] = context_->createUnboundBuffer( - const_cast(input_buffers[i]), chunk_bytes); - outs[i] = context_->createUnboundBuffer(output_buffers[i], chunk_bytes); + const_cast(send_buffers[i].opaque()), chunk_bytes); + outs[i] = context_->createUnboundBuffer( + const_cast(recv_buffers[i].opaque()), chunk_bytes); } } @@ -268,9 +276,10 @@ absl::Status GlooCollectivesCommunicator::AllToAll( outs[recv_rank]->recv(recv_rank, slot); } - std::memcpy(output_buffers[my_rank], input_buffers[my_rank], chunk_bytes); + std::memcpy(const_cast(recv_buffers[my_rank].opaque()), + send_buffers[my_rank].opaque(), chunk_bytes); - auto deadline = absl::ToChronoTime(absl::Now() + timeout); + auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); for (int i = 0; i < world_size; i++) { if (i != my_rank) { ins[i]->waitSend(deadline); @@ -284,19 +293,20 @@ absl::Status GlooCollectivesCommunicator::AllToAll( return absl::OkStatus(); } -absl::Status GlooCollectivesCommunicator::AllGather(const RendezvousKey& key, - size_t chunk_bytes, - const void* input_buffer, - void* output_buffer, - absl::Duration timeout) { +absl::Status GlooCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { uint32_t tag = 0; // TODO(phawkins): use better tags. + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + gloo::AllgatherOptions options(context_); options.setTag(tag); - options.setTimeout(absl::ToChronoMilliseconds(timeout)); - options.setInput(reinterpret_cast(const_cast(input_buffer)), - chunk_bytes); - options.setOutput(reinterpret_cast(output_buffer), + options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); + options.setInput(reinterpret_cast(send_buffer.opaque()), chunk_bytes); + options.setOutput(reinterpret_cast(recv_buffer.opaque()), chunk_bytes * context_->size); try { @@ -357,122 +367,77 @@ absl::Status ReduceScatterHelper(std::shared_ptr context, return absl::OkStatus(); } -absl::Status GlooCollectivesCommunicator::ReduceScatter( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - size_t chunk_bytes = chunk_elems * primitive_util::ByteWidth(element_type); +absl::Status GlooCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); std::unique_ptr temp(new char[chunk_bytes * context_->size]); - std::memcpy(temp.get(), input_buffer, chunk_bytes * context_->size); - switch (element_type) { + std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size); + switch (dtype) { case S8: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case PRED: case U8: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case S16: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case U16: - TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); break; case S32: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case U32: - TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); break; case S64: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case U64: - TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); break; case BF16: - TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), chunk_elems)); + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); break; case F16: TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), chunk_elems)); + context_, reduction_kind, temp.get(), count)); break; case F32: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case F64: TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), chunk_elems)); + temp.get(), count)); break; case C64: TF_RETURN_IF_ERROR(ReduceScatterHelper>( - context_, reduction_kind, temp.get(), chunk_elems)); + context_, reduction_kind, temp.get(), count)); break; case C128: TF_RETURN_IF_ERROR(ReduceScatterHelper>( - context_, reduction_kind, temp.get(), chunk_elems)); + context_, reduction_kind, temp.get(), count)); break; default: return absl::InvalidArgumentError("Unknown datatype in reducescatter"); } - std::memcpy(output_buffer, temp.get(), chunk_bytes); + std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes); return absl::OkStatus(); } -GlooCollectives::GlooCollectives( - std::unique_ptr store, - std::shared_ptr device) - : store_(std::move(store)), device_(std::move(device)) {} - -GlooCollectives::~GlooCollectives() = default; - -absl::StatusOr> -GlooCollectives::GetCommunicator( - absl::Span global_devices, int rank) { - Context* context; - { - absl::MutexLock lock(&mu_); - auto& context_ref = contexts_[std::make_tuple( - std::vector(global_devices.begin(), - global_devices.end()), - rank)]; - if (!context_ref) { - context_ref = std::make_unique(); - } - context = context_ref.get(); - } - absl::MutexLock context_lock(&context->mu); - if (context->communicator) { - return context->communicator; - } - auto gloo_context = - std::make_shared(rank, global_devices.size()); - auto prefix_store = gloo::rendezvous::PrefixStore( - absl::StrCat("gloo/", - absl::StrJoin(global_devices, ",", - [](std::string* out, GlobalDeviceId id) { - absl::StrAppend(out, id.value()); - })), - *store_); - try { - gloo_context->connectFullMesh(prefix_store, device_); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo context initialization failed: ", e.what())); - } - context->communicator = - std::make_shared(std::move(gloo_context)); - return context->communicator; -} - } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h new file mode 100644 index 00000000000000..234716da759340 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_communicator.h @@ -0,0 +1,103 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "gloo/context.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +// XLA communicator implemented using Gloo communication library. +class GlooCommunicator : public Communicator { + public: + GlooCommunicator(std::shared_ptr context, size_t rank, + size_t num_ranks); + ~GlooCommunicator() override; + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) override; + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) override; + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return num_ranks_; } + + std::string ToString() const override { + return absl::StrCat("GlooCommunicator [rank: ", rank_, + " num_ranks: ", num_ranks_, "]"); + } + + private: + std::shared_ptr context_; + size_t rank_; + size_t num_ranks_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_ diff --git a/third_party/xla/xla/pjrt/cpu/gloo_kv_store.cc b/third_party/xla/xla/backends/cpu/collectives/gloo_kv_store.cc similarity index 93% rename from third_party/xla/xla/pjrt/cpu/gloo_kv_store.cc rename to third_party/xla/xla/backends/cpu/collectives/gloo_kv_store.cc index 7feb80b3435c00..bba2b7a6451f30 100644 --- a/third_party/xla/xla/pjrt/cpu/gloo_kv_store.cc +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_kv_store.cc @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/cpu/gloo_kv_store.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" #include // NOLINT #include #include #include -#include #include #include @@ -39,7 +38,8 @@ GlooKeyValueStore::~GlooKeyValueStore() = default; void GlooKeyValueStore::set(const std::string& key, const std::vector& data) { - ThrowIfError(kv_store_->Set(key, std::string_view(data.data(), data.size()))); + ThrowIfError( + kv_store_->Set(key, absl::string_view(data.data(), data.size()))); } std::vector GlooKeyValueStore::get(const std::string& key) { diff --git a/third_party/xla/xla/pjrt/cpu/gloo_kv_store.h b/third_party/xla/xla/backends/cpu/collectives/gloo_kv_store.h similarity index 90% rename from third_party/xla/xla/pjrt/cpu/gloo_kv_store.h rename to third_party/xla/xla/backends/cpu/collectives/gloo_kv_store.h index 2872168372c0d6..1cba490ba5ce65 100644 --- a/third_party/xla/xla/pjrt/cpu/gloo_kv_store.h +++ b/third_party/xla/xla/backends/cpu/collectives/gloo_kv_store.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PJRT_CPU_GLOO_KV_STORE_H_ -#define XLA_PJRT_CPU_GLOO_KV_STORE_H_ +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_GLOO_KV_STORE_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_GLOO_KV_STORE_H_ #include // NOLINT #include @@ -49,4 +49,4 @@ class GlooKeyValueStore : public ::gloo::rendezvous::Store { } // namespace xla::cpu -#endif // XLA_PJRT_CPU_GLOO_KV_STORE_H_ +#endif // XLA_BACKENDS_CPU_COLLECTIVES_GLOO_KV_STORE_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/in_process_collectives.cc new file mode 100644 index 00000000000000..0fb139c5d38e07 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_collectives.cc @@ -0,0 +1,50 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/in_process_collectives.h" + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/collectives/in_process_communicator.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/core/collectives/communicator.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +absl::StatusOr>> +InProcessCollectives::CreateCommunicators( + const CliqueKey& clique_key, const std::optional& clique_id, + absl::Span ranks, const Config& config) { + std::vector> communicators; + communicators.reserve(ranks.size()); + + for (auto& device_rank : ranks) { + size_t rank = device_rank.rank.value(); + communicators.push_back(std::make_unique( + rank, clique_key.num_devices())); + } + + return communicators; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_collectives.h b/third_party/xla/xla/backends/cpu/collectives/in_process_collectives.h new file mode 100644 index 00000000000000..9d3150a469aca3 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_collectives.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COLLECTIVES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COLLECTIVES_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/backends/cpu/collectives/in_process_communicator.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/core/collectives/communicator.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class InProcessCollectives : public CpuCollectives { + public: + absl::StatusOr>> + CreateCommunicators(const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) final; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COLLECTIVES_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc new file mode 100644 index 00000000000000..2d4dc88f9ef27a --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.cc @@ -0,0 +1,420 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/in_process_communicator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/rendezvous.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { +namespace { + +template +static bool ByRank(const Participant* a, const Participant* b) { + return a->rank < b->rank; +} + +template +T GetInitialValue(ReductionKind reduction_kind) { + switch (reduction_kind) { + case ReductionKind::SUM: + return static_cast(0); + case ReductionKind::PRODUCT: + return static_cast(1); + case ReductionKind::MIN: + return std::numeric_limits::has_infinity + ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + case ReductionKind::MAX: + return std::numeric_limits::has_infinity + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + } +} + +// We cannot use static_assert(false), because the C++ standard (prior to +// CWG2518) does not allow the statement discarded by a constexpr if to +// be ill-formed for every possible specialization. +// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if +template +constexpr bool always_false_v = false; + +template +void ReduceHelper(absl::Span acc, absl::Span inputs) { + // TODO(penporn): make sure this gets vectorized. + if constexpr (reduction_kind == ReductionKind::SUM) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] += inputs[j][i]; + } + } + } else if constexpr (reduction_kind == ReductionKind::PRODUCT) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] *= inputs[j][i]; + } + } + } else if constexpr (reduction_kind == ReductionKind::MIN) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] = std::min(acc[i], inputs[j][i]); + } + } + } else if constexpr (reduction_kind == ReductionKind::MAX) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] = std::max(acc[i], inputs[j][i]); + } + } + } else { + static_assert(always_false_v, "Unsupported reduction kind"); + } +} + +template +absl::Status ReduceScatter(ReductionKind reduction_kind, + absl::Span inputs, void* output, + int64_t num_elems) { + using T = primitive_util::NativeTypeOf; + T initial_value = GetInitialValue(reduction_kind); + + absl::Span out_chunk = + absl::MakeSpan(reinterpret_cast(output), num_elems); + for (int64_t i = 0; i < num_elems; ++i) { + out_chunk[i] = initial_value; + } + + absl::Span input_chunks( + reinterpret_cast(inputs.data()), inputs.size()); + switch (reduction_kind) { + case ReductionKind::SUM: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::PRODUCT: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// AllReduce +//===----------------------------------------------------------------------===// + +struct AllReduceParticipant { + size_t rank; + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; +}; + +static absl::Status AllReduceOp( + PrimitiveType primitive_type, size_t count, ReductionKind reduction_kind, + absl::Span participants) { + absl::c_sort(participants, ByRank); + + if (!primitive_util::IsArrayType(primitive_type)) { + return Unimplemented( + "Unexpected datatype: %s", + primitive_util::LowercasePrimitiveTypeName(primitive_type)); + } + + // Collect reduction inputs from all participants. + std::vector inputs(participants.size()); + for (auto* participant : participants) { + inputs[participant->rank] = participant->src.opaque(); + } + + // Reduce all inputs into the destination buffer at rank 0. + void* output = participants[0]->dest.opaque(); + + TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( + [&](const auto type_tag) { + return ReduceScatter(reduction_kind, inputs, output, count); + }, + primitive_type)); + + // Copy all-reduced output to all other participants. + for (size_t i = 1; i < participants.size(); ++i) { + std::memcpy(participants[i]->dest.opaque(), participants[0]->dest.opaque(), + count * primitive_util::ByteWidth(primitive_type)); + } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// ReduceScatter +//===----------------------------------------------------------------------===// + +struct ReduceScatterParticipant { + size_t rank; + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; +}; + +static absl::Status ReduceScatterOp( + PrimitiveType primitive_type, size_t count, ReductionKind reduction_kind, + absl::Span participants) { + absl::c_sort(participants, ByRank); + + if (!primitive_util::IsArrayType(primitive_type)) { + return Unimplemented( + "Unexpected datatype: %s", + primitive_util::LowercasePrimitiveTypeName(primitive_type)); + } + + size_t num_participants = participants.size(); + size_t num_bytes = count * primitive_util::ByteWidth(primitive_type); + + for (size_t i = 0; i < num_participants; ++i) { + size_t offset = i * num_bytes; + + // Collect reduction inputs from all participants. + std::vector inputs(num_participants); + for (size_t j = 0; j < num_participants; ++j) { + std::byte* src = static_cast(participants[j]->src.opaque()); + inputs[j] = src + offset; + } + + // Reduce all inputs into the destination buffer. + void* output = participants[i]->dest.opaque(); + + TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( + [&](const auto type_tag) { + return ReduceScatter(reduction_kind, inputs, output, count); + }, + primitive_type)); + } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// AllGather +//===----------------------------------------------------------------------===// + +struct AllGatherParticipant { + size_t rank; + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; +}; + +static absl::Status AllGatherOp( + size_t num_bytes, absl::Span participants) { + absl::c_sort(participants, ByRank); + + size_t num_participants = participants.size(); + + for (size_t i = 0; i < num_participants; ++i) { + for (size_t j = 0; j < num_participants; ++j) { + std::byte* dest = static_cast(participants[i]->dest.opaque()); + size_t offset = j * num_bytes; + std::memcpy(dest + offset, participants[j]->src.opaque(), num_bytes); + } + } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// AllToAll +//===----------------------------------------------------------------------===// + +struct AllToAllParticipant { + size_t rank; + + std::vector src; + std::vector dest; +}; + +static absl::Status AllToAllOp( + size_t num_bytes, absl::Span participants) { + absl::c_sort(participants, ByRank); + + size_t num_participants = participants.size(); + + for (size_t i = 0; i < num_participants; ++i) { + for (size_t j = 0; j < num_participants; ++j) { + std::memcpy(participants[j]->dest[i].opaque(), + participants[i]->src[j].opaque(), num_bytes); + } + } + + return absl::OkStatus(); +} + +//===----------------------------------------------------------------------===// +// CollectivePermute +//===----------------------------------------------------------------------===// + +struct CollectivePermuteParticipant { + size_t rank; + std::optional src_rank; + + se::DeviceMemoryBase src; + se::DeviceMemoryBase dest; +}; + +static absl::Status CollectivePermuteOp( + size_t num_bytes, + absl::Span participants) { + absl::c_sort(participants, ByRank); + + for (const CollectivePermuteParticipant* participant : participants) { + void* dest = participant->dest.opaque(); + + if (participant->src_rank) { + size_t src_rank = participant->src_rank->value(); + std::memcpy(dest, participants.at(src_rank)->src.opaque(), num_bytes); + } else { + std::memset(dest, 0, num_bytes); + } + } + return absl::OkStatus(); +} + +} // namespace + +//===----------------------------------------------------------------------===// + +InProcessCommunicator::InProcessCommunicator(size_t rank, size_t num_ranks) + : rank_(rank), num_ranks_(num_ranks) {} + +absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + std::string name = absl::StrCat("all reduce ", key.ToString()); + AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer}; + + return Rendezvous( + name, key, partiticipant, key.num_local_participants, + std::bind(AllReduceOp, dtype, count, reduction_kind, + std::placeholders::_1)); +} + +absl::Status InProcessCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + std::string name = absl::StrCat("collective permute ", key.ToString()); + CollectivePermuteParticipant partiticipant{rank_, source_rank, send_buffer, + recv_buffer}; + + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + return Rendezvous( + name, key, partiticipant, key.num_local_participants, + std::bind(CollectivePermuteOp, num_bytes, std::placeholders::_1)); +} + +absl::Status InProcessCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + std::string name = absl::StrCat("all to all ", key.ToString()); + AllToAllParticipant partiticipant{rank_, + {send_buffers.begin(), send_buffers.end()}, + {recv_buffers.begin(), recv_buffers.end()}}; + + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + return Rendezvous( + name, key, partiticipant, key.num_local_participants, + std::bind(AllToAllOp, num_bytes, std::placeholders::_1)); +} + +absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + std::string name = absl::StrCat("all gather ", key.ToString()); + AllGatherParticipant partiticipant{rank_, send_buffer, recv_buffer}; + + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + return Rendezvous( + name, key, partiticipant, key.num_local_participants, + std::bind(AllGatherOp, num_bytes, std::placeholders::_1)); +} + +absl::Status InProcessCommunicator::ReduceScatter( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + std::string name = absl::StrCat("reduce scatter ", key.ToString()); + ReduceScatterParticipant partiticipant{rank_, send_buffer, recv_buffer}; + + return Rendezvous( + name, key, partiticipant, key.num_local_participants, + std::bind(ReduceScatterOp, dtype, count, reduction_kind, + std::placeholders::_1)); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h new file mode 100644 index 00000000000000..f4366c858f6608 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/in_process_communicator.h @@ -0,0 +1,99 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +// XLA communicator that implements collective operations using shared memory +// and works only within a single process. +class InProcessCommunicator : public Communicator { + public: + InProcessCommunicator(size_t rank, size_t num_ranks); + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) override; + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) override; + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return num_ranks_; } + + std::string ToString() const override { + return absl::StrCat("InProcessCommunicator [rank: ", rank_, + " num_ranks: ", num_ranks_, "]"); + } + + private: + size_t rank_; + size_t num_ranks_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_ diff --git a/third_party/xla/xla/backends/cpu/collectives/mpi_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/mpi_collectives.cc new file mode 100644 index 00000000000000..c368ed986289f3 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_collectives.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/mpi_collectives.h" + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "mpi.h" +#include "xla/backends/cpu/collectives/mpi_communicator.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/core/collectives/communicator.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +void MpiCollectives::Init() { + int provided; + MPI_Init_thread(nullptr, nullptr, MPI_THREAD_FUNNELED, &provided); + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); + VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; +} + +void MpiCollectives::Finalize() { MPI_Finalize(); } + +absl::StatusOr>> +MpiCollectives::CreateCommunicators(const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) { + int flag; + MPI_Is_thread_main(&flag); + if (!flag) { + return absl::UnknownError( + "MPI: Communicator requested from a thread that is not " + "the one MPI was initialized from. Multiple " + "threads/devices per process are not yet supported."); + } + + std::vector> communicators; + for (auto& device_rank : ranks) { + size_t rank = device_rank.rank.value(); + int color; + int key = 0; + if (clique_key.num_devices() > 0) { + color = static_cast(clique_key.devices().at(0).value()); + key = rank; + } else { + color = MPI_UNDEFINED; + } + communicators.push_back(std::make_unique(color, key)); + } + + return communicators; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/mpi_collectives.h b/third_party/xla/xla/backends/cpu/collectives/mpi_collectives.h new file mode 100644 index 00000000000000..702cb05fa4faf3 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_collectives.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_MPI_COLLECTIVES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_MPI_COLLECTIVES_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/core/collectives/communicator.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class MpiCollectives : public CpuCollectives { + public: + /* + The user has to explicitly call Init() and Finalize() before and + after use. + For example, using the Python client, this can be achieved with: + + collectives = xla_client._xla.make_mpi_collectives() + collectives.Init() + atexit.register(collectives.Finalize) + */ + void Init(); + void Finalize(); + + absl::StatusOr>> + CreateCommunicators(const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) final; + + private: + absl::Status ExchangeGlobalDeviceIds( + absl::Span global_devices, int rank); + + int mpi_world_rank_; + int mpi_world_size_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_MPI_COLLECTIVES_H_ diff --git a/third_party/xla/xla/pjrt/cpu/mpi_collectives.cc b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc similarity index 52% rename from third_party/xla/xla/pjrt/cpu/mpi_collectives.cc rename to third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc index d2c93fd75450f5..0062593da75407 100644 --- a/third_party/xla/xla/pjrt/cpu/mpi_collectives.cc +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.cc @@ -13,37 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/cpu/mpi_collectives.h" +#include "xla/backends/cpu/collectives/mpi_communicator.h" -#include #include -#include #include -#include -#include #include -#include -#include -#include #include -#include "mpi.h" // NOLINT #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" #include "absl/types/span.h" +#include "mpi.h" +#include "xla/core/collectives/rank_id.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" -#include "xla/service/global_device_id.h" #include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla::cpu { @@ -125,57 +115,58 @@ static absl::Status MpiErrorToAbslStatus(int error) { return absl::OkStatus(); } -MpiCollectivesCommunicator::MpiCollectivesCommunicator(int color, int key) { +MpiCommunicator::MpiCommunicator(int color, int key) { MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); MPI_Comm_rank(comm_, &mpi_rank_); MPI_Comm_size(comm_, &mpi_size_); } -MpiCollectivesCommunicator::~MpiCollectivesCommunicator() { - MPI_Comm_free(&comm_); -}; +MpiCommunicator::~MpiCommunicator() { MPI_Comm_free(&comm_); }; -absl::Status MpiCollectivesCommunicator::AllReduce( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t num_elements, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); +absl::Status MpiCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus(MPI_Allreduce(input_buffer, output_buffer, - num_elements, type, op, comm_)); + return MpiErrorToAbslStatus(MPI_Allreduce( + send_buffer.opaque(), recv_buffer.opaque(), count, type, op, comm_)); } -absl::Status MpiCollectivesCommunicator::CollectivePermute( - const RendezvousKey& key, size_t num_bytes, std::optional source_rank, - absl::Span target_ranks, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { +absl::Status MpiCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { int tag = 0; // TODO come up with better tags. const int rank = mpi_rank_; std::vector requests; + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + if (source_rank) { - if (*source_rank == rank) { - std::memcpy(output_buffer, input_buffer, num_bytes); + if (source_rank->value() == rank) { + std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); } else { - VLOG(1) << "recv at " << rank << " from " << *source_rank; + VLOG(1) << "recv at " << rank << " from " << source_rank->value(); requests.emplace_back(); TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Irecv(output_buffer, num_bytes, MPI_BYTE, *source_rank, tag, - comm_, &requests.back()))); + MPI_Irecv(recv_buffer.opaque(), num_bytes, MPI_BYTE, + source_rank->value(), tag, comm_, &requests.back()))); } } else { - std::memset(output_buffer, 0, num_bytes); + std::memset(recv_buffer.opaque(), 0, num_bytes); } - for (int target : target_ranks) { + for (RankId target : target_ranks) { if (target != rank) { - VLOG(1) << "send from " << rank << " to " << target; + VLOG(1) << "send from " << rank << " to " << target.value(); requests.emplace_back(); TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Isend(input_buffer, num_bytes, MPI_BYTE, target, tag, comm_, - &requests.back()))); + MPI_Isend(send_buffer.opaque(), num_bytes, MPI_BYTE, target.value(), + tag, comm_, &requests.back()))); } } @@ -187,18 +178,28 @@ absl::Status MpiCollectivesCommunicator::CollectivePermute( return absl::OkStatus(); } -absl::Status MpiCollectivesCommunicator::AllToAll( - const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, absl::Duration timeout) { +absl::Status MpiCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { // We can't use MPI_Alltoall directly because it assumes that the inputs and // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. int tag = 0; // TODO use better tags. const int rank = mpi_rank_; const int size = mpi_size_; - TF_RET_CHECK(size == input_buffers.size()); - TF_RET_CHECK(size == output_buffers.size()); + TF_RET_CHECK(size == send_buffers.size()); + TF_RET_CHECK(size == recv_buffers.size()); + + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + + std::vector input_buffers; + std::vector output_buffers; + + for (int i = 0; i < size; i++) { + input_buffers.push_back(const_cast(send_buffers[i].opaque())); + output_buffers.push_back(const_cast(recv_buffers[i].opaque())); + } std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); @@ -214,70 +215,28 @@ absl::Status MpiCollectivesCommunicator::AllToAll( return absl::OkStatus(); } -absl::Status MpiCollectivesCommunicator::AllGather(const RendezvousKey& key, - size_t chunk_bytes, - const void* input_buffer, - void* output_buffer, - absl::Duration timeout) { - return MpiErrorToAbslStatus(MPI_Allgather(input_buffer, chunk_bytes, MPI_BYTE, - output_buffer, chunk_bytes, - MPI_BYTE, comm_)); +absl::Status MpiCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); + return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type, + recv_buffer.opaque(), count, type, + comm_)); } -absl::Status MpiCollectivesCommunicator::ReduceScatter( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { +absl::Status MpiCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { const int size = mpi_size_; - std::vector recvcounts(size, chunk_elems); - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + std::vector recvcounts(size, count); + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus(MPI_Reduce_scatter( - input_buffer, output_buffer, recvcounts.data(), type, op, comm_)); -} - -void MpiCollectives::Init() { - int provided; - MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided); - MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); - MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); - VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; -} - -void MpiCollectives::Finalize() { - contexts_.clear(); - MPI_Finalize(); -} - -absl::StatusOr> -MpiCollectives::GetCommunicator(absl::Span global_devices, - int rank) { - int flag; - MPI_Is_thread_main(&flag); - if (!flag) { - return absl::UnknownError( - absl::StrCat("MPI: Communicator requested from a thread that is not " - "the one MPI was initialized from. Multiple " - "threads/devices per process are not yet supported.")); - } - - auto& context = contexts_[std::make_tuple( - std::vector(global_devices.begin(), global_devices.end()), - rank)]; - if (context) { - return context; - } - - int color; - int key = 0; - if (global_devices.size() > 0) { - color = static_cast(global_devices.at(0).value()); - key = rank; - } else { - color = MPI_UNDEFINED; - } - context = std::make_shared(color, key); - return context; + return MpiErrorToAbslStatus( + MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(), + recvcounts.data(), type, op, comm_)); } } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h new file mode 100644 index 00000000000000..cfed534b66bd51 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/collectives/mpi_communicator.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mpi.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class MpiCommunicator : public Communicator { + public: + explicit MpiCommunicator(int color, int key); + ~MpiCommunicator() override; + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) override; + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) override; + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return mpi_size_; } + + std::string ToString() const override { + return absl::StrCat("MpiCommunicator [rank: ", mpi_rank_, + " num_ranks: ", mpi_size_, "]"); + } + + private: + MPI_Comm comm_; + int mpi_rank_; + int mpi_size_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_ diff --git a/third_party/xla/xla/backends/cpu/nanort/BUILD b/third_party/xla/xla/backends/cpu/nanort/BUILD index 6fbc3573e13054..58730eadef70c1 100644 --- a/third_party/xla/xla/backends/cpu/nanort/BUILD +++ b/third_party/xla/xla/backends/cpu/nanort/BUILD @@ -1,7 +1,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/backends/cpu/nanort:package_groups.bzl", "xla_cpu_nanort_packages") load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -59,6 +58,7 @@ xla_cc_test( "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -105,3 +105,74 @@ cc_library( "@local_tsl//tsl/profiler/lib:traceme_encode", ], ) + +cc_library( + name = "ifrt_client", + srcs = ["ifrt_client.cc"], + hdrs = ["ifrt_client.h"], + deps = [ + ":nanort_client", + ":nanort_executable", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu:alignment", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_sharding", + "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_layout", + "//xla/pjrt:utils", + "//xla/python/ifrt", + "//xla/python/ifrt:attribute_map", + "//xla/python/ifrt/hlo:hlo_program", + "//xla/python/pjrt_ifrt:pjrt_dtype", + "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service:hlo_module_config", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/concurrency:ref_count", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:fingerprint", + ], +) + +cc_library( + name = "register_nanort_for_ifrt_tests", + testonly = True, + srcs = ["register_nanort_for_ifrt_tests.cc"], + deps = [ + ":ifrt_client", + "//xla/python/ifrt:test_util", + ], + alwayslink = True, +) + +xla_cc_test( + name = "ifrt_client_test", + srcs = ["ifrt_client_test.cc"], + deps = [ + ":register_nanort_for_ifrt_tests", + "//xla/python/ifrt:array_impl_test_lib", + "//xla/python/ifrt:client_impl_test_lib", + "//xla/python/ifrt:test_util", + "//xla/python/ifrt:tuple_impl_test_lib", + "//xla/python/pjrt_ifrt:xla_executable_impl_test_lib", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc new file mode 100644 index 00000000000000..cf4365656b72f4 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc @@ -0,0 +1,1420 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/nanort/ifrt_client.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/nullability.h" +#include "absl/container/btree_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/backends/cpu/alignment.h" +#include "xla/backends/cpu/nanort/nanort_executable.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/utils.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/index.h" +#include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" + +namespace xla::cpu { +namespace { + +static const char kMemoryKind[] = ""; + +// Returns a Future that is immediately ready with the given status. This is +// mostly useful because everything NanoRT does is immediately ready. +ifrt::Future<> Ready(absl::Status status = absl::OkStatus()) { + return ifrt::Future<>(std::move(status)); +} + +// Base class for all value types. This class doesn't participate in the llvm +// RTTI hierarchy (you can't dynamically cast to it), rather it just +// implements some virtual methods that have the same implementation for all +// NanoRT value types. +template +class NanoValue : public llvm::RTTIExtends { + public: + explicit NanoValue(NanoIfrtClient* client) : client_(client) {} + + ifrt::Client* client() const override { return client_; } + + // Called by subclasses to get access to client() without having to cast. + NanoIfrtClient* nano_client() const { return client_; } + + // All nano values are immediately ready. + ifrt::Future<> GetReadyFuture() const override { return Ready(); } + + // Subclasses must still implement Delete(). + ifrt::Future<> Delete() override = 0; + bool IsDeleted() const override = 0; + + // Helper that returns an error if this value is accessed after it has been + // deleted. Meant to be called with TF_RETURN_IF_ERROR at the top of + // relevant methods. + absl::Status ValidateNotDeleted() const { + if (IsDeleted()) { + return absl::FailedPreconditionError("Tried to access a deleted value."); + } + return absl::OkStatus(); + } + + private: + NanoIfrtClient* client_; +}; + +// Array implementation. +// +// This class always holds a continuous buffer of memory, if a sharding is +// provided, it will be disassembled as needed to satisfy caller expectations. +// +// See ShardedNanoArray for the case where the array is constructed from +// multiple existing shards. +class NanoArray final : public NanoValue { + public: + // A pointer to the underlying buffer. We use a shared_ptr because for some + // operations (like disassembly) we can just alias the memory, but we still + // need to support deletion of the NanoArray that created the buffer. + using DataPtr = std::shared_ptr; + + NanoArray(NanoIfrtClient* client, ifrt::DType dtype, ifrt::Shape shape, + DataPtr data, std::shared_ptr sharding) + : NanoValue(client), + dtype_(std::move(dtype)), + shape_(std::move(shape)), + data_(std::move(data)), + sharding_(std::move(sharding)) {} + + // Allocates a new array of the given type and shape. + static absl::StatusOr> Allocate( + NanoIfrtClient* client, ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding) { + TF_RET_CHECK(dtype.byte_size().has_value()); + TF_ASSIGN_OR_RETURN( + DataPtr data_ptr, + AllocateData(dtype.byte_size().value() * shape.num_elements())); + return tsl::TakeRef(new NanoArray(client, dtype, shape, std::move(data_ptr), + std::move(sharding))); + } + + // Creates an array from a host buffer. The buffer will be used directly + // without a copy if the copy semantics allow it and the layout is row major + // and dense. + static absl::StatusOr> FromBuffer( + NanoIfrtClient* client, void* data, ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding, + std::optional> byte_strides, bool make_copy, + std::function on_done_with_host_buffer) { + auto size = dtype.byte_size().value_or(0) * shape.num_elements(); + TF_RET_CHECK(size > 0); + DataPtr data_ptr; + if (!on_done_with_host_buffer) { + on_done_with_host_buffer = [] {}; + } + bool layout_compatible = LayoutCompatible(dtype, shape, byte_strides); + bool aligned = reinterpret_cast(data) % Align() == 0; + + if (!layout_compatible || !aligned) { + // Input is not aligned, or has a weird layout, so we need to copy it. + make_copy = true; + } + + if (make_copy) { + TF_ASSIGN_OR_RETURN(data_ptr, AllocateData(size)); + if (layout_compatible) { + // Input has a compatible layout, so we can just do a memcpy. + memcpy(data_ptr.get(), data, size); + } else { + // Input has an incompatible layout, so we need to copy it with an + // appropriate stride. + TF_ASSIGN_OR_RETURN(auto dense_strides, DenseByteStrides(dtype, shape)); + TF_RETURN_IF_ERROR(CopyWithByteStrides( + reinterpret_cast(data_ptr.get()), dense_strides, + reinterpret_cast(data), + byte_strides.value_or(dense_strides), shape.dims(), + dtype.byte_size().value())); + } + // We're done with the input buffer, so we can allow the caller to clean + // it up. + on_done_with_host_buffer(); + } else { + // We're allowed to keep the input buffer, and it's dense and row major, + // so we can just use it directly. + data_ptr = DataPtr(data, [done = std::move(on_done_with_host_buffer)]( + void* ptr) { done(); }); + } + TF_RET_CHECK(data_ptr != nullptr); + return tsl::TakeRef(new NanoArray(client, dtype, shape, std::move(data_ptr), + std::move(sharding))); + } + + const DataPtr& data() const { return data_; } + + // Copies a sub-array of the given size from src to dst. The dst array must + // already be allocated and of the correct type and shape. Values outside of + // the specified sub-array of dst will be left untouched. + // + // This is mostly intended to support sharding and assembling. + static absl::Status CopySubArray(NanoArray& dst, + absl::Span dst_loc, + NanoArray& src, + absl::Span src_loc, + absl::Span size) { + // Make sure the arrays are the same type and the type is supported. + TF_RET_CHECK(dst.dtype() == src.dtype()); + TF_RET_CHECK(dst.dtype().byte_size().has_value()); + + // Make sure all the dims are compatible. + TF_RET_CHECK(dst.shape().dims().size() == size.size()); + TF_RET_CHECK(src.shape().dims().size() == size.size()); + TF_RET_CHECK(dst.shape().dims().size() == size.size()); + TF_RET_CHECK(dst_loc.size() == size.size()); + TF_RET_CHECK(src_loc.size() == size.size()); + + // Make sure what we're copying is within the bounds of the arrays. + for (size_t i = 0; i < size.size(); ++i) { + TF_RET_CHECK(dst_loc[i] + size[i] <= dst.shape().dims()[i]); + TF_RET_CHECK(src_loc[i] + size[i] <= src.shape().dims()[i]); + } + + int64_t element_size = dst.dtype().byte_size().value(); + + // Returns the size of a row in bytes for the given shape. + auto row_size = [=](absl::Span shape) { + if (shape.empty()) return element_size; // Scalar. + return shape.back() * element_size; + }; + + // Since this is always row major, we can do one memcpy per row, and rows + // will always be evenly spaces within the arrays. + int64_t src_row_stride = row_size(src.shape().dims()); + int64_t dst_row_stride = row_size(dst.shape().dims()); + int64_t copy_row_size = row_size(size); + + // How many rows do we have to copy? + int64_t copy_num_rows = 1; + for (int64_t i = 0; i + 1 < size.size(); ++i) { + copy_num_rows *= size[i]; + } + + // Returns a pointer to the given position in the array. + auto get_row_ptr = [&](NanoArray& array, + absl::Span position) -> std::byte* { + size_t offset = 0; + size_t stride = 1; + for (int i = position.size() - 1; i >= 0; --i) { + offset += stride * position[i]; + stride *= array.shape().dims()[i]; + } + offset *= element_size; + return static_cast(array.data().get()) + offset; + }; + + // Get the pointers to the start of the rows we're copying. + std::byte* dst_row_start = get_row_ptr(dst, dst_loc); + std::byte* src_row_start = get_row_ptr(src, src_loc); + + // Copy the rows. + for (int64_t i = 0; i < copy_num_rows; ++i) { + memcpy(dst_row_start, src_row_start, copy_row_size); + dst_row_start += dst_row_stride; + src_row_start += src_row_stride; + } + return absl::OkStatus(); + } + + absl::StatusOr>> Disassemble() { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + if (sharding().IsFullyReplicated()) { + if (sharding().devices()->size() == 1) { + // Only one device and one shard, so we can just return a reference to + // this array. + return std::vector>{tsl::FormRef(this)}; + } + + // If the array is fully replicated and there are multiple "devices", we + // need to make one "copy" per device. + std::vector> shards; + shards.reserve(sharding().devices()->size()); + for (auto* device : sharding().devices()->devices()) { + auto one_device_sharding = ifrt::SingleDeviceSharding::Create( + device, sharding().memory_kind()); + shards.push_back( + tsl::TakeRef(new NanoArray(nano_client(), dtype_, shape_, data_, + std::move(one_device_sharding)))); + } + return shards; + } + + // The array is sharded, copy the appropriate sub-arrays. + TF_ASSIGN_OR_RETURN(auto index_domains, sharding().IndexDomains(shape())); + TF_RET_CHECK(index_domains.size() == sharding().devices()->size()); + std::vector> shards; + shards.reserve(index_domains.size()); + for (int i = 0; i < index_domains.size(); ++i) { + const auto& index_domain = index_domains[i]; + auto* device = sharding().devices()->devices()[i]; + auto one_device_sharding = + ifrt::SingleDeviceSharding::Create(device, sharding().memory_kind()); + TF_ASSIGN_OR_RETURN( + auto shard, + NanoArray::Allocate(nano_client(), dtype(), index_domain.shape(), + std::move(one_device_sharding))); + TF_RETURN_IF_ERROR(NanoArray::CopySubArray( + // To the origin of this shard. + *shard, ifrt::Index::Zeros(shape().dims().size()).elements(), + // From the assembled array. + *this, index_domain.origin().elements(), + // The in the shape of this shard. + index_domain.shape().dims())); + shards.push_back(std::move(shard)); + } + return shards; + } + + NanoRtExecutable::Argument AsArgument() { + return NanoRtExecutable::Argument( + reinterpret_cast(data_.get()), + dtype_.byte_size().value() * shape_.num_elements()); + } + + NanoRtExecutable::Result AsResult() { + return NanoRtExecutable::Result( + reinterpret_cast(data_.get()), + dtype_.byte_size().value() * shape_.num_elements()); + } + + std::string DebugString() const override { + return absl::StrCat("NanoArray(", dtype_.DebugString(), ", ", + shape_.DebugString(), ", @", + reinterpret_cast(data_.get()), ")"); + } + + ifrt::Future<> Delete() override { + data_ = nullptr; + return Ready(); + } + + bool IsDeleted() const override { return data_ == nullptr; } + + ifrt::DType dtype() const override { return dtype_; } + + const ifrt::Shape& shape() const override { return shape_; } + + const ifrt::Sharding& sharding() const override { return *sharding_; } + + absl::Nonnull> shared_ptr_sharding() + const override { + return sharding_; + } + + absl::StatusOr> layout() const override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return std::make_shared(xla::Layout(shape().dims())); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + TF_ASSIGN_OR_RETURN(auto shards, Disassemble()); + return std::vector>(shards.begin(), shards.end()); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return DisassembleIntoSingleDeviceArrays(array_copy_semantics); + } + + absl::StatusOr> FullyReplicatedShard( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return tsl::FormRef(this); + } + + ifrt::Future<> CopyToHostBuffer( + void* data, std::optional> byte_strides, + ifrt::ArrayCopySemantics semantics) override { + // Run everything in a lambda so we can use error macros and convert to a + // future once. + return Ready([&] { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_dtype, + ifrt::ToPrimitiveType(dtype())); + if (!byte_strides.has_value() || + xla::HasMajorToMinorLayout(xla_dtype, shape().dims(), + *byte_strides)) { + memcpy(data, data_.get(), + dtype().byte_size().value() * shape().num_elements()); + } else { + TF_ASSIGN_OR_RETURN(auto in_strides, + DenseByteStrides(dtype(), shape())); + TF_RETURN_IF_ERROR(CopyWithByteStrides( + reinterpret_cast(data), *byte_strides, + reinterpret_cast(data_.get()), in_strides, + shape().dims(), dtype().byte_size().value())); + } + return absl::OkStatus(); + }()); + } + + static char ID; // NOLINT + + private: + // Returns true if the given data type, shape, and strides are compatible + // with NanoArray (we can either use this memory directly or memcpy it into + // our own memory). + static bool LayoutCompatible( + ifrt::DType dtype, const ifrt::Shape& shape, + std::optional> byte_strides) { + if (!dtype.byte_size().has_value()) { + return false; + } + auto xla_dtype = ifrt::ToPrimitiveType(dtype); + if (!xla_dtype.ok()) { + return false; + } + if (!byte_strides.has_value()) { + return true; + } + return xla::HasMajorToMinorLayout(*xla_dtype, shape.dims(), *byte_strides); + } + + // Returns the byte strides for a dense array with the given type and shape. + static absl::StatusOr> DenseByteStrides( + ifrt::DType dtype, ifrt::Shape shape) { + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_dtype, + ifrt::ToPrimitiveType(dtype)); + auto xla_shape = xla::ShapeUtil::MakeShape(xla_dtype, shape.dims()); + auto strides = xla::ShapeUtil::ByteStrides(xla_shape); + if (!strides.has_value()) { + return absl::InvalidArgumentError(absl::StrCat( + "Couldn't compute byte strides for shape:", xla_shape.ToString())); + } + return std::move(*strides); + } + + // Allocates an aligned buffer of the given size. + static absl::StatusOr AllocateData(size_t size) { + DataPtr data_ptr(aligned_alloc(Align(), std::max(size, Align())), + [](void* ptr) { free(ptr); }); + if (data_ptr == nullptr) { + return absl::InternalError(absl::StrCat( + "Failed to allocate memory for NanoArray. Errno: ", strerror(errno))); + } + return data_ptr; + } + + // Copies data between two buffers that represent the same shape but have + // different byte strides. This is a recursive method that peels back dims + // until we get to a scalar, which isn't very efficient but the common case + // is expected to be a row major array without padding. + static absl::Status CopyWithByteStrides( + std::byte* dst, absl::Span dst_byte_strides, + const std::byte* src, absl::Span src_byte_strides, + absl::Span dims, int64_t elem_size) { + TF_RET_CHECK(dims.size() == dst_byte_strides.size()); + TF_RET_CHECK(dims.size() == src_byte_strides.size()); + // Scalar. Just copy it. + if (dims.empty()) { + memcpy(dst, src, elem_size); + return absl::OkStatus(); + } + // Peel back dims recursively until we get to a scalar. + for (int64_t i = 0; i < dims[0]; ++i) { + TF_RETURN_IF_ERROR(CopyWithByteStrides(dst, dst_byte_strides.subspan(1), + src, src_byte_strides.subspan(1), + dims.subspan(1), elem_size)); + dst += dst_byte_strides[0]; + src += src_byte_strides[0]; + } + return absl::OkStatus(); + } + + ifrt::DType dtype_; + ifrt::Shape shape_; + DataPtr data_; + std::shared_ptr sharding_; +}; + +char NanoArray::ID = 'A'; // NOLINT + +// Sharded array implementation. Represents an array that should be assembled +// from multiple arrays, but we aren't sure how to assemble it yet. +class ShardedNanoArray final : public NanoValue { + public: + // Creates an array from the given shards. Note that if we can assemble the + // array using the given sharding, this method will return a NanoArray. + static absl::StatusOr> FromShards( + NanoIfrtClient* client, ifrt::Shape shape, + std::shared_ptr sharding, + std::vector> shards) { + if (shards.empty()) { + return absl::InvalidArgumentError( + "Can't create a sharded array with no shards."); + } + xla::ifrt::DType dtype = shards[0]->dtype(); + + auto array = tsl::TakeRef(new ShardedNanoArray( + client, dtype, shape, sharding, std::move(shards))); + + // Try to eagerly assemble the array. Sometimes this cannot be done + // because arrays are loaded with a simple per device sharding and we + // won't know how to assemble it until the program is run. + if (auto dense_array = array->Assemble(sharding); dense_array.ok()) { + return dense_array; + } + + // If we can't assemble the array, we'll just return the sharded array. It + // will be assembled at execution time when we know the actual sharding. + return array; + } + + const std::vector>& shards() { return shards_; } + + // Assembles the array using the given sharding to prepare it as an input to + // execution. If this array has already been assembled using the given + // sharding, this method will return the cached result. This optimizes a + // common case where a checkpoint is loaded with an unknown sharding, but + // then we find the real sharding when the program is run. + absl::StatusOr> AssembleForExecution( + std::shared_ptr sharding) { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + absl::call_once(assemble_once_, [this, sharding]() { + assemble_result_ = Assemble(sharding); + }); + TF_RETURN_IF_ERROR(assemble_result_.status()); + if (assemble_result_.value()->shared_ptr_sharding() != sharding) { + // Bleh... We cached the wrong sharding somehow. This means one sharded + // array was an input to two different programs with different + // shardings, this should be unlikely. + return Assemble(sharding); + } + return assemble_result_; + } + + ifrt::Future<> Delete() override { + // Sharded arrays are never borrowed like dense arrays are, so we can just + // clear the shards and let them be destroyed. + shards_.clear(); + assemble_result_ = absl::Status(absl::StatusCode::kUnavailable, ""); + return Ready(); + } + + bool IsDeleted() const override { return shards_.empty(); } + + std::string DebugString() const override { + auto result = + absl::StrCat("ShardedNanoArray(", dtype_.DebugString(), ", ", + shape_.DebugString(), ", ", sharding_->DebugString()); + for (const auto& shard : shards_) { + absl::StrAppend(&result, ", ", shard->DebugString()); + } + absl::StrAppend(&result, ")"); + return result; + } + + ifrt::DType dtype() const override { return dtype_; } + + const ifrt::Shape& shape() const override { return shape_; } + + const ifrt::Sharding& sharding() const override { return *sharding_; } + + absl::Nonnull> shared_ptr_sharding() + const override { + return sharding_; + } + + absl::StatusOr> layout() const override { + return std::make_shared(xla::Layout(shape().dims())); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return std::vector>(shards_.begin(), shards_.end()); + } + + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + return DisassembleIntoSingleDeviceArrays(array_copy_semantics); + } + + absl::StatusOr> FullyReplicatedShard( + ifrt::ArrayCopySemantics semantics) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + return tsl::FormRef(this); + } + + ifrt::Future<> CopyToHostBuffer( + void* data, std::optional> byte_strides, + ifrt::ArrayCopySemantics semantics) override { + return Ready( + absl::InternalError("Cannot copy sharded array to host buffer.")); + } + + static char ID; // NOLINT + + private: + ShardedNanoArray(NanoIfrtClient* client, ifrt::DType dtype, ifrt::Shape shape, + std::shared_ptr sharding, + std::vector> shards) + : NanoValue(client), + dtype_(std::move(dtype)), + shape_(std::move(shape)), + sharding_(std::move(sharding)), + shards_(std::move(shards)) {} + + absl::StatusOr> Assemble( + std::shared_ptr sharding) { + TF_ASSIGN_OR_RETURN(auto index_domains, sharding->IndexDomains(shape())); + if (index_domains.size() != shards_.size()) { + return absl::FailedPreconditionError( + absl::StrCat("Number of index domains ", index_domains.size(), + " not equal to number of arrays ", shards_.size())); + } + + for (int i = 0; i < index_domains.size(); ++i) { + if (index_domains[i].shape() != shards_[i]->shape()) { + return absl::FailedPreconditionError(absl::StrCat( + "Index domain ", index_domains[i].shape().DebugString(), + " not equal to array shape ", shards_[i]->shape().DebugString())); + } + } + + // If the sharding is replicated in any way, this comparator will dedupe + // arrays that have the same logical destination. + struct IndexDomainCmp { + bool operator()(const ifrt::IndexDomain& a, + const ifrt::IndexDomain& b) const { + return std::lexicographical_compare( + a.origin().elements().begin(), a.origin().elements().end(), + b.origin().elements().begin(), b.origin().elements().end()); + } + }; + + // Index the arrays by where we are copying them to. Note that this will + // implicitly filter out replicated shards since they will have the same + // destination in the assembled array. + absl::btree_map + index_domain_device_arrays; + for (int i = 0; i < index_domains.size(); ++i) { + index_domain_device_arrays[index_domains[i]] = shards_[i].get(); + } + + TF_ASSIGN_OR_RETURN(auto result, NanoArray::Allocate(nano_client(), dtype(), + shape(), sharding)); + + // Copy the shards into the final array. + auto shard_origin = ifrt::Index::Zeros(shards_[0]->shape().dims().size()); + for (const auto& [index_domain, shard] : index_domain_device_arrays) { + TF_RETURN_IF_ERROR(NanoArray::CopySubArray( + *result, index_domain.origin().elements(), *shard, + shard_origin.elements(), shard->shape().dims())); + } + + return result; + } + + ifrt::DType dtype_; + ifrt::Shape shape_; + std::shared_ptr sharding_; + std::vector> shards_; + + absl::once_flag assemble_once_; + absl::StatusOr> assemble_result_; +}; + +char ShardedNanoArray::ID = 'A'; // NOLINT + +// Tuple implementation. +class NanoTuple final : public NanoValue { + public: + explicit NanoTuple(NanoIfrtClient* client, + absl::Span> values) + : NanoValue(client), + values_(values.begin(), values.end()) {} + + ifrt::Future<> Delete() override { + for (auto& value : values_) { + value->Delete(); + } + values_.clear(); + deleted_ = true; + return Ready(); + } + + bool IsDeleted() const override { + for (auto& value : values_) { + if (value->IsDeleted()) { + return true; + } + } + return deleted_; + } + + // Returns the arity of the tuple. + int Arity() override { return values_.size(); } + + // Unpacks the tuple into its constituent pieces. + absl::Status Unpack( + absl::Span> values) override { + TF_RETURN_IF_ERROR(ValidateNotDeleted()); + if (values.size() != values_.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Tuple arity mismatch: expected ", values_.size(), + ", got ", values.size())); + } + for (int i = 0; i < values_.size(); ++i) { + values[i] = values_[i]; + } + return absl::OkStatus(); + } + + std::string DebugString() const override { + std::string result = "NanoTuple("; + for (const auto& value : values_) { + absl::StrAppend(&result, value->DebugString(), ", "); + } + absl::StrAppend(&result, ")"); + return result; + } + + static char ID; // NOLINT + + private: + bool deleted_ = false; + std::vector> values_; +}; + +char NanoTuple::ID = 'T'; // NOLINT + +// Executable implementation. +class NanoExecutable final + : public llvm::RTTIExtends { + public: + // Creates a NanoExecutable from an ifrt::Program. + static absl::StatusOr> Create( + NanoIfrtClient* client, std::unique_ptr program) { + auto* xla_program = llvm::dyn_cast(program.get()); + if (xla_program == nullptr) { + return absl::InvalidArgumentError("NanoRT requires an HloProgram"); + } + XlaComputation computation; + TF_RETURN_IF_ERROR(MlirToXlaComputation(xla_program->mlir_module, + computation, false, true, false)); + TF_ASSIGN_OR_RETURN(auto nano_executable, + client->nano_client()->Compile(computation)); + + if (computation.proto().computations().size() != 1) { + return absl::InvalidArgumentError( + absl::StrCat("NanoRT only supports single-computation programs, got ", + computation.proto().computations().size())); + } + + TF_ASSIGN_OR_RETURN(auto program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(auto proto_input_shardings, + GetInputShardings(program_shape, computation)); + TF_ASSIGN_OR_RETURN(auto proto_output_shardings, + GetOutputShardings(program_shape, computation)); + auto input_shardings = + IfrtShardingsFromProto(client, proto_input_shardings); + auto output_shardings = + IfrtShardingsFromProto(client, proto_output_shardings); + + return absl::WrapUnique(new NanoExecutable( + client, std::move(computation), std::move(program_shape), + std::move(nano_executable), std::move(input_shardings), + std::move(output_shardings))); + } + + ifrt::Client* client() const override { return client_; } + + absl::string_view name() const override { return program_.name(); } + + absl::StatusOr Execute( + absl::Span> args, + const ExecuteOptions& options, + std::optional> devices) override { + if (args.size() != input_shardings_.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Number of arguments ", args.size(), + " is not what executable expects ", input_shardings_.size())); + } + + // Convert the ifrt arrays to nano arrays. 'tmp' holds any arrays that had + // to be assembled. + std::vector> tmp; + TF_ASSIGN_OR_RETURN(auto nano_args, + NanoArgumentsFromIfrtArguments(args, tmp)); + + TF_ASSIGN_OR_RETURN(auto result_arrays, AllocateResults()); + std::vector nano_results; + nano_results.reserve(result_arrays.size()); + for (auto& result_array : result_arrays) { + nano_results.push_back( + llvm::dyn_cast(result_array.get())->AsResult()); + } + + auto event = executable_->Execute(nano_args, nano_results, + NanoRtExecutable::PreallocatedTemp{}); + + // TODO(jsoyke): Consider making this non-blocking if we ever use this + // interface for models that require threading, or if we want to delay + // execution until we know where the outputs will be stored. + tsl::BlockUntilReady(event); + + if (event.IsError()) return event.GetError(); + if (!event.IsConcrete()) { + return absl::InternalError("NanoRT result is not concrete."); + } + + ExecuteResult result; + if (options.fill_status) { + result.status = Ready(); + } + result.outputs = std::move(result_arrays); + return result; + } + + // Returns a fingerprint of this executable. + absl::StatusOr> Fingerprint() const override { + return absl::UnimplementedError("Fingerprint is not implemented."); + } + + absl::StatusOr Serialize() const override { + return absl::UnimplementedError("Serialize is not implemented."); + } + + ifrt::Future<> GetReadyFuture() const override { return Ready(); } + + int num_devices() const override { return 1; } + + int64_t SizeOfGeneratedCodeInBytes() const override { return 0; } + + absl::StatusOr GetCompiledMemoryStats() const override { + return absl::UnimplementedError( + "GetCompiledMemoryStats is not implemented."); + } + + std::optional> GetParameterShardings() + const override { + auto shardings = GetInputShardings(program_shape_, program_); + if (!shardings.ok()) return std::nullopt; + return *shardings; + } + + std::optional> GetOutputShardings() const override { + auto shardings = GetOutputShardings(program_shape_, program_); + if (!shardings.ok()) return std::nullopt; + return *shardings; + } + + absl::StatusOr>> + GetParameterLayouts() const override { + std::vector> layouts; + layouts.reserve(program_shape_.parameters().size()); + for (const auto& shape : program_shape_.parameters()) { + layouts.push_back( + std::make_shared(xla::Layout(shape.dimensions()))); + } + return layouts; + } + + absl::StatusOr>> + GetOutputLayouts() const override { + const auto& result_shape = program_shape_.result(); + const auto result_shapes = + result_shape.IsTuple() + ? absl::MakeConstSpan(result_shape.tuple_shapes()) + : absl::MakeConstSpan(&result_shape, 1); + std::vector> layouts; + layouts.reserve(result_shapes.size()); + for (const auto& shape : result_shapes) { + layouts.push_back( + std::make_shared(xla::Layout(shape.dimensions()))); + } + return layouts; + } + + absl::StatusOr>> GetHloModules() + const override { + std::vector> hlo_modules(1); + TF_ASSIGN_OR_RETURN( + hlo_modules[0], + HloModule::CreateFromProto(program_.proto(), HloModuleConfig())); + return hlo_modules; + } + + absl::StatusOr>> + GetOutputMemoryKinds() const override { + std::vector> memory_kinds; + memory_kinds.reserve(output_shardings_.size()); + for (const auto& _ : output_shardings_) { + memory_kinds.push_back({kMemoryKind}); + } + return memory_kinds; + } + + absl::StatusOr GetCostAnalysis() const override { + return absl::UnimplementedError("GetCostAnalysis is not implemented."); + } + + ifrt::Future<> Delete() override { + client_ = nullptr; + program_ = {}; + program_shape_ = {}; + executable_.reset(); + input_shardings_.clear(); + output_shardings_.clear(); + return Ready(); + } + + bool IsDeleted() const override { return executable_ == nullptr; } + + absl::Span addressable_devices() const override { + return client_->addressable_devices(); + } + + static char ID; // NOLINT + + private: + NanoExecutable(NanoIfrtClient* client, XlaComputation program, + ProgramShape program_shape, + std::unique_ptr executable, + std::vector> input_shardings, + std::vector> output_shardings) + : client_(client), + program_(std::move(program)), + program_shape_(std::move(program_shape)), + executable_(std::move(executable)), + input_shardings_(std::move(input_shardings)), + output_shardings_(std::move(output_shardings)) {} + + // Converts an OpSharding proto (from an HLO Instruction) to an ifrt + // sharding. + static std::vector> IfrtShardingsFromProto( + NanoIfrtClient* client, absl::Span shardings) { + std::vector> result; + result.reserve(shardings.size()); + for (const auto& sharding : shardings) { + if (sharding.type() == OpSharding::REPLICATED || + sharding.type() == OpSharding::MAXIMAL) { + result.push_back(client->default_sharding()); + continue; + } + int num_tiles = 1; + for (const auto dim : sharding.tile_assignment_dimensions()) { + num_tiles *= dim; + } + // Repeat the device for each tile. We only have one device anyway so + // just used the first. + auto device_list = ifrt::BasicDeviceList::Create( + ifrt::BasicDeviceList::Devices(num_tiles, client->devices()[0])); + auto xla_sharding = *HloSharding::FromProto(sharding); + result.push_back(ifrt::HloSharding::Create( + std::move(device_list), client->devices()[0]->Memories()[0]->Kind(), + std::move(xla_sharding))); + } + return result; + } + + static absl::StatusOr> GetInputShardings( + const ProgramShape& program_shape, const XlaComputation& computation) { + std::vector shardings(program_shape.parameters().size()); + for (const auto& instruction : + computation.proto().computations(0).instructions()) { + if (instruction.opcode() == "parameter" && instruction.has_sharding()) { + if (instruction.parameter_number() >= shardings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Parameter number ", instruction.parameter_number(), + " is out of range for program with ", + program_shape.parameters().size(), " parameters.")); + } + shardings[instruction.parameter_number()] = instruction.sharding(); + } + } + return shardings; + } + + static absl::StatusOr> GetOutputShardings( + const ProgramShape& program_shape, const XlaComputation& computation) { + const auto& result_shape = program_shape.result(); + + int output_id = computation.proto().computations(0).root_id(); + + std::vector shardings( + (result_shape.IsTuple() ? result_shape.tuple_shapes().size() : 1)); + + for (const auto& instruction : + computation.proto().computations(0).instructions()) { + // We found a sharded output instruction. + if (instruction.id() == output_id && instruction.has_sharding()) { + if (result_shape.IsTuple()) { + TF_RET_CHECK(instruction.sharding().tuple_shardings().size() == + result_shape.tuple_shapes().size()); + for (int i = 0; i < instruction.sharding().tuple_shardings().size(); + ++i) { + shardings[i] = instruction.sharding().tuple_shardings()[i]; + } + } else { + shardings[0] = instruction.sharding(); + } + } + } + return shardings; + } + + // Allocates the results for the program. + absl::StatusOr>> AllocateResults() { + const auto& result_shape = program_shape_.result(); + const auto result_shapes = + result_shape.IsTuple() + ? absl::MakeConstSpan(result_shape.tuple_shapes()) + : absl::MakeConstSpan(&result_shape, 1); + TF_RET_CHECK(result_shapes.size() == output_shardings_.size()); + + std::vector> result_arrays; + result_arrays.reserve(result_shapes.size()); + + for (int i = 0; i < result_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN(auto ifrt_type, + ifrt::ToDType(result_shapes[i].element_type())); + ifrt::Shape ifrt_shape(result_shapes[i].dimensions()); + TF_ASSIGN_OR_RETURN(auto array, + NanoArray::Allocate(client_, ifrt_type, ifrt_shape, + output_shardings_[i])); + result_arrays.push_back(std::move(array)); + } + return result_arrays; + } + + // Converts the ifrt arrays to nano arguments. 'tmp' holds any arrays that + // had to be assembled. + absl::StatusOr> + NanoArgumentsFromIfrtArguments( + absl::Span> args, + std::vector>& tmp) { + std::vector nano_args; + nano_args.reserve(args.size()); + + for (int i = 0; i < args.size(); ++i) { + auto* nano_array = llvm::dyn_cast_or_null(args[i].get()); + if (nano_array == nullptr) { + // The input isn't a nano array, so it must be a sharded array. + auto* sharded_array = + llvm::dyn_cast_or_null(args[i].get()); + if (sharded_array == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Argument is not a NanoArray or ShardedNanoArray: ", + args[i]->DebugString())); + } + TF_ASSIGN_OR_RETURN( + auto dense_array, + sharded_array->AssembleForExecution(input_shardings_[i])); + nano_array = dense_array.get(); + tmp.push_back(std::move(dense_array)); + } + nano_args.push_back(nano_array->AsArgument()); + } + + return nano_args; + } + + NanoIfrtClient* client_; + XlaComputation program_; + ProgramShape program_shape_; + std::unique_ptr executable_; + std::vector> input_shardings_; + std::vector> output_shardings_; +}; + +char NanoExecutable::ID = 'E'; // NOLINT + +// Compiler implementation. +class NanoCompiler final + : public llvm::RTTIExtends { + public: + explicit NanoCompiler(NanoIfrtClient* client) : client_(client) {} + + absl::StatusOr> Compile( + std::unique_ptr program, + std::unique_ptr options) override { + return NanoExecutable::Create(client_, std::move(program)); + } + + absl::StatusOr> Compile( + std::unique_ptr program, const ifrt::Topology& topology, + std::unique_ptr options) override { + return absl::UnimplementedError("Partial compilation is not implemented."); + } + + absl::StatusOr> + DeserializeLoadedExecutable( + absl::string_view serialized, + std::unique_ptr options) override { + return absl::UnimplementedError( + "DeserializeLoadedExecutable is not implemented."); + } + static char ID; // NOLINT + + private: + NanoIfrtClient* client_; +}; + +char NanoCompiler::ID = 'C'; // NOLINT + +// Memory implementation. There is only one address space so this doesn't do +// much. +class NanoMemory final : public llvm::RTTIExtends { + public: + explicit NanoMemory(NanoIfrtClient* client) : client_(client) {} + + ifrt::MemoryId Id() const override { return ifrt::MemoryId(0); } + + const ifrt::MemoryKind& Kind() const override { + static ifrt::MemoryKind mem_kind(kMemoryKind); + return mem_kind; + } + + absl::string_view ToString() const override { return "NanoRT CPU Memory"; } + absl::string_view DebugString() const override { return ToString(); } + absl::Span Devices() const override { + return client_->devices(); + } + + static char ID; // NOLINT + + private: + NanoMemory() = default; + + NanoIfrtClient* client_; +}; + +char NanoMemory::ID = 'M'; // NOLINT + +// Device implementation. There is only one device so this doesn't do much. +class NanoDevice final : public llvm::RTTIExtends { + public: + NanoDevice(NanoIfrtClient* client, ifrt::Memory* memory) + : client_(client), memory_(memory) {} + + ifrt::Client* client() const override { return client_; } + + ifrt::DeviceId Id() const override { return ifrt::DeviceId(0); } + + const ifrt::AttributeMap& Attributes() const override { + static auto attributes = new ifrt::AttributeMap({}); + return *attributes; + } + + absl::string_view Kind() const override { return "cpu"; } + + absl::string_view ToString() const override { return "NanoRT CPU"; } + + absl::string_view DebugString() const override { return ToString(); } + + absl::StatusOr DefaultMemory() const override { + return memory_; + } + + absl::Span Memories() const override { + return absl::MakeConstSpan(&memory_, 1); + } + + bool IsAddressable() const override { return true; } + + int ProcessIndex() const override { return 0; } + + static char ID; // NOLINT + + private: + NanoIfrtClient* client_; + ifrt::Memory* memory_; +}; + +char NanoDevice::ID = 'D'; // NOLINT + +} // namespace + +NanoIfrtClient::~NanoIfrtClient() = default; + +std::shared_ptr NanoIfrtClient::Create() { + return CreateWithDevices(1); +} + +std::shared_ptr NanoIfrtClient::CreateWithDevices( + int num_devices) { + return std::shared_ptr(new NanoIfrtClient(num_devices)); +} + +std::shared_ptr NanoIfrtClient::default_sharding() const { + return ifrt::SingleDeviceSharding::Create(device_.get(), ifrt::MemoryKind{}); +} + +absl::StatusOr> +NanoIfrtClient::MakeArrayFromHostBuffer( + const void* data, ifrt::DType dtype, ifrt::Shape shape, + std::optional> byte_strides, + absl::Nonnull> sharding, + HostBufferSemantics semantics, + std::function on_done_with_host_buffer) { + bool make_copy = false; + switch (semantics) { + case HostBufferSemantics::kImmutableUntilTransferCompletes: + case HostBufferSemantics::kImmutableOnlyDuringCall: + make_copy = true; + break; + case HostBufferSemantics::kImmutableZeroCopy: + case HostBufferSemantics::kMutableZeroCopy: + make_copy = false; + break; + } + return NanoArray::FromBuffer(this, const_cast(data), dtype, shape, + std::move(sharding), byte_strides, make_copy, + on_done_with_host_buffer); +} + +absl::StatusOr> +NanoIfrtClient::AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) { + std::vector> nano_arrays; + nano_arrays.reserve(arrays.size()); + for (const auto& array : arrays) { + auto* nano_array = llvm::dyn_cast_or_null(array.get()); + if (nano_array == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Array is not a NanoArray: ", array->DebugString())); + } + nano_arrays.push_back(tsl::FormRef(nano_array)); + } + return ShardedNanoArray::FromShards(this, shape, sharding, + std::move(nano_arrays)); +} + +absl::StatusOr> +NanoIfrtClient::AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) { + return AssembleArrayFromSingleDeviceArrays(shape, sharding, arrays, + array_copy_semantics); +} + +absl::StatusOr>> +NanoIfrtClient::CopyArrays( + absl::Span> arrays, + std::optional> devices, + std::optional memory_kind, + ifrt::ArrayCopySemantics semantics) { + std::vector> result; + result.reserve(arrays.size()); + for (const auto& array : arrays) { + tsl::RCReference copy; + TF_ASSIGN_OR_RETURN(auto sharding, array->sharding().WithDeviceAssignment( + devices, memory_kind)); + if (auto nano_array = llvm::dyn_cast_or_null(array.get())) { + copy = tsl::TakeRef(new NanoArray(this, nano_array->dtype(), + nano_array->shape(), nano_array->data(), + std::move(sharding))); + } else if (auto sharded_nano_array = + llvm::dyn_cast_or_null(array.get())) { + std::vector> shards_copy; + shards_copy.reserve(sharded_nano_array->shards().size()); + for (const auto& shard : sharded_nano_array->shards()) { + shards_copy.push_back(tsl::TakeRef( + new NanoArray(this, shard->dtype(), shard->shape(), shard->data(), + shard->shared_ptr_sharding()))); + } + TF_ASSIGN_OR_RETURN( + copy, ShardedNanoArray::FromShards(this, sharded_nano_array->shape(), + std::move(sharding), + std::move(shards_copy))); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Array is not a NanoArray or ShardedNanoArray: ", + array->DebugString())); + } + TF_RET_CHECK(copy != nullptr); + result.push_back(copy); + } + return result; +} + +absl::StatusOr>> +NanoIfrtClient::RemapArrays( + const ifrt::RemapPlan& plan, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) { + return absl::UnimplementedError("RemapArrays is not implemented."); +} + +ifrt::Future<> NanoIfrtClient::GetReadyFuture( + absl::Span> values) { + return Ready(); +} + +absl::StatusOr> NanoIfrtClient::MakeTuple( + absl::Span> values) { + return tsl::MakeRef(this, std::move(values)); +} + +absl::string_view NanoIfrtClient::runtime_type() const { return "nano"; } + +absl::string_view NanoIfrtClient::platform_name() const { + return xla::CpuName(); +} + +absl::string_view NanoIfrtClient::platform_version() const { + return xla::CpuName(); +} + +ifrt::PlatformId NanoIfrtClient::platform_id() const { + return tsl::Fingerprint64(platform_name()); +} + +const ifrt::AttributeMap& NanoIfrtClient::Attributes() const { + static auto attributes = new ifrt::AttributeMap({}); + return *attributes; +} + +int NanoIfrtClient::device_count() const { return devices_.size(); } + +int NanoIfrtClient::addressable_device_count() const { return device_count(); } + +absl::Span NanoIfrtClient::devices() const { + return devices_; +} + +absl::Span NanoIfrtClient::addressable_devices() const { + return devices(); +} + +int NanoIfrtClient::process_index() const { return 0; } + +absl::Span NanoIfrtClient::GetAllDevices() const { + return devices(); +} + +absl::StatusOr +NanoIfrtClient::GetDefaultDeviceAssignment(int num_replicas, + int num_partitions) const { + return ifrt::DeviceAssignment(1, 1); +} + +absl::StatusOr NanoIfrtClient::LookupDevice( + ifrt::DeviceId device_id) const { + return LookupAddressableDevice(device_id.value()); +} + +absl::StatusOr NanoIfrtClient::LookupAddressableDevice( + int local_hardware_id) const { + return device_.get(); +} + +ifrt::Compiler* NanoIfrtClient::GetDefaultCompiler() { return compiler_.get(); } + +absl::StatusOr> +NanoIfrtClient::GetTopologyForDevices( + const tsl::RCReference& devices) const { + return absl::UnimplementedError("GetTopologyForDevices is not implemented."); +} + +absl::StatusOr> +NanoIfrtClient::GetDefaultLayout(ifrt::DType dtype, + absl::Span dims, + ifrt::Device* device, + xla::ifrt::MemoryKind memory_kind) const { + return std::make_shared(xla::Layout(dims)); +} + +NanoIfrtClient::NanoIfrtClient(int32_t num_devices) + : compiler_(std::make_unique(this)), + memory_(std::make_unique(this)), + device_(std::make_unique(this, memory_.get())), + default_sharding_( + ifrt::SingleDeviceSharding::Create(device_.get(), memory_->Kind())), + devices_(num_devices, device_.get()) {} + +char NanoIfrtClient::ID = 'N'; // NOLINT + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h new file mode 100644 index 00000000000000..96530d62bdb1bf --- /dev/null +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.h @@ -0,0 +1,197 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_NANORT_IFRT_CLIENT_H_ +#define XLA_BACKENDS_CPU_NANORT_IFRT_CLIENT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "xla/backends/cpu/nanort/nanort_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/ifrt/tuple.h" +#include "xla/python/ifrt/value.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace xla::cpu { + +// NanoIfrtClient is a thin wrapper around NanoRtClient that implements the +// ifrt::Client interface. +// +// Unlike NanoRtClient, this class will honor sharding annotations in XLA +// programs, mostly to satisfy IFRT callers. The sharding will be undone as soon +// as possible and reused (either when the sharded arrays is assembled or when +// it is first accessed by an executable). Even so, this client will have much +// better performance with unsharded inputs. +// +// Note: Array remapping is currently unimplemented. +// +// Note: We may add support for callers to access the underlying executables and +// buffers directly in the future, this would allow the "load path" that +// initializes programs and variables to be reused while still getting the +// performance wins of NanoRt at execution time. +class NanoIfrtClient : public llvm::RTTIExtends { + public: + ~NanoIfrtClient() override; + + // Creates a client with a single device. Typically this is how this client + // should be used. + static std::shared_ptr Create(); + + // Creates a client with the given number of devices, this is provided for + // testing and to allow the client to be used in applications that expect + // programs to be sharded. + static std::shared_ptr CreateWithDevices(int32_t num_devices); + + // Returns a single device sharding. Generally callers should prefer to use + // this when possible for optimal performance. + std::shared_ptr default_sharding() const; + + // Returns the underlying NanoRtClient. + NanoRtClient* nano_client() { return &client_; } + + using HostBufferSemantics = xla::ifrt::Client::HostBufferSemantics; + + // Creates an array from a host buffer. The buffer will be used directly + // without a copy if the copy semantics allow it and the layout is row major + // and dense. + absl::StatusOr> MakeArrayFromHostBuffer( + const void* data, ifrt::DType dtype, ifrt::Shape shape, + std::optional> byte_strides, + absl::Nonnull> sharding, + HostBufferSemantics semantics, + std::function on_done_with_host_buffer) override; + + // Assembles a sharded array from a list of single device arrays. If the + // provided sharding is specific enough to assemble a dense array, this method + // will actually return an assembled array that pretends it is sharded. + // + // Otherwise we will produce an assembled array on demand when it is first + // accessed by an XLA program. + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) override; + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, + absl::Nonnull> sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override; + + absl::StatusOr>> CopyArrays( + absl::Span> arrays, + std::optional> devices, + std::optional memory_kind, + ifrt::ArrayCopySemantics semantics) override; + + absl::StatusOr>> RemapArrays( + const ifrt::RemapPlan& plan, + absl::Span> arrays, + ifrt::ArrayCopySemantics semantics) override; + + ifrt::Future<> GetReadyFuture( + absl::Span> values) override; + + absl::StatusOr> MakeTuple( + absl::Span> values) override; + + absl::string_view runtime_type() const override; + + absl::string_view platform_name() const override; + absl::string_view platform_version() const override; + ifrt::PlatformId platform_id() const override; + + const ifrt::AttributeMap& Attributes() const override; + + int device_count() const override; + int addressable_device_count() const override; + absl::Span devices() const override; + absl::Span addressable_devices() const override; + int process_index() const override; + + absl::Span GetAllDevices() const override; + + absl::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + absl::StatusOr LookupDevice( + ifrt::DeviceId device_id) const override; + absl::StatusOr LookupAddressableDevice( + int local_hardware_id) const override; + + ifrt::Compiler* GetDefaultCompiler() override; + + absl::StatusOr> GetTopologyForDevices( + const tsl::RCReference& devices) const override; + + absl::StatusOr> GetDefaultLayout( + ifrt::DType dtype, absl::Span dims, ifrt::Device* device, + xla::ifrt::MemoryKind memory_kind) const override; + + static char ID; // NOLINT + + private: + explicit NanoIfrtClient(int32_t num_devices); + + // The underlying NanoRtClient. + NanoRtClient client_; + + // The compiler, memory, and device objects. See cc file for implementation + // details. + std::unique_ptr compiler_; + std::unique_ptr memory_; + std::unique_ptr device_; + + // The default sharding for this client. When this sharding is used it + // typically means that we can use an array's contents directly. + std::shared_ptr default_sharding_; + + // Some of the ifrt::Client methods return a span of devices, so we need to + // keep storage for them here. Note that this may repeat the device_ pointer + // multiple times if this client is configured with multiple devices. This is + // mostly to make IFRT callers that expect sharded programs to run on multiple + // devices happy. This has the unusual property that we have multiple devices + // but a single device_id, but this seems to work fine and most documentation + // warns that devices may be repeated within a device list or sharding. + std::vector devices_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_NANORT_IFRT_CLIENT_H_ diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc new file mode 100644 index 00000000000000..efe24079a9016a --- /dev/null +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc @@ -0,0 +1,34 @@ +/* Copyright 2023 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "xla/python/ifrt/test_util.h" + +// For now, all of the tests we run are provided by IFRT, they use +// NanoIfrtClient via the "register_nanort_for_ifrt_tests" target, which can +// also be used to run NanoIfrtClient in other tests. see the BUILD file for the +// list. We need a main function to filter out one test that doesn't seem worth +// supporting. + +int main(int argc, char** argv) { + // This test expects copies to multiple devices to fail, but we only have one + // device and it doesn't seem worth pretending that we have more. + static constexpr absl::string_view kFilter = + "-ArrayImplTest.CopyMixedSourceDevices"; + xla::ifrt::test_util::SetTestFilterIfNotUserSpecified(kFilter); + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc index 992a8b51137847..50b4d521de81cb 100644 --- a/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc +++ b/third_party/xla/xla/backends/cpu/nanort/nanort_client_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/backends/cpu/nanort/nanort_executable.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" @@ -46,7 +46,7 @@ using Arguments = absl::InlinedVector; using Results = absl::InlinedVector; TEST(NanoRtClientTest, CompileAndRunScalarComputation) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule add ENTRY e { @@ -80,7 +80,7 @@ TEST(NanoRtClientTest, CompileAndRunScalarComputation) { } TEST(NanoRtClientTest, CompileAndRunTupledComputation) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule add_and_mul ENTRY e { @@ -119,7 +119,7 @@ TEST(NanoRtClientTest, CompileAndRunTupledComputation) { } TEST(NanoRtClientTest, CompileAndRunConstantComputation) { - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule cst ENTRY e { @@ -149,7 +149,7 @@ TEST(NanoRtClientTest, CompileAndRunConstantComputation) { } TEST(NanoRtClientTest, CompileAndRunConditionalComputation) { - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule conditional %add (x: f32[]) -> f32[] { diff --git a/third_party/xla/xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc b/third_party/xla/xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc new file mode 100644 index 00000000000000..b804c257f79be5 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/nanort/register_nanort_for_ifrt_tests.cc @@ -0,0 +1,29 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/nanort/ifrt_client.h" +#include "xla/python/ifrt/test_util.h" + +namespace xla::cpu { +namespace { + +// Link this in to use the NanoIfrtClient as the default IFRT client for tests. +// IFRT tests expect the client to have multiple devices. +const bool kUnused = (ifrt::test_util::RegisterClientFactory( + [] { return NanoIfrtClient::CreateWithDevices(4); }), + true); + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/package_groups.bzl b/third_party/xla/xla/backends/cpu/package_groups.bzl new file mode 100644 index 00000000000000..c5a3ffb5c88435 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/package_groups.bzl @@ -0,0 +1,8 @@ +"""Package groups for XLA:CPU backend internal access.""" + +# Integrations should use PJRT as the API to access XLA. +def xla_cpu_backend_access(name = "xla_cpu_backend_access"): + native.package_group( + name = "xla_backend_cpu_internal_access", + packages = ["//..."], + ) diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index c48abc6020e125..279e0163918008 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -55,14 +55,15 @@ xla_cc_test( srcs = ["buffer_allocations_test.cc"], deps = [ ":buffer_allocations", + ":thunk_testlib", + "//xla:literal_util", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -126,7 +127,6 @@ cc_library( hdrs = ["function_library.h"], deps = [ ":kernel_c_api", - "//xla:util", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", @@ -144,21 +144,23 @@ cc_library( ":resource_use", "//xla:executable_run_options", "//xla:util", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:in_process_collectives", + "//xla/core/collectives", "//xla/ffi:execution_context", "//xla/runtime:buffer_use", "//xla/service:global_device_id", - "//xla/service/cpu:collectives_interface", "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_runtime", - "//xla/service/cpu:in_process_collectives", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", ], @@ -167,13 +169,19 @@ cc_library( cc_library( name = "thunk_testlib", testonly = 1, + srcs = ["thunk_testlib.cc"], hdrs = ["thunk_testlib.h"], deps = [ + ":buffer_allocations", ":resource_use", ":thunk", + "//xla:literal", "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", ], ) @@ -183,13 +191,13 @@ xla_cc_test( deps = [ ":thunk", "//xla:executable_run_options", - "//xla/service/cpu:collectives_interface", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/service/cpu:cpu_executable_run_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -201,13 +209,17 @@ cc_library( deps = [ ":resource_use", ":thunk", + "//xla:util", "//xla/runtime:buffer_use", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -215,7 +227,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -229,24 +241,26 @@ xla_cc_test( ":thread_pool_task_runner", ":thunk", ":thunk_executor", + ":thunk_testlib", + "//xla:literal", + "//xla:literal_util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", ], ) @@ -301,10 +315,10 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -319,10 +333,10 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", - "//xla/service/cpu:collectives_interface", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -420,17 +434,19 @@ xla_cc_test( ":buffer_allocations", ":convolution_thunk", ":thunk", + ":thunk_testlib", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -445,13 +461,16 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", - "//xla/service/cpu:collectives_interface", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -476,11 +495,13 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", - "//xla/service/cpu:collectives_interface", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", @@ -488,8 +509,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", @@ -507,10 +526,10 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", - "//xla/service/cpu:collectives_interface", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -538,11 +557,12 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:rank_id", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", - "//xla/service/cpu:collectives_interface", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -571,14 +591,21 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_clique_key", + "//xla/backends/cpu/collectives:cpu_cliques", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", - "//xla/service/cpu:collectives_interface", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", @@ -588,9 +615,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) @@ -628,14 +652,15 @@ xla_cc_test( ":buffer_allocations", ":copy_thunk", ":thunk", + ":thunk_testlib", + "//xla:literal_util", "//xla:shape_util", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -678,6 +703,38 @@ cc_library( ], ) +cc_library( + name = "dot_lib", + srcs = ["dot_lib.cc"], + hdrs = ["dot_lib.h"], + deps = [ + ":thunk", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/framework/contraction:eigen_contraction_kernel", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + cc_library( name = "dot_thunk", srcs = [ @@ -691,6 +748,7 @@ cc_library( ], hdrs = ["dot_thunk.h"], deps = [ + ":dot_lib", ":thunk", "//xla:shape_util", "//xla:types", @@ -701,6 +759,7 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/framework/contraction:eigen_contraction_kernel", + "//xla/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -792,16 +851,16 @@ xla_cc_test( ":buffer_allocations", ":logical_id_thunk", ":thunk", + ":thunk_testlib", "//xla:executable_run_options", + "//xla:literal_util", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -862,26 +921,27 @@ cc_library( ":kernel_c_api", ":thunk", "//xla:util", + "//xla/backends/cpu/codegen:llvm_ir_kernel_spec", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -896,17 +956,17 @@ xla_cc_test( ":kernel_c_api", ":kernel_thunk", ":thunk", + ":thunk_testlib", + "//xla:literal_util", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor:device_memory", "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -971,6 +1031,9 @@ cc_library( "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -984,9 +1047,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -999,17 +1059,21 @@ xla_cc_test( ":function_library", ":sort_thunk", ":thunk", + ":thunk_testlib", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -1045,16 +1109,16 @@ xla_cc_test( ":thunk", ":thunk_testlib", ":while_thunk", + "//xla:literal_util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:env", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc index fa55bbc48dbffc..82847710d0b75f 100644 --- a/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -76,12 +76,14 @@ tsl::AsyncValueRef AllGatherThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { + CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); + for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = source_shape(i); TF_RETURN_IF_ERROR(comm.AllGather( - key, ShapeUtil::ByteSizeOf(shape), data.source[i].opaque(), - data.destination[i].opaque(), DefaultCollectiveTimeout())); + data.source[i], data.destination[i], shape.element_type(), + ShapeUtil::ElementsIn(shape), executor)); } return absl::OkStatus(); }); diff --git a/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc index a5d9d283867c2d..9c6ac2ead41620 100644 --- a/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc @@ -21,23 +21,23 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -101,13 +101,13 @@ tsl::AsyncValueRef AllReduceThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { + CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = destination_shape(i); TF_RETURN_IF_ERROR(comm.AllReduce( - key, reduction_kind_, shape.element_type(), - ShapeUtil::ElementsIn(shape), data.source[i].opaque(), - data.destination[i].opaque(), DefaultCollectiveTimeout())); + data.source[i], data.destination[i], shape.element_type(), + ShapeUtil::ElementsIn(shape), reduction_kind_, executor)); } return absl::OkStatus(); }); diff --git a/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc index 8badd0c4e7e232..b97ff3409deecc 100644 --- a/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc @@ -23,16 +23,16 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -75,24 +75,13 @@ tsl::AsyncValueRef AllToAllThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { + CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); const Shape& shape = destination_shape(0); - absl::InlinedVector input_buffers; - input_buffers.reserve(data.source.size()); - for (int i = 0; i < data.source.size(); ++i) { - input_buffers.push_back(data.source[i].opaque()); - } - - absl::InlinedVector output_buffers; - output_buffers.reserve(data.destination.size()); - for (int i = 0; i < data.destination.size(); ++i) { - output_buffers.push_back(data.destination[i].opaque()); - } - - TF_RETURN_IF_ERROR(comm.AllToAll(key, ShapeUtil::ByteSizeOf(shape), - input_buffers, output_buffers, - DefaultCollectiveTimeout())); + TF_RETURN_IF_ERROR( + comm.AllToAll(data.source, data.destination, shape.element_type(), + ShapeUtil::ElementsIn(shape), executor)); return absl::OkStatus(); }); diff --git a/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc index c92be6205ac910..bcaa241e89136b 100644 --- a/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc @@ -15,58 +15,48 @@ limitations under the License. #include "xla/backends/cpu/runtime/buffer_allocations.h" -#include -#include - +#include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { TEST(BufferAllocationsTest, GetDeviceAddress) { - std::vector buffers; - std::vector data = {1.0, 2.0, 3.0, 4.0}; - - size_t size_in_bytes = data.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + auto data = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); - BufferAllocations allocations(buffers); + BufferAllocation alloc = CreateBufferAllocation(0, data); + BufferAllocation::Slice slice = CreateBufferAllocationSlice( + alloc, /*offset=*/2 * sizeof(float), /*size=*/sizeof(float)); - BufferAllocation alloc(0, size_in_bytes, 0); - BufferAllocation::Slice slice(&alloc, /*offset=*/2 * sizeof(float), - /*size=*/sizeof(float)); + BufferAllocations allocations = CreateBufferAllocations(data); TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase alloc_mem, allocations.GetDeviceAddress(0)); - EXPECT_EQ(alloc_mem.opaque(), &data[0]); + EXPECT_EQ(alloc_mem.opaque(), &data.data()[0]); TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase slice_mem, allocations.GetDeviceAddress(slice)); - EXPECT_EQ(slice_mem.opaque(), &data[2]); + EXPECT_EQ(slice_mem.opaque(), &data.data()[2]); } TEST(BufferAllocationsTest, GetDeviceAddressUnchecked) { - std::vector buffers; - std::vector data = {1.0, 2.0, 3.0, 4.0}; - - size_t size_in_bytes = data.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + auto data = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); - BufferAllocations allocations(buffers); + BufferAllocation alloc = CreateBufferAllocation(0, data); + BufferAllocation::Slice slice = CreateBufferAllocationSlice( + alloc, /*offset=*/2 * sizeof(float), /*size=*/sizeof(float)); - BufferAllocation alloc(0, size_in_bytes, 0); - BufferAllocation::Slice slice(&alloc, /*offset=*/2 * sizeof(float), - /*size=*/sizeof(float)); + BufferAllocations allocations = CreateBufferAllocations(data); se::DeviceMemoryBase alloc_mem = allocations.GetDeviceAddressUnchecked(0); - EXPECT_EQ(alloc_mem.opaque(), &data[0]); + EXPECT_EQ(alloc_mem.opaque(), &data.data()[0]); se::DeviceMemoryBase slice_mem = allocations.GetDeviceAddressUnchecked(slice); - EXPECT_EQ(slice_mem.opaque(), &data[2]); + EXPECT_EQ(slice_mem.opaque(), &data.data()[2]); } } // namespace diff --git a/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc index a830c0f7fd4ea1..3e46d388a5f671 100644 --- a/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -25,15 +26,17 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -83,12 +86,12 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { : logical_id.replica_id; // Find replicas that we will communicate with. - std::optional source_replica_id; - std::vector copy_to; + std::optional source_replica_id; + std::vector copy_to; for (auto& [from, to] : source_target_pairs_) { if (from == logical_device_id) { - copy_to.push_back(to); + copy_to.push_back(RankId(to)); } if (to == logical_device_id) { TF_RET_CHECK(!source_replica_id.has_value()) @@ -98,6 +101,10 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { } } + auto rank_fmt = [](std::string* out, RankId rank) { + absl::StrAppend(out, rank.value()); + }; + VLOG(3) << absl::StreamFormat( "CollectivePermute: #source_buffers=%d, #destination_buffers=%d, " "source_target_pairs=[%s], logical_device_id=%d (%s), " @@ -106,7 +113,8 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { absl::StrJoin(source_target_pairs_, ", ", absl::PairFormatter("->")), logical_device_id, op_params().has_channel_id ? "computation id" : "replica id", - source_replica_id.value_or(-1), absl::StrJoin(copy_to, ",")); + source_replica_id.value_or(RankId(-1)).value(), + absl::StrJoin(copy_to, ",", rank_fmt)); for (int i = 0; i < data.source.size(); ++i) { VLOG(3) << absl::StreamFormat( @@ -122,13 +130,15 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { + CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); + for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = source_shape(i); TF_RETURN_IF_ERROR(comm.CollectivePermute( - key, ShapeUtil::ByteSizeOf(shape), source_replica_id, copy_to, - data.source[i].opaque(), data.destination[i].opaque(), - DefaultCollectiveTimeout())); + data.source[i], data.destination[i], shape.element_type(), + ShapeUtil::ElementsIn(shape), source_replica_id, copy_to, + executor)); } return absl::OkStatus(); }); diff --git a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc index 4bebdd09cd31c1..35a6f72fb9671d 100644 --- a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc @@ -32,23 +32,27 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_cliques.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla::cpu { @@ -172,7 +176,7 @@ CollectiveThunk::ExecuteWithCommunicator( TF_RET_CHECK(params) << "Collective parameters are not set for collective operation"; - CollectivesInterface* collectives = params->collectives; + CpuCollectives* collectives = params->collectives; TF_RET_CHECK(collectives) << "Collectives interface is not set for collective operation"; @@ -183,8 +187,10 @@ CollectiveThunk::ExecuteWithCommunicator( VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); - TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, - collectives->GetCommunicator(key.global_devices, rank)); + CpuCliqueKey clique_key(key.global_devices); + TF_ASSIGN_OR_RETURN( + Communicator * communicator, + AcquireCommunicator(collectives, clique_key, RankId(rank))); TF_RETURN_IF_ERROR(callback(key, *communicator)); diff --git a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h index 8efc767838806d..e226f7ab3834b6 100644 --- a/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h @@ -31,7 +31,6 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" @@ -86,8 +85,8 @@ class CollectiveThunk : public Thunk { protected: // Callback for collective thunk implementations. - using Callback = absl::AnyInvocable; + using Callback = absl::AnyInvocable; static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype); diff --git a/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc index a5222a8de6bb3d..589273b87977d7 100644 --- a/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { diff --git a/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h index 0b78a1cffb26fc..fa8cbdab6eff3c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h @@ -17,9 +17,10 @@ limitations under the License. #define XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ #include +#include #include #include -#include +#include #include "xla/backends/cpu/runtime/concurrency.h" #include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" // IWYU pragma: keep @@ -32,17 +33,20 @@ limitations under the License. namespace xla::cpu::internal { +constexpr auto kMaxConvMatrixSize = static_cast(8) << 30; // 8 GiB + // Returns in 'out_data' (assumes to be zero-initialized) image patch in storage -// order (width, height, depth), constructed from patches in 'col_data', which -// is required to be in storage order (in_width * in_height, filter_width, -// filter_height, in_depth). Based on TF implementation by Yangqing Jia (jiayq). +// order (width, height, depth), constructed from patches in 'conv_matrix', +// which is required to be in storage order (in_width * in_height, filter_width, +// filter_height, out_depth). +// Based on TF implementation by Yangqing Jia (jiayq). // TODO(adambanas): The original implementation implicitly rotates the kernel by // 180 degrees, but to be backwards compatible, we cannot do that in XLA. This -// results in counterintuitive operations on col_data, which is also 15-20% +// results in counterintuitive operations on conv_matrix, which is also 15-20% // slower. Try alternative approaches (e.g. rotate kernel before matrix // multiplication in the calling function). template -void Pack2DPatches(const T* col_data, const int depth, const int height, +void Pack2DPatches(const T* conv_matrix, const int depth, const int height, const int width, const int filter_h, const int filter_w, const int pad_top, const int pad_bottom, const int pad_left, const int pad_right, const int stride_h, const int stride_w, @@ -55,7 +59,7 @@ void Pack2DPatches(const T* col_data, const int depth, const int height, const int filter_spatial_size = filter_h * filter_w; int w_patch_begin = pad_left - filter_w + 1; - col_data += depth * (filter_spatial_size - 1); + conv_matrix += depth * (filter_spatial_size - 1); for (int w = 0; w < w_patches_number; ++w) { int h_patch_begin = pad_top - filter_h + 1; for (int h = 0; h < h_patches_number; ++h) { @@ -73,17 +77,17 @@ void Pack2DPatches(const T* col_data, const int depth, const int height, // in the output buffer, at all depths if (iw >= 0 && iw < width && ih >= 0 && ih < height) { for (int i = 0; i < depth; ++i) { - out_im_patch_data[i] += col_data[i]; + out_im_patch_data[i] += conv_matrix[i]; } } out_im_patch_data += depth; - col_data -= depth; + conv_matrix -= depth; } // Jump over remaining number of depth. out_im_patch_data += depth * (height - filter_h); } - col_data += 2 * depth * filter_spatial_size; + conv_matrix += 2 * depth * filter_spatial_size; h_patch_begin += stride_h; } w_patch_begin += stride_w; @@ -96,7 +100,7 @@ void Pack2DPatches(const T* col_data, const int depth, const int height, // Explore these alternatives. // TODO(adambanas): Add support for feature group count. template -void EigenTransposedConv2D( +bool EigenTransposedConv2D( const EigenDevice& device, ScalarType* out, ScalarType* lhs, ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, Eigen::Index input_y, Eigen::Index input_channels, Eigen::Index kernel_x, @@ -106,17 +110,18 @@ void EigenTransposedConv2D( Eigen::Index padding_y_before, Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, std::function done_callback, bool use_thunk_runtime) { - // TODO(adambanas): Current custom conv algorithm doesn't support both - // multiple input channels and multiple output channels (i.e. kernel_filters) - // at the same time. - CHECK(input_channels == 1 || kernel_filters == 1); - - typedef Eigen::TensorMap, - Eigen::Unaligned> - TensorMap; - typedef Eigen::TensorMap, - Eigen::Aligned> - ConstTensorMap; + // Grouped convolutions are not supported yet. + CHECK(kernel_channels == input_channels); + + using TensorMap2D = + Eigen::TensorMap, + Eigen::Unaligned>; + using ConstTensorMap3D = + Eigen::TensorMap, + Eigen::Aligned>; + using ConstTensorMap2D = + Eigen::TensorMap, + Eigen::Aligned>; // Total spatial dimensions. const int input_image_size = input_x * input_y; @@ -124,10 +129,19 @@ void EigenTransposedConv2D( // Kernel dimensions per input channel. const int kernel_total_size = kernel_x * kernel_y * kernel_filters; - // Intermediate buffer - std::vector col_buffer; - col_buffer.resize(input_batch * input_image_size * kernel_total_size); - ScalarType* col_buffer_data = col_buffer.data(); + // Intermediate buffer (convolution matrix) + const size_t buffer_size = input_batch * input_image_size * kernel_total_size; + if (buffer_size * sizeof(ScalarType) > kMaxConvMatrixSize) { + LOG(WARNING) + << "Falling back to generic convolution implementation, because custom " + "transposed convolution algorithm needs too much memory (" + << buffer_size * sizeof(ScalarType) + << " bytes, exceeding the threshold of " << kMaxConvMatrixSize + << " bytes)."; + return false; + } + auto conv_matrix = std::make_unique(buffer_size); + ScalarType* conv_matrix_data = conv_matrix.get(); // Initialize output to zero. ScalarType* out_data = out; @@ -135,17 +149,17 @@ void EigenTransposedConv2D( out_data + input_batch * output_image_size * kernel_filters, ScalarType(0.0f)); - // Initialize contraction dims (we need to transpose 'B' below). - Eigen::array, 1> contract_dims; - contract_dims[0].first = 1; - contract_dims[0].second = 1; + // Initialize contraction dims (we need to transpose 'B' below, the dimension + // we need to contract is 'kernel_channels'). + Eigen::array, 1> contract_dims = { + Eigen::IndexPair(1, 1)}; - // Compute intermediate results (convolution matrix) into col_buffer. - TensorMap C(col_buffer_data, input_batch * input_image_size, - kernel_total_size); + // Compute intermediate results (convolution matrix) into conv_matrix. + TensorMap2D C(conv_matrix_data, input_batch * input_image_size, + kernel_total_size); - ConstTensorMap A(lhs, input_batch * input_image_size, input_channels); - ConstTensorMap B(rhs, kernel_total_size, input_channels); + ConstTensorMap2D A(lhs, input_batch * input_image_size, input_channels); + ConstTensorMap3D B(rhs, kernel_x * kernel_y, kernel_channels, kernel_filters); // Use concurrent execution if we have a thread pool device. constexpr bool use_thread_pool = @@ -162,24 +176,22 @@ void EigenTransposedConv2D( const int output_offset = output_image_size * kernel_filters; // Pack the calculated patches into the output buffer. - // NOTE: The ownership of the col_buffer is transferred to the lambda without - // data copy or reallocation. Thanks to that, col_buffer_data pointer remains - // valid, and that is important because 'C' matrix is referencing it. We need - // to make sure this lambda is never copied, otherwise col_buffer won't - // contain contraction results at the time lambda is called. - auto pack_patches = [=, col_buffer = std::move(col_buffer)]() { + // NOTE: The ownership of the conv_matrix is transferred to the lambda without + // data copy or reallocation. Thanks to that, conv_matrix_data pointer remains + // valid, and that is important because 'C' matrix is referencing it. + auto pack_patches = [=, conv_matrix = std::move(conv_matrix)]() { // Using local pointers to buffers, because lambda is not mutable. - const ScalarType* col_buffer_data = col_buffer.data(); + const ScalarType* conv_matrix_data = conv_matrix.get(); ScalarType* local_out_data = out_data; // TODO(adambanas): Run this part in parallel. for (int image_id = 0; image_id < input_batch; ++image_id) { Pack2DPatches( - col_buffer_data, kernel_filters, output_y, output_x, kernel_y, + conv_matrix_data, kernel_filters, output_y, output_x, kernel_y, kernel_x, padding_y_before, padding_y_after, padding_x_before, padding_x_after, lhs_y_dilation, lhs_x_dilation, local_out_data); - col_buffer_data += input_offset; + conv_matrix_data += input_offset; local_out_data += output_offset; } @@ -190,24 +202,34 @@ void EigenTransposedConv2D( } }; + // Molds the output of the contraction into the shape expected by packing + // algorithm: + // - the minor dimension (dims[1]): the patch values to be packed; contiguous + // in memory + // - the major dimension (dims[0]): everything else + Eigen::DSizes post_contract_dims; + post_contract_dims[0] = input_batch * input_image_size; + post_contract_dims[1] = kernel_total_size; + if (done_callback) { // Schedule the work in the thread pool and return. - C.device(device, std::move(pack_patches)) = A.contract(B, contract_dims); + C.device(device, std::move(pack_patches)) = + A.contract(B, contract_dims).reshape(post_contract_dims); } else { // Run synchronously in the current thread. - C.device(device) = A.contract(B, contract_dims); + C.device(device) = A.contract(B, contract_dims).reshape(post_contract_dims); pack_patches(); } + return true; } inline bool CanUseCustomTransposedConv( - Eigen::Index input_channels, Eigen::Index kernel_filters, Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count) { return (lhs_x_dilation > 1 || lhs_y_dilation > 1) && rhs_x_dilation == 1 && - rhs_y_dilation == 1 && (input_channels == 1 || kernel_filters == 1) && - feature_group_count == 1 && x_stride == 1 && y_stride == 1; + rhs_y_dilation == 1 && feature_group_count == 1 && x_stride == 1 && + y_stride == 1; } // Algorithm that works for all types of 2D convolutions. Even though it works @@ -361,23 +383,25 @@ void EigenConv2D(const EigenDevice& device, ScalarType* out, ScalarType* lhs, Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, std::function done_callback, bool use_thunk_runtime) { - if (CanUseCustomTransposedConv(input_channels, kernel_filters, x_stride, - y_stride, lhs_x_dilation, lhs_y_dilation, - rhs_x_dilation, rhs_y_dilation, + if (CanUseCustomTransposedConv(x_stride, y_stride, lhs_x_dilation, + lhs_y_dilation, rhs_x_dilation, rhs_y_dilation, feature_group_count)) { - EigenTransposedConv2D( - device, out, lhs, rhs, input_batch, input_x, input_y, input_channels, - kernel_x, kernel_y, kernel_channels, kernel_filters, output_x, output_y, - padding_x_before, padding_x_after, padding_y_before, padding_y_after, - lhs_x_dilation, lhs_y_dilation, done_callback, use_thunk_runtime); - } else { - EigenGenericConv2D( - device, out, lhs, rhs, input_batch, input_x, input_y, input_channels, - kernel_x, kernel_y, kernel_channels, kernel_filters, output_x, output_y, - x_stride, y_stride, padding_x_before, padding_x_after, padding_y_before, - padding_y_after, lhs_x_dilation, lhs_y_dilation, rhs_x_dilation, - rhs_y_dilation, feature_group_count, done_callback, use_thunk_runtime); + if (EigenTransposedConv2D( + device, out, lhs, rhs, input_batch, input_x, input_y, + input_channels, kernel_x, kernel_y, kernel_channels, kernel_filters, + output_x, output_y, padding_x_before, padding_x_after, + padding_y_before, padding_y_after, lhs_x_dilation, lhs_y_dilation, + done_callback, use_thunk_runtime)) { + return; + } + // Transposed convolution failed, fallback to generic implementation. } + EigenGenericConv2D( + device, out, lhs, rhs, input_batch, input_x, input_y, input_channels, + kernel_x, kernel_y, kernel_channels, kernel_filters, output_x, output_y, + x_stride, y_stride, padding_x_before, padding_x_after, padding_y_before, + padding_y_after, lhs_x_dilation, lhs_y_dilation, rhs_x_dilation, + rhs_y_dilation, feature_group_count, done_callback, use_thunk_runtime); } template diff --git a/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc index 20a75d1f97ebcc..ce8142444ebe5d 100644 --- a/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc @@ -15,27 +15,25 @@ limitations under the License. #include "xla/backends/cpu/runtime/convolution_thunk.h" -#include #include #include -#include +#include #include #include #include "absl/algorithm/container.h" #include "absl/status/status.h" +#include "absl/types/span.h" #include "Eigen/Core" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" -#include "xla/primitive_util.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { @@ -102,23 +100,6 @@ std::vector MakeDataVector(const std::vector& dims) { return std::vector(size, ElementType(0.0)); } -template -std::vector MakeBuffers( - const std::vector& input, - const std::vector& kernel, - const std::vector& output) { - std::vector buffers; - size_t input_size_in_bytes = input.size() * sizeof(ElementType); - buffers.emplace_back(se::DeviceMemoryBase(input.data(), input_size_in_bytes)); - size_t kernel_size_in_bytes = kernel.size() * sizeof(ElementType); - buffers.emplace_back( - se::DeviceMemoryBase(kernel.data(), kernel_size_in_bytes)); - size_t output_size_in_bytes = output.size() * sizeof(ElementType); - buffers.emplace_back( - se::DeviceMemoryBase(output.data(), output_size_in_bytes)); - return buffers; -} - ConvolutionThunk::Options MakeConvolutionOptions() { ConvolutionThunk::Options options; options.multi_threaded = false; @@ -175,107 +156,80 @@ Window MakeWindow(int convolution_rank) { template class ConvolutionThunkBuilder { public: - // Set convolution options. If not called before Build(), default options are - // used. - void SetOptions(ConvolutionThunk::Options options) { - options_ = std::move(options); - } - - // Constructor that lets the user specify the convolution dimensions. - auto Build(ConvolutionDimensions dims = ConvolutionDimensions()) { - // Data dimensions. - auto input_dims = MakeInputDims(dims); - auto kernel_dims = MakeKernelDims(dims); - auto output_dims = MakeOutputDims(dims); + ConvolutionThunkBuilder(ConvolutionThunkBuilder&&) = delete; + ConvolutionThunkBuilder& operator=(ConvolutionThunkBuilder&&) = delete; - return Build(input_dims, kernel_dims, output_dims); - } + explicit ConvolutionThunkBuilder( + ConvolutionDimensions dims = ConvolutionDimensions()) + : ConvolutionThunkBuilder(MakeInputDims(dims), MakeKernelDims(dims), + MakeOutputDims(dims)) {} - // Constructor that lets the user specify each data dimension separately. - auto Build(const std::vector& input_dims, - const std::vector& kernel_dims, - const std::vector& output_dims) { + ConvolutionThunkBuilder(absl::Span input_dims, + absl::Span kernel_dims, + absl::Span output_dims) { // Convolution rank inferred from the input dimensions. int convolution_rank = input_dims.size() - 2; + // Convolution parameters. + dnums_ = MakeConvolutionDimensionNumbers(convolution_rank); + window_ = MakeWindow(convolution_rank); + // Actual data. - input_ = MakeDataVector(input_dims); - kernel_ = MakeDataVector(kernel_dims); - output_ = MakeDataVector(output_dims); - - // Buffers. - size_t input_size_in_bytes = input_.size() * sizeof(ElementType); - buffers_.emplace_back( - se::DeviceMemoryBase(input_.data(), input_size_in_bytes)); - size_t kernel_size_in_bytes = kernel_.size() * sizeof(ElementType); - buffers_.emplace_back( - se::DeviceMemoryBase(kernel_.data(), kernel_size_in_bytes)); - size_t output_size_in_bytes = output_.size() * sizeof(ElementType); - buffers_.emplace_back( - se::DeviceMemoryBase(output_.data(), output_size_in_bytes)); - - // Buffer allocations. - allocations_ = std::make_unique(buffers_); - - input_alloc_ = - std::make_unique(0, input_size_in_bytes, 0); - kernel_alloc_ = - std::make_unique(1, kernel_size_in_bytes, 0); - output_alloc_ = - std::make_unique(2, output_size_in_bytes, 0); - - BufferAllocation::Slice input_slice(input_alloc_.get(), 0, - input_size_in_bytes); - BufferAllocation::Slice kernel_slice(kernel_alloc_.get(), 0, - kernel_size_in_bytes); - BufferAllocation::Slice output_slice(output_alloc_.get(), 0, - output_size_in_bytes); - - // Shapes. - auto primitive_type = primitive_util::NativeToPrimitiveType(); - Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_dims); - Shape kernel_shape = ShapeUtil::MakeShape(primitive_type, kernel_dims); - Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_dims); + input_ = LiteralUtil::CreateFull(input_dims, ElementType(0.0)); + kernel_ = LiteralUtil::CreateFull(kernel_dims, ElementType(0.0)); + output_ = LiteralUtil::CreateFull(output_dims, ElementType(0.0)); - // Convolution parameters. - auto dnums = MakeConvolutionDimensionNumbers(convolution_rank); - auto window = MakeWindow(convolution_rank); + input_alloc_ = CreateBufferAllocation(0, input_); + kernel_alloc_ = CreateBufferAllocation(1, kernel_); + output_alloc_ = CreateBufferAllocation(2, output_); + } - // Create thunk. - return ConvolutionThunk::Create( - {"convolution"}, options_, std::move(input_slice), input_shape, - std::move(kernel_slice), kernel_shape, std::move(output_slice), - output_shape, dnums, window, - /*feature_group_count=*/1); + // Set convolution options. If not called before Build(), default options are + // used. + void SetOptions(ConvolutionThunk::Options options) { + options_ = std::move(options); } - // Get execution parameters for the last created thunk. - auto GetExecutionParams() { - return Thunk::ExecuteParams{nullptr, allocations_.get()}; + BufferAllocations GetAllocations() { + return CreateBufferAllocations(input_, kernel_, output_); + } + + auto Build() { + auto [input_slice, kernel_slice, output_slice] = + CreateBufferAllocationSlice(*input_alloc_, *kernel_alloc_, + *output_alloc_); + return ConvolutionThunk::Create( + {"convolution"}, options_, input_slice, input_.shape(), kernel_slice, + kernel_.shape(), output_slice, output_.shape(), dnums_, window_, + /*feature_group_count=*/1); } private: - std::vector input_; - std::vector kernel_; - std::vector output_; - std::vector buffers_; - ConvolutionThunk::Options options_ = MakeConvolutionOptions(); + ConvolutionDimensionNumbers dnums_; + Window window_; + + Literal input_; + Literal kernel_; + Literal output_; - // Unique pointers, because they are created only when needed. - std::unique_ptr allocations_; - std::unique_ptr input_alloc_; - std::unique_ptr kernel_alloc_; - std::unique_ptr output_alloc_; + std::optional input_alloc_; + std::optional kernel_alloc_; + std::optional output_alloc_; + + ConvolutionThunk::Options options_ = MakeConvolutionOptions(); }; template void SuccessfulConvolution(int convolution_rank) { - ConvolutionThunkBuilder builder; - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, builder.Build(ConvolutionDimensions(convolution_rank))); + ConvolutionThunkBuilder builder( + ConvolutionDimensions{convolution_rank}); + TF_ASSERT_OK_AND_ASSIGN(auto thunk, builder.Build()); + BufferAllocations allocations = builder.GetAllocations(); // Execute thunk and wait for completion. - Thunk::ExecuteParams params = builder.GetExecutionParams(); + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + auto execute_event = thunk->Execute(params); tsl::BlockUntilReady(execute_event); @@ -308,10 +262,10 @@ TEST(ConvolutionThunkTest, CreationErrorOnUnsupportedType) { } TEST(ConvolutionThunkTest, CreationErrorOnTooHighConvolutionRank) { - ConvolutionThunkBuilder builder; + ConvolutionThunkBuilder builder( + ConvolutionDimensions(/*convolution_rank=*/4)); - auto status_or_thunk = - builder.Build(ConvolutionDimensions(/*convolution_rank=*/4)); + auto status_or_thunk = builder.Build(); EXPECT_EQ(status_or_thunk.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status_or_thunk.status().message(), @@ -319,10 +273,10 @@ TEST(ConvolutionThunkTest, CreationErrorOnTooHighConvolutionRank) { } TEST(ConvolutionThunkTest, CreationErrorOnTooLowConvolutionRank) { - ConvolutionThunkBuilder builder; + ConvolutionThunkBuilder builder( + ConvolutionDimensions(/*convolution_rank=*/0)); - auto status_or_thunk = - builder.Build(ConvolutionDimensions(/*convolution_rank=*/0)); + auto status_or_thunk = builder.Build(); EXPECT_EQ(status_or_thunk.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status_or_thunk.status().message(), @@ -330,8 +284,6 @@ TEST(ConvolutionThunkTest, CreationErrorOnTooLowConvolutionRank) { } TEST(ConvolutionThunkTest, CreationErrorOnMismatchedKernelBufferRank) { - ConvolutionThunkBuilder builder; - ConvolutionDimensions dims_2d(/*convolution_rank=*/2); auto input_dims = MakeInputDims(dims_2d); auto output_dims = MakeOutputDims(dims_2d); @@ -340,7 +292,9 @@ TEST(ConvolutionThunkTest, CreationErrorOnMismatchedKernelBufferRank) { ConvolutionDimensions dims_3d(/*convolution_rank=*/3); auto kernel_dims = MakeKernelDims(dims_3d); - auto status_or_thunk = builder.Build(input_dims, kernel_dims, output_dims); + ConvolutionThunkBuilder builder(input_dims, kernel_dims, output_dims); + + auto status_or_thunk = builder.Build(); EXPECT_EQ(status_or_thunk.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status_or_thunk.status().message(), @@ -349,8 +303,6 @@ TEST(ConvolutionThunkTest, CreationErrorOnMismatchedKernelBufferRank) { } TEST(ConvolutionThunkTest, CreationErrorOnMismatchedOutputBufferRank) { - ConvolutionThunkBuilder builder; - ConvolutionDimensions dims_2d(/*convolution_rank=*/2); auto input_dims = MakeInputDims(dims_2d); auto kernel_dims = MakeKernelDims(dims_2d); @@ -359,7 +311,9 @@ TEST(ConvolutionThunkTest, CreationErrorOnMismatchedOutputBufferRank) { ConvolutionDimensions dims_3d(/*convolution_rank=*/3); auto output_dims = MakeOutputDims(dims_3d); - auto status_or_thunk = builder.Build(input_dims, kernel_dims, output_dims); + ConvolutionThunkBuilder builder(input_dims, kernel_dims, output_dims); + auto status_or_thunk = builder.Build(); + EXPECT_EQ(status_or_thunk.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status_or_thunk.status().message(), @@ -368,8 +322,6 @@ TEST(ConvolutionThunkTest, CreationErrorOnMismatchedOutputBufferRank) { } TEST(ConvolutionThunkTest, CreationErrorOnBatchSizeMismatch) { - ConvolutionThunkBuilder builder; - ConvolutionDimensions dims; dims.batch_size = 1; auto input_dims = MakeInputDims(dims); @@ -379,7 +331,9 @@ TEST(ConvolutionThunkTest, CreationErrorOnBatchSizeMismatch) { dims.batch_size = 2; auto output_dims = MakeOutputDims(dims); - auto status_or_thunk = builder.Build(input_dims, kernel_dims, output_dims); + ConvolutionThunkBuilder builder(input_dims, kernel_dims, output_dims); + auto status_or_thunk = builder.Build(); + EXPECT_EQ(status_or_thunk.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status_or_thunk.status().message(), @@ -388,8 +342,6 @@ TEST(ConvolutionThunkTest, CreationErrorOnBatchSizeMismatch) { } TEST(ConvolutionThunkTest, CreationErrorOnOutputChannelsMismatch) { - ConvolutionThunkBuilder builder; - ConvolutionDimensions dims; dims.output_channels = 3; auto input_dims = MakeInputDims(dims); @@ -399,7 +351,9 @@ TEST(ConvolutionThunkTest, CreationErrorOnOutputChannelsMismatch) { dims.output_channels = 4; auto output_dims = MakeOutputDims(dims); - auto status_or_thunk = builder.Build(input_dims, kernel_dims, output_dims); + ConvolutionThunkBuilder builder(input_dims, kernel_dims, output_dims); + auto status_or_thunk = builder.Build(); + EXPECT_EQ(status_or_thunk.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( @@ -411,15 +365,19 @@ TEST(ConvolutionThunkTest, CreationErrorOnOutputChannelsMismatch) { TEST(ConvolutionThunkTest, ExecutionErrorOnMissingThreadPoolInMultiThreadedMode) { ConvolutionThunkBuilder builder; + auto options = MakeConvolutionOptions(); options.multi_threaded = true; builder.SetOptions(options); - TF_ASSERT_OK_AND_ASSIGN(auto thunk, builder.Build(ConvolutionDimensions())); + TF_ASSERT_OK_AND_ASSIGN(auto thunk, builder.Build()); + BufferAllocations allocations = builder.GetAllocations(); // Execute thunk and wait for completion. - Thunk::ExecuteParams params = builder.GetExecutionParams(); + Thunk::ExecuteParams params; params.intra_op_threadpool = nullptr; + params.buffer_allocations = &allocations; + auto execute_event = thunk->Execute(params); tsl::BlockUntilReady(execute_event); diff --git a/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc index 8a8e4fb4debd27..ea7592a1c781ac 100644 --- a/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc @@ -15,36 +15,32 @@ limitations under the License. #include "xla/backends/cpu/runtime/copy_thunk.h" -#include -#include - #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { TEST(CopyThunkTest, CopyEmptyShape) { - std::vector buffers; - buffers.emplace_back(se::DeviceMemoryBase(nullptr, 0)); - buffers.emplace_back(se::DeviceMemoryBase(nullptr, 0)); - - BufferAllocations allocations(buffers); + auto src = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto dst = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - BufferAllocation src_alloc(/*index=*/0, /*size=*/100, /*color=*/0); - BufferAllocation dst_alloc(/*index=*/1, /*size=*/100, /*color=*/0); + BufferAllocations allocations = CreateBufferAllocations(src, dst); + auto [src_alloc, dst_alloc] = CreateBufferAllocation(src, dst); - BufferAllocation::Slice src_slice(&src_alloc, 0, 0); - BufferAllocation::Slice dst_slice(&dst_alloc, 0, 0); + BufferAllocation::Slice src_slice = + CreateBufferAllocationSlice(src_alloc, 0, 0); + BufferAllocation::Slice dst_slice = + CreateBufferAllocationSlice(src_alloc, 0, 0); Shape shape = ShapeUtil::MakeShape(F32, {0, 2}); @@ -60,27 +56,18 @@ TEST(CopyThunkTest, CopyEmptyShape) { } TEST(CopyThunkTest, CopySameShape) { - std::vector buffers; - std::vector src = {1.0, 2.0, 3.0, 4.0}; - std::vector dst(4, 0.0); - - size_t size_in_bytes = src.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(src.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(dst.data(), size_in_bytes)); + auto src = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto dst = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - BufferAllocations allocations(buffers); + BufferAllocations allocations = CreateBufferAllocations(src, dst); - BufferAllocation src_alloc(/*index=*/0, size_in_bytes, /*color=*/0); - BufferAllocation dst_alloc(/*index=*/1, size_in_bytes, /*color=*/0); - - BufferAllocation::Slice src_slice(&src_alloc, 0, size_in_bytes); - BufferAllocation::Slice dst_slice(&dst_alloc, 0, size_in_bytes); - - Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto [src_alloc, dst_alloc] = CreateBufferAllocation(src, dst); + auto [src_slice, dst_slice] = + CreateBufferAllocationSlice(src_alloc, dst_alloc); TF_ASSERT_OK_AND_ASSIGN( - auto thunk, - CopyThunk::Create({"copy"}, src_slice, shape, dst_slice, shape)); + auto thunk, CopyThunk::Create({"copy"}, src_slice, src.shape(), dst_slice, + dst.shape())); Thunk::ExecuteParams params = {nullptr, &allocations}; @@ -92,29 +79,21 @@ TEST(CopyThunkTest, CopySameShape) { } TEST(CopyThunkTest, CopyTransposed) { - std::vector buffers; - std::vector src = {1.0, 2.0, 3.0, 4.0}; - std::vector dst(4, 0.0); - - size_t size_in_bytes = src.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(src.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(dst.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); + auto src = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto dst = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - BufferAllocation src_alloc(/*index=*/0, size_in_bytes, /*color=*/0); - BufferAllocation dst_alloc(/*index=*/1, size_in_bytes, /*color=*/0); + BufferAllocations allocations = CreateBufferAllocations(src, dst); - BufferAllocation::Slice src_slice(&src_alloc, 0, size_in_bytes); - BufferAllocation::Slice dst_slice(&dst_alloc, 0, size_in_bytes); + auto [src_alloc, dst_alloc] = CreateBufferAllocation(src, dst); + auto [src_slice, dst_slice] = + CreateBufferAllocationSlice(src_alloc, dst_alloc); - Shape src_shape = ShapeUtil::MakeShape(F32, {2, 2}); - *src_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - Shape dst_shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape transposed_shape = src.shape(); + *transposed_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); TF_ASSERT_OK_AND_ASSIGN( - auto thunk, - CopyThunk::Create({"copy"}, src_slice, src_shape, dst_slice, dst_shape)); + auto thunk, CopyThunk::Create({"copy"}, src_slice, transposed_shape, + dst_slice, dst.shape())); Thunk::ExecuteParams params = {nullptr, &allocations}; @@ -122,30 +101,29 @@ TEST(CopyThunkTest, CopyTransposed) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - std::vector expected = {1.0, 3.0, 2.0, 4.0}; - EXPECT_EQ(expected, dst); + EXPECT_EQ(dst, LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}})); } TEST(CopyThunkTest, CopyTransposedEmptyShape) { - std::vector buffers; - buffers.emplace_back(se::DeviceMemoryBase(nullptr, 0)); - buffers.emplace_back(se::DeviceMemoryBase(nullptr, 0)); + auto src = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto dst = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - BufferAllocations allocations(buffers); + BufferAllocations allocations = CreateBufferAllocations(src, dst); + auto [src_alloc, dst_alloc] = CreateBufferAllocation(src, dst); - BufferAllocation src_alloc(/*index=*/0, /*size=*/100, /*color=*/0); - BufferAllocation dst_alloc(/*index=*/1, /*size=*/100, /*color=*/0); + BufferAllocation::Slice src_slice = + CreateBufferAllocationSlice(src_alloc, 0, 0); + BufferAllocation::Slice dst_slice = + CreateBufferAllocationSlice(src_alloc, 0, 0); - BufferAllocation::Slice src_slice(&src_alloc, 0, 0); - BufferAllocation::Slice dst_slice(&dst_alloc, 0, 0); + Shape shape = ShapeUtil::MakeShape(F32, {0, 2}); - Shape src_shape = ShapeUtil::MakeShape(F32, {0, 2}); - *src_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - Shape dst_shape = ShapeUtil::MakeShape(F32, {0, 2}); + Shape transposed_shape = shape; + *transposed_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); TF_ASSERT_OK_AND_ASSIGN( - auto thunk, - CopyThunk::Create({"copy"}, src_slice, src_shape, dst_slice, dst_shape)); + auto thunk, CopyThunk::Create({"copy"}, src_slice, transposed_shape, + dst_slice, shape)); Thunk::ExecuteParams params = {nullptr, &allocations}; diff --git a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc index 8f693a1e3c5378..974a77522ac77d 100644 --- a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -132,6 +132,12 @@ absl::StatusOr BuildCallFrameForTypedFFI( // memory addresses will be updated at runtime. for (int i = 0; i < op_buffers.arguments_buffers.size(); ++i) { auto& shape = op_buffers.arguments_shapes[i]; + + if (shape.IsToken()) { + builder.AddTokenArg(); + continue; + } + auto elements = absl::c_accumulate(shape.dimensions(), 1ULL, std::multiplies()); auto dtype_bytes = primitive_util::ByteWidth(shape.element_type()); @@ -144,6 +150,12 @@ absl::StatusOr BuildCallFrameForTypedFFI( // memory addresses will be updated at runtime. for (int i = 0; i < op_buffers.results_buffers.size(); ++i) { auto& shape = op_buffers.results_shapes[i]; + + if (shape.IsToken()) { + builder.AddTokenRet(); + continue; + } + auto elements = absl::c_accumulate(shape.dimensions(), 1ULL, std::multiplies()); auto dtype_bytes = primitive_util::ByteWidth(shape.element_type()); diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc b/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc new file mode 100644 index 00000000000000..05aaca671a474a --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/dot_lib.cc @@ -0,0 +1,144 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/dot_lib.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/layout_util.h" +#include "xla/runtime/buffer_use.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/util.h" + +namespace xla::cpu { + +absl::InlinedVector DotBufferUses(const DotSlices& slices) { + return {BufferUse::Read(slices.lhs_buffer), + BufferUse::Read(slices.rhs_buffer), + BufferUse::Write(slices.out_buffer)}; +} + +absl::StatusOr GetDotShape(const DotDimensionNumbers& dot_dimensions, + const Shape& lhs_shape, + const Shape& rhs_shape, + const Shape& out_shape) { + // All shapes must be in dim0-major layout. + if (!LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) || + !LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) || + !LayoutUtil::IsMonotonicWithDim0Major(out_shape.layout())) { + return InvalidArgument( + "DotThunk requires all operands and outputs to be in " + "dim0-major layout: lhs_shape=[%s], rhs_shape=[%s], out_shape=[%s]", + lhs_shape.ToString(true), rhs_shape.ToString(true), + out_shape.ToString(true)); + } + + // Batch dimensions must be contiguous and start at 0. + std::vector batch_dims(dot_dimensions.lhs_batch_dimensions().size()); + absl::c_iota(batch_dims, 0); + + if (!absl::c_equal(dot_dimensions.lhs_batch_dimensions(), batch_dims) || + !absl::c_equal(dot_dimensions.rhs_batch_dimensions(), batch_dims)) { + return InvalidArgument( + "Batch dimensions must be contiguous and start at 0: " + "lhs_batch_dims=[%s], rhs_batch_dims=[%s]", + absl::StrJoin(dot_dimensions.lhs_batch_dimensions(), ","), + absl::StrJoin(dot_dimensions.rhs_batch_dimensions(), ",")); + } + + int64_t num_batch_dims = batch_dims.size(); + int64_t batch_size = + std::accumulate(out_shape.dimensions().begin(), + out_shape.dimensions().begin() + num_batch_dims, 1LL, + std::multiplies()); + + Shape lhs_matmul_shape = ShapeUtil::DeleteDimensions(batch_dims, lhs_shape); + Shape rhs_matmul_shape = ShapeUtil::DeleteDimensions(batch_dims, rhs_shape); + Shape out_matmul_shape = ShapeUtil::DeleteDimensions(batch_dims, out_shape); + + // Check that matmul shapes are rank 2 or less and can be represented as + // Eigen 2D contraction. + if (lhs_matmul_shape.rank() > 2 || rhs_matmul_shape.rank() > 2 || + out_matmul_shape.rank() > 2) { + return InvalidArgument( + "MatMul shape must be rank 2 or less: lhs=%s, rhs=%s, out=%s", + lhs_matmul_shape.ToString(true), rhs_matmul_shape.ToString(true), + out_matmul_shape.ToString(true)); + } + + return DotShape{ + batch_size, + std::move(lhs_matmul_shape), + std::move(rhs_matmul_shape), + std::move(out_matmul_shape), + }; +} + +absl::StatusOr GetDotCanonicalDims( + const DotDimensionNumbers& dot_dimensions, const DotShape& dot_shape) { + // Copy from the original dot dimension numbers. + absl::InlinedVector lhs_contracting_dims; + absl::InlinedVector rhs_contracting_dims; + + lhs_contracting_dims.assign( + dot_dimensions.lhs_contracting_dimensions().begin(), + dot_dimensions.lhs_contracting_dimensions().end()); + rhs_contracting_dims.assign( + dot_dimensions.rhs_contracting_dimensions().begin(), + dot_dimensions.rhs_contracting_dimensions().end()); + + // Adjust contracting dimensions for leading batch dimensions. + for (int64_t& dim : lhs_contracting_dims) + dim -= dot_dimensions.lhs_batch_dimensions_size(); + for (int64_t& dim : rhs_contracting_dims) + dim -= dot_dimensions.rhs_batch_dimensions_size(); + + // Non-contracting dots should never make it here. + TF_RET_CHECK(lhs_contracting_dims.size() == 1); + TF_RET_CHECK(rhs_contracting_dims.size() == 1); + TF_RET_CHECK(lhs_contracting_dims[0] < 2); + TF_RET_CHECK(rhs_contracting_dims[0] < 2); + + auto is_column_major = [](const Shape& shape) { + return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; + }; + + return DotCanonicalDims{ + /*m=*/dot_shape.lhs_matmul_shape.rank() <= 1 + ? int64_t{1} + : dot_shape.lhs_matmul_shape.dimensions(1 - lhs_contracting_dims[0]), + /*k=*/dot_shape.lhs_matmul_shape.dimensions(lhs_contracting_dims[0]), + /*n=*/dot_shape.rhs_matmul_shape.rank() <= 1 + ? int64_t{1} + : dot_shape.rhs_matmul_shape.dimensions(1 - rhs_contracting_dims[0]), + /*lhs_column_major=*/is_column_major(dot_shape.lhs_matmul_shape), + /*lhs_canonical=*/dot_shape.lhs_matmul_shape.rank() <= 1 || + lhs_contracting_dims[0] == 1, + /*rhs_column_major=*/is_column_major(dot_shape.rhs_matmul_shape), + /*rhs_canonical=*/rhs_contracting_dims[0] == 0}; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_lib.h b/third_party/xla/xla/backends/cpu/runtime/dot_lib.h new file mode 100644 index 00000000000000..e913f56c9f0bc8 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/dot_lib.h @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_ +#define XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_ + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/shape.h" + +namespace xla::cpu { + +// Allocation slices of the dot operation. +struct DotSlices { + BufferAllocation::Slice lhs_buffer; + Shape lhs_shape; + + BufferAllocation::Slice rhs_buffer; + Shape rhs_shape; + + BufferAllocation::Slice out_buffer; + Shape out_shape; +}; + +// Shape of the batched dot operation supported by the XLA:CPU runtime. +struct DotShape { + // Product of batch dimensions. + int64_t batch_size; + + // Shapes of the non-batch matrix-multiplication for the dot operation + Shape lhs_matmul_shape; + Shape rhs_matmul_shape; + Shape out_matmul_shape; +}; + +// Dot operation is implemented as a matrix-matrix multiply (row-major x +// rowm-major or col-major x col-major). For batched dot operations, it is +// implemented as multiple matrix multiplications repeated for each batch +// element. +struct DotCanonicalDims { + // The number of rows in the LHS. + int64_t m; + + // The number of columns in the LHS, which also must be equal to the + // number of rows in the RHS. + int64_t k; + + // The number of columns in the RHS. + int64_t n; + + // True if the LHS matrix is column major. + bool lhs_column_major; + + // True if the LHS contraction dimension is 1. + bool lhs_canonical; + + // True if the RHS matrix is column major. + bool rhs_column_major; + + // True if the RHS contraction dimension is 0. + bool rhs_canonical; +}; + +// Returns buffer uses of the dot operation. +absl::InlinedVector DotBufferUses(const DotSlices& slices); + +// Verifies dot dimensions and shapes and returns the shape of the dot operation +// in a form that is convenient for the runtime implementation. +absl::StatusOr GetDotShape(const DotDimensionNumbers& dot_dimensions, + const Shape& lhs_shape, + const Shape& rhs_shape, + const Shape& out_shape); + +// Get canonical dot dimensions for the given dot shape. +absl::StatusOr GetDotCanonicalDims( + const DotDimensionNumbers& dot_dimensions, const DotShape& dot_shape); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_DOT_LIB_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc index 3b0d81ff346429..00bcec6a2df83c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc @@ -17,197 +17,75 @@ limitations under the License. #include #include -#include #include -#include #include -#include -#include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/dot_lib.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/logging.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { -namespace { - -// Dot operation is implemented as a matrix-matrix multiply (row-major x -// rowm-major or col-major x col-major). For batched dot operations, it is -// implemented as multiple matrix multiplications repeated for each batch -// element. -// -// We rely on col-major Eigen contraction and figure out how to represent dot -// operation as a contraction based on the dot dimension numbers. -struct MatMulDims { - // The number of rows in the LHS. - int64_t m; - - // The number of columns in the LHS, which also must be equal to the - // number of rows in the RHS. - int64_t k; - - // The number of columns in the RHS. - int64_t n; - - // True if the LHS matrix is column major. - bool lhs_column_major; - - // True if the LHS contraction dimension is 1. - bool lhs_canonical; - - // True if the RHS matrix is column major. - bool rhs_column_major; - - // True if the RHS contraction dimension is 0. - bool rhs_canonical; -}; - -} // namespace - -static MatMulDims GetMatMulDims( - const Shape& lhs_shape, absl::Span lhs_contracting_dims, - const Shape& rhs_shape, absl::Span rhs_contracting_dims) { - // Non-contracting dots should never make it here. - CHECK_EQ(lhs_contracting_dims.size(), 1); - CHECK_EQ(rhs_contracting_dims.size(), 1); - CHECK_LT(lhs_contracting_dims[0], 2); - CHECK_LT(rhs_contracting_dims[0], 2); - - auto is_column_major = [](const Shape& shape) { - return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; - }; - - return MatMulDims{ - /*m=*/lhs_shape.rank() <= 1 - ? 1LL - : lhs_shape.dimensions(1LL - lhs_contracting_dims[0]), - /*k=*/lhs_shape.dimensions(lhs_contracting_dims[0]), - /*n=*/rhs_shape.rank() <= 1 - ? 1LL - : rhs_shape.dimensions(1LL - rhs_contracting_dims[0]), - /*lhs_column_major=*/is_column_major(lhs_shape), - /*lhs_canonical=*/lhs_shape.rank() <= 1 || lhs_contracting_dims[0] == 1, - /*rhs_column_major=*/is_column_major(rhs_shape), - /*rhs_canonical=*/rhs_contracting_dims[0] == 0}; -} absl::StatusOr> DotThunk::Create( Info info, DotDimensionNumbers dot_dimensions, BufferAllocation::Slice lhs_buffer, Shape lhs_shape, BufferAllocation::Slice rhs_buffer, Shape rhs_shape, BufferAllocation::Slice out_buffer, Shape out_shape) { - // All shapes must be in dim0-major layout. - if (!LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) || - !LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) || - !LayoutUtil::IsMonotonicWithDim0Major(out_shape.layout())) { - return InvalidArgument( - "DotThunk requires all operands and outputs to be in " - "dim0-major layout: lhs_shape=[%s], rhs_shape=[%s], out_shape=[%s]", - lhs_shape.ToString(true), rhs_shape.ToString(true), - out_shape.ToString(true)); - } + TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, + rhs_shape, out_shape)); - // Batch dimensions must be contiguous and start at 0. - std::vector batch_dims(dot_dimensions.lhs_batch_dimensions().size()); - absl::c_iota(batch_dims, 0); - - if (!absl::c_equal(dot_dimensions.lhs_batch_dimensions(), batch_dims) || - !absl::c_equal(dot_dimensions.rhs_batch_dimensions(), batch_dims)) { - return InvalidArgument( - "Batch dimensions must be contiguous and start at 0: " - "lhs_batch_dims=[%s], rhs_batch_dims=[%s]", - absl::StrJoin(dot_dimensions.lhs_batch_dimensions(), ","), - absl::StrJoin(dot_dimensions.rhs_batch_dimensions(), ",")); - } + TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); - int64_t num_batch_dims = batch_dims.size(); - int64_t batch_size = - std::accumulate(out_shape.dimensions().begin(), - out_shape.dimensions().begin() + num_batch_dims, 1LL, - std::multiplies()); - - Shape lhs_matmul_shape = ShapeUtil::DeleteDimensions(batch_dims, lhs_shape); - Shape rhs_matmul_shape = ShapeUtil::DeleteDimensions(batch_dims, rhs_shape); - Shape out_matmul_shape = ShapeUtil::DeleteDimensions(batch_dims, out_shape); - - // Check that matmul shapes are rank 2 or less and can be represented as - // Eigen 2D contraction. - if (lhs_matmul_shape.rank() > 2 || rhs_matmul_shape.rank() > 2 || - out_matmul_shape.rank() > 2) { - return InvalidArgument( - "MatMul shape must be rank 2 or less: lhs=%s, rhs=%s, out=%s", - lhs_matmul_shape.ToString(true), rhs_matmul_shape.ToString(true), - out_matmul_shape.ToString(true)); - } + DotSlices dot_slices{lhs_buffer, std::move(lhs_shape), + rhs_buffer, std::move(rhs_shape), + out_buffer, std::move(out_shape)}; - return absl::WrapUnique(new DotThunk( - info, std::move(dot_dimensions), lhs_buffer, std::move(lhs_shape), - rhs_buffer, std::move(rhs_shape), out_buffer, std::move(out_shape), - batch_size, std::move(lhs_matmul_shape), std::move(rhs_matmul_shape), - std::move(out_matmul_shape))); + return absl::WrapUnique( + new DotThunk(info, std::move(dot_dimensions), std::move(dot_slices), + std::move(dot_shape), std::move(dot_canonical_dims))); } DotThunk::DotThunk(Info info, DotDimensionNumbers dot_dimensions, - BufferAllocation::Slice lhs_buffer, Shape lhs_shape, - BufferAllocation::Slice rhs_buffer, Shape rhs_shape, - BufferAllocation::Slice out_buffer, Shape out_shape, - int64_t batch_size, Shape lhs_matmul_shape, - Shape rhs_matmul_shape, Shape out_matmul_shape) + DotSlices dot_slices, DotShape dot_shape, + DotCanonicalDims dot_canonical_dims) : Thunk(Kind::kDot, info), - dot_dimensions_(dot_dimensions), - lhs_buffer_(lhs_buffer), - lhs_shape_(lhs_shape), - rhs_buffer_(rhs_buffer), - rhs_shape_(rhs_shape), - out_buffer_(out_buffer), - out_shape_(out_shape), - batch_size_(batch_size), - lhs_matmul_shape_(lhs_matmul_shape), - rhs_matmul_shape_(rhs_matmul_shape), - out_matmul_shape_(out_matmul_shape) { - // Copy from the original dot dimension numbers. - lhs_matmul_contracting_dims_.assign( - dot_dimensions_.lhs_contracting_dimensions().begin(), - dot_dimensions_.lhs_contracting_dimensions().end()); - rhs_matmul_contracting_dims_.assign( - dot_dimensions_.rhs_contracting_dimensions().begin(), - dot_dimensions_.rhs_contracting_dimensions().end()); - - // Adjust contracting dimensions for leading batch dimensions. - for (int64_t& dim : lhs_matmul_contracting_dims_) - dim -= dot_dimensions_.lhs_batch_dimensions_size(); - for (int64_t& dim : rhs_matmul_contracting_dims_) - dim -= dot_dimensions_.rhs_batch_dimensions_size(); -} + dot_dimensions_(std::move(dot_dimensions)), + dot_slices_(std::move(dot_slices)), + dot_shape_(std::move(dot_shape)), + dot_canonical_dims_(std::move(dot_canonical_dims)) {} tsl::AsyncValueRef DotThunk::Execute( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_data, - params.buffer_allocations->GetDeviceAddress(lhs_buffer_)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase lhs_data, + params.buffer_allocations->GetDeviceAddress(dot_slices_.lhs_buffer)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_data, - params.buffer_allocations->GetDeviceAddress(rhs_buffer_)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase rhs_data, + params.buffer_allocations->GetDeviceAddress(dot_slices_.rhs_buffer)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase out_data, - params.buffer_allocations->GetDeviceAddress(out_buffer_)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase out_data, + params.buffer_allocations->GetDeviceAddress(dot_slices_.out_buffer)); VLOG(3) << absl::StreamFormat( "Dot operation: lhs_batch_dims=[%s], rhs_batch_dims=[%s], " @@ -217,31 +95,28 @@ tsl::AsyncValueRef DotThunk::Execute( absl::StrJoin(dot_dimensions_.lhs_contracting_dimensions(), ","), absl::StrJoin(dot_dimensions_.rhs_contracting_dimensions(), ",")); - VLOG(3) << absl::StreamFormat(" lhs: %s in slice %s (%p)", - lhs_shape_.ToString(true), - lhs_buffer_.ToString(), lhs_data.opaque()); - VLOG(3) << absl::StreamFormat(" rhs: %s in slice %s (%p)", - rhs_shape_.ToString(true), - rhs_buffer_.ToString(), rhs_data.opaque()); - VLOG(3) << absl::StreamFormat(" out: %s in slice %s (%p)", - out_shape_.ToString(true), - out_buffer_.ToString(), out_data.opaque()); - VLOG(3) << absl::StreamFormat( - " matmul shape: batch_size=%d, lhs=%s, rhs=%s, out=%s", batch_size_, - lhs_matmul_shape_.ToString(true), rhs_matmul_shape_.ToString(true), - out_matmul_shape_.ToString(true)); + " lhs: %s in slice %s (%p)", dot_slices_.lhs_shape.ToString(true), + dot_slices_.lhs_buffer.ToString(), lhs_data.opaque()); + VLOG(3) << absl::StreamFormat( + " rhs: %s in slice %s (%p)", dot_slices_.rhs_shape.ToString(true), + dot_slices_.rhs_buffer.ToString(), rhs_data.opaque()); + VLOG(3) << absl::StreamFormat( + " out: %s in slice %s (%p)", dot_slices_.out_shape.ToString(true), + dot_slices_.out_buffer.ToString(), out_data.opaque()); - MatMulDims matmul_dims = - GetMatMulDims(lhs_matmul_shape_, lhs_matmul_contracting_dims_, - rhs_matmul_shape_, rhs_matmul_contracting_dims_); + VLOG(3) << absl::StreamFormat( + " matmul shape: batch_size=%d, lhs=%s, rhs=%s, out=%s", + dot_shape_.batch_size, dot_shape_.lhs_matmul_shape.ToString(true), + dot_shape_.rhs_matmul_shape.ToString(true), + dot_shape_.out_matmul_shape.ToString(true)); VLOG(3) << absl::StreamFormat( " matmul dims: m=%d, k=%d, n=%d, lhs_column_major=%v, lhs_canonical=%v, " "rhs_column_major=%v, rhs_canonical=%v", - matmul_dims.m, matmul_dims.k, matmul_dims.n, matmul_dims.lhs_column_major, - matmul_dims.lhs_canonical, matmul_dims.rhs_column_major, - matmul_dims.rhs_canonical); + dot_canonical_dims_.m, dot_canonical_dims_.k, dot_canonical_dims_.n, + dot_canonical_dims_.lhs_column_major, dot_canonical_dims_.lhs_canonical, + dot_canonical_dims_.rhs_column_major, dot_canonical_dims_.rhs_canonical); if (params.intra_op_threadpool == nullptr) { return InvalidArgument("Intra-op threadpool must be provided for DotThunk"); @@ -262,36 +137,41 @@ tsl::AsyncValueRef DotThunk::Execute( void* lhs = lhs_data.opaque(); void* rhs = rhs_data.opaque(); - bool transpose_lhs = !matmul_dims.lhs_canonical; - bool transpose_rhs = !matmul_dims.rhs_canonical; + int64_t m = dot_canonical_dims_.m; + int64_t n = dot_canonical_dims_.n; + int64_t k = dot_canonical_dims_.k; + + bool transpose_lhs = !dot_canonical_dims_.lhs_canonical; + bool transpose_rhs = !dot_canonical_dims_.rhs_canonical; - CHECK_EQ(matmul_dims.lhs_column_major, matmul_dims.rhs_column_major); - if (!matmul_dims.lhs_column_major) { - std::swap(matmul_dims.m, matmul_dims.n); + CHECK_EQ(dot_canonical_dims_.lhs_column_major, + dot_canonical_dims_.rhs_column_major); + if (!dot_canonical_dims_.lhs_column_major) { + std::swap(m, n); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } - PrimitiveType element_type = lhs_matmul_shape_.element_type(); + PrimitiveType element_type = dot_shape_.lhs_matmul_shape.element_type(); int64_t byte_width = primitive_util::ByteWidth(element_type); - int64_t lhs_stride = matmul_dims.m * matmul_dims.k * byte_width; - int64_t rhs_stride = matmul_dims.k * matmul_dims.n * byte_width; - int64_t out_stride = matmul_dims.m * matmul_dims.n * byte_width; + int64_t lhs_stride = m * k * byte_width; + int64_t rhs_stride = k * n * byte_width; + int64_t out_stride = m * n * byte_width; auto batch_ptr = [&](void* ptr, int64_t stride, int64_t index) -> void* { return static_cast(ptr) + stride * index; }; - tsl::CountDownAsyncValueRef state(batch_size_); + tsl::CountDownAsyncValueRef state(dot_shape_.batch_size); auto dispatch = [&](auto type_tag) { - for (int64_t i = 0; i < batch_size_; ++i) { + for (int64_t i = 0; i < dot_shape_.batch_size; ++i) { TypedMatMul( params.intra_op_threadpool, batch_ptr(out, out_stride, i), - batch_ptr(lhs, lhs_stride, i), batch_ptr(rhs, rhs_stride, i), - matmul_dims.m, matmul_dims.n, matmul_dims.k, transpose_lhs, - transpose_rhs, [state]() mutable { state.CountDown(); }); + batch_ptr(lhs, lhs_stride, i), batch_ptr(rhs, rhs_stride, i), m, n, k, + transpose_lhs, transpose_rhs, + [state]() mutable { state.CountDown(); }); } }; diff --git a/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h index 61bcb8194e1150..15b5b97fd33c22 100644 --- a/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ #define XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ +#include "xla/backends/cpu/runtime/dot_lib.h" #define EIGEN_USE_THREADS #include @@ -30,7 +31,6 @@ limitations under the License. #include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" #include "xla/backends/cpu/runtime/thunk.h" -#include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -48,18 +48,11 @@ class DotThunk final : public Thunk { tsl::AsyncValueRef Execute(const ExecuteParams& params) final; - BufferUses buffer_uses() const final { - return {BufferUse::Read(lhs_buffer_), BufferUse::Read(rhs_buffer_), - BufferUse::Write(out_buffer_)}; - } + BufferUses buffer_uses() const final { return DotBufferUses(dot_slices_); } private: - DotThunk(Info info, DotDimensionNumbers dot_dimensions, - BufferAllocation::Slice lhs_buffer, Shape lhs_shape, - BufferAllocation::Slice rhs_buffer, Shape rhs_shape, - BufferAllocation::Slice out_buffer, Shape out_shape, - int64_t batch_size, Shape lhs_matmul_shape, Shape rhs_matmul_shape, - Shape out_matmul_shape); + DotThunk(Info info, DotDimensionNumbers dot_dimensions, DotSlices dot_slices, + DotShape dot_shape, DotCanonicalDims dot_canonical_dims); using DoneCallback = absl::AnyInvocable; @@ -77,23 +70,9 @@ class DotThunk final : public Thunk { DoneCallback done); DotDimensionNumbers dot_dimensions_; - - BufferAllocation::Slice lhs_buffer_; - Shape lhs_shape_; - - BufferAllocation::Slice rhs_buffer_; - Shape rhs_shape_; - - BufferAllocation::Slice out_buffer_; - Shape out_shape_; - - // Product of batch dimensions. - int64_t batch_size_; - - // Shapes of the non-batch matrix-multiplication for the dot operation - Shape lhs_matmul_shape_; - Shape rhs_matmul_shape_; - Shape out_matmul_shape_; + DotSlices dot_slices_; + DotShape dot_shape_; + DotCanonicalDims dot_canonical_dims_; // Contracting dimensions of the LHS and RHS matmul shapes. absl::InlinedVector lhs_matmul_contracting_dims_; diff --git a/third_party/xla/xla/backends/cpu/runtime/function_library.h b/third_party/xla/xla/backends/cpu/runtime/function_library.h index 68c92f26936b85..76e213c0296faf 100644 --- a/third_party/xla/xla/backends/cpu/runtime/function_library.h +++ b/third_party/xla/xla/backends/cpu/runtime/function_library.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include @@ -69,7 +68,7 @@ class FunctionLibrary { } template >* = nullptr> - absl::StatusOr ResolveFunction(std::string_view name) { + absl::StatusOr ResolveFunction(absl::string_view name) { TF_ASSIGN_OR_RETURN(void* ptr, ResolveFunction(GetTypeId(), name)); return reinterpret_cast(ptr); } @@ -79,7 +78,7 @@ class FunctionLibrary { // id. Implementation might choose not to verify the type id and then it is up // to the caller to ensure the resolved function is of the correct type. virtual absl::StatusOr ResolveFunction(TypeId type_id, - std::string_view name) = 0; + absl::string_view name) = 0; private: // Returns a type id for a given function type. diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel.cc b/third_party/xla/xla/backends/cpu/runtime/kernel.cc index c554667e152f65..ae5bf9be3dd3b3 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -205,13 +206,25 @@ void KernelExecuteState::CallSync(uint64_t task_index) { void KernelExecuteState::CallAsync(uint64_t start_index, uint64_t end_index) { CHECK_LT(start_index, end_index) << "Invalid task index range"; // Crash OK - while (end_index - start_index > 1) { - uint64_t mid_index = (start_index + end_index) / 2; - task_runner_([self = this, mid_index, end_index] { - self->CallAsync(mid_index, end_index); - }); - end_index = mid_index; + + auto dispatch = [&](auto index_type) { + using Index = decltype(index_type); + while (end_index - start_index > 1) { + uint64_t mid_index = (start_index + end_index) / 2; + task_runner_([self = this, mid = Index(mid_index), + end = Index(end_index)] { self->CallAsync(mid, end); }); + end_index = mid_index; + } + }; + + // If the number of tasks is small, we can use uint16_t to index them and hit + // small object optimization in the std::function and avoid a heap allocation. + if (ABSL_PREDICT_TRUE(end_index <= std::numeric_limits::max())) { + dispatch(uint16_t{}); + } else { + dispatch(uint64_t{}); } + CallSync(start_index); } diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc index 0b591a0b5855b6..2578dc1b7c85ac 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/backends/cpu/runtime/kernel_thunk.h" -#define EIGEN_USE_THREADS - #include #include #include @@ -24,18 +22,22 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/base/call_once.h" #include "absl/base/optimization.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/kernel.h" @@ -46,12 +48,14 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" + namespace xla::cpu { namespace internal { @@ -111,8 +115,9 @@ template KernelThunk::KernelThunk( Info info, absl::Span arguments_buffers, absl::Span results_buffers, - absl::flat_hash_set invariant_arguments, std::string kernel_name, - se::ThreadDim thread_dim, std::optional min_alignment) + std::optional> invariant_arguments, + std::string kernel_name, se::ThreadDim thread_dim, + std::optional min_alignment) : Thunk(Kind::kKernel, std::move(info)), invariant_arguments_(std::move(invariant_arguments)), num_kernel_args_(arguments_buffers.size() + results_buffers.size()), @@ -200,7 +205,9 @@ KernelThunk::ExecuteInternal( // TODO(abanas): Check also for overlapping buffers. TF_RETURN_IF_ERROR( CheckBufferAlignment(info(), min_alignment_.value_or(0), kernel_args)); - TF_RETURN_IF_ERROR(CheckInvariantBuffersMemory(kernel_args)); + if (invariant_arguments_.has_value()) { + TF_RETURN_IF_ERROR(CheckInvariantBuffersMemory(kernel_args)); + } } // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk @@ -252,9 +259,10 @@ template absl::Status KernelThunk::CheckInvariantBuffersMemory( const KernelArgs& kernel_args) const { + CHECK(invariant_arguments_.has_value()); // Crash OK if (ABSL_PREDICT_FALSE(VLOG_IS_ON(10))) { VLOG(10) << "Verify invariant buffers: "; - for (auto index : invariant_arguments_) { + for (auto index : *invariant_arguments_) { VLOG(10) << absl::StreamFormat(" invariant arg id: %d", index); } } @@ -267,7 +275,7 @@ KernelThunk::CheckInvariantBuffersMemory( // Verify all argument buffers. for (int64_t i = 0; i < arguments.size(); ++i) { const XLA_CPU_KernelArg& argument = arguments[i]; - if (invariant_arguments_.contains(i)) { + if (invariant_arguments_->contains(i)) { // This argument should be read only, i.e. not one of the results. if (Contains(results, argument)) { return Internal("Mismatch in invariant buffers metadata"); @@ -308,7 +316,7 @@ absl::StatusOr> KernelThunk::Create( absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, - absl::flat_hash_set invariant_arguments, + std::optional> invariant_arguments, std::optional min_alignment) { if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { return Internal("Host kernel %s minimum alignment %d is not a power of 2", @@ -350,4 +358,25 @@ absl::StatusOr> KernelThunk::Create( thread_dim, min_alignment)); } +absl::StatusOr> KernelThunk::Create( + Thunk::Info info, std::unique_ptr kernel_spec, + std::optional min_alignment) { + std::vector arguments_buffers; + std::vector results_buffers; + + for (const BufferUse& buffer_use : kernel_spec->buffer_uses()) { + if (buffer_use.access() == BufferUse::kRead) { + arguments_buffers.push_back(buffer_use.slice()); + } else { + results_buffers.push_back(buffer_use.slice()); + } + } + + const std::string& kernel_name = kernel_spec->kernel_source().kernel_name(); + + return Create(std::move(info), arguments_buffers, results_buffers, + std::move(kernel_name), kernel_spec->thread_dim(), std::nullopt, + min_alignment); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h index 4e11b4ad2e1996..173f44420719ab 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/backends/cpu/runtime/kernel.h" #include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/backends/cpu/runtime/thunk.h" @@ -95,7 +96,7 @@ class KernelThunk : public Thunk { KernelThunk(Info info, absl::Span arguments_buffers, absl::Span results_buffers, - absl::flat_hash_set invariant_arguments, + std::optional> invariant_arguments, std::string kernel_name, se::ThreadDim thread_dim, std::optional min_alignment); @@ -105,7 +106,7 @@ class KernelThunk : public Thunk { ResultsBuffers results_buffers_; // A set of invariant arguments (their indices). - absl::flat_hash_set invariant_arguments_; + std::optional> invariant_arguments_; size_t num_kernel_args_; @@ -155,9 +156,13 @@ class KernelThunk final : public internal::KernelThunk<> { absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, - absl::flat_hash_set invariant_arguments, + std::optional> invariant_arguments, std::optional min_alignment = std::nullopt); + static absl::StatusOr> Create( + Thunk::Info info, std::unique_ptr kernel_spec, + std::optional min_alignment); + tsl::AsyncValueRef Execute( const Thunk::ExecuteParams& params) final; }; diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc index eed4eec1ce90db..2cdd55e9ecdcd1 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc @@ -15,25 +15,23 @@ limitations under the License. #include "xla/backends/cpu/runtime/kernel_thunk.h" -#include #include -#include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" -#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { @@ -41,7 +39,7 @@ namespace { class AddF32HostKernel : public FunctionLibrary { public: absl::StatusOr ResolveFunction(TypeId type_id, - std::string_view name) final { + absl::string_view name) final { auto kernel = +[](const XLA_CPU_KernelCallFrame* call_frame) { const XLA_CPU_KernelArg& in = call_frame->args[0]; const XLA_CPU_KernelArg& out = call_frame->args[1]; @@ -67,26 +65,18 @@ TEST(KernelThunkTest, CheckAlignment) { } TEST(KernelThunkTest, AddF32) { - std::vector buffers; - std::vector in = {1.0, 2.0, 3.0, 4.0}; - std::vector out(4, 0.0); + auto in = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto out = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - size_t size_in_bytes = in.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + BufferAllocations allocations = CreateBufferAllocations(in, out); - BufferAllocations allocations(buffers); - - BufferAllocation in_alloc(0, size_in_bytes, 0); - BufferAllocation out_alloc(1, size_in_bytes, 0); - - BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); - BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + auto [in_alloc, out_alloc] = CreateBufferAllocation(in, out); + auto [in_slice, out_slice] = CreateBufferAllocationSlice(in_alloc, out_alloc); TF_ASSERT_OK_AND_ASSIGN( auto thunk, KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, "add_f32", - se::ThreadDim(4), /*invariant_arguments=*/{0})); + se::ThreadDim(4), /*invariant_arguments=*/{{0}})); AddF32HostKernel host_kernels; Thunk::ExecuteParams params = {&host_kernels, &allocations}; @@ -95,25 +85,21 @@ TEST(KernelThunkTest, AddF32) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError(); - std::vector expected = {2.0, 4.0, 6.0, 8.0}; - EXPECT_EQ(out, expected); + EXPECT_EQ(out, LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}})); } TEST(KernelThunkTest, AddF32Inline) { - std::vector buffers; - std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + auto in_out = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - size_t size_in_bytes = in_out.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + BufferAllocations allocations = CreateBufferAllocations(in_out); - BufferAllocations allocations(buffers); - BufferAllocation in_out_alloc(0, size_in_bytes, 0); - BufferAllocation::Slice in_out_slice(&in_out_alloc, 0, size_in_bytes); + BufferAllocation alloc = CreateBufferAllocation(0, in_out); + BufferAllocation::Slice slice = CreateBufferAllocationSlice(alloc); TF_ASSERT_OK_AND_ASSIGN( - auto thunk, KernelThunk::Create( - {"add_f32"}, {in_out_slice}, {in_out_slice}, "add_f32", - se::ThreadDim(4), /*invariant_arguments=*/{})); + auto thunk, + KernelThunk::Create({"add_f32"}, {slice}, {slice}, "add_f32", + se::ThreadDim(4), /*invariant_arguments=*/{{}})); AddF32HostKernel host_kernels; Thunk::ExecuteParams params = {&host_kernels, &allocations}; @@ -122,8 +108,7 @@ TEST(KernelThunkTest, AddF32Inline) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - std::vector expected = {2.0, 4.0, 6.0, 8.0}; - EXPECT_EQ(in_out, expected); + EXPECT_EQ(in_out, LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}})); } TEST(KernelThunkInvariantBuffersTest, MissingBufferSlice) { @@ -131,27 +116,19 @@ TEST(KernelThunkInvariantBuffersTest, MissingBufferSlice) { GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; #endif - std::vector buffers; - std::vector in = {1.0, 2.0, 3.0, 4.0}; - std::vector out(4, 0.0); - - size_t size_in_bytes = in.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); + auto in = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto out = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - BufferAllocation in_alloc(0, size_in_bytes, 0); - BufferAllocation out_alloc(1, size_in_bytes, 0); + BufferAllocations allocations = CreateBufferAllocations(in, out); - BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); - BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + auto [in_alloc, out_alloc] = CreateBufferAllocation(in, out); + auto [in_slice, out_slice] = CreateBufferAllocationSlice(in_alloc, out_alloc); // Invariant buffer set is incorrect - should include in_slice, but is empty. TF_ASSERT_OK_AND_ASSIGN( auto thunk, KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, "add_f32", - se::ThreadDim(4), /*invariant_arguments=*/{})); + se::ThreadDim(4), /*invariant_arguments=*/{{}})); AddF32HostKernel host_kernels; Thunk::ExecuteParams params = {&host_kernels, &allocations}; @@ -171,22 +148,18 @@ TEST(KernelThunkInvariantBuffersTest, ExtraInputOutputBufferSlice) { GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; #endif - std::vector buffers; - std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + auto in_out = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + BufferAllocations allocations = CreateBufferAllocations(in_out); - size_t size_in_bytes = in_out.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); - BufferAllocation in_out_alloc(0, size_in_bytes, 0); - BufferAllocation::Slice in_out_slice(&in_out_alloc, 0, size_in_bytes); + BufferAllocation alloc = CreateBufferAllocation(0, in_out); + BufferAllocation::Slice slice = CreateBufferAllocationSlice(alloc); // Invariant buffer set is incorrect - should be empty, but contains input // buffer that's not invariant. TF_ASSERT_OK_AND_ASSIGN( - auto thunk, KernelThunk::Create( - {"add_f32"}, {in_out_slice}, {in_out_slice}, "add_f32", - se::ThreadDim(4), /*invariant_arguments=*/{0})); + auto thunk, + KernelThunk::Create({"add_f32"}, {slice}, {slice}, "add_f32", + se::ThreadDim(4), /*invariant_arguments=*/{{0}})); AddF32HostKernel host_kernels; Thunk::ExecuteParams params = {&host_kernels, &allocations}; @@ -209,31 +182,21 @@ TEST(KernelThunkInvariantBuffersTest, GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; #endif - // We've got only one memory section - std::vector buffers; - std::vector in_out = {1.0, 2.0, 3.0, 4.0}; - - // We've got two buffer slices with different indexes, but both pointing to - // the same memory section. - size_t size_in_bytes = in_out.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); - - BufferAllocation in_0_alloc(0, size_in_bytes, 0); - BufferAllocation in_1_alloc(1, size_in_bytes, 0); + // We've got only one literal, but two buffer slices that point to the same + // memory region. + auto data = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + BufferAllocations allocations = CreateBufferAllocations(data, data); - BufferAllocation::Slice in_0_slice(&in_0_alloc, 0, size_in_bytes); - BufferAllocation::Slice in_1_slice(&in_1_alloc, 0, size_in_bytes); + auto [alloc_0, alloc_1] = CreateBufferAllocation(data, data); + auto [slice_0, slice_1] = CreateBufferAllocationSlice(alloc_0, alloc_1); - // Invariant buffer set is incorrect. in_1_slice is not aliased to any output, - // but it points to the same memory section as in_0_slice (which is not - // invariant, because is aliased with the output). + // Invariant buffer set is incorrect. slice_1 is not aliased to any output, + // but it points to the same memory region as slice_0 (which is not + // invariant, because it is aliased with the output). TF_ASSERT_OK_AND_ASSIGN( - auto thunk, KernelThunk::Create({"add_f32"}, {in_0_slice, in_1_slice}, - {in_0_slice}, "add_f32", se::ThreadDim(4), - /*invariant_arguments=*/{1})); + auto thunk, KernelThunk::Create({"add_f32"}, {slice_0, slice_1}, + {slice_0}, "add_f32", se::ThreadDim(4), + /*invariant_arguments=*/{{1}})); AddF32HostKernel host_kernels; Thunk::ExecuteParams params = {&host_kernels, &allocations}; diff --git a/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc index c8dd0a60782fed..6bf1a404469163 100644 --- a/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc @@ -24,13 +24,13 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/executable_run_options.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { @@ -52,19 +52,15 @@ absl::StatusOr CreateDeviceAssignment( } TEST(LogicalIdThunkTest, GetReplicaId) { - std::vector dst(1, std::numeric_limits::min()); + auto dst = LiteralUtil::CreateR0(std::numeric_limits::min()); - std::vector buffers; - buffers.emplace_back(se::DeviceMemoryBase(dst.data(), sizeof(int32_t))); - - BufferAllocation alloc(/*index=*/0, /*size=*/sizeof(int32_t), /*color=*/0); - BufferAllocation::Slice id_slice(&alloc, /*offset=*/0, - /*size=*/sizeof(int32_t)); + BufferAllocation alloc = CreateBufferAllocation(0, dst); + BufferAllocation::Slice id_slice = CreateBufferAllocationSlice(alloc); std::string name(Thunk::KindToString(Thunk::Kind::kReplicaId)); TF_ASSERT_OK_AND_ASSIGN(auto thunk, ReplicaIdThunk::Create({name}, id_slice)); - BufferAllocations allocations(buffers); + BufferAllocations allocations = CreateBufferAllocations(dst); TF_ASSERT_OK_AND_ASSIGN(DeviceAssignment device_assn, CreateDeviceAssignment({{0, 1}})); @@ -83,25 +79,20 @@ TEST(LogicalIdThunkTest, GetReplicaId) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - EXPECT_EQ(dst[0], 0); + EXPECT_EQ(dst, LiteralUtil::CreateR0(0)); } TEST(LogicalIdThunkTest, GetPartitionId) { - std::vector dst(2, std::numeric_limits::min()); - - std::vector buffers; - static constexpr auto kDataSize = 2 * sizeof(int32_t); - buffers.emplace_back(se::DeviceMemoryBase(dst.data(), kDataSize)); + auto dst = LiteralUtil::CreateR0(std::numeric_limits::min()); - BufferAllocation alloc(/*index=*/0, /*size=*/kDataSize, /*color=*/0); - BufferAllocation::Slice id_slice(&alloc, /*offset=*/sizeof(int32_t), - /*size=*/sizeof(int32_t)); + BufferAllocation alloc = CreateBufferAllocation(0, dst); + BufferAllocation::Slice id_slice = CreateBufferAllocationSlice(alloc); std::string name(Thunk::KindToString(Thunk::Kind::kPartitionId)); TF_ASSERT_OK_AND_ASSIGN(auto thunk, PartitionIdThunk::Create({name}, id_slice)); - BufferAllocations allocations(buffers); + BufferAllocations allocations = CreateBufferAllocations(dst); TF_ASSERT_OK_AND_ASSIGN(DeviceAssignment device_assn, CreateDeviceAssignment({{0}, {1}})); @@ -120,8 +111,7 @@ TEST(LogicalIdThunkTest, GetPartitionId) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - EXPECT_EQ(dst[0], std::numeric_limits::min()); - EXPECT_EQ(dst[1], 0); + EXPECT_EQ(dst, LiteralUtil::CreateR0(0)); } } // namespace diff --git a/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc index 920aa3dc545b19..570621d6c970eb 100644 --- a/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc @@ -24,12 +24,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -89,14 +89,16 @@ ReduceScatterThunk::Execute(const ExecuteParams& params) { return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { + CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); + for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = destination_shape(i); TF_RETURN_IF_ERROR(comm.ReduceScatter( - key, reduction_kind_, shape.element_type(), - ShapeUtil::ElementsIn(shape), data.source[i].opaque(), - data.destination[i].opaque(), DefaultCollectiveTimeout())); + data.source[i], data.destination[i], shape.element_type(), + ShapeUtil::ElementsIn(shape), reduction_kind_, executor)); } + return absl::OkStatus(); }); } diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index 3b3c5381883257..96534db43b1345 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -31,6 +31,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/attributes.h" #include "absl/base/call_once.h" #include "absl/base/dynamic_annotations.h" #include "absl/base/optimization.h" @@ -51,10 +52,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -138,6 +139,76 @@ namespace { // The size of the largest element we support (std::complex). static constexpr size_t kMaxElementSize = 16; +// Type erased storage suitable for storing any primitive type. +using ValueStorage = std::array; + +// Pointers to the input arrays together with their primitive sizes. +template +class Inputs { + public: + Inputs(std::array ptrs, + std::array primitive_sizes) { + for (size_t i = 0; i < n; ++i) { + ptrs_and_primitive_sizes_[i] = {ptrs[i], primitive_sizes[i]}; + } + } + + // Accessing arrays with `operator[]` has zero overheads, so we don't need to + // use pointers to data in contrast to `DInputs` below. + + std::byte* ptr(size_t i, size_t offset) const { + DCHECK_LT(i, n) << "Input index out of bounds"; + auto& [ptr, primitive_size] = ptrs_and_primitive_sizes_[i]; + return ptr + offset * primitive_size; + } + + size_t primitive_size(size_t i) const { + return ptrs_and_primitive_sizes_[i].second; + } + + private: + // Pointers into the input buffers and each input's primitive size. Keep + // pointers and primitives sizes next to each other to avoid cache misses + // on a hot path. + std::array, n> ptrs_and_primitive_sizes_; +}; + +class DInputs { + public: + DInputs(std::vector ptrs, std::vector primitive_sizes) + : n_(ptrs.size()), ptrs_and_primitive_sizes_(ptrs.size()) { + DCHECK_EQ(ptrs.size(), primitive_sizes.size()); + for (size_t i = 0; i < ptrs.size(); ++i) { + ptrs_and_primitive_sizes_[i] = {ptrs[i], primitive_sizes[i]}; + } + } + + size_t n() const { return n_; } + + // Accessing vectors with `operator[]` is significantly slower than using a + // pointer to data because of libc++ hardening which checks for OOB access on + // every call. We know that we are not going to access out of bounds, so we + // use a pointer to data instead. + + std::byte* ptr(size_t i, size_t offset) const { + DCHECK_LT(i, n_) << "Input index out of bounds"; + auto& [ptr, primitive_size] = ptrs_and_primitive_sizes_.data()[i]; + return ptr + offset * primitive_size; + } + + size_t primitive_size(size_t i) const { + return ptrs_and_primitive_sizes_.data()[i].second; + } + + private: + size_t n_; // number of sorted inputs + + // Pointers into the input buffers and each input's primitive size. Keep + // pointers and primitives sizes next to each other to avoid cache misses + // on a hot path. + std::vector> ptrs_and_primitive_sizes_; +}; + // Forward declare reference type defined below. template struct Ref; @@ -148,125 +219,212 @@ template struct Value { Value(const Ref& ref); // NOLINT - const void* compared_value(size_t i) const { return value[i].data(); } + void FillComparedValues(const void** __restrict compared_values) const; - // Use properly aligned byte array to store primitive values. - using ValueStorage = std::array; - alignas(alignof(std::max_align_t)) std::array value; - std::array value_sizes; + std::array values; }; struct DValue { DValue(const DRef& ref); // NOLINT - const void* compared_value(size_t i) const { return value[i].data(); } + void FillComparedValues(const void** __restrict compared_values) const; - // Use properly aligned byte array to store primitive values. - using ValueStorage = std::array; - std::vector value; - std::vector value_sizes; - size_t n; + std::vector values; }; // Reference to values stored in the input buffers. template struct Ref { - Ref(std::array ptr, std::array ptr_sizes) - : ptr(ptr), ptr_sizes(ptr_sizes) {} + Ref(const Inputs* inputs, size_t offset) + : inputs(inputs), offset(offset) {} Ref& operator=(const Value& value); Ref& operator=(const Ref& other); - const void* compared_value(size_t i) const { return ptr[i]; } + void FillComparedValues(const void** __restrict compared_values) const; + + std::byte* ptr(size_t i) const { return inputs->ptr(i, offset); } + size_t primitive_size(size_t i) const { return inputs->primitive_size(i); } - std::array ptr; - std::array ptr_sizes; + const Inputs* inputs; + size_t offset; }; struct DRef { - DRef(std::vector ptr, std::vector ptr_sizes) - : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {} + DRef(const DInputs* inputs, size_t offset) : inputs(inputs), offset(offset) {} DRef& operator=(const DValue& value); DRef& operator=(const DRef& other); - const void* compared_value(size_t i) const { return ptr[i]; } + void FillComparedValues(const void** __restrict compared_values) const; - std::vector ptr; - std::vector ptr_sizes; - const size_t n; + size_t n() const { return inputs->n(); } + std::byte* ptr(size_t i) const { return inputs->ptr(i, offset); } + size_t primitive_size(size_t i) const { return inputs->primitive_size(i); } + + const DInputs* inputs; + size_t offset; }; +// We know that we can only copy up to 16 bytes for the largest element type +// and can specialize `std::memcpy` to allow LLVM to inline it with statically +// known sizes. +static ABSL_ATTRIBUTE_ALWAYS_INLINE void Memcpy(void* __restrict dest, + const void* __restrict src, + size_t n) { + switch (n) { + case 1: + std::memcpy(dest, src, 1); + break; + case 2: + std::memcpy(dest, src, 2); + break; + case 4: + std::memcpy(dest, src, 4); + break; + case 8: + std::memcpy(dest, src, 8); + break; + case 16: + std::memcpy(dest, src, 16); + break; + default: + LOG(FATAL) << "Unsupported memcpy size: " << n; + } +} + +// Specialize swap for statically known sizes to avoid going through the same +// switch statement multiple times. +static ABSL_ATTRIBUTE_ALWAYS_INLINE void Swap(void* __restrict a, + void* __restrict b, size_t n) { + std::array tmp; + switch (n) { + case 1: + std::memcpy(tmp.data(), a, 1); + std::memcpy(a, b, 1); + std::memcpy(b, tmp.data(), 1); + break; + case 2: + std::memcpy(tmp.data(), a, 2); + std::memcpy(a, b, 2); + std::memcpy(b, tmp.data(), 2); + break; + case 4: + std::memcpy(tmp.data(), a, 4); + std::memcpy(a, b, 4); + std::memcpy(b, tmp.data(), 4); + break; + case 8: + std::memcpy(tmp.data(), a, 8); + std::memcpy(a, b, 8); + std::memcpy(b, tmp.data(), 8); + break; + case 16: + std::memcpy(tmp.data(), a, 16); + std::memcpy(a, b, 16); + std::memcpy(b, tmp.data(), 16); + break; + default: + LOG(FATAL) << "Unsupported swap size: " << n; + } +} + template -Value::Value(const Ref& ref) : value_sizes(ref.ptr_sizes) { +ABSL_ATTRIBUTE_ALWAYS_INLINE Value::Value(const Ref& ref) { for (size_t i = 0; i < n; ++i) { - std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]); + Memcpy(values[i].data(), ref.ptr(i), ref.primitive_size(i)); } } -DValue::DValue(const DRef& ref) - : value_sizes(ref.ptr_sizes), n(ref.ptr.size()) { - value.reserve(n); - for (size_t i = 0; i < n; ++i) { - value.emplace_back(); - std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]); +template +ABSL_ATTRIBUTE_ALWAYS_INLINE void Value::FillComparedValues( + const void** __restrict compared_values) const { + for (const ValueStorage& value : values) { + *compared_values = value.data(); + compared_values += 2; + } +} + +ABSL_ATTRIBUTE_ALWAYS_INLINE DValue::DValue(const DRef& ref) : values(ref.n()) { + for (size_t i = 0, end = ref.n(); i < end; ++i) { + Memcpy(values.data()[i].data(), ref.ptr(i), ref.primitive_size(i)); + } +} + +ABSL_ATTRIBUTE_ALWAYS_INLINE void DValue::FillComparedValues( + const void** __restrict compared_values) const { +#pragma unroll 8 + for (const ValueStorage& value : values) { + *compared_values = value.data(); + compared_values += 2; } } template -Ref& Ref::operator=(const Value& value) { - DCHECK(ptr_sizes == value.value_sizes); +ABSL_ATTRIBUTE_ALWAYS_INLINE Ref& Ref::operator=(const Value& value) { for (size_t i = 0; i < n; ++i) { - std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]); + Memcpy(ptr(i), value.values.data()[i].data(), primitive_size(i)); } return *this; } -DRef& DRef::operator=(const DValue& value) { - DCHECK(ptr_sizes == value.value_sizes); +template +ABSL_ATTRIBUTE_ALWAYS_INLINE Ref& Ref::operator=(const Ref& other) { for (size_t i = 0; i < n; ++i) { - std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]); + DCHECK_EQ(primitive_size(i), other.primitive_size(i)); + Memcpy(ptr(i), other.ptr(i), primitive_size(i)); } return *this; } template -Ref& Ref::operator=(const Ref& other) { - DCHECK(ptr_sizes == other.ptr_sizes); +ABSL_ATTRIBUTE_ALWAYS_INLINE void Ref::FillComparedValues( + const void** __restrict compared_values) const { for (size_t i = 0; i < n; ++i) { - std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]); + *compared_values = ptr(i); + compared_values += 2; + } +} + +ABSL_ATTRIBUTE_ALWAYS_INLINE DRef& DRef::operator=(const DValue& value) { + for (size_t i = 0, end = n(); i < end; ++i) { + Memcpy(ptr(i), value.values.data()[i].data(), primitive_size(i)); } return *this; } -DRef& DRef::operator=(const DRef& other) { - DCHECK(ptr_sizes == other.ptr_sizes); - const size_t n = other.ptr.size(); - for (size_t i = 0; i < n; ++i) { - std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]); +ABSL_ATTRIBUTE_ALWAYS_INLINE DRef& DRef::operator=(const DRef& other) { + for (size_t i = 0, end = n(); i < end; ++i) { + DCHECK_EQ(primitive_size(i), other.primitive_size(i)); + Memcpy(ptr(i), other.ptr(i), primitive_size(i)); } return *this; } +ABSL_ATTRIBUTE_ALWAYS_INLINE void DRef::FillComparedValues( + const void** __restrict compared_values) const { +#pragma unroll 8 + for (size_t i = 0, end = n(); i < end; ++i) { + *compared_values = ptr(i); + compared_values += 2; + } +} + // Swap function required by `std::sort` and `std::stable_sort` implementations. template -void swap(const Ref& lhs, const Ref& rhs) { +ABSL_ATTRIBUTE_ALWAYS_INLINE void swap(const Ref& lhs, const Ref& rhs) { for (size_t i = 0; i < n; ++i) { - std::array tmp; - std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]); - std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]); - std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]); + DCHECK_EQ(lhs.primitive_size(i), rhs.primitive_size(i)); + size_t primitive_size = lhs.primitive_size(i); + Swap(lhs.ptr(i), rhs.ptr(i), primitive_size); } } -void swap(const DRef& lhs, const DRef& rhs) { - DCHECK(lhs.ptr_sizes == rhs.ptr_sizes); - const size_t n = lhs.ptr.size(); - for (size_t i = 0; i < n; ++i) { - std::array tmp; - std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]); - std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]); - std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]); +ABSL_ATTRIBUTE_ALWAYS_INLINE void swap(const DRef& lhs, const DRef& rhs) { + for (size_t i = 0, end = lhs.n(); i < end; ++i) { + DCHECK_EQ(lhs.primitive_size(i), rhs.primitive_size(i)); + size_t primitive_size = lhs.primitive_size(i); + Swap(lhs.ptr(i), rhs.ptr(i), primitive_size); } } @@ -277,51 +435,42 @@ struct Ptr { Ptr() = default; - Ptr(std::array ptr, std::array ptr_sizes) - : ptr(ptr), ptr_sizes(ptr_sizes) {} + explicit Ptr(const Inputs* inputs, size_t offset = 0) + : inputs(inputs), offset(offset) {} - Ref operator*() const { return Ref{ptr, ptr_sizes}; } + Ref operator*() const { return Ref{inputs, offset}; } Ptr& operator+=(difference_type diff) { - for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i]; + offset += diff; return *this; } Ptr& operator-=(difference_type diff) { - for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i]; + offset -= diff; return *this; } Ptr operator+(difference_type diff) const { - std::array upd; - for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i]; - return Ptr{upd, ptr_sizes}; + return Ptr(inputs, offset + diff); } Ptr operator-(difference_type diff) const { - std::array upd; - for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i]; - return Ptr{upd, ptr_sizes}; + return Ptr(inputs, offset - diff); } - // In all comparison operators defined below we use only the ptr at index 0, - // because we know that all pointers change together and this is an - // implementation detail of sort iterator. - difference_type operator-(const Ptr& rhs) const { - DCHECK(ptr_sizes == rhs.ptr_sizes); - return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0]; + return offset - rhs.offset; } - bool operator==(const Ptr& rhs) const { return ptr[0] == rhs.ptr[0]; } - bool operator!=(const Ptr& rhs) const { return ptr[0] != rhs.ptr[0]; } - bool operator>(const Ptr& rhs) const { return ptr[0] > rhs.ptr[0]; } - bool operator<(const Ptr& rhs) const { return ptr[0] < rhs.ptr[0]; } - bool operator>=(const Ptr& rhs) const { return ptr[0] >= rhs.ptr[0]; } - bool operator<=(const Ptr& rhs) const { return ptr[0] <= rhs.ptr[0]; } + bool operator==(const Ptr& rhs) const { return offset == rhs.offset; } + bool operator!=(const Ptr& rhs) const { return offset != rhs.offset; } + bool operator>(const Ptr& rhs) const { return offset > rhs.offset; } + bool operator<(const Ptr& rhs) const { return offset < rhs.offset; } + bool operator>=(const Ptr& rhs) const { return offset >= rhs.offset; } + bool operator<=(const Ptr& rhs) const { return offset <= rhs.offset; } - std::array ptr; // pointers into the input buffers - std::array ptr_sizes; // pointers sizes in bytes + const Inputs* inputs; // pointer to the input arrays + size_t offset; // offset into the inputs arrays }; struct DPtr { @@ -329,52 +478,42 @@ struct DPtr { DPtr() = default; - DPtr(std::vector ptr, std::vector ptr_sizes) - : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {} + explicit DPtr(const DInputs* inputs, size_t offset = 0) + : inputs(inputs), offset(offset) {} - DRef operator*() const { return DRef{ptr, ptr_sizes}; } + DRef operator*() const { return DRef{inputs, offset}; } DPtr& operator+=(difference_type diff) { - for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i]; + offset += diff; return *this; } DPtr& operator-=(difference_type diff) { - for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i]; + offset -= diff; return *this; } DPtr operator+(difference_type diff) const { - std::vector upd(n); - for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i]; - return DPtr{upd, ptr_sizes}; + return DPtr(inputs, offset + diff); } DPtr operator-(difference_type diff) const { - std::vector upd(n); - for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i]; - return DPtr{upd, ptr_sizes}; + return DPtr(inputs, offset - diff); } - // In all comparison operators defined below we use only the ptr at index 0, - // because we know that all pointers change together and this is an - // implementation detail of sort iterator. - difference_type operator-(const DPtr& rhs) const { - DCHECK(ptr_sizes == rhs.ptr_sizes); - return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0]; + return offset - rhs.offset; } - bool operator==(const DPtr& rhs) const { return ptr[0] == rhs.ptr[0]; } - bool operator!=(const DPtr& rhs) const { return ptr[0] != rhs.ptr[0]; } - bool operator>(const DPtr& rhs) const { return ptr[0] > rhs.ptr[0]; } - bool operator<(const DPtr& rhs) const { return ptr[0] < rhs.ptr[0]; } - bool operator>=(const DPtr& rhs) const { return ptr[0] >= rhs.ptr[0]; } - bool operator<=(const DPtr& rhs) const { return ptr[0] <= rhs.ptr[0]; } + bool operator==(const DPtr& rhs) const { return offset == rhs.offset; } + bool operator!=(const DPtr& rhs) const { return offset != rhs.offset; } + bool operator>(const DPtr& rhs) const { return offset > rhs.offset; } + bool operator<(const DPtr& rhs) const { return offset < rhs.offset; } + bool operator>=(const DPtr& rhs) const { return offset >= rhs.offset; } + bool operator<=(const DPtr& rhs) const { return offset <= rhs.offset; } - std::vector ptr; // pointers into the input buffers - std::vector ptr_sizes; // pointers sizes in bytes - size_t n; + const DInputs* inputs; // pointer to the input arrays + size_t offset; // offset into the inputs arrays }; // We rely on `std::sort` and `std::stable_sort` to sort the raw data. We sort @@ -393,7 +532,7 @@ class SortIterator { SortIterator() = default; SortIterator(pointer ptr, difference_type stride) - : ptr_(ptr), stride_(stride) {} + : ptr_(std::move(ptr)), stride_(stride) {} SortIterator(const SortIterator& other) = default; SortIterator& operator=(const SortIterator& other) = default; @@ -401,6 +540,7 @@ class SortIterator { SortIterator& operator=(SortIterator&& other) = default; reference operator*() const { return *ptr_; } + reference operator[](difference_type diff) const { return *(*this + diff); } difference_type operator-(const SortIterator& rhs) const { return (ptr_ - rhs.ptr_) / stride_; @@ -538,27 +678,26 @@ static void SortInplace(const SortDims& sort_dims, int64_t offset, absl::Span data, absl::Span shapes, bool is_stable, SortThunk::LessThan* less_than) { - std::array ptr; - std::array ptr_sizes; + std::array ptrs; + std::array primitive_sizes; for (size_t i = 0; i < n; ++i) { std::byte* base = reinterpret_cast(data[i].opaque()); - ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); - ptr[i] = base + offset * ptr_sizes[i]; + primitive_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); + ptrs[i] = base + offset * primitive_sizes[i]; } + Inputs inputs(ptrs, primitive_sizes); + auto compare = [&](const auto& a, const auto& b) { - std::array data; - for (size_t i = 0, j = 0; i < n; i += 1, j += 2) { - data[j] = a.compared_value(i); - data[j + 1] = b.compared_value(i); - } - return (*less_than)(data.data()); + std::array values; + a.FillComparedValues(&values[0]); + b.FillComparedValues(&values[1]); + return (*less_than)(values.data()); }; SortIterator, Ref, Ptr> begin( - Ptr(ptr, ptr_sizes), - /*stride=*/sort_dims.inner_dim_size); + Ptr(&inputs), /*stride=*/sort_dims.inner_dim_size); if (is_stable) { std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); } else { @@ -570,25 +709,28 @@ static void DSortInplace(const SortDims& sort_dims, int64_t offset, absl::Span data, absl::Span shapes, bool is_stable, SortThunk::LessThan* less_than, size_t n) { - std::vector ptr(n); - std::vector ptr_sizes(n); + std::vector ptrs(n); + std::vector primitive_sizes(n); for (size_t i = 0; i < n; ++i) { std::byte* base = reinterpret_cast(data[i].opaque()); - ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); - ptr[i] = base + offset * ptr_sizes[i]; + primitive_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); + ptrs[i] = base + offset * primitive_sizes[i]; } - auto compare = [&](const auto& a, const auto& b) { - std::vector data(2 * n); - for (size_t i = 0, j = 0; i < n; i += 1, j += 2) { - data[j] = a.compared_value(i); - data[j + 1] = b.compared_value(i); - } - return (*less_than)(data.data()); + DInputs inputs(std::move(ptrs), std::move(primitive_sizes)); + + // Allocate scratch space for sorted values outside of the lambda to avoid + // allocating it on every call to `compare`. + std::vector values(2 * n); + + auto compare = [&, values = values.data()](const auto& a, const auto& b) { + a.FillComparedValues(&values[0]); + b.FillComparedValues(&values[1]); + return (*less_than)(values); }; - SortIterator begin(DPtr(ptr, ptr_sizes), + SortIterator begin(DPtr(&inputs), /*stride=*/sort_dims.inner_dim_size); if (is_stable) { std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); @@ -638,10 +780,8 @@ static absl::Status SortInplace( type); }; - // use "sort" for statically known number of sorted inputs (expected to be + // Use "sort" for statically known number of sorted inputs (expected to be // faster) and "dsort" for dynamically known number of sorted inputs. - // for 100 elements stable sort is 1.5 times faster than stable dsort. - // for 100 elements unstable sort is 2.47 times faster than unstable dsort. switch (data.size()) { case 1: DCHECK_EQ(shapes.size(), 1); @@ -696,33 +836,6 @@ static absl::Status SortInplace( case 16: sort(std::integral_constant{}); break; - case 17: - sort(std::integral_constant{}); - break; - case 18: - sort(std::integral_constant{}); - break; - case 19: - sort(std::integral_constant{}); - break; - case 20: - sort(std::integral_constant{}); - break; - case 21: - sort(std::integral_constant{}); - break; - case 22: - sort(std::integral_constant{}); - break; - case 23: - sort(std::integral_constant{}); - break; - case 24: - sort(std::integral_constant{}); - break; - case 25: - sort(std::integral_constant{}); - break; default: dsort(data.size()); break; diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h index c73ad534db2aad..6d32ab1ac3c5f6 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ #define XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ -#include #include #include #include @@ -24,10 +23,8 @@ limitations under the License. #include #include "absl/base/call_once.h" -#include "absl/base/thread_annotations.h" #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc index 98d1eea03703c8..797847a42c8bc5 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc @@ -16,37 +16,38 @@ limitations under the License. #include "xla/backends/cpu/runtime/sort_thunk.h" #include -#include -#include #include #include -#include -#include -#include +#include #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { namespace { class SortThunkTest : public testing::TestWithParam {}; +// Sorts the data using only the first input (that must be float!). static bool LessThan(const void** data) { auto* lhs = reinterpret_cast(data[0]); auto* rhs = reinterpret_cast(data[1]); @@ -55,39 +56,29 @@ static bool LessThan(const void** data) { class LessThanComparator : public FunctionLibrary { public: - static void LessThanWrapper(bool* result, const void*, const void** data, - const void*, const void*, const void*) { - *result = LessThan(data); - } - absl::StatusOr ResolveFunction(TypeId type_id, - std::string_view name) final { + absl::string_view name) final { DCHECK_EQ(name, "less_than"); return reinterpret_cast(LessThanWrapper); } + + private: + static void LessThanWrapper(bool* result, const void*, const void** data, + const void*, const void*, const void*) { + *result = LessThan(data); + } }; TEST_P(SortThunkTest, DescendingSortPlainArray) { bool is_stable = GetParam(); - const int data_size = 10000; - - std::vector buffers; - std::vector data(data_size); - std::default_random_engine gen; - std::uniform_real_distribution distribution(0.0, 1000.0); + TF_ASSERT_OK_AND_ASSIGN(auto data, + LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, {10000}), 1.0f, 0.1f)); - for (int i = 0; i < data_size; i++) { - data[i] = distribution(gen); - } - - const size_t size_in_bytes = data_size * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); - - const BufferAllocations allocations(buffers); - const BufferAllocation alloc(0, size_in_bytes, 0); - const BufferAllocation::Slice slice0(&alloc, 0, size_in_bytes); - const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + BufferAllocations allocations = CreateBufferAllocations(data); + BufferAllocation alloc = CreateBufferAllocation(0, data); + BufferAllocation::Slice slice = CreateBufferAllocationSlice(alloc); // The comparator function is not used in the plain array sort when the sort // direction is specified and data types are supported. @@ -95,7 +86,7 @@ TEST_P(SortThunkTest, DescendingSortPlainArray) { // Use sort direction to activate the most efficient sorting function. TF_ASSERT_OK_AND_ASSIGN( - auto thunk, SortThunk::Create({"sort"}, {{slice0, data_shape}}, + auto thunk, SortThunk::Create({"sort"}, {{slice, data.shape()}}, /*dimension=*/0, is_stable, fake_less_than, SortThunk::SortDirection::kDescending)); @@ -106,37 +97,27 @@ TEST_P(SortThunkTest, DescendingSortPlainArray) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - EXPECT_TRUE( - std::is_sorted(data.cbegin(), data.cend(), std::greater())); + EXPECT_TRUE(std::is_sorted(data.data().begin(), + data.data().end(), std::greater())); } TEST_P(SortThunkTest, Sort1D) { bool is_stable = GetParam(); - std::vector buffers; - std::vector data = {2.0, 4.0, 1.0, 3.0}; - std::vector indices = {0, 1, 2, 3}; - - size_t size_in_bytes = data.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(indices.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); - - BufferAllocation alloc0(0, size_in_bytes, 0); - BufferAllocation alloc1(1, size_in_bytes, 0); + auto data = LiteralUtil::CreateR1({2.0, 4.0, 1.0, 3.0}); + auto indices = LiteralUtil::CreateR1({0, 1, 2, 3}); - BufferAllocation::Slice slice0(&alloc0, 0, size_in_bytes); - BufferAllocation::Slice slice1(&alloc1, 0, size_in_bytes); + BufferAllocations allocations = CreateBufferAllocations(data, indices); - Shape data_shape = ShapeUtil::MakeShape(F32, {4}); - Shape indices_shape = ShapeUtil::MakeShape(S32, {4}); + auto [alloc0, alloc1] = CreateBufferAllocation(data, indices); + auto [slice0, slice1] = CreateBufferAllocationSlice(alloc0, alloc1); TF_ASSERT_OK_AND_ASSIGN( - auto thunk, SortThunk::Create( - {"sort"}, {{slice0, data_shape}, {slice1, indices_shape}}, - /*dimension=*/0, is_stable, LessThan, - SortThunk::SortDirection::kAscending)); + auto thunk, + SortThunk::Create({"sort"}, + {{slice0, data.shape()}, {slice1, indices.shape()}}, + /*dimension=*/0, is_stable, LessThan, + SortThunk::SortDirection::kAscending)); Thunk::ExecuteParams params; params.buffer_allocations = &allocations; @@ -145,68 +126,42 @@ TEST_P(SortThunkTest, Sort1D) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - std::vector expected_data = {1.0, 2.0, 3.0, 4.0}; - std::vector expected_indices = {2, 0, 3, 1}; - - EXPECT_EQ(data, expected_data); - EXPECT_EQ(indices, expected_indices); + EXPECT_EQ(data, LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0})); + EXPECT_EQ(indices, LiteralUtil::CreateR1({2, 0, 3, 1})); } -TEST_P(SortThunkTest, DynamicSort1D) { +TEST_P(SortThunkTest, Sort1DDynamicNumInputs) { bool is_stable = GetParam(); - // 33 empty slices + 2 slices with data = 35 slices - // This amount of slices will call the dynamic sort implementation. - constexpr int num_of_empty_slices = 33; - constexpr int total_num_of_slices = num_of_empty_slices + 2; - - // size of each of 33 data buffers - constexpr int data_size = 31; - - // values range will be [5.0, 35.0] - constexpr float starting_value = 5.0f; - - std::array data{ - 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, - 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, - 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, - 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; - std::array indices; - std::iota(indices.begin(), indices.end(), 0); - - // This is a container for the rest of the buffers. - std::array empty; - - const size_t data_size_in_bytes = data.size() * sizeof(float); - const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); - const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); - - const BufferAllocation alloc0(0, data_size_in_bytes, 0); - const BufferAllocation alloc1(1, ind_size_in_bytes, 0); - const BufferAllocation rest(2, empty_size_in_bytes, 0); - - const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); - const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); - - const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); - const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); - const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); - - std::vector buffers; - buffers.emplace_back(se::DeviceMemoryBase(data.data(), data_size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(indices.data(), ind_size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); - - BufferAllocations allocations(buffers); - - std::array inputs{ - {{slice0, data_shape}, {slice1, indices_shape}}}; - for (int i = 0; i < num_of_empty_slices; ++i) { - constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); - inputs[i + 2].slice = BufferAllocation::Slice( - &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); - inputs[i + 2].shape = rest_shape; - } + Literal data = LiteralUtil::CreateR1( + {17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, + 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, + 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, + 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}); + + Literal indices = LiteralUtil::CreateR1( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}); + + // We use dummy data to create large number of input to trigger the dynamic + // sort implementation, but we don't use it for sorting. + TF_ASSERT_OK_AND_ASSIGN( + Literal dummy_data, + LiteralUtil::CreateRandomLiteral(data.shape(), 1.0f, 0.1f)); + + BufferAllocations allocations = + CreateBufferAllocations(data, indices, dummy_data); + + auto [data_alloc, indices_alloc, dummy_alloc] = + CreateBufferAllocation(data, indices, dummy_data); + auto [data_slice, indices_slice, dummy_slice] = + CreateBufferAllocationSlice(data_alloc, indices_alloc, dummy_alloc); + + // We use only first input for sorting, the rest of the inputs are shuffled + // according to the values in the `data` literal. + std::vector inputs = {{data_slice, data.shape()}, + {indices_slice, indices.shape()}}; + inputs.resize(40, {dummy_slice, dummy_data.shape()}); TF_ASSERT_OK_AND_ASSIGN( auto thunk, SortThunk::Create({"sort"}, inputs, @@ -220,11 +175,15 @@ TEST_P(SortThunkTest, DynamicSort1D) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - std::array expected_data; - std::iota(expected_data.begin(), expected_data.end(), starting_value); - const std::array expected_indices{ - 2, 28, 20, 5, 6, 3, 30, 13, 21, 8, 24, 1, 0, 16, 12, 26, - 7, 15, 19, 25, 14, 22, 29, 11, 10, 4, 27, 9, 23, 18, 17}; + auto expected_data = LiteralUtil::CreateR1( + {5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, + 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f}); + + auto expected_indices = LiteralUtil::CreateR1( + {2, 28, 20, 5, 6, 3, 30, 13, 21, 8, 24, 1, 0, 16, 12, 26, + 7, 15, 19, 25, 14, 22, 29, 11, 10, 4, 27, 9, 23, 18, 17}); EXPECT_EQ(data, expected_data); EXPECT_EQ(indices, expected_indices); @@ -233,30 +192,19 @@ TEST_P(SortThunkTest, DynamicSort1D) { TEST_P(SortThunkTest, Sort2D) { bool is_stable = GetParam(); - std::vector buffers; - std::vector data = {2.0, 4.0, 1.0, 3.0}; - std::vector indices = {0, 1, 2, 3}; - - size_t size_in_bytes = data.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(indices.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); - - BufferAllocation alloc0(0, size_in_bytes, 0); - BufferAllocation alloc1(1, size_in_bytes, 0); + auto data = LiteralUtil::CreateR2({{2.0, 4.0}, {1.0, 3.0}}); + auto indices = LiteralUtil::CreateR2({{0, 1}, {2, 3}}); - BufferAllocation::Slice slice0(&alloc0, 0, size_in_bytes); - BufferAllocation::Slice slice1(&alloc1, 0, size_in_bytes); + BufferAllocations allocations = CreateBufferAllocations(data, indices); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - Shape indices_shape = ShapeUtil::MakeShape(S32, {2, 2}); + auto [alloc0, alloc1] = CreateBufferAllocation(data, indices); + auto [slice0, slice1] = CreateBufferAllocationSlice(alloc0, alloc1); // Sort along the dimension `0`. TF_ASSERT_OK_AND_ASSIGN( auto sort_dim0, SortThunk::Create({"sort"}, - {{slice0, data_shape}, {slice1, indices_shape}}, + {{slice0, data.shape()}, {slice1, indices.shape()}}, /*dimension=*/0, is_stable, "less_than", SortThunk::SortDirection::kAscending)); @@ -270,20 +218,17 @@ TEST_P(SortThunkTest, Sort2D) { tsl::BlockUntilReady(execute_event0); ASSERT_FALSE(execute_event0.IsError()); - std::vector expected_data = {1.0, 3.0, 2.0, 4.0}; - std::vector expected_indices = {2, 3, 0, 1}; - - EXPECT_EQ(data, expected_data); - EXPECT_EQ(indices, expected_indices); + EXPECT_EQ(data, LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}})); + EXPECT_EQ(indices, LiteralUtil::CreateR2({{2, 3}, {0, 1}})); // Reset data and indices to make it unsorted along the dimension `1`. - data = {4.0, 3.0, 2.0, 1.0}; - indices = {0, 1, 2, 3}; + data = LiteralUtil::CreateR2({{4.0, 3.0}, {2.0, 1.0}}); + indices = LiteralUtil::CreateR2({{0, 1}, {2, 3}}); TF_ASSERT_OK_AND_ASSIGN( auto sort_dim1, SortThunk::Create({"sort"}, - {{slice0, data_shape}, {slice1, indices_shape}}, + {{slice0, data.shape()}, {slice1, indices.shape()}}, /*dimension=*/1, /*is_stable=*/false, "less_than", SortThunk::SortDirection::kAscending)); @@ -292,36 +237,25 @@ TEST_P(SortThunkTest, Sort2D) { tsl::BlockUntilReady(execute_event1); ASSERT_FALSE(execute_event1.IsError()); - expected_data = {3.0, 4.0, 1.0, 2.0}; - expected_indices = {1, 0, 3, 2}; - - EXPECT_EQ(data, expected_data); - EXPECT_EQ(indices, expected_indices); + EXPECT_EQ(data, LiteralUtil::CreateR2({{3.0, 4.0}, {1.0, 2.0}})); + EXPECT_EQ(indices, LiteralUtil::CreateR2({{1, 0}, {3, 2}})); } TEST_P(SortThunkTest, Sort2DWithLayout) { bool is_stable = GetParam(); - std::vector buffers; - std::vector data = {4.0, 3.0, 2.0, 1.0}; - std::vector indices = {0, 1, 2, 3}; - - size_t size_in_bytes = data.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(indices.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); + auto data = LiteralUtil::CreateR2({{4.0, 3.0}, {2.0, 1.0}}); + auto indices = LiteralUtil::CreateR2({{0, 1}, {2, 3}}); - BufferAllocation alloc0(0, size_in_bytes, 0); - BufferAllocation alloc1(1, size_in_bytes, 0); + BufferAllocations allocations = CreateBufferAllocations(data, indices); - BufferAllocation::Slice slice0(&alloc0, 0, size_in_bytes); - BufferAllocation::Slice slice1(&alloc1, 0, size_in_bytes); + auto [alloc0, alloc1] = CreateBufferAllocation(data, indices); + auto [slice0, slice1] = CreateBufferAllocationSlice(alloc0, alloc1); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape data_shape = data.shape(); *data_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - Shape indices_shape = ShapeUtil::MakeShape(S32, {2, 2}); + Shape indices_shape = indices.shape(); *indices_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); // Sort along the dimension `0`. @@ -342,15 +276,12 @@ TEST_P(SortThunkTest, Sort2DWithLayout) { tsl::BlockUntilReady(execute_event0); ASSERT_FALSE(execute_event0.IsError()); - std::vector expected_data = {3.0, 4.0, 1.0, 2.0}; - std::vector expected_indices = {1, 0, 3, 2}; - - EXPECT_EQ(data, expected_data); - EXPECT_EQ(indices, expected_indices); + EXPECT_EQ(data, LiteralUtil::CreateR2({{3.0, 4.0}, {1.0, 2.0}})); + EXPECT_EQ(indices, LiteralUtil::CreateR2({{1, 0}, {3, 2}})); // Reset data and indices to make it unsorted along the dimension `1`. - data = {2.0, 4.0, 1.0, 3.0}; - indices = {0, 1, 2, 3}; + data = LiteralUtil::CreateR2({{2.0, 4.0}, {1.0, 3.0}}); + indices = LiteralUtil::CreateR2({{0, 1}, {2, 3}}); TF_ASSERT_OK_AND_ASSIGN( auto sort_dim1, @@ -364,173 +295,83 @@ TEST_P(SortThunkTest, Sort2DWithLayout) { tsl::BlockUntilReady(execute_event1); ASSERT_FALSE(execute_event1.IsError()); - expected_data = {1.0, 3.0, 2.0, 4.0}; - expected_indices = {2, 3, 0, 1}; - - EXPECT_EQ(data, expected_data); - EXPECT_EQ(indices, expected_indices); + EXPECT_EQ(data, LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}})); + EXPECT_EQ(indices, LiteralUtil::CreateR2({{2, 3}, {0, 1}})); } -void BM_DynamicSort1D(::testing::benchmark::State& state, bool is_stable) { - const int total_num_of_slices = state.range(0); - const int num_of_empty_slices = total_num_of_slices - 2; - - // size of each of data buffers - constexpr int data_size = 31; - - const std::array data{ - 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, - 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, - 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, - 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; - std::array indices; - std::iota(indices.begin(), indices.end(), 0); - - // This is the container for the rest of the buffers. - std::vector empty(data_size * num_of_empty_slices); - - const size_t data_size_in_bytes = data.size() * sizeof(float); - const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); - const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); - - const BufferAllocation alloc0(0, data_size_in_bytes, 0); - const BufferAllocation alloc1(1, ind_size_in_bytes, 0); - const BufferAllocation rest(2, empty_size_in_bytes, 0); - - const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); - const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); - - const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); - const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); - const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); - - for (auto s : state) { - // Pause timing to avoid counting the time spent in the setup. - state.PauseTiming(); - auto data_clone(data); - auto indices_clone(indices); - - std::vector buffers; - buffers.emplace_back( - se::DeviceMemoryBase(data_clone.data(), data_size_in_bytes)); - buffers.emplace_back( - se::DeviceMemoryBase(indices_clone.data(), ind_size_in_bytes)); - buffers.emplace_back( - se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); - - BufferAllocations allocations(buffers); - - std::vector inputs(total_num_of_slices); - inputs[0] = {slice0, data_shape}; - inputs[1] = {slice1, indices_shape}; - for (int i = 0; i < num_of_empty_slices; ++i) { - constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); - inputs[i + 2].slice = BufferAllocation::Slice( - &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); - inputs[i + 2].shape = rest_shape; - } - - Thunk::ExecuteParams params; - params.buffer_allocations = &allocations; +INSTANTIATE_TEST_SUITE_P(SortThunk, SortThunkTest, testing::Bool(), + testing::PrintToStringParamName()); - state.ResumeTiming(); - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, SortThunk::Create({"sort"}, inputs, - /*dimension=*/0, is_stable, LessThan, - SortThunk::SortDirection::kAscending)); +//===----------------------------------------------------------------------===// +// Performance benchmarks below. +//===----------------------------------------------------------------------===// - auto execute_event = thunk->Execute(params); - tsl::BlockUntilReady(execute_event); - ASSERT_FALSE(execute_event.IsError()); - } -} +void BM_Sort1D(benchmark::State& state) { + int64_t input_size = state.range(0); + int64_t num_inputs = state.range(1); + bool is_stable = state.range(2); + bool sort_ascending = state.range(3); -void BM_SortPlainArray(::testing::benchmark::State& state, bool is_stable) { - const int data_size = state.range(0); + CHECK_GE(num_inputs, 1) << "Number of inputs must be at least 1"; // Crash OK - std::vector data(data_size); + auto data = LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, {input_size}), 1.0f, 1.0f); + CHECK_OK(data) << "Failed to create random literal"; // Crash OK - std::default_random_engine gen; - std::uniform_real_distribution distribution(0.0, 1000.0); + // We use dummy data to create additional inputs, but we don't use it for + // sorting and simply shuffle it according to the values in the first input. + auto dummy_data = + LiteralUtil::CreateRandomLiteral(data->shape(), 1.f, 1.f); + CHECK_OK(dummy_data) << "Failed to create random literal"; // Crash OK - for (int i = 0; i < data_size; i++) { - data[i] = distribution(gen); - } + // Use sort direction to activate the most efficient sorting function, or fall + // back on the comparator functor. + std::optional direction; + if (sort_ascending) direction = SortThunk::SortDirection::kAscending; - const size_t size_in_bytes = data_size * sizeof(float); - const BufferAllocation alloc(0, size_in_bytes, 0); - const BufferAllocation::Slice slice0(&alloc, 0, size_in_bytes); - const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + auto [alloc, dummy_alloc] = CreateBufferAllocation(*data, *dummy_data); + auto [slice, dummy_slice] = CreateBufferAllocationSlice(alloc, dummy_alloc); for (auto s : state) { - state.PauseTiming(); - auto data_clone(data); - std::vector buffer; - buffer.emplace_back(se::DeviceMemoryBase(data_clone.data(), size_in_bytes)); + // Clone the data to avoid sorting already sorted data. + Literal data_copy = data->Clone(); + BufferAllocations allocations = + CreateBufferAllocations(data_copy, *dummy_data); - const BufferAllocations allocations(buffer); + std::vector inputs = {{slice, data_copy.shape()}}; + inputs.resize(num_inputs, {dummy_slice, dummy_data->shape()}); Thunk::ExecuteParams params; params.buffer_allocations = &allocations; - // The comparator function is not used in the plain array sort when the sort - // direction is specified and data types are supported. - auto fake_less_than = [](const void** data) { return false; }; + auto thunk = + SortThunk::Create({"sort"}, inputs, + /*dimension=*/0, is_stable, LessThan, direction); + CHECK_OK(thunk) << "Failed to create sort thunk"; // Crash OK - state.ResumeTiming(); - // Use sort direction to activate the most efficient sorting function. - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, - SortThunk::Create({"sort"}, {{slice0, data_shape}}, - /*dimension=*/0, is_stable, fake_less_than, - SortThunk::SortDirection::kAscending)); - - auto execute_event = thunk->Execute(params); + auto execute_event = (*thunk)->Execute(params); tsl::BlockUntilReady(execute_event); - ASSERT_FALSE(execute_event.IsError()); + CHECK(execute_event.IsConcrete()); } } -void BM_StableDynamicSort1D(::testing::benchmark::State& state) { - BM_DynamicSort1D(state, /*is_stable=*/true); -} - -void BM_UnstableDynamicSort1D(::testing::benchmark::State& state) { - BM_DynamicSort1D(state, /*is_stable=*/false); -} - -void BM_StableSortPlainArray(::testing::benchmark::State& state) { - BM_SortPlainArray(state, /*is_stable=*/true); -} - -void BM_UnstableSortPlainArray(::testing::benchmark::State& state) { - BM_SortPlainArray(state, /*is_stable=*/false); -} - -BENCHMARK(BM_StableDynamicSort1D) - ->MeasureProcessCPUTime() - ->Arg(35) - ->Arg(50) - ->Arg(100); - -BENCHMARK(BM_UnstableDynamicSort1D) - ->MeasureProcessCPUTime() - ->Arg(35) - ->Arg(50) - ->Arg(100); - -BENCHMARK(BM_StableSortPlainArray) +BENCHMARK(BM_Sort1D) ->MeasureProcessCPUTime() - ->Arg(10000) - ->Arg(100000); - -BENCHMARK(BM_UnstableSortPlainArray) - ->MeasureProcessCPUTime() - ->Arg(10000) - ->Arg(100000); - -INSTANTIATE_TEST_SUITE_P(SortThunk, SortThunkTest, testing::Bool(), - testing::PrintToStringParamName()); + ->ArgNames({"input_size", "num_inputs", "is_stable", "sort_ascending"}) + // Sort using ascending directions. + ->Args({1000, 1, false, true}) + ->Args({1000, 2, false, true}) + ->Args({1000, 4, false, true}) + ->Args({1000, 8, false, true}) + ->Args({1000, 16, false, true}) + ->Args({1000, 32, false, true}) + // Sort using LessThan comparator. + ->Args({1000, 1, false, false}) + ->Args({1000, 2, false, false}) + ->Args({1000, 4, false, false}) + ->Args({1000, 8, false, false}) + ->Args({1000, 16, false, false}) + ->Args({1000, 32, false, false}); } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk.cc b/third_party/xla/xla/backends/cpu/runtime/thunk.cc index 1b56a0194014ee..96cf954095f20c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.cc @@ -20,24 +20,24 @@ limitations under the License. #include #include #include -#include #include +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/backends/cpu/collectives/in_process_collectives.h" #include "xla/executable_run_options.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_executable_run_options.h" -#include "xla/service/cpu/in_process_collectives.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" namespace xla::cpu { -std::string_view Thunk::KindToString(Kind kind) { +absl::string_view Thunk::KindToString(Kind kind) { switch (kind) { case Kind::kAllGather: return "all-gather"; @@ -81,6 +81,8 @@ std::string_view Thunk::KindToString(Kind kind) { return "topk"; case Kind::kWhile: return "while"; + case Kind::kXnnFusion: + return "xnn-fusion"; } } Thunk::Thunk(Kind kind, Info info) @@ -100,15 +102,14 @@ Thunk::CollectiveExecuteParams::Create( // Default implementation of a collectives interface that can execute // collective operations within the same process. - static CollectivesInterface* in_process_collectives = - new runtime::InProcessCollectives(); + static CpuCollectives* in_process_collectives = new InProcessCollectives(); // If CPU executable run options are set, use the collectives interface // provided by the executable run options if it is set. Otherwise, use the // in-process collectives interface. const CpuExecutableRunOptions* cpu_run_options = run_options->cpu_executable_run_options(); - CollectivesInterface* collectives = + CpuCollectives* collectives = cpu_run_options && cpu_run_options->collectives() ? cpu_run_options->collectives() : in_process_collectives; @@ -120,8 +121,7 @@ Thunk::CollectiveExecuteParams::Create( Thunk::CollectiveExecuteParams::CollectiveExecuteParams( RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, - const DeviceAssignment* device_assignment, - CollectivesInterface* collectives) + const DeviceAssignment* device_assignment, CpuCollectives* collectives) : run_id(run_id), local_device_ordinal(local_device_ordinal), global_device_id(global_device_id), diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h index bdb145c64df65b..2c86db92517745 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h @@ -22,28 +22,26 @@ limitations under the License. #include #include #include -#include #include #include #include #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" -#include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" namespace Eigen { struct ThreadPoolDevice; @@ -89,6 +87,7 @@ class Thunk { kSort, kTopK, kWhile, + kXnnFusion, }; struct Info { @@ -133,7 +132,7 @@ class Thunk { Kind kind() const { return kind_; } const Info& info() const { return info_; } - static std::string_view KindToString(Kind kind); + static absl::string_view KindToString(Kind kind); // Returns the list of buffers used by a thunk. Thunk executor relies on this // information to execute thunks concurrently and to avoid data races. @@ -164,13 +163,13 @@ class Thunk { GlobalDeviceId global_device_id; const DeviceAssignment* device_assignment = nullptr; - CollectivesInterface* collectives = nullptr; + CpuCollectives* collectives = nullptr; private: CollectiveExecuteParams(RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, const DeviceAssignment* device_assignment, - CollectivesInterface* collectives); + CpuCollectives* collectives); }; //===--------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index d330f2116e14d2..97625473b44200 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -15,17 +15,23 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk_executor.h" +#include + #include #include #include +#include +#include #include #include +#include #include #include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/base/optimization.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -36,17 +42,38 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/util.h" +#include "tsl/platform/numbers.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { +// If XLA:CPU compiled with `-DXLA_CPU_USE_BLOCKING_THUNK_EXECUTOR` we'll run +// all thunks sequentially and block on the completion of all thunks, which is +// helpful for debugging and gives more readable Xprof traces. +// +// WARNING: This option is UNSAFE and can lead to deadlocks. It should be used +// only for debugging purposes. +static constexpr bool UseBlockingThunkExecutor() { +#if defined(XLA_CPU_USE_BLOCKING_THUNK_EXECUTOR) + return true; +#else + return false; +#endif // XLA_CPU_USE_BLOCKING_THUNK_EXECUTOR +} + ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, + NodesEdges nodes_in_edges, + NodesEdges nodes_out_edges, std::vector nodes_defs, const ThunkExecutor::Options& options) : thunk_sequence_(std::move(thunk_sequence)), options_(options), num_thunks_(thunk_sequence_.size()), + nodes_in_edges_(std::move(nodes_in_edges)), + nodes_out_edges_(std::move(nodes_out_edges)), nodes_defs_(std::move(nodes_defs)), is_sequential_(true) { for (NodeId i = 0; i < nodes_defs_.size(); ++i) { @@ -61,9 +88,6 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, } } - // Erase redundant edges between nodes. - int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities(); - // Check if constructed execution DAG is sequential: every node depends on the // completion of the previous node. for (NodeId i = 1; i < nodes_defs_.size() && is_sequential_; ++i) { @@ -84,11 +108,15 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, is_sequential_ |= thunk_sequence_.size() <= options.execute_sequential_num_thunks_threshold; + // Force sequential execution if we are running in blocking mode as it makes + // Xprof traces easier to read. + is_sequential_ |= UseBlockingThunkExecutor(); + VLOG(2) << absl::StreamFormat( "Constructed ThunkExecutor with %d nodes: #source_nodes=%d " - "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v, small_buffers=%v", - nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges, - is_sequential_, small_buffers); + "#sink_nodes=%d, is_sequential=%v, small_buffers=%v", + nodes_defs_.size(), source_.size(), sink_.size(), is_sequential_, + small_buffers); // Sanity check that all vectors are empty or all vectors are non-empty. DCHECK((!source_.empty() && !sink_.empty() && !thunk_sequence_.empty()) || @@ -97,7 +125,13 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, absl::StatusOr ThunkExecutor::Create( ThunkSequence thunk_sequence, const ThunkExecutor::Options& options) { - std::vector defs(thunk_sequence.size()); + // Make sure that thunk sequence size fits into NodeId. + if (thunk_sequence.size() > std::numeric_limits::max()) { + return Internal("Can't create ThunkExecutor with more than %d thunks", + std::numeric_limits::max()); + } + + std::vector builders(thunk_sequence.size()); std::vector buffer_rwsets(thunk_sequence.size()); std::vector resource_rwsets(thunk_sequence.size()); @@ -108,7 +142,7 @@ absl::StatusOr ThunkExecutor::Create( // most recent updates that touch the whole buffer slice. for (NodeId i = 0; i < thunk_sequence.size(); ++i) { - defs[i].id = i; + builders[i].id = i; Thunk& thunk = *thunk_sequence[i]; buffer_rwsets[i].AddAll(thunk.buffer_uses()); @@ -118,24 +152,33 @@ absl::StatusOr ThunkExecutor::Create( // Check if node `i` must be executed after node `j`. if (buffer_rwsets[j].HasConflicts(buffer_rwsets[i]) || resource_rwsets[j].HasConflicts(resource_rwsets[i])) { - defs[j].out_edges.push_back(i); - defs[i].in_edges.push_back(j); + builders[j].out_edges.push_back(i); + builders[i].in_edges.push_back(j); } } } // Verify that both in-edges and out-edges are sorted in ascending order as we // use this property later. - for (NodeId i = 0; i < defs.size(); ++i) { - DCHECK(absl::c_is_sorted(defs[i].out_edges)); - DCHECK(absl::c_is_sorted(defs[i].in_edges)); + for (NodeId i = 0; i < builders.size(); ++i) { + DCHECK(absl::c_is_sorted(builders[i].out_edges)); + DCHECK(absl::c_is_sorted(builders[i].in_edges)); } - return ThunkExecutor(std::move(thunk_sequence), std::move(defs), options); + // Erase redundant edges between nodes. + int64_t num_erased_edges = + RunTransitiveReductionAndUpdatePriorities(absl::MakeSpan(builders)); + VLOG(5) << absl::StreamFormat( + "Transitive reduction erased %d edges from the nodes graph", + num_erased_edges); + + auto [in_edges, out_edges, nodes_defs] = CreateNodeDefs(std::move(builders)); + return ThunkExecutor(std::move(thunk_sequence), std::move(in_edges), + std::move(out_edges), std::move(nodes_defs), options); } ThunkExecutor::ExecuteState::Node::Node(const NodeDef& node_def) - : counter(node_def.in_edges.size()), out_edges(&node_def.out_edges) {} + : counter(node_def.in_edges.size()), out_edges(node_def.out_edges) {} ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner) @@ -204,12 +247,39 @@ tsl::AsyncValueRef ThunkExecutor::Execute( return execute_event; } +// We deliberately opt-out from the cognitive complexity check, as this +// function is on a hot path, any any attempt to split it leads to measurable +// regressions in microbenchmarks. tsl::AsyncValueRef +// NOLINTNEXTLINE(readability-function-cognitive-complexity) ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { + if constexpr (UseBlockingThunkExecutor()) { + VLOG(2) << absl::StreamFormat( + "ThunkExecutor::ExecuteSequential: execute %d thunks in blocking mode", + num_thunks_); + } + for (auto it = thunk_sequence_.begin(); it != thunk_sequence_.end(); ++it) { + // Record thunk execution start time in blocking mode. + uint64_t start_us; + if constexpr (UseBlockingThunkExecutor()) { + start_us = tsl::Env::Default()->NowMicros(); + } + Thunk& thunk = **it; auto execute_event = thunk.Execute(params); + // Log thunk execution time in blocking mode. + if constexpr (UseBlockingThunkExecutor()) { + tsl::BlockUntilReady(execute_event); + VLOG(2) << absl::StreamFormat( + " thunk[%d] took %s (op_name: %s)", + std::distance(thunk_sequence_.begin(), it), + tsl::strings::HumanReadableElapsedTime( + (tsl::Env::Default()->NowMicros() - start_us) / 1000000.0), + thunk.info().op_name); + } + // Fast path for thunks executed inline and returned OkExecuteEvent. if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) { continue; @@ -296,7 +366,11 @@ void ThunkExecutor::ResumeExecuteSequential( event.SetStateConcrete(); } +// We deliberately opt-out from the cognitive complexity check, as this +// function is on a hot path, any any attempt to split it leads to measurable +// regressions in microbenchmarks. template +// NOLINTNEXTLINE(readability-function-cognitive-complexity) void ThunkExecutor::Execute(ExecuteState* state, const Thunk::ExecuteParams& params, ReadyQueue ready_queue, @@ -431,10 +505,10 @@ void ThunkExecutor::ProcessOutEdges( // Load `is_sink` before dropping node counters because otherwise it might // race with NodeDef destructor. - bool is_sink = node.out_edges->empty(); + bool is_sink = node.out_edges.empty(); // Append ready nodes to the back of the ready queue. - for (NodeId out_edge : *node.out_edges) { + for (NodeId out_edge : node.out_edges) { ExecuteState::Node& out_node = state->node(out_edge); int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release); @@ -467,10 +541,50 @@ void ThunkExecutor::ProcessOutEdges( } } +std::tuple> +ThunkExecutor::CreateNodeDefs(std::vector builders) { + // Find how many in-edges and out-edges we have in total. + size_t num_in_edges = 0, num_out_edges = 0; + for (const NodeDefBuilder& b : builders) { + num_in_edges += b.in_edges.size(); + num_out_edges += b.out_edges.size(); + } + + NodesEdges nodes_in_edges; + NodesEdges nodes_out_edges; + std::vector nodes_defs; + + // Reserve memory to avoid re-allocation and dangling spans into freed memory. + nodes_in_edges.reserve(num_in_edges); + nodes_out_edges.reserve(num_out_edges); + nodes_defs.reserve(builders.size()); + + for (const NodeDefBuilder& b : builders) { + size_t num_in_edges = b.in_edges.size(); + size_t num_out_edges = b.out_edges.size(); + + auto inserted_in_edges = nodes_in_edges.insert( + nodes_in_edges.end(), b.in_edges.begin(), b.in_edges.end()); + auto inserted_out_edges = nodes_out_edges.insert( + nodes_out_edges.end(), b.out_edges.begin(), b.out_edges.end()); + + nodes_defs.push_back(NodeDef{ + b.id, b.priority, + num_in_edges ? absl::MakeConstSpan(&*inserted_in_edges, num_in_edges) + : absl::Span(), + num_out_edges ? absl::MakeConstSpan(&*inserted_out_edges, num_out_edges) + : absl::Span()}); + } + + return std::make_tuple(std::move(nodes_in_edges), std::move(nodes_out_edges), + std::move(nodes_defs)); +} + // Erases edge from `from` node to `to` node if it exists. We rely on the fact // that out and in-edges are sorted and use binary search on a critical path. -static int64_t EraseEdge(ThunkExecutor::NodeDef& from, - ThunkExecutor::NodeDef& to) { +static int64_t EraseEdge(ThunkExecutor::NodeDefBuilder& from, + ThunkExecutor::NodeDefBuilder& to) { DCHECK_NE(from.id, to.id) << "Nodes must be different"; DCHECK_LT(from.id, to.id) << "Nodes must be ordered"; @@ -514,7 +628,8 @@ static int64_t EraseEdge(ThunkExecutor::NodeDef& from, return 1; } -int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() { +int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities( + absl::Span builders) { int64_t num_erased_edges = 0; // Keep workspace for DFS traversal between iterations. @@ -531,17 +646,17 @@ int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() { // For each node we do a DFS traversal and delete redundant edges that // connect source node with the node reachable via DFS. We do traversal in // reverse order as we end up traversing fewer edges this way. - for (int64_t i = nodes_defs_.size() - 1; i >= 0; --i) { - NodeDef& source_node = nodes_defs_[i]; + for (int64_t i = builders.size() - 1; i >= 0; --i) { + NodeDefBuilder& source_node = builders[i]; // Clear DFS workspace from previous iteration. stack.clear(); - visited.assign(nodes_defs_.size(), false); + visited.assign(builders.size(), false); // Initialize stack with nodes reachable via immediate out nodes. We mark // immediate out nodes as visited to correctly compute node priority below. for (int64_t out_id : source_node.out_edges) { - NodeDef& out_node = nodes_defs_[out_id]; + NodeDefBuilder& out_node = builders[out_id]; visited[out_id] = true; for (int64_t start_id : out_node.out_edges) add_to_stack(start_id); } @@ -551,7 +666,7 @@ int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() { int64_t node_id = stack.back(); stack.pop_back(); - NodeDef& node = nodes_defs_[node_id]; + NodeDefBuilder& node = builders[node_id]; num_erased_edges += EraseEdge(source_node, node); for (int64_t out_id : node.out_edges) add_to_stack(out_id); diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index 54b4a4be2ac0c6..aaae96c906cc92 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -23,12 +23,14 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/fixed_array.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" @@ -69,7 +71,7 @@ class ThunkExecutor { using Options = internal::ThunkExecutorOptions; // Nodes identified by their index in the captured ThunkSequence. - using NodeId = int64_t; + using NodeId = int32_t; static constexpr NodeId kInvalidNodeId = std::numeric_limits::min(); @@ -79,8 +81,22 @@ class ThunkExecutor { static absl::StatusOr Create( ThunkSequence thunk_sequence, const Options& options = Options()); + // We store all `in_edges` and `out_edges` referenced by the `NodeDef` inside + // large vectors to optimize for data locality on a hot path. + using NodesEdges = std::vector; + // NodeDef defines an execution order for all thunks in a sequence. struct NodeDef { + NodeId id = kInvalidNodeId; + int64_t priority = 0; + absl::Span in_edges; + absl::Span out_edges; + }; + + // A NodeDef builder to collect all in-edges and out-edges before constructing + // a NodeDef. We use it at ThunkExecutor creation time when we don't know how + // many in-edges and out-edges we have in total. + struct NodeDefBuilder { NodeId id = kInvalidNodeId; int64_t priority = 0; std::vector in_edges; @@ -177,7 +193,7 @@ class ThunkExecutor { explicit Node(const NodeDef& node_def); alignas(kAtomicAlignment) std::atomic counter; - const std::vector* out_edges; + absl::Span out_edges; }; static_assert(std::is_trivially_destructible_v, @@ -189,7 +205,10 @@ class ThunkExecutor { ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner); - Node& node(NodeId id) { return *reinterpret_cast(&nodes[id]); } + Node& node(NodeId id) { + DCHECK_LT(id, nodes.size()) << "Node id is out of bounds"; + return *reinterpret_cast(&nodes.data()[id]); + } ThunkExecutor* executor; Thunk::TaskRunner* runner; @@ -208,7 +227,8 @@ class ThunkExecutor { absl::Status abort_status ABSL_GUARDED_BY(abort_mutex); }; - ThunkExecutor(ThunkSequence thunk_sequence, std::vector nodes_defs, + ThunkExecutor(ThunkSequence thunk_sequence, NodesEdges nodes_in_edges, + NodesEdges nodes_out_edges, std::vector nodes_defs, const Options& options); // Executes thunks sequentially starting from the first thunk in the sequence. @@ -240,17 +260,25 @@ class ThunkExecutor { tsl::AsyncValuePtr node_event, ExecuteState::Node& node, ReadyQueue& ready_queue); - // Runs a transitive reduction on the NodeDef graph to remove redundant edges, - // and updates nodes priorities. Returns the number of removed edges. + // Converts a vector of NodeDefBuilder to a tuple of NodesEdges and a vector + // of NodeDef. + static std::tuple> + CreateNodeDefs(std::vector builders); + + // Runs a transitive reduction on the NodeDefBuilder graph to remove redundant + // edges, and updates nodes priorities. Returns the number of removed edges. // // See: https://en.wikipedia.org/wiki/Transitive_reduction - int64_t RunTransitiveReductionAndUpdatePriorities(); + static int64_t RunTransitiveReductionAndUpdatePriorities( + absl::Span builders); ThunkSequence thunk_sequence_; Options options_; int64_t num_thunks_; + NodesEdges nodes_in_edges_; // `in_edges` referenced by `nodes_defs_` + NodesEdges nodes_out_edges_; // `out_edges` referenced by `nodes_defs_` std::vector nodes_defs_; std::vector source_; diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc index 511456e2adf762..dd315236916dd1 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -36,18 +36,20 @@ limitations under the License. #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/threadpool.h" #define EIGEN_USE_THREADS @@ -94,17 +96,6 @@ auto MakeTaskRunnerFrom(Runner&& runner, WorkerId&& worker_id) { std::forward(worker_id)); } -template -std::vector AsDeviceMemory( - absl::Span* const> data) { - std::vector buffers; - for (auto& vec : data) { - buffers.emplace_back( - se::DeviceMemoryBase(vec->data(), vec->size() * sizeof(T))); - } - return buffers; -} - // A test-only thunk for verifying thunk executor implementation: // // dst += src (for all srcs and dsts slices) @@ -483,10 +474,9 @@ TEST(ThunkExecutorTest, Execute) { ThunkExecutor executor, ThunkExecutor::Create(std::move(sequence), OptionsForTest())); - std::vector data(20, 1); // shared src and dst allocation - - auto buffers = AsDeviceMemory({&data}); - BufferAllocations allocations(buffers); + // Shared src and dst allocation. + auto data = LiteralUtil::CreateFull({20}, int32_t{1}); + BufferAllocations allocations = CreateBufferAllocations(data); auto task_runner = MakeTaskRunnerFrom( [&](Thunk::Task task) { @@ -507,9 +497,10 @@ TEST(ThunkExecutorTest, Execute) { ASSERT_TRUE(execute_event.IsConcrete()); EXPECT_THAT(trace, ElementsAre("", "b", "a", "c")); - EXPECT_THAT(data, ElementsAre(2, 2, 2, 2, 2, // slice0 - 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // slice2 - 2, 2, 2, 2, 2)); // slice1 + EXPECT_EQ(data, LiteralUtil::CreateR1({2, 2, 2, 2, 2, // slice0 + 4, 4, 4, 4, 4, // slice2 + 4, 4, 4, 4, 4, // ... + 2, 2, 2, 2, 2})); // slice1 } //===----------------------------------------------------------------------===// @@ -572,10 +563,8 @@ TEST(ThunkExecutorTest, ExecuteOnCorrectThreadPool) { ThunkExecutor executor, ThunkExecutor::Create(std::move(sequence), OptionsForTest())); - std::vector data(60, 1); // shared src and dst allocation - - auto buffers = AsDeviceMemory({&data}); - BufferAllocations allocations(buffers); + auto data = LiteralUtil::CreateFull({60}, uint8_t{1}); + BufferAllocations allocations = CreateBufferAllocations(data); // Task runner must be used only when ThunkExecutor detects that it runs on a // wrong thread and has to jump into the task runner. @@ -609,17 +598,27 @@ TEST(ThunkExecutorTest, ExecuteOnCorrectThreadPool) { enum class SharedResourceUse { kNo, kAll, kRandom }; struct GeneratedThunkSequence { + explicit GeneratedThunkSequence(int64_t num_elements) + : src(LiteralUtil::CreateFull({num_elements}, int32_t{1})), + dst(LiteralUtil::CreateFull({num_elements}, int32_t{0})), + expected(LiteralUtil::CreateFull({num_elements}, int32_t{0})), + src_alloc(CreateBufferAllocation(0, src)), + dst_alloc(CreateBufferAllocation(1, dst)), + expected_shared_resource_value(0), + expected_literals({&src, &expected}), + literals({&src, &dst}) {} + + Literal src; + Literal dst; + Literal expected; + BufferAllocation src_alloc; BufferAllocation dst_alloc; - std::vector src; - std::vector dst; - std::vector expected; - int32_t expected_shared_resource_value; - std::vector expected_buffers; - std::vector buffers; + std::vector expected_literals; + std::vector literals; ThunkSequence sequence; }; @@ -628,18 +627,8 @@ static absl::StatusOr> GenerateThunkSequence(size_t num_elements, size_t num_thunks, SharedResourceUse shared_resource_use, bool inject_errors) { - auto g = std::make_unique(GeneratedThunkSequence{ - BufferAllocation(/*index=*/0, num_elements * sizeof(int32_t), 0), - BufferAllocation(/*index=*/1, num_elements * sizeof(int32_t), 0), - /*src=*/std::vector(num_elements, 1), - /*dst=*/std::vector(num_elements, 0), - /*expected=*/std::vector(num_elements, 0), - /*expected_shared_resource_value=*/0, - }); - + auto g = std::make_unique(num_elements); g->sequence.reserve(num_thunks); - g->expected_buffers = AsDeviceMemory({&g->src, &g->expected}); - g->buffers = AsDeviceMemory({&g->src, &g->dst}); std::minstd_rand0 engine; @@ -661,7 +650,8 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks, BufferAllocation::Slice dst = random_slice(&g->dst_alloc); // Pre-compute expected result while building the thunk sequence. - BufferAllocations allocations(g->expected_buffers); + BufferAllocations allocations = + CreateBufferAllocations(absl::MakeSpan(g->expected_literals)); TF_RETURN_IF_ERROR(AddI32Thunk::Execute(&allocations, src, dst)); bool use_resource = [&] { @@ -747,7 +737,8 @@ TEST_P(ThunkExecutorStressTest, Execute) { ThunkExecutor executor, ThunkExecutor::Create(std::move(g->sequence), executor_options)); - BufferAllocations allocations(g->buffers); + BufferAllocations allocations = + CreateBufferAllocations(absl::MakeSpan(g->literals)); Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, device(), task_runner()}; @@ -886,7 +877,8 @@ static void BM_SequentialThunkExecutor(benchmark::State& state) { auto e = ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); - BufferAllocations allocations(g->buffers); + BufferAllocations allocations = + CreateBufferAllocations(absl::MakeSpan(g->literals)); Thunk::ExecuteParams params = {nullptr, &allocations}; for (auto _ : state) { @@ -901,16 +893,15 @@ static void BM_SyncThunkExecutor(benchmark::State& state) { auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks, /*shared_resource_use=*/SharedResourceUse::kNo, - /*inject_errors=*/false) - .value(); - auto e = - ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); + /*inject_errors=*/false); + auto e = ThunkExecutor::Create(std::move((*g)->sequence), OptionsForTest()); - BufferAllocations allocations(g->buffers); + BufferAllocations allocations = + CreateBufferAllocations(absl::MakeSpan((*g)->literals)); Thunk::ExecuteParams params = {nullptr, &allocations}; for (auto _ : state) { - auto execute_event = e.Execute(params); + auto execute_event = e->Execute(params); tsl::BlockUntilReady(execute_event); CHECK(execute_event.IsConcrete()); } @@ -925,19 +916,18 @@ static void BM_AsyncThunkExecutor(benchmark::State& state) { auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks, /*shared_resource_use=*/SharedResourceUse::kNo, - /*inject_errors=*/false) - .value(); - auto e = - ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); + /*inject_errors=*/false); + auto e = ThunkExecutor::Create(std::move((*g)->sequence), OptionsForTest()); - BufferAllocations allocations(g->buffers); + BufferAllocations allocations = + CreateBufferAllocations(absl::MakeSpan((*g)->literals)); ThreadPoolTaskRunner task_runner(thread_pool.AsEigenThreadPool()); Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, &device, &task_runner}; for (auto _ : state) { - auto execute_event = e.Execute(params); + auto execute_event = e->Execute(params); tsl::BlockUntilReady(execute_event); CHECK(execute_event.IsConcrete()); } diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc index 1b0dd200a864c8..d8bc5faafdbaa7 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/executable_run_options.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_executable_run_options.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla::cpu { namespace { @@ -93,13 +93,12 @@ TEST(ThunkTest, CollectiveExecuteParams) { // Test forwarding collectives interface from CpuExecutableRunOptions. CpuExecutableRunOptions cpu_run_options; cpu_run_options.set_collectives( - reinterpret_cast(0x12345678)); + reinterpret_cast(0x12345678)); run_options.set_cpu_executable_run_options(&cpu_run_options); TF_ASSERT_OK_AND_ASSIGN(params, Thunk::CollectiveExecuteParams::Create(&run_options)); - EXPECT_EQ(params.collectives, - reinterpret_cast(0x12345678)); + EXPECT_EQ(params.collectives, reinterpret_cast(0x12345678)); } } // namespace diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.cc new file mode 100644 index 00000000000000..96fc5f68115e4e --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.cc @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/thunk_testlib.h" + +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/literal.h" +#include "xla/service/buffer_assignment.h" +#include "xla/stream_executor/device_memory.h" + +namespace xla::cpu { + +BufferAllocation CreateBufferAllocation(size_t index, const Literal& literal) { + size_t size_in_bytes = literal.size_bytes(); + return BufferAllocation(index, size_in_bytes, 0); +} + +BufferAllocation::Slice CreateBufferAllocationSlice( + const BufferAllocation& allocation) { + return CreateBufferAllocationSlice(allocation, 0, allocation.size()); +} + +BufferAllocation::Slice CreateBufferAllocationSlice( + const BufferAllocation& allocation, int64_t offset, int64_t size) { + return BufferAllocation::Slice(&allocation, offset, size); +} + +BufferAllocations CreateBufferAllocations(absl::Span literals) { + std::vector buffers; + buffers.reserve(literals.size()); + + for (auto* literal : literals) { + size_t size_in_bytes = literal->size_bytes(); + buffers.emplace_back(literal->untyped_data(), size_in_bytes); + } + + return BufferAllocations(buffers); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h index 4da0650efee7c4..9476184750c552 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h @@ -16,14 +16,78 @@ limitations under the License. #ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ #define XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ +#include +#include +#include +#include +#include +#include + #include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/literal.h" #include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { +//===----------------------------------------------------------------------===// +// A set of helper functions to create buffer allocations from Literals. +//===----------------------------------------------------------------------===// + +// Creates a BufferAllocation with given index from a literal. +BufferAllocation CreateBufferAllocation(size_t index, const Literal& literal); + +// Creates an array of BufferAllocations from a variadic pack of literals. +template < + typename... Literals, + std::enable_if_t...>>* = + nullptr> +std::array CreateBufferAllocation( + Literals&... literals) { + size_t index = 0; + return {CreateBufferAllocation(index++, literals)...}; +} + +// Creates a BufferAllocation::Slice that covers the entire allocation. +BufferAllocation::Slice CreateBufferAllocationSlice( + const BufferAllocation& allocation); + +// Creates a BufferAllocation::Slice that covers a subrange of the allocation. +BufferAllocation::Slice CreateBufferAllocationSlice( + const BufferAllocation& allocation, int64_t offset, int64_t size); + +// Creates an array of BufferAllocation::Slice from a pack of allocations. Each +// slice covers the entire corresponding allocation. +template ...>>* = nullptr> +std::array +CreateBufferAllocationSlice(const BufferAllocations&... allocations) { + return {CreateBufferAllocationSlice(allocations)...}; +} + +// Creates a BufferAllocations from a span of literals. +BufferAllocations CreateBufferAllocations(absl::Span literals); + +// Creates a BufferAllocations from a variadic pack of literals. +template < + typename... Literals, + std::enable_if_t...>>* = + nullptr> +BufferAllocations CreateBufferAllocations(Literals&... literals) { + std::vector literals_ptrs = {&literals...}; + return CreateBufferAllocations(absl::MakeSpan(literals_ptrs)); +} + +//===----------------------------------------------------------------------===// +// A library of test-only thunks. +//===----------------------------------------------------------------------===// + // A test-only thunk to create a Thunk with a specific buffer use. class BufferUseThunk : public Thunk { public: diff --git a/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc index d4b874a72b380f..0a78fff7818792 100644 --- a/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc @@ -26,15 +26,15 @@ limitations under the License. #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/literal_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/env.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #define EIGEN_USE_THREADS @@ -161,27 +161,21 @@ class BodyThunk : public Thunk { TEST(WhileThunkTest, NonBlockingExecute) { static constexpr size_t kNumIterations = 100; - BufferAllocation pred_alloc(0, sizeof(char), 0); - BufferAllocation cnt_alloc(1, sizeof(int32_t), 0); + auto pred = LiteralUtil::CreateR0(false); + auto counter = LiteralUtil::CreateR0(0); - BufferAllocation::Slice pred_slice(&pred_alloc, 0, sizeof(char)); - BufferAllocation::Slice cnt_slice(&cnt_alloc, 0, sizeof(int32_t)); + BufferAllocations allocations = CreateBufferAllocations(pred, counter); - std::vector buffers; - std::vector predicate = {false}; - std::vector counter = {0}; - - buffers.emplace_back(se::DeviceMemoryBase(predicate.data(), sizeof(char))); - buffers.emplace_back(se::DeviceMemoryBase(counter.data(), sizeof(int32_t))); - - BufferAllocations allocations(buffers); + auto [pred_alloc, counter_alloc] = CreateBufferAllocation(pred, counter); + auto [pred_slice, counter_slice] = + CreateBufferAllocationSlice(pred_alloc, counter_alloc); ThunkSequence cond_sequence; cond_sequence.push_back( std::make_unique(kNumIterations, pred_slice)); ThunkSequence body_sequence; - body_sequence.push_back(std::make_unique(cnt_slice)); + body_sequence.push_back(std::make_unique(counter_slice)); TF_ASSERT_OK_AND_ASSIGN( auto thunk, @@ -200,26 +194,20 @@ TEST(WhileThunkTest, NonBlockingExecute) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - EXPECT_EQ(counter[0], kNumIterations); + EXPECT_EQ(counter, LiteralUtil::CreateR0(kNumIterations)); } TEST(WhileThunkTest, NonBlockingExecuteWithTripCount) { static constexpr size_t kNumIterations = 100; - BufferAllocation pred_alloc(0, sizeof(char), 0); - BufferAllocation cnt_alloc(1, sizeof(int32_t), 0); - - BufferAllocation::Slice pred_slice(&pred_alloc, 0, sizeof(char)); - BufferAllocation::Slice cnt_slice(&cnt_alloc, 0, sizeof(int32_t)); - - std::vector buffers; - std::vector predicate = {false}; - std::vector counter = {0}; + auto pred = LiteralUtil::CreateR0(false); + auto counter = LiteralUtil::CreateR0(0); - buffers.emplace_back(se::DeviceMemoryBase(predicate.data(), sizeof(char))); - buffers.emplace_back(se::DeviceMemoryBase(counter.data(), sizeof(int32_t))); + BufferAllocations allocations = CreateBufferAllocations(pred, counter); - BufferAllocations allocations(buffers); + auto [pred_alloc, counter_alloc] = CreateBufferAllocation(pred, counter); + auto [pred_slice, counter_slice] = + CreateBufferAllocationSlice(pred_alloc, counter_alloc); // We pass empty cond sequence, because we know the trip count, and check that // predicate value is ignored (it is initialized to false) and body executed @@ -227,7 +215,7 @@ TEST(WhileThunkTest, NonBlockingExecuteWithTripCount) { ThunkSequence cond_sequence; ThunkSequence body_sequence; - body_sequence.push_back(std::make_unique(cnt_slice)); + body_sequence.push_back(std::make_unique(counter_slice)); TF_ASSERT_OK_AND_ASSIGN( auto thunk, WhileThunk::Create( @@ -246,7 +234,7 @@ TEST(WhileThunkTest, NonBlockingExecuteWithTripCount) { tsl::BlockUntilReady(execute_event); ASSERT_FALSE(execute_event.IsError()); - EXPECT_EQ(counter[0], kNumIterations); + EXPECT_EQ(counter, LiteralUtil::CreateR0(kNumIterations)); } } // namespace diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD index f63eb4d9377791..8b65dedcb6eaac 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD @@ -1,3 +1,4 @@ +load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -13,10 +14,227 @@ package_group( ], ) +cc_library( + name = "object_pool", + hdrs = ["object_pool.h"], + deps = [ + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "object_pool_test", + srcs = ["object_pool_test.cc"], + deps = [ + ":object_pool", + "//xla/tsl/platform:env", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "parallel_loop_runner", + srcs = ["parallel_loop_runner.cc"], + hdrs = ["parallel_loop_runner.h"], + deps = [ + "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/math:math_util", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/base:core_headers", + "@eigen_archive//:eigen3", + ], +) + +xla_cc_test( + name = "parallel_loop_runner_test", + srcs = ["parallel_loop_runner_test.cc"], + deps = [ + ":parallel_loop_runner", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:env", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + ], +) + cc_library( name = "xnn_interop", + srcs = ["xnn_interop.cc"], hdrs = ["xnn_interop.h"], deps = [ + "//xla:util", + "@XNNPACK", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "xnn_threadpool", + srcs = ["xnn_threadpool.cc"], + hdrs = ["xnn_threadpool.h"], + # copybara:uncomment_begin(google-only) + # local_defines = select({ + # "@pthreadpool:pthreadpool_header_only_explicit_true": [ + # "XLA_CPU_USE_CUSTOM_PTHREADPOOL", + # ], + # "//conditions:default": [], + # }), + # copybara:uncomment_end + deps = [ + ":parallel_loop_runner", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:platform_port", + "@pthreadpool", + ], +) + +xla_cc_test( + name = "xnn_threadpool_test", + srcs = ["xnn_threadpool_test.cc"], + deps = [ + ":parallel_loop_runner", + ":xnn_threadpool", + "//xla/tsl/concurrency:async_value", + "@XNNPACK", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/synchronization", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@pthreadpool", + ], +) + +cc_library( + name = "xnn_dot_thunk", + srcs = ["xnn_dot_thunk.cc"], + hdrs = ["xnn_dot_thunk.h"], + deps = [ + ":xnn_fusion_thunk", + ":xnn_interop", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime:dot_lib", + "//xla/backends/cpu/runtime:thunk", + "//xla/service:buffer_assignment", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "@XNNPACK", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +xla_cc_test( + name = "xnn_dot_thunk_test", + srcs = ["xnn_dot_thunk_test.cc"], + deps = [ + ":xnn_dot_thunk", + "//xla:executable_run_options", + "//xla:literal_util", + "//xla:shape_util", + "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_testlib", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "xnn_fusion_thunk", + srcs = ["xnn_fusion_thunk.cc"], + hdrs = ["xnn_fusion_thunk.h"], + deps = [ + ":object_pool", + ":parallel_loop_runner", + ":xnn_interop", + ":xnn_threadpool", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime:dot_lib", + "//xla/backends/cpu/runtime:thunk", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/stream_executor:device_memory", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@XNNPACK", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + "@pthreadpool", + ], +) + +xla_cc_test( + name = "xnn_fusion_thunk_test", + srcs = ["xnn_fusion_thunk_test.cc"], + deps = [ + ":xnn_fusion_thunk", + ":xnn_interop", + "//xla:executable_run_options", + "//xla:literal_util", + "//xla:shape_util", + "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_testlib", + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@XNNPACK", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/object_pool.h b/third_party/xla/xla/backends/cpu/runtime/xnnpack/object_pool.h new file mode 100644 index 00000000000000..32313c2d04487e --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/object_pool.h @@ -0,0 +1,140 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_OBJECT_POOL_H_ +#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_OBJECT_POOL_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +// A non-blocking pool of objects of type `T`. Objects in the pool are created +// lazily when needed by calling the user-provided `builder` function. +// +// This object pool is intended to be used on a critical path and optimized for +// zero-allocation in steady state. +template +class ObjectPool { + struct Entry { + T object; + std::atomic next; + }; + + public: + explicit ObjectPool(absl::AnyInvocable(Args...)> builder); + ~ObjectPool(); + + class BorrowedObject { + public: + ~BorrowedObject(); + + T& operator*() { return entry_->object; } + T* operator->() { return &entry_->object; } + + BorrowedObject(BorrowedObject&&) = default; + BorrowedObject& operator=(BorrowedObject&&) = default; + + private: + friend class ObjectPool; + + BorrowedObject(ObjectPool* parent, std::unique_ptr entry); + + ObjectPool* parent_; + std::unique_ptr entry_; + }; + + absl::StatusOr GetOrCreate(Args... args); + + size_t num_created() const { return num_created_.load(); } + + private: + absl::StatusOr> CreateEntry(Args... args); + std::unique_ptr PopEntry(); + void PushEntry(std::unique_ptr entry); + + absl::AnyInvocable(Args...)> builder_; + std::atomic head_; + std::atomic num_created_; +}; + +template +ObjectPool::ObjectPool( + absl::AnyInvocable(Args...)> builder) + : builder_(std::move(builder)), head_(nullptr), num_created_(0) {} + +template +ObjectPool::~ObjectPool() { + while (Entry* entry = head_.load()) { + head_.store(entry->next); + delete entry; + } +} + +template +auto ObjectPool::CreateEntry(Args... args) + -> absl::StatusOr> { + auto entry = std::make_unique(); + TF_ASSIGN_OR_RETURN(entry->object, builder_(std::forward(args)...)); + entry->next = nullptr; + num_created_.fetch_add(1); + return entry; +} + +template +auto ObjectPool::PopEntry() -> std::unique_ptr { + Entry* head = head_.load(); + while (head && !head_.compare_exchange_weak(head, head->next)) { + } + return std::unique_ptr(head); +} + +template +void ObjectPool::PushEntry(std::unique_ptr entry) { + Entry* head = head_.load(); + Entry* new_head = entry.release(); + do { + new_head->next = head; + } while (!head_.compare_exchange_weak(head, new_head)); +} + +template +ObjectPool::BorrowedObject::BorrowedObject( + ObjectPool* parent, std::unique_ptr entry) + : parent_(parent), entry_(std::move(entry)) {} + +template +ObjectPool::BorrowedObject::~BorrowedObject() { + if (parent_ && entry_) parent_->PushEntry(std::move(entry_)); +} + +template +auto ObjectPool::GetOrCreate(Args... args) + -> absl::StatusOr { + if (std::unique_ptr entry = PopEntry()) { + return BorrowedObject(this, std::move(entry)); + } + TF_ASSIGN_OR_RETURN(auto entry, CreateEntry(std::forward(args)...)); + return BorrowedObject(this, std::move(entry)); +} + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_OBJECT_POOL_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/object_pool_test.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/object_pool_test.cc new file mode 100644 index 00000000000000..bdad63e68621d5 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/object_pool_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/object_pool.h" + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/synchronization/blocking_counter.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/threadpool.h" + +namespace xla::cpu { +namespace { + +using IntPool = ObjectPool>; + +TEST(ObjectPoolTest, GetOrCreate) { + int32_t counter = 0; + IntPool pool([&]() -> absl::StatusOr> { + return std::make_unique(counter++); + }); + + TF_ASSERT_OK_AND_ASSIGN(auto obj0, pool.GetOrCreate()); + ASSERT_EQ(**obj0, 0); + + TF_ASSERT_OK_AND_ASSIGN(auto obj1, pool.GetOrCreate()); + ASSERT_EQ(**obj1, 1); + + auto destroy = [](IntPool::BorrowedObject obj) {}; + destroy(std::move(obj0)); + destroy(std::move(obj1)); + + TF_ASSERT_OK_AND_ASSIGN(auto obj2, pool.GetOrCreate()); + ASSERT_EQ(**obj2, 1); + ASSERT_EQ(counter, 2); +} + +TEST(ObjectPoolTest, GetOrCreateUnderContention) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + + std::atomic counter = 0; + IntPool pool([&]() -> absl::StatusOr> { + return std::make_unique(counter++); + }); + + size_t num_tasks = 10; + absl::BlockingCounter blocking_counter(num_tasks); + + for (int32_t t = 0; t < num_tasks; ++t) { + threads.Schedule([&] { + for (int32_t i = 0; i < 100; ++i) { + TF_ASSERT_OK_AND_ASSIGN(auto obj, pool.GetOrCreate()); + ASSERT_GE(**obj, 0); + } + blocking_counter.DecrementCount(); + }); + } + + blocking_counter.Wait(); + + // We should create at most one object for each thread in the pool. + EXPECT_LE(counter, 8); +} + +//===----------------------------------------------------------------------===// +// Performance benchmarks. +//===----------------------------------------------------------------------===// + +static void BM_GetOrCreate(benchmark::State& state) { + IntPool pool([cnt = 0]() mutable -> absl::StatusOr> { + return std::make_unique(cnt++); + }); + + for (auto _ : state) { + auto obj = pool.GetOrCreate(); + benchmark::DoNotOptimize(obj); + } +} + +BENCHMARK(BM_GetOrCreate); + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc new file mode 100644 index 00000000000000..f3a23b04861437 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc @@ -0,0 +1,416 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/concurrency/chain.h" +#include "xla/tsl/lib/math/math_util.h" +#include "xla/tsl/platform/logging.h" + +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" + +namespace xla::cpu { + +using Task = std::function; + +// Returns non-reference-counted async value ref in constructed state. +// +// Returned async value is a per-process singleton stored in a storage with a +// static duration, and can be safely compared using pointer equality. +static tsl::AsyncValueRef OkDoneEventSingleton() { + static tsl::AsyncValueOwningRef* singleton = [] { + auto* storage = new tsl::internal::AsyncValueStorage(); + return new tsl::AsyncValueOwningRef( + tsl::MakeAvailableAsyncValueRef(*storage)); + }(); + return singleton->AsRef(); +} + +ParallelLoopRunner::ParallelLoopRunner(const Eigen::ThreadPoolDevice* device) + : done_event_(OkDoneEventSingleton()), device_(device) {} + +tsl::AsyncValueRef ParallelLoopRunner::ResetDoneEvent() { + auto done_event = std::move(done_event_); + done_event_ = OkDoneEventSingleton(); + return done_event; +} + +size_t ParallelLoopRunner::num_threads() const { + return device_.load()->numThreadsInPool(); +} + +tsl::AsyncValueRef ParallelLoopRunner::TakeDoneEvent( + ParallelLoopRunner&& runner) { + return std::move(runner.done_event_); +} + +ParallelLoopRunner::ParallelTaskConfig +ParallelLoopRunner::ComputeParallelTaskConfig(size_t num_tasks) const { + // We limit the number of parallel tasks per thread to avoid excessive task + // scheduling overheads at run time. + static constexpr size_t kMaxTasksPerThread = 4; + + size_t parallel_task_size = + tsl::MathUtil::CeilOfRatio(num_tasks, kMaxTasksPerThread * num_threads()); + size_t num_parallel_tasks = + tsl::MathUtil::CeilOfRatio(num_tasks, parallel_task_size); + + return {num_tasks, parallel_task_size, num_parallel_tasks}; +} + +template +static void Parallelize(ParallelizeContext* ctx, Index start_index, + Index end_index) { + CHECK_LT(start_index, end_index) << "Invalid task index range"; // Crash OK + + // Recursively split the task into two halves and schedule the right half into + // the thread pool. + while (end_index - start_index > 1) { + Index mid_index = (start_index + end_index) / 2; + ctx->device->enqueueNoNotification([ctx, mid_index, end_index] { + Parallelize(ctx, mid_index, end_index); + }); + end_index = mid_index; + } + + // Execute the `start_index` task in the caller thread. + ctx->parallel_task(start_index); + + // If count down is completed, delete the context. + if (ctx->count_down.CountDown()) { + delete ctx; + } +} + +template +void ParallelLoopRunner::Parallelize( + tsl::CountDownAsyncValueRef count_down, size_t start_index, + size_t end_index, ParallelTask&& parallel_task) { + CHECK_LT(start_index, end_index) << "Invalid task index range"; // Crash OK + + struct ParallelizeContext { + ParallelizeContext(tsl::CountDownAsyncValueRef count_down, + const Eigen::ThreadPoolDevice* device, + ParallelTask&& parallel_task) + : count_down(std::move(count_down)), + device(device), + parallel_task(std::forward(parallel_task)) {} + + tsl::CountDownAsyncValueRef count_down; + const Eigen::ThreadPoolDevice* device; + ParallelTask parallel_task; + }; + + auto ctx = std::make_unique( + std::move(count_down), device_, + std::forward(parallel_task)); + + // We try to use uint16_t for index type because it enables small buffer + // optimization in the constructed `std::function` tasks. + if (ABSL_PREDICT_TRUE(end_index <= std::numeric_limits::max())) { + xla::cpu::Parallelize(ctx.release(), start_index, end_index); + } else { + xla::cpu::Parallelize(ctx.release(), start_index, end_index); + } +} + +template +void ParallelLoopRunner::ScheduleOne(Task&& task) { + auto event = tsl::MakeConstructedAsyncValueRef(); + done_event_.AndThen([event, task = std::forward(task)] { + task(); + event.SetStateConcrete(); + }); + done_event_ = std::move(event); +} + +template +void ParallelLoopRunner::ScheduleAll(size_t num_tasks, + ParallelTask&& parallel_task) { + tsl::CountDownAsyncValueRef count_down(num_tasks); + auto count_down_done = count_down.AsRef(); + + done_event_.AndThen([this, num_tasks, count_down = std::move(count_down), + parallel_task = + std::forward(parallel_task)] { + Parallelize(std::move(count_down), 0, num_tasks, std::move(parallel_task)); + }); + done_event_ = std::move(count_down_done); +} + +namespace { + +// Multidimensional index types for the parallel loop runner tasks. We launch +// tasks using one-dimensional `task_index` and convert it into a +// multidimensional index type depending on the loop type. + +struct Task1DTile1DIndex { + size_t offset; + size_t extent; +}; + +struct Task2DTile1DIndex { + size_t i; + size_t offset_j; + size_t extent_j; +}; + +struct Task3DTile2DIndex { + size_t i; + size_t offset_j; + size_t offset_k; + size_t extent_j; + size_t extent_k; +}; + +} // namespace + +auto ParallelLoopRunner::ParallelTaskConfig::ParallelTaskRange( + size_t parallel_task_index) const -> TaskRange { + size_t begin = parallel_task_index * parallel_task_size; + size_t end = std::min(num_tasks, begin + parallel_task_size); + return {begin, end}; +} + +static Task1DTile1DIndex Delinearize(size_t task_index, size_t range, + size_t tile) { + size_t offset = task_index * tile; + size_t extent = std::min(range - offset, tile); + return {offset, extent}; +} + +static size_t NumTasks(size_t range_i, size_t range_j, size_t tile_j) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + size_t num_tasks = range_i * num_tile_j_tasks; + DCHECK_GT(num_tasks, 0) << "Expected at least one tile task"; + return num_tasks; +} + +static Task2DTile1DIndex Delinearize(size_t task_index, size_t range_i, + size_t range_j, size_t tile_j) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + DCHECK_GT(num_tile_j_tasks, 0) << "Expected at least one tile j task"; + + // Compute task indices along the `i` and `j` dimensions. + size_t task_i = task_index / num_tile_j_tasks; + size_t task_j = task_index % num_tile_j_tasks; + + // Convert task index into the offset and extent along the `j` dimension. + size_t offset_j = task_j * tile_j; + size_t extent_j = std::min(range_j - offset_j, tile_j); + + return {task_i, offset_j, extent_j}; +} + +static size_t NumTasks(size_t range_i, size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + size_t num_tile_k_tasks = tsl::MathUtil::CeilOfRatio(range_k, tile_k); + size_t num_tasks = range_i * num_tile_j_tasks * num_tile_k_tasks; + DCHECK_GT(num_tasks, 0) << "Expected at least one tile task"; + return num_tasks; +} + +static Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i, + size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + size_t num_tile_k_tasks = tsl::MathUtil::CeilOfRatio(range_k, tile_k); + size_t num_tile_tasks = num_tile_j_tasks * num_tile_k_tasks; + + DCHECK_GT(num_tile_j_tasks, 0) << "Expected at least one tile j task"; + DCHECK_GT(num_tile_k_tasks, 0) << "Expected at least one tile k task"; + + // Compute task indices along the `i`, `j` and `k` dimensions. + size_t task_i = task_index / num_tile_tasks; + task_index %= num_tile_tasks; + + size_t task_j = task_index / num_tile_k_tasks; + task_index %= num_tile_k_tasks; + + size_t task_k = task_index; + + // Convert task indices into the offset and extent along the `j` and `k` + // dimensions. + size_t offset_j = task_j * tile_j; + size_t offset_k = task_k * tile_k; + size_t extent_j = std::min(range_j - offset_j, tile_j); + size_t extent_k = std::min(range_k - offset_k, tile_k); + + return {task_i, offset_j, offset_k, extent_j, extent_k}; +} + +// In the `Parallelize` implementations below: +// +// (1) If done event is already available, execute the task immediately in the +// caller thread. In this case we don't need to overwrite the done event, +// because the existing one will correctly represent the state of the +// parallel loop runner (all scheduled loops are ready). +// +// (2) If done event is not available, we have to overwrite it with a new one +// that will be set to concrete state after the task is executed. + +void ParallelLoopRunner::Parallelize(size_t range, Task1D task) { + DCHECK(done_event_) << "Parallel loop runner is in moved-from state"; + DCHECK_GT(range, 0) << "Expected at least one task"; + + // Fast path for the degenerate parallel loop with single task. + if (ABSL_PREDICT_TRUE(range == 1)) { + // Execute task in the caller thread if done event is already available. + if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) { + task(0); + return; + } + + // Schedule task when done event becomes available. + ScheduleOne([task = std::move(task)] { task(0); }); + return; + } + + // Schedule `parallel_config.num_parallel_tasks` into the underlying thread + // pool when done event becomes available. + auto parallel_config = ComputeParallelTaskConfig(range); + auto parallel_task = [parallel_config, + task = std::move(task)](size_t parallel_task_index) { + auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index); + for (size_t i = begin; i < end; ++i) task(i); + }; + + ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task)); +} + +void ParallelLoopRunner::Parallelize(size_t range, size_t tile, + Task1DTile1D task) { + DCHECK(done_event_) << "Parallel loop runner is in moved-from state"; + + size_t num_tasks = tsl::MathUtil::CeilOfRatio(range, tile); + DCHECK_GT(num_tasks, 0) << "Expected at least one task"; + + // Fast path for the degenerate parallel loop with single task. + if (ABSL_PREDICT_TRUE(num_tasks == 1)) { + DCHECK_EQ(range, tile) << "Expected range to be equal to tile"; + + // Execute task in the caller thread if done event is already available. + if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) { + task(0, range); + return; + } + + // Schedule task when done event becomes available. + ScheduleOne([range, task = std::move(task)] { task(0, range); }); + return; + } + + // Schedule `parallel_config.num_parallel_tasks` into the underlying thread + // pool when done event becomes available. + auto parallel_config = ComputeParallelTaskConfig(num_tasks); + auto parallel_task = [range, tile, parallel_config, + task = std::move(task)](size_t parallel_task_index) { + auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index); + for (size_t i = begin; i < end; ++i) { + auto x = Delinearize(i, range, tile); + task(x.offset, x.extent); + } + }; + + ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task)); +} + +void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j, + size_t tile_j, Task2DTile1D task) { + DCHECK(done_event_) << "Parallel loop runner is in moved-from state"; + size_t num_tasks = NumTasks(range_i, range_j, tile_j); + + // Fast path for the degenerate parallel loop with single task. + if (ABSL_PREDICT_TRUE(num_tasks == 1)) { + DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile"; + + // Execute task in the caller thread if done event is already available. + if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) { + task(0, 0, range_j); + return; + } + + // Schedule task when done event becomes available. + ScheduleOne([range_j, task = std::move(task)] { task(0, 0, range_j); }); + return; + } + + // Schedule `parallel_config.num_parallel_tasks` into the underlying thread + // pool when done event becomes available. + auto parallel_config = ComputeParallelTaskConfig(num_tasks); + auto parallel_task = [range_i, range_j, tile_j, parallel_config, + task = std::move(task)](size_t parallel_task_index) { + auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index); + for (size_t i = begin; i < end; ++i) { + auto x = Delinearize(i, range_i, range_j, tile_j); + task(x.i, x.offset_j, x.extent_j); + } + }; + + ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task)); +} + +void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j, + size_t range_k, size_t tile_j, + size_t tile_k, Task3DTile2D task) { + DCHECK(done_event_) << "Parallel loop runner is in moved-from state"; + size_t num_tasks = NumTasks(range_i, range_j, range_k, tile_j, tile_k); + + // Fast path for the degenerate parallel loop with single task. + if (ABSL_PREDICT_TRUE(num_tasks == 1)) { + DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile"; + DCHECK_EQ(range_k, tile_k) << "Expected range to be equal to tile"; + + // Execute task in the caller thread if done event is already available. + if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) { + task(0, 0, 0, range_j, range_k); + return; + } + + // Schedule task when done event becomes available. + ScheduleOne([range_j, range_k, task = std::move(task)] { + task(0, 0, 0, range_j, range_k); + }); + return; + } + + // Schedule `parallel_config.num_parallel_tasks` into the underlying thread + // pool when done event becomes available. + auto parallel_config = ComputeParallelTaskConfig(num_tasks); + auto parallel_task = [range_i, range_j, range_k, tile_j, tile_k, + parallel_config, + task = std::move(task)](size_t parallel_task_index) { + auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index); + for (size_t i = begin; i < end; ++i) { + auto x = Delinearize(i, range_i, range_j, range_k, tile_j, tile_k); + task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k); + } + }; + + ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task)); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h b/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h new file mode 100644 index 00000000000000..361378a6084d76 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h @@ -0,0 +1,161 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_PARALLEL_LOOP_RUNNER_H_ +#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_PARALLEL_LOOP_RUNNER_H_ + +#include +#include +#include + +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/concurrency/chain.h" + +namespace Eigen { +struct ThreadPoolDevice; +} // namespace Eigen + +namespace xla::cpu { + +// Parallel loop runner uses underlying Eigen ThreadPoolDevice to execute +// parallel loops providing implicit synchronization: the next parallel loop +// starts execution only after all tasks from the previous loop are completed. +// +// Scheduled parallel loops execute asynchronously without blocking the caller +// thread. It is the user's responsibility to ensure that all values captured by +// the task are valid until the task is completed. +// +// Parallel loop runner is an implementation of the `pthreadpool` API adaptor +// for XLA:CPU runtime. +// +// WARNING: ParallelLoopRunner is not thread-safe, and must be externally +// synchronized by the user. +class ParallelLoopRunner { + public: + explicit ParallelLoopRunner(const Eigen::ThreadPoolDevice* device); + + // Takes ownership of the runner and returns a done event. After the done + // event is transferred to the caller, it is illegal to schedule more parallel + // loops on the moved-from runner. + static tsl::AsyncValueRef TakeDoneEvent( + ParallelLoopRunner&& runner); + + using Task1D = std::function; + + using Task1DTile1D = std::function; + + using Task2DTile1D = + std::function; + + using Task3DTile2D = + std::function; + + // This function implements a parallel version of a following loop: + // + // for (size_t i = 0; i < range; i++) + // task(i); + void Parallelize(size_t range, Task1D task); + + // This function implements a parallel version of a following loop: + // + // for (size_t i = 0; i < range; i += tile) + // task(i, std::min(range - i, tile)); + void Parallelize(size_t range, size_t tile, Task1DTile1D task); + + // This function implements a parallel version of a following loop: + // + // for (size_t i = 0; i < range_i; i++) + // for (size_t j = 0; j < range_j; j += tile_j) + // task(i, j, min(range_j - j, tile_j)); + void Parallelize(size_t range_i, size_t range_j, size_t tile_j, + Task2DTile1D task); + + // This function implements a parallel version of a following loop: + // + // for (size_t i = 0; i < range_i; i++) + // for (size_t j = 0; j < range_j; j += tile_j) + // for (size_t k = 0; k < range_k; k += tile_k) + // task(i, j, k, min(range_j - j, tile_j), min(range_k - k, tile_k)); + void Parallelize(size_t range_i, size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k, Task3DTile2D task); + + // Resets the parallel loop runner `done_event` and returns the previous one + // to the caller. + tsl::AsyncValueRef ResetDoneEvent(); + + tsl::AsyncValueRef done_event() const { return done_event_; } + + const Eigen::ThreadPoolDevice* device() const { return device_; } + void set_device(const Eigen::ThreadPoolDevice* device) { device_ = device; } + + size_t num_threads() const; + + private: + // When parallelizing loops, we split the loop iteration space of `num_tasks` + // size into `num_parallel_tasks` parallel tasks, each of which processes + // `parallel_task_size` original tasks sequentially on a single thread. We do + // this to avoid excessive task scheduling overheads at run time. + struct ParallelTaskConfig { + struct TaskRange { + size_t begin; + size_t end; + }; + + TaskRange ParallelTaskRange(size_t parallel_task_index) const; + + size_t num_tasks; + size_t parallel_task_size; + size_t num_parallel_tasks; + }; + + ParallelTaskConfig ComputeParallelTaskConfig(size_t num_tasks) const; + + // Schedules tasks in the [start_index, end_index) range into the Eigen thread + // pool using recursive work splitting. Executes the `start_index` task in the + // caller thread. + template + void Parallelize(tsl::CountDownAsyncValueRef count_down, + size_t start_index, size_t end_index, + ParallelTask&& parallel_task); + + // Schedules `task` as the AndThen callback of the `done_event_`. Updates + // `done_event_` to the new completion event. + template + void ScheduleOne(Task&& task); + + // Schedules `num_tasks` invocation of the `parallel_task` into the Eigen + // thread pool when the `done_event_` becomes available. Updates `done_event_` + // to the new completion event. + template + void ScheduleAll(size_t num_tasks, ParallelTask&& parallel_task); + + // Async value that signals completion of the last scheduled parallel loop. + tsl::AsyncValueRef done_event_; + + // We keep a pointer to the Eigen thread pool device as an atomic variable + // because we might update it between concurrent runs of XNNPACK operations + // and non-atomic access to the `device_` pointer might lead to a data race. + // + // In practice PjRt CPU client owns the intra-op thread pool and passes it to + // XLA via Thunk::ExecuteParams, and PjRt client might have multiple thread + // pools for different NUMA nodes, and we have to be able to switch between + // them from run to run. + std::atomic device_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_PARALLEL_LOOP_RUNNER_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc new file mode 100644 index 00000000000000..59dbf031c1eb27 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/cleanup/cleanup.h" +#include "absl/types/span.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/threadpool.h" + +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" + +namespace xla::cpu { +namespace { + +TEST(ParallelLoopRunnerTest, Parallelize1D) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + constexpr int32_t d0 = 128; + + auto* data = new int32_t[d0](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t offset) { data[offset] += 1; }; + + runner.Parallelize(d0, increment); + runner.Parallelize(d0, increment); + runner.Parallelize(d0, increment); + runner.Parallelize(d0, increment); + runner.Parallelize(d0, increment); + + tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0], d0), + [](int32_t value) { return value == 5; })); +} + +TEST(ParallelLoopRunnerTest, Parallelize1DTile1D) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + constexpr int32_t d0 = 128; + + auto* data = new int32_t[d0](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t offset, size_t extent) { + for (size_t i = offset; i < offset + extent; ++i) { + data[i] += 1; + } + }; + + runner.Parallelize(d0, 1, increment); + runner.Parallelize(d0, 2, increment); + runner.Parallelize(d0, 3, increment); + runner.Parallelize(d0, 4, increment); + runner.Parallelize(d0, 5, increment); + + tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0], d0), + [](int32_t value) { return value == 5; })); +} + +TEST(ParallelLoopRunnerTest, Parallelize2DTile1D) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + constexpr int32_t d0 = 4; + constexpr int32_t d1 = 39; + + auto* data = new int32_t[d0][d1](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t i, size_t offset_j, size_t extent_j) { + for (size_t j = offset_j; j < offset_j + extent_j; ++j) { + data[i][j] += 1; + } + }; + + runner.Parallelize(d0, d1, 1, increment); + runner.Parallelize(d0, d1, 2, increment); + runner.Parallelize(d0, d1, 3, increment); + runner.Parallelize(d0, d1, 4, increment); + runner.Parallelize(d0, d1, 5, increment); + + tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0][0], d0 * d1), + [](int32_t value) { return value == 5; })); +} + +TEST(ParallelLoopRunnerTest, Parallelize3DTile2D) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + constexpr int32_t d0 = 4; + constexpr int32_t d1 = 39; + constexpr int32_t d2 = 63; + + auto* data = new int32_t[d0][d1][d2](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t i, size_t offset_j, size_t offset_k, + size_t extent_j, size_t extent_k) { + for (size_t j = offset_j; j < offset_j + extent_j; ++j) { + for (size_t k = offset_k; k < offset_k + extent_k; ++k) { + data[i][j][k] += 1; + } + } + }; + + runner.Parallelize(d0, d1, d2, 1, 5, increment); + runner.Parallelize(d0, d1, d2, 2, 4, increment); + runner.Parallelize(d0, d1, d2, 3, 4, increment); + runner.Parallelize(d0, d1, d2, 4, 3, increment); + runner.Parallelize(d0, d1, d2, 5, 1, increment); + + tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0][0][0], d0 * d1 * d2), + [](int32_t value) { return value == 5; })); +} + +//===----------------------------------------------------------------------===// +// Performance benchmarks. +//===----------------------------------------------------------------------===// + +static void BM_SingleTask1DLoop(benchmark::State& state) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + for (auto _ : state) { + runner.Parallelize(1, 1, [](size_t, size_t) {}); + tsl::BlockUntilReady(runner.done_event()); + } +} + +BENCHMARK(BM_SingleTask1DLoop); + +static void BM_Parallelize2DTile1D(benchmark::State& state) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + size_t range = 4; + size_t tile = 1; + + for (auto _ : state) { + runner.Parallelize(range, range, tile, [](size_t, size_t, size_t) {}); + tsl::BlockUntilReady(runner.done_event()); + } +} + +BENCHMARK(BM_Parallelize2DTile1D); + +static void BM_Parallelize3DTile2D(benchmark::State& state) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + size_t range = 4; + size_t tile = 1; + + for (auto _ : state) { + runner.Parallelize(range, range, range, tile, tile, + [](size_t, size_t, size_t, size_t, size_t) {}); + tsl::BlockUntilReady(runner.done_event()); + } +} + +BENCHMARK(BM_Parallelize3DTile2D); + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc new file mode 100644 index 00000000000000..92d32d86e2461c --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc @@ -0,0 +1,183 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "xnnpack.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/dot_lib.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" +#include "xla/service/buffer_assignment.h" +#include "xla/shape.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +absl::StatusOr XnnDotThunk::BuildDotSubgraph( + absl::Span arguments, absl::Span results) { + xnn_subgraph_t subgraph = nullptr; + XNN_RETURN_IF_ERROR(xnn_create_subgraph(/*external_value_ids=*/3, + /*flags=*/0, &subgraph)); + + uint32_t lhs_id = XNN_INVALID_VALUE_ID; + uint32_t rhs_id = XNN_INVALID_VALUE_ID; + uint32_t out_id = XNN_INVALID_VALUE_ID; + + auto dims = [](absl::Span dims) -> std::vector { + return {dims.begin(), dims.end()}; + }; + + std::vector lhs_dims = dims(dot_slices_.lhs_shape.dimensions()); + std::vector rhs_dims = dims(dot_slices_.rhs_shape.dimensions()); + std::vector out_dims = dims(dot_slices_.out_shape.dimensions()); + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, lhs_dims.size(), lhs_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id)); + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, rhs_dims.size(), rhs_dims.data(), nullptr, + /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id)); + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, out_dims.size(), out_dims.data(), nullptr, + /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id)); + + XNN_RETURN_IF_ERROR(xnn_define_batch_matrix_multiply( + subgraph, lhs_id, rhs_id, out_id, + /*flags=*/dot_canonical_dims_.rhs_canonical ? 0 : XNN_FLAG_TRANSPOSE_B)); + + return subgraph; +} + +absl::StatusOr XnnDotThunk::IsSupported( + const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, + const Shape& rhs_shape, const Shape& out_shape) { + // TODO(ezhulenev): Support other element types. + if (lhs_shape.element_type() != F32 || rhs_shape.element_type() != F32 || + out_shape.element_type() != F32) { + return false; + } + + TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, + rhs_shape, out_shape)); + + TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); + + // XNNPACK does not support transposing LHS or col-major layouts. + return dot_canonical_dims.lhs_canonical && + !dot_canonical_dims.lhs_column_major && + !dot_canonical_dims.rhs_column_major; +} + +absl::StatusOr> XnnDotThunk::Create( + Info info, DotDimensionNumbers dot_dimensions, + BufferAllocation::Slice lhs_buffer, Shape lhs_shape, + BufferAllocation::Slice rhs_buffer, Shape rhs_shape, + BufferAllocation::Slice out_buffer, Shape out_shape) { + TF_RETURN_IF_ERROR(InitializeXnnPack()); + + TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, + rhs_shape, out_shape)); + + TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); + + DotSlices dot_slices{lhs_buffer, std::move(lhs_shape), + rhs_buffer, std::move(rhs_shape), + out_buffer, std::move(out_shape)}; + + return absl::WrapUnique( + new XnnDotThunk(info, std::move(dot_dimensions), std::move(dot_slices), + std::move(dot_shape), std::move(dot_canonical_dims))); +} + +static std::vector DotArguments( + const DotSlices& slices) { + return {XnnFusionThunk::Argument{slices.lhs_buffer, slices.lhs_shape}, + XnnFusionThunk::Argument{slices.rhs_buffer, slices.rhs_shape}}; +} + +static std::vector DotResults(const DotSlices& slices) { + return {XnnFusionThunk::Result{slices.out_buffer, slices.out_shape}}; +} + +XnnDotThunk::XnnDotThunk(Info info, DotDimensionNumbers dot_dimensions, + DotSlices dot_slices, DotShape dot_shape, + DotCanonicalDims dot_canonical_dims) + : XnnFusionThunk(std::move(info), DotArguments(dot_slices), + DotResults(dot_slices), + std::bind(&XnnDotThunk::BuildDotSubgraph, this, + std::placeholders::_1, std::placeholders::_2)), + dot_dimensions_(std::move(dot_dimensions)), + dot_slices_(std::move(dot_slices)), + dot_shape_(std::move(dot_shape)), + dot_canonical_dims_(std::move(dot_canonical_dims)) {} + +std::string XnnDotThunk::fusion_kind() const { return "dot"; } + +std::string XnnDotThunk::fusion_description() const { + return absl::StrFormat( + "lhs_batch_dims=[%s], rhs_batch_dims=[%s], " + "lhs_contract_dims=[%s], rhs_contract_dims=[%s]", + absl::StrJoin(dot_dimensions_.lhs_batch_dimensions(), ","), + absl::StrJoin(dot_dimensions_.rhs_batch_dimensions(), ","), + absl::StrJoin(dot_dimensions_.lhs_contracting_dimensions(), ","), + absl::StrJoin(dot_dimensions_.rhs_contracting_dimensions(), ",")); +} + +std::vector XnnDotThunk::fusion_details() const { + return { + absl::StrFormat(" matmul shape: batch_size=%d, lhs=%s, rhs=%s, out=%s", + dot_shape_.batch_size, + dot_shape_.lhs_matmul_shape.ToString(true), + dot_shape_.rhs_matmul_shape.ToString(true), + dot_shape_.out_matmul_shape.ToString(true)), + absl::StrFormat(" matmul dims: m=%d, k=%d, n=%d, lhs_column_major=%v, " + "lhs_canonical=%v rhs_column_major=%v, rhs_canonical=%v", + dot_canonical_dims_.m, dot_canonical_dims_.k, + dot_canonical_dims_.n, + dot_canonical_dims_.lhs_column_major, + dot_canonical_dims_.lhs_canonical, + dot_canonical_dims_.rhs_column_major, + dot_canonical_dims_.rhs_canonical), + }; +} + +std::string XnnDotThunk::argument_name(size_t index) const { + return index == 0 ? "lhs" : "rhs"; +} + +std::string XnnDotThunk::result_name(size_t index) const { return "out"; } + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h new file mode 100644 index 00000000000000..b3ae7e88b5e69e --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_DOT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_DOT_THUNK_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/dot_lib.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/shape.h" + +namespace xla::cpu { + +// Dot operation implemented on top of XNNPACK. +class XnnDotThunk final : public XnnFusionThunk { + public: + // Returns true if the dot operation is supported by XNNPACK. Returns an error + // if the dot operation shape is invalid. + static absl::StatusOr IsSupported( + const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape, + const Shape& rhs_shape, const Shape& out_shape); + + static absl::StatusOr> Create( + Info info, DotDimensionNumbers dot_dimensions, + BufferAllocation::Slice lhs_buffer, Shape lhs_shape, + BufferAllocation::Slice rhs_buffer, Shape rhs_shape, + BufferAllocation::Slice out_buffer, Shape out_shape); + + protected: + std::string fusion_kind() const final; + std::string fusion_description() const final; + + bool has_fusion_details() const final { return true; } + std::vector fusion_details() const final; + + std::string argument_name(size_t index) const final; + std::string result_name(size_t index) const final; + + private: + XnnDotThunk(Info info, DotDimensionNumbers dot_dimensions, + DotSlices dot_slices, DotShape dot_shape, + DotCanonicalDims dot_canonical_dims); + + absl::StatusOr BuildDotSubgraph( + absl::Span arguments, absl::Span results); + + DotDimensionNumbers dot_dimensions_; + DotSlices dot_slices_; + DotShape dot_shape_; + DotCanonicalDims dot_canonical_dims_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_DOT_THUNK_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc new file mode 100644 index 00000000000000..b811e2566612d0 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h" + +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(XnnDotThunkTest, SimpleDot) { + auto lhs = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto rhs = LiteralUtil::CreateR2({{4.0, 3.0}, {2.0, 1.0}}); + auto out = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + + BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out); + + auto [lhs_alloc, rhs_alloc, out_alloc] = + CreateBufferAllocation(lhs, rhs, out); + auto [lhs_slice, rhs_slice, out_slice] = + CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + + DotDimensionNumbers dot_dimensions; + dot_dimensions.add_lhs_contracting_dimensions(1); + dot_dimensions.add_rhs_contracting_dimensions(0); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, XnnDotThunk::Create({"dot"}, dot_dimensions, lhs_slice, shape, + rhs_slice, shape, out_slice, shape)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError(); + + EXPECT_EQ(out, LiteralUtil::CreateR2({{8.0, 5.0}, {20.0, 13.0}})); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc new file mode 100644 index 00000000000000..e88ac6b530a6ca --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc @@ -0,0 +1,239 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "xnnpack.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "pthreadpool.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h" +#include "xla/runtime/buffer_use.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla::cpu { + +// XNNPACK runtime instantiated for the fusion operation. +struct XnnFusionThunk::XnnRuntime { + XnnRuntime() = default; + ~XnnRuntime() { Destroy(); } + + XnnRuntime(XnnRuntime&&); + XnnRuntime& operator=(XnnRuntime&&); + + tsl::AsyncValueRef Invoke( + const Eigen::ThreadPoolDevice* device, + absl::Span arguments, + absl::Span results); + + void Destroy(); + + std::unique_ptr runner; + pthreadpool_t threadpool = nullptr; + + xnn_subgraph_t subgraph = nullptr; + xnn_workspace_t workspace = nullptr; + xnn_runtime_t runtime = nullptr; +}; + +XnnFusionThunk::XnnRuntime::XnnRuntime(XnnRuntime&& other) { + *this = std::move(other); +} + +auto XnnFusionThunk::XnnRuntime::operator=(XnnRuntime&& other) -> XnnRuntime& { + Destroy(); + + threadpool = other.threadpool; + subgraph = other.subgraph; + workspace = other.workspace; + runtime = other.runtime; + + other.threadpool = nullptr; + other.subgraph = nullptr; + other.workspace = nullptr; + other.runtime = nullptr; + + runner = std::move(other.runner); + return *this; +} + +tsl::AsyncValueRef +XnnFusionThunk::XnnRuntime::Invoke(const Eigen::ThreadPoolDevice* device, + absl::Span arguments, + absl::Span results) { + // Create external values for all arguments and results. + absl::InlinedVector external_values; + external_values.reserve(arguments.size() + results.size()); + + // External tensor id for arguments and results. + uint32_t id = 0; + + for (auto& argument : arguments) { + external_values.push_back(xnn_external_value{id++, argument.opaque()}); + } + + for (auto& result : results) { + external_values.push_back(xnn_external_value{id++, result.opaque()}); + } + + XNN_RETURN_IF_ERROR(xnn_setup_runtime_v2(runtime, external_values.size(), + external_values.data())); + + runner->set_device(device); + XNN_RETURN_IF_ERROR(xnn_invoke_runtime(runtime)); + return runner->ResetDoneEvent(); +} + +void XnnFusionThunk::XnnRuntime::Destroy() { + if (runtime != nullptr) XNN_LOG_IF_ERROR(xnn_delete_runtime(runtime)); + if (subgraph != nullptr) XNN_LOG_IF_ERROR(xnn_delete_subgraph(subgraph)); + if (workspace != nullptr) XNN_LOG_IF_ERROR(xnn_release_workspace(workspace)); + + bool owned_threadpool = threadpool != nullptr && IsCustomPthreadpoolEnabled(); + if (owned_threadpool) pthreadpool_destroy(threadpool); +} + +absl::StatusOr XnnFusionThunk::CreateXnnRuntime( + const Eigen::ThreadPoolDevice* device) { + bool use_custom_threadpool = device && IsCustomPthreadpoolEnabled(); + VLOG(3) << absl::StreamFormat( + "Create XNN runtime for `%s` operation: num_created=%d, " + "use_custom_threadpool=%v", + info().op_name, xnn_runtime_pool_.num_created(), use_custom_threadpool); + + XnnRuntime runtime; + + // Construct XNNPACK subgraph using user-provided builder function. + TF_ASSIGN_OR_RETURN(runtime.subgraph, builder_(arguments_, results_)); + + // If XLA is compiled with custom pthreadpool, use it in XNNPACK runtime, + // otherwise we'll run all XNNPACK operations in the default pthreadpool. + runtime.runner = std::make_unique(device); + if (use_custom_threadpool) { + runtime.threadpool = CreateCustomPthreadpool(runtime.runner.get()); + } else { + runtime.threadpool = DefaultPthreadpool(); + } + + XNN_RETURN_IF_ERROR(xnn_create_workspace(&runtime.workspace)); + + XNN_RETURN_IF_ERROR( + xnn_create_runtime_v4(runtime.subgraph, nullptr, runtime.workspace, + runtime.threadpool, 0, &runtime.runtime)); + + XNN_RETURN_IF_ERROR(xnn_reshape_runtime(runtime.runtime)); + + return {std::move(runtime)}; +} + +absl::StatusOr> XnnFusionThunk::Create( + Info info, std::vector arguments, std::vector results, + Builder builder) { + TF_RETURN_IF_ERROR(InitializeXnnPack()); + + return absl::WrapUnique( + new XnnFusionThunk(std::move(info), std::move(arguments), + std::move(results), std::move(builder))); +} + +XnnFusionThunk::XnnFusionThunk(Info info, std::vector arguments, + std::vector results, Builder builder) + : Thunk(Kind::kXnnFusion, std::move(info)), + arguments_(std::move(arguments)), + results_(std::move(results)), + builder_(std::move(builder)), + xnn_runtime_pool_(std::bind(&XnnFusionThunk::CreateXnnRuntime, this, + std::placeholders::_1)) {} + +XnnFusionThunk::~XnnFusionThunk() = default; + +XnnFusionThunk::BufferUses XnnFusionThunk::buffer_uses() const { + BufferUses buffer_uses; + for (const Argument& argument : arguments_) { + buffer_uses.push_back(BufferUse::Read(argument.slice)); + } + for (const Result& result : results_) { + buffer_uses.push_back(BufferUse::Write(result.slice)); + } + return buffer_uses; +} + +tsl::AsyncValueRef XnnFusionThunk::Execute( + const ExecuteParams& params) { + VLOG(3) << absl::StreamFormat("XNN %s `%s`: %s", fusion_kind(), + info().op_name, fusion_description()); + + if (VLOG_IS_ON(3) && has_fusion_details()) { + for (auto& detail : fusion_details()) VLOG(3) << detail; + } + + // Resolve device memory for arguments. + absl::InlinedVector arguments_buffers; + arguments_buffers.resize(arguments_.size()); + for (size_t i = 0; i < arguments_.size(); ++i) { + Argument& argument = arguments_[i]; + + TF_ASSIGN_OR_RETURN( + arguments_buffers[i], + params.buffer_allocations->GetDeviceAddress(argument.slice)); + + VLOG(3) << absl::StreamFormat(" %s: %s in slice %s (%p)", argument_name(i), + argument.shape.ToString(true), + argument.slice.ToString(), + arguments_buffers[i].opaque()); + } + + // Resolve device memory for results. + absl::InlinedVector results_buffers; + results_buffers.resize(results_.size()); + for (size_t i = 0; i < results_.size(); ++i) { + Result& result = results_[i]; + + TF_ASSIGN_OR_RETURN( + results_buffers[i], + params.buffer_allocations->GetDeviceAddress(results_[i].slice)); + + VLOG(3) << absl::StreamFormat(" %s: %s in slice %s (%p)", result_name(i), + result.shape.ToString(true), + result.slice.ToString(), + results_buffers[i].opaque()); + } + + TF_ASSIGN_OR_RETURN( + auto runtime, xnn_runtime_pool_.GetOrCreate(params.intra_op_threadpool)); + + return runtime->Invoke(params.intra_op_threadpool, + absl::MakeSpan(arguments_buffers), + absl::MakeSpan(results_buffers)); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h new file mode 100644 index 00000000000000..1653bb2bc609f1 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h @@ -0,0 +1,105 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_FUSION_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_FUSION_THUNK_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/object_pool.h" +#include "xla/service/buffer_assignment.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/async_value_ref.h" + +// Forward declare XNNPACK types. +typedef struct xnn_subgraph* xnn_subgraph_t; // NOLINT + +namespace xla::cpu { + +// XNN fusion thunk encapsulates XNNPACK subgraph contructed from an XLA fusion +// operation, where each HLO op has a corresponding XNNPACK operator. +class XnnFusionThunk : public Thunk { + public: + ~XnnFusionThunk() override; + + struct Argument { + BufferAllocation::Slice slice; + Shape shape; + }; + + struct Result { + BufferAllocation::Slice slice; + Shape shape; + }; + + // Builder function constructs XNNPACK subgraph for the fusion operation. + using Builder = absl::AnyInvocable( + absl::Span arguments, absl::Span results)>; + + static absl::StatusOr> Create( + Info info, std::vector arguments, std::vector results, + Builder builder); + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + + BufferUses buffer_uses() const final; + + protected: + XnnFusionThunk(Info info, std::vector arguments, + std::vector results, Builder builder); + + // Extension points for subclasses to customize the logging behavior. + virtual std::string fusion_kind() const { return "fusion"; } + virtual std::string fusion_description() const { return ""; } + + virtual bool has_fusion_details() const { return false; } + virtual std::vector fusion_details() const { return {}; } + + virtual std::string argument_name(size_t index) const { + return absl::StrCat("arg #", index); + } + + virtual std::string result_name(size_t index) const { + return absl::StrCat("res #", index); + } + + private: + // XNNPACK runtime instantiated for the fusion operation. + struct XnnRuntime; + + absl::StatusOr CreateXnnRuntime( + const Eigen::ThreadPoolDevice* device); + + std::vector arguments_; + std::vector results_; + Builder builder_; + + // XLA:CPU executable can be called concurrently from multiple threads, + // and we need to keep a pool of XNNPACK runtimes to avoid data races. + ObjectPool xnn_runtime_pool_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_FUSION_THUNK_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk_test.cc new file mode 100644 index 00000000000000..2ee61b734ba72c --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" + +#include +#include +#include +#include + +#include "xnnpack.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +static absl::StatusOr CreateBinaryAdd( + absl::Span arguments, + absl::Span results) { + xnn_subgraph_t subgraph = nullptr; + XNN_RETURN_IF_ERROR(xnn_create_subgraph(/*external_value_ids=*/3, + /*flags=*/0, &subgraph)); + + auto dims = [](absl::Span dims) -> std::vector { + return {dims.begin(), dims.end()}; + }; + + uint32_t lhs_id = XNN_INVALID_VALUE_ID; + uint32_t rhs_id = XNN_INVALID_VALUE_ID; + uint32_t out_id = XNN_INVALID_VALUE_ID; + + std::vector lhs_dims = dims(arguments[0].shape.dimensions()); + std::vector rhs_dims = dims(arguments[1].shape.dimensions()); + std::vector out_dims = dims(results[0].shape.dimensions()); + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, lhs_dims.size(), lhs_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id)); + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, rhs_dims.size(), rhs_dims.data(), nullptr, + /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id)); + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, out_dims.size(), out_dims.data(), nullptr, + /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id)); + + xnn_binary_params params = {-std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + XNN_RETURN_IF_ERROR(xnn_define_binary(subgraph, xnn_binary_add, ¶ms, + lhs_id, rhs_id, out_id, /*flags=*/0)); + + return subgraph; +} + +TEST(XnnFusionThunkTest, ElementwiseAdd) { + auto lhs = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto rhs = LiteralUtil::CreateR1({4.0, 3.0, 2.0, 1.0}); + auto out = LiteralUtil::CreateR1({0.0, 0.0, 0.0, 0.0}); + + BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out); + + auto [lhs_alloc, rhs_alloc, out_alloc] = + CreateBufferAllocation(lhs, rhs, out); + auto [lhs_slice, rhs_slice, out_slice] = + CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + + XnnFusionThunk::Argument lhs_arg = {lhs_slice, shape}; + XnnFusionThunk::Argument rhs_arg = {rhs_slice, shape}; + XnnFusionThunk::Result out_res = {out_slice, shape}; + + TF_ASSERT_OK_AND_ASSIGN(auto thunk, + XnnFusionThunk::Create({"fusion"}, {lhs_arg, rhs_arg}, + {out_res}, &CreateBinaryAdd)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError(); + + EXPECT_EQ(out, LiteralUtil::CreateR1({5.0, 5.0, 5.0, 5.0})); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.cc new file mode 100644 index 00000000000000..65e255654818bd --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.cc @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" + +#include "xnnpack.h" +#include "absl/status/status.h" +#include "xla/util.h" + +namespace xla::cpu { + +absl::Status InitializeXnnPack() { + static xnn_status status = xnn_initialize(/*allocator=*/nullptr); + if (status != xnn_status_success) { + return Internal("XNNPACK initialization failed"); + } + return absl::OkStatus(); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.h b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.h index 52e655d2eeaf37..47f6aa3d29402a 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.h +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_interop.h @@ -16,8 +16,65 @@ limitations under the License. #ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_INTEROP_H_ #define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_INTEROP_H_ -#include "xnnpack.h" // IWYU pragma: keep +#include "xnnpack.h" +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" -namespace xla::cpu {} +namespace xla::cpu { + +#define XNN_RETURN_IF_ERROR(expr) \ + do { \ + absl::Status s = XnnStatusToStatus(expr); \ + if (!s.ok()) { \ + return s; \ + } \ + } while (0) + +#define XNN_LOG_IF_ERROR(expr) \ + do { \ + absl::Status s = XnnStatusToStatus(expr); \ + if (!s.ok()) { \ + LOG(ERROR) << "XNNPACK operation failed: " << s; \ + } \ + } while (0) + +// Statically initializes XNNPACK for the current process. +absl::Status InitializeXnnPack(); + +// Converts XNNPACK status to absl::Status. +inline absl::Status XnnStatusToStatus(xnn_status status) { + if (ABSL_PREDICT_TRUE(status == xnn_status_success)) { + return absl::OkStatus(); + } + + auto error_message = [](xnn_status status) { + switch (status) { + case xnn_status_success: + return ""; + case xnn_status_uninitialized: + return "uninitialized"; + case xnn_status_invalid_parameter: + return "invalid parameter"; + case xnn_status_invalid_state: + return "invalid state"; + case xnn_status_unsupported_parameter: + return "unsupported parameter"; + case xnn_status_unsupported_hardware: + return "unsupported hardware"; + case xnn_status_out_of_memory: + return "out of memory"; + case xnn_status_reallocation_required: + return "reallocation required"; + case xnn_status_deprecated: + return "deprecated"; + } + }; + + return Internal("XNNPACK operation failed: %s", error_message(status)); +} + +} // namespace xla::cpu #endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_INTEROP_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc new file mode 100644 index 00000000000000..49d03eba57e130 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc @@ -0,0 +1,470 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h" + +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/optimization.h" +#include "pthreadpool.h" +#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/threadpool.h" +#include "tsl/platform/cpu_info.h" + +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" + +// `pthreadpool` API implementation on top of ParallelLoopRunner. +// +// When building with `pthreadpool_header_only` config, `pthreadpool` becomes a +// header-only library, and we implement the API on top of ParallelLoopRunner. +// +// At link time `pthreadpool` symbols resolved to our own implementation. This +// is a temporary hack around the fact that it's impossible to customize +// `pthreadpool` implementation at run time. The downsize is that it's +// impossible to have two `pthreadpool` implementations linked into the same +// binary. +// +// WARNING: This is under construction and implements only the subset of the API +// surface which is needed by XNNPACK uses inside XLA. + +namespace xla::cpu { + +bool IsCustomPthreadpoolEnabled() { +#if defined(XLA_CPU_USE_CUSTOM_PTHREADPOOL) + return true; +#else + return false; +#endif // XLA_CPU_USE_CUSTOM_PTHREADPOOL +} + +// Default XLA:CPU pthreadpool initialized once per process. +static absl::once_flag pthreadpool_init; +static pthreadpool_t default_pthreadpool; + +pthreadpool_t DefaultPthreadpool() { + if (IsCustomPthreadpoolEnabled()) { + LOG(WARNING) << "Default pthreadpool is not supported when build with " + "`--define pthreadpool_header_only=true`"; + return nullptr; + } + + absl::call_once(pthreadpool_init, []() { + default_pthreadpool = pthreadpool_create(tsl::port::MaxParallelism()); + }); + + return default_pthreadpool; +} + +namespace { + +class Pthreadpool { + public: + virtual ~Pthreadpool() = default; + virtual ParallelLoopRunner* runner() = 0; +}; + +// Wraps user-provided parallel loop runner into the custom pthreadpool. +class WrappedParallelLoopRunner : public Pthreadpool { + public: + explicit WrappedParallelLoopRunner(ParallelLoopRunner* runner) + : runner_(runner) {} + ParallelLoopRunner* runner() final { return runner_; } + + private: + ParallelLoopRunner* runner_; +}; + +// Wraps newly created thread pool into the custom pthreadpool. +class OwnedParallelLoopRunner : public Pthreadpool { + public: + explicit OwnedParallelLoopRunner(size_t threads_count) + : thread_pool_(tsl::Env::Default(), "xnn_threadpool", threads_count), + device_(thread_pool_.AsEigenThreadPool(), threads_count), + runner_(&device_) {} + + ParallelLoopRunner* runner() final { return &runner_; } + + private: + tsl::thread::ThreadPool thread_pool_; + Eigen::ThreadPoolDevice device_; + ParallelLoopRunner runner_; +}; + +} // namespace + +pthreadpool_t CreateCustomPthreadpool(ParallelLoopRunner* runner) { + if (IsCustomPthreadpoolEnabled()) { + return reinterpret_cast( + std::make_unique(runner).release()); + } + LOG(FATAL) << "To use custom pthreadpool, build with " + "`--define pthreadpool_header_only=true`"; +} + +static pthreadpool_t CreateCustomPthreadpool(size_t threads_count) { // NOLINT + if (IsCustomPthreadpoolEnabled()) { + return reinterpret_cast( + std::make_unique(threads_count).release()); + } + LOG(FATAL) << "To use custom pthreadpool, build with " + "`--define pthreadpool_header_only=true`"; +} + +static Pthreadpool* Cast(pthreadpool_t threadpool) { + return reinterpret_cast(threadpool); +} + +xla::cpu::ParallelLoopRunner* GetParallelLoopRunner(pthreadpool_t threadpool) { + return IsCustomPthreadpoolEnabled() ? Cast(threadpool)->runner() : nullptr; +} + +//===----------------------------------------------------------------------===// +// C++ implementation of the subset of `pthreadpool` C API. +//===----------------------------------------------------------------------===// + +static void DestroyCustomPthreadpool(pthreadpool_t threadpool) { // NOLINT + if (ABSL_PREDICT_FALSE(threadpool == nullptr)) { + return; + } + + tsl::BlockUntilReady(Cast(threadpool)->runner()->done_event()); + delete Cast(threadpool); +} + +static size_t GetThreadsCount(pthreadpool_t threadpool) { // NOLINT + if (ABSL_PREDICT_FALSE(threadpool == nullptr)) { + return 0; + } + + return Cast(threadpool)->runner()->num_threads(); +} + +static void Parallelize1D( // NOLINT + pthreadpool_t threadpool, pthreadpool_task_1d_t function, void* context, + size_t range, uint32_t flags) { + if (ABSL_PREDICT_FALSE(threadpool == nullptr)) { + for (size_t i = 0; i < range; ++i) { + function(context, i); + } + return; + } + + ParallelLoopRunner::Task1D task = [function, context](size_t offset) { + (*function)(context, offset); + }; + Cast(threadpool)->runner()->Parallelize(range, task); +} + +static void Parallelize1DTile1D( // NOLINT + pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function, + void* context, size_t range, size_t tile, uint32_t flags) { + if (ABSL_PREDICT_FALSE(threadpool == nullptr)) { + for (size_t i = 0; i < range; i += tile) { + function(context, i, std::min(range - i, tile)); + } + return; + } + + ParallelLoopRunner::Task1DTile1D task = [function, context](size_t offset, + size_t extent) { + (*function)(context, offset, extent); + }; + Cast(threadpool)->runner()->Parallelize(range, tile, task); +} + +static void Parallelize2DTile1D(pthreadpool_t threadpool, // NOLINT + pthreadpool_task_2d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, + size_t tile_j, uint32_t flags) { + if (ABSL_PREDICT_FALSE(threadpool == nullptr)) { + for (size_t i = 0; i < range_i; i++) { + for (size_t j = 0; j < range_j; j += tile_j) { + function(context, i, j, std::min(range_j - j, tile_j)); + } + } + return; + } + + ParallelLoopRunner::Task2DTile1D task = + [function, context](size_t offset_i, size_t offset_j, size_t extent_j) { + (*function)(context, offset_i, offset_j, extent_j); + }; + Cast(threadpool)->runner()->Parallelize(range_i, range_j, tile_j, task); +} + +static void Parallelize3DTile2D(pthreadpool_t threadpool, // NOLINT + pthreadpool_task_3d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, + size_t range_k, size_t tile_j, size_t tile_k, + uint32_t flags) { + if (ABSL_PREDICT_FALSE(threadpool == nullptr)) { + for (size_t i = 0; i < range_i; i++) { + for (size_t j = 0; j < range_j; j += tile_j) { + for (size_t k = 0; k < range_k; k += tile_k) { + function(context, i, j, k, std::min(range_j - j, tile_j), + std::min(range_k - k, tile_k)); + } + } + } + return; + } + + ParallelLoopRunner::Task3DTile2D task = + [function, context](size_t offset_i, size_t offset_j, size_t offset_k, + size_t extent_j, size_t extent_k) { + (*function)(context, offset_i, offset_j, offset_k, extent_j, extent_k); + }; + Cast(threadpool) + ->runner() + ->Parallelize(range_i, range_j, range_k, tile_j, tile_k, task); +} + +} // namespace xla::cpu + +#if defined(XLA_CPU_USE_CUSTOM_PTHREADPOOL) + +extern "C" pthreadpool_t pthreadpool_create(size_t threads_count) { + return xla::cpu::CreateCustomPthreadpool(threads_count); +} + +extern "C" void pthreadpool_destroy(pthreadpool_t threadpool) { + xla::cpu::DestroyCustomPthreadpool(threadpool); +} + +extern "C" size_t pthreadpool_get_threads_count(pthreadpool_t threadpool) { + return xla::cpu::GetThreadsCount(threadpool); +} + +extern "C" void pthreadpool_parallelize_1d(pthreadpool_t threadpool, + pthreadpool_task_1d_t function, + void* context, size_t range, + uint32_t flags) { + xla::cpu::Parallelize1D(threadpool, function, context, range, flags); +} + +extern "C" void pthreadpool_parallelize_1d_with_thread( + pthreadpool_t threadpool, pthreadpool_task_1d_with_thread_t function, + void* context, size_t range, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_1d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_1d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_1d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function, + void* context, size_t range, size_t tile, uint32_t flags) { + xla::cpu::Parallelize1DTile1D(threadpool, function, context, range, tile, + flags); +} + +extern "C" void pthreadpool_parallelize_2d(pthreadpool_t threadpool, + pthreadpool_task_2d_t function, + void* context, size_t range_i, + size_t range_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_with_thread( + pthreadpool_t threadpool, pthreadpool_task_2d_with_thread_t function, + void* context, size_t range_i, size_t range_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t tile_j, + uint32_t flags) { + xla::cpu::Parallelize2DTile1D(threadpool, function, context, range_i, range_j, + tile_j, flags); +} + +extern "C" void pthreadpool_parallelize_2d_tile_1d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t tile_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_1d_with_uarch_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_1d_with_id_with_thread_t function, void* context, + uint32_t default_uarch_index, uint32_t max_uarch_index, size_t range_i, + size_t range_j, size_t tile_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, size_t tile_i, size_t tile_j, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_2d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_2d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t tile_i, size_t tile_j, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d(pthreadpool_t threadpool, + pthreadpool_task_3d_t function, + void* context, size_t range_i, + size_t range_j, size_t range_k, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t tile_k, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_thread_t function, void* context, + size_t range_i, size_t range_j, size_t range_k, size_t tile_k, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_1d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t range_k, size_t tile_k, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d_with_uarch_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_id_with_thread_t function, void* context, + uint32_t default_uarch_index, uint32_t max_uarch_index, size_t range_i, + size_t range_j, size_t range_k, size_t tile_k, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k, uint32_t flags) { + xla::cpu::Parallelize3DTile2D(threadpool, function, context, range_i, range_j, + range_k, tile_j, tile_k, flags); +} + +extern "C" void pthreadpool_parallelize_3d_tile_2d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t range_k, size_t tile_j, + size_t tile_k, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_4d(pthreadpool_t threadpool, + pthreadpool_task_4d_t function, + void* context, size_t range_i, + size_t range_j, size_t range_k, + size_t range_l, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_4d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_4d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t range_l, size_t tile_l, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_4d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_4d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t range_l, size_t tile_k, size_t tile_l, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_4d_tile_2d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_4d_tile_2d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t range_k, size_t range_l, + size_t tile_k, size_t tile_l, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_5d(pthreadpool_t threadpool, + pthreadpool_task_5d_t function, + void* context, size_t range_i, + size_t range_j, size_t range_k, + size_t range_l, size_t range_m, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_5d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_5d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t range_l, size_t range_m, size_t tile_m, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_5d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_5d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t range_l, size_t range_m, size_t tile_l, size_t tile_m, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_6d(pthreadpool_t threadpool, + pthreadpool_task_6d_t function, + void* context, size_t range_i, + size_t range_j, size_t range_k, + size_t range_l, size_t range_m, + size_t range_n, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_6d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_6d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t range_l, size_t range_m, size_t range_n, size_t tile_n, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_6d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_6d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t range_l, size_t range_m, size_t range_n, size_t tile_m, + size_t tile_n, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +#endif // XLA_CPU_USE_CUSTOM_PTHREADPOOL diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h new file mode 100644 index 00000000000000..4afe664bba8cd6 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_THREADPOOL_H_ +#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_THREADPOOL_H_ + +#include "pthreadpool.h" +#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h" + +namespace xla::cpu { + +// Returns true if the custom pthreadpool is enabled. +bool IsCustomPthreadpoolEnabled(); + +// Returns the default per-process pthreadpool. If custom `pthreadpool` is +// enabled, it will return nullptr. +pthreadpool_t DefaultPthreadpool(); + +// Creates a `pthreadpool` that uses the given `runner` to execute work. If +// custom `pthreadpool` is disabled, it will kill the process. +pthreadpool_t CreateCustomPthreadpool(xla::cpu::ParallelLoopRunner* runner); + +// Returns the parallel loop runner associated with the given `pthreadpool`. If +// the `pthreadpool` is not associated with a parallel loop runner, returns +// nullptr. +xla::cpu::ParallelLoopRunner* GetParallelLoopRunner(pthreadpool_t threadpool); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_THREADPOOL_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc new file mode 100644 index 00000000000000..7cdf1dd1cb91a0 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h" + +#include +#include +#include +#include + +#include "xnnpack.h" +#include "absl/algorithm/container.h" +#include "pthreadpool.h" +#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +static xnn_status CreateBinaryOpsSubgraph(xnn_subgraph_t subgraph, + std::vector dims) { + uint32_t lhs_id = XNN_INVALID_VALUE_ID; + uint32_t rhs_id = XNN_INVALID_VALUE_ID; + uint32_t out0_id = XNN_INVALID_VALUE_ID; + uint32_t out1_id = XNN_INVALID_VALUE_ID; + + if (auto s = xnn_define_tensor_value(subgraph, xnn_datatype_fp32, dims.size(), + dims.data(), nullptr, /*external_id=*/0, + XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value(subgraph, xnn_datatype_fp32, dims.size(), + dims.data(), nullptr, /*external_id=*/1, + XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, + /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out0_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out1_id); + s != xnn_status_success) { + return s; + } + + xnn_binary_params params = {-std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + if (auto s = xnn_define_binary(subgraph, xnn_binary_add, ¶ms, lhs_id, + rhs_id, out0_id, /*flags=*/0); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_binary(subgraph, xnn_binary_multiply, ¶ms, lhs_id, + rhs_id, out1_id, /*flags=*/0); + s != xnn_status_success) { + return s; + } + + return xnn_status_success; +} + +static xnn_status CreateDotSubgraph(xnn_subgraph_t subgraph, size_t m, size_t n, + size_t k) { + uint32_t lhs_id = XNN_INVALID_VALUE_ID; + uint32_t rhs_id = XNN_INVALID_VALUE_ID; + uint32_t out_id = XNN_INVALID_VALUE_ID; + + std::vector lhs_dims = {m, k}; + std::vector rhs_dims = {k, n}; + std::vector out_dims = {m, n}; + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, lhs_dims.size(), lhs_dims.data(), + nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, rhs_dims.size(), rhs_dims.data(), + nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, out_dims.size(), out_dims.data(), + nullptr, + /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id); + s != xnn_status_success) { + return s; + } + + if (auto s = + xnn_define_batch_matrix_multiply(subgraph, lhs_id, rhs_id, out_id, + /*flags=*/0); + s != xnn_status_success) { + return s; + } + + return xnn_status_success; +} + +TEST(XnnThreadPoolTest, Binary) { + pthreadpool_t threadpool = pthreadpool_create(8); + ASSERT_NE(threadpool, nullptr); + + ASSERT_EQ(xnn_initialize(/*allocator=*/nullptr), xnn_status_success); + + xnn_workspace_t workspace = nullptr; + ASSERT_EQ(xnn_create_workspace(&workspace), xnn_status_success); + + xnn_subgraph_t subgraph = nullptr; + + ASSERT_EQ( + xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph), + xnn_status_success); + + size_t d0 = 1024; + CreateBinaryOpsSubgraph(subgraph, {d0, d0}); + + std::vector lhs(d0 * d0, 2.0f); + std::vector rhs(d0 * d0, 3.0f); + std::vector out0(d0 * d0, 0.0f); + std::vector out1(d0 * d0, 0.0f); + + xnn_runtime_t runtime = nullptr; + ASSERT_EQ(xnn_create_runtime_v4(subgraph, nullptr, workspace, threadpool, 0, + &runtime), + xnn_status_success); + + std::vector external_values = { + xnn_external_value{0, lhs.data()}, + xnn_external_value{1, rhs.data()}, + xnn_external_value{2, out0.data()}, + xnn_external_value{3, out1.data()}, + }; + + ASSERT_EQ(xnn_reshape_runtime(runtime), xnn_status_success); + ASSERT_EQ(xnn_setup_runtime_v2(runtime, 4, external_values.data()), + xnn_status_success); + + ASSERT_EQ(xnn_invoke_runtime(runtime), xnn_status_success); + + if (ParallelLoopRunner* runner = GetParallelLoopRunner(threadpool)) { + tsl::BlockUntilReady(runner->done_event()); + ASSERT_TRUE(runner->done_event().IsConcrete()); + } + + ASSERT_TRUE(absl::c_all_of(out0, [](float v) { return v == 5.0f; })); + ASSERT_TRUE(absl::c_all_of(out1, [](float v) { return v == 6.0f; })); + + ASSERT_EQ(xnn_delete_runtime(runtime), xnn_status_success); + ASSERT_EQ(xnn_delete_subgraph(subgraph), xnn_status_success); + ASSERT_EQ(xnn_release_workspace(workspace), xnn_status_success); + + pthreadpool_destroy(threadpool); +} + +TEST(XnnThreadPoolTest, Dot) { + pthreadpool_t threadpool = pthreadpool_create(8); + ASSERT_NE(threadpool, nullptr); + + ASSERT_EQ(xnn_initialize(/*allocator=*/nullptr), xnn_status_success); + + xnn_workspace_t workspace = nullptr; + ASSERT_EQ(xnn_create_workspace(&workspace), xnn_status_success); + + xnn_subgraph_t subgraph = nullptr; + + ASSERT_EQ( + xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph), + xnn_status_success); + + size_t m = 256, k = 256, n = 256; + CreateDotSubgraph(subgraph, m, k, n); + + std::vector lhs(m * k, 1.0f); + std::vector rhs(k * n, 1.0f); + std::vector out(m * n, 0.0f); + + xnn_runtime_t runtime = nullptr; + ASSERT_EQ(xnn_create_runtime_v4(subgraph, nullptr, workspace, threadpool, 0, + &runtime), + xnn_status_success); + + std::vector external_values = { + xnn_external_value{0, lhs.data()}, + xnn_external_value{1, rhs.data()}, + xnn_external_value{2, out.data()}, + }; + + ASSERT_EQ(xnn_reshape_runtime(runtime), xnn_status_success); + ASSERT_EQ(xnn_setup_runtime_v2(runtime, 3, external_values.data()), + xnn_status_success); + + ASSERT_EQ(xnn_invoke_runtime(runtime), xnn_status_success); + + if (ParallelLoopRunner* runner = GetParallelLoopRunner(threadpool)) { + tsl::BlockUntilReady(runner->done_event()); + ASSERT_TRUE(runner->done_event().IsConcrete()); + } + + ASSERT_TRUE(absl::c_all_of(out, [&](float v) { return v == k; })); + + ASSERT_EQ(xnn_delete_runtime(runtime), xnn_status_success); + ASSERT_EQ(xnn_delete_subgraph(subgraph), xnn_status_success); + ASSERT_EQ(xnn_release_workspace(workspace), xnn_status_success); + + pthreadpool_destroy(threadpool); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/testlib/BUILD b/third_party/xla/xla/backends/cpu/testlib/BUILD index 668fce96125486..66e1b0da302d94 100644 --- a/third_party/xla/xla/backends/cpu/testlib/BUILD +++ b/third_party/xla/xla/backends/cpu/testlib/BUILD @@ -23,14 +23,16 @@ cc_library( srcs = ["kernel_runner.cc"], hdrs = ["kernel_runner.h"], deps = [ - ":llvm_ir_kernel_spec", "//xla/backends/cpu/codegen:jit_compiler", + "//xla/backends/cpu/codegen:llvm_ir_kernel_spec", "//xla/backends/cpu/runtime:function_library", "//xla/backends/cpu/runtime:kernel", "//xla/backends/cpu/runtime:kernel_c_api", "//xla/codegen:kernel_spec", "//xla/codegen:llvm_ir_kernel_source", "//xla/codegen/testlib:kernel_runner", + "//xla/service/cpu:runtime_symbol_generator", + "//xla/tsl/platform:errors", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -64,12 +66,12 @@ xla_cc_test( cc_library( name = "llvm_ir_kernel_emitter", - testonly = 1, + testonly = 1, # TODO(willfroom): Move to runtime(?) & plug into ir_emitter2 once the interface is stable. srcs = ["llvm_ir_kernel_emitter.cc"], hdrs = ["llvm_ir_kernel_emitter.h"], deps = [ - ":llvm_ir_kernel_spec", "//xla:util", + "//xla/backends/cpu/codegen:llvm_ir_kernel_spec", "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", "//xla/codegen:llvm_ir_kernel_source", @@ -77,6 +79,7 @@ cc_library( "//xla/service:buffer_assignment", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:Core", @@ -84,47 +87,44 @@ cc_library( ], ) -cc_library( - name = "llvm_ir_kernel_spec", - testonly = 1, - srcs = ["llvm_ir_kernel_spec.cc"], - hdrs = ["llvm_ir_kernel_spec.h"], - deps = [ - "//xla/codegen:kernel_spec", - "//xla/codegen:llvm_ir_kernel_source", - "//xla/service:buffer_assignment", - "//xla/stream_executor:launch_dim", - ], -) - tsl_pybind_extension( - name = "kernel_runner_extention", + name = "_extension", testonly = 1, - srcs = ["kernel_runner_extention.cc"], - visibility = ["//visibility:private"], # the extention should always be linked via kernel_runner_pylib + srcs = ["kernel_runner_extension.cc"], + visibility = ["//visibility:private"], # the extension should always be linked via testlib deps = [ ":kernel_runner", ":llvm_ir_kernel_emitter", - ":llvm_ir_kernel_spec", - # placeholder for index annotation deps + # placeholder for index annotation deps # buildcleaner: keep "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@nanobind", "@local_config_python//:python_headers", # buildcleaner: keep + "//xla/backends/cpu/codegen:elemental_kernel_emitter", + "//xla/backends/cpu/codegen:jit_compiler", + "//xla/backends/cpu/codegen:llvm_ir_kernel_spec", + "//xla/backends/cpu/codegen:target_machine_features", + "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", + "//xla/codegen/testlib:kernel_runner", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/cpu:cpu_compiler_pure", "//xla/stream_executor:launch_dim", ], ) pytype_strict_library( - name = "kernel_runner_pylib", + name = "testlib", testonly = 1, - srcs = ["kernel_runner.py"], + srcs = [ + "__init__.py", + ], srcs_version = "PY3", deps = [ - ":kernel_runner_extention", - "//xla/codegen/testlib:kernel_runner_pylib", + ":_extension", + "//xla/codegen/testlib", # buildcleaner: keep ], ) @@ -154,9 +154,28 @@ py_strict_test( "no_oss", ], deps = [ - ":kernel_runner_pylib", + ":testlib", + "//third_party/py/numpy", + "//xla/codegen/testlib", + "@absl_py//absl/testing:absltest", + ], +) + +py_strict_test( + name = "elemental_kernel_emitter_test", + srcs = ["elemental_kernel_emitter_test.py"], + main = "elemental_kernel_emitter_test.py", + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss", + ], + deps = [ + ":testlib", "//third_party/py/numpy", - "//xla/codegen/testlib:kernel_runner_pylib", + "//xla/codegen/testlib", + "//xla/python:xla_extension", "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/third_party/xla/xla/backends/cpu/testlib/__init__.py b/third_party/xla/xla/backends/cpu/testlib/__init__.py index e69de29bb2d1d6..74881ff0f44ce3 100644 --- a/third_party/xla/xla/backends/cpu/testlib/__init__.py +++ b/third_party/xla/xla/backends/cpu/testlib/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Public API for cpu codegen testlib.""" + +from xla.backends.cpu.testlib import _extension + +# go/keep-sorted start +ElementalKernelEmitter = _extension.ElementalKernelEmitter +HloCompiler = _extension.HloCompiler +JitCompiler = _extension.JitCompiler +KernelRunner = _extension.KernelRunner +LlvmIrKernelEmitter = _extension.LlvmIrKernelEmitter +LlvmIrKernelSpec = _extension.LlvmIrKernelSpec +TargetMachineFeatures = _extension.TargetMachineFeatures +# go/keep-sorted end diff --git a/third_party/xla/xla/backends/cpu/testlib/elemental_kernel_emitter_test.py b/third_party/xla/xla/backends/cpu/testlib/elemental_kernel_emitter_test.py new file mode 100644 index 00000000000000..fd24142d2ae916 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/testlib/elemental_kernel_emitter_test.py @@ -0,0 +1,383 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable, Sequence +import dataclasses +import itertools +import math + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +from xla.backends.cpu import testlib as testlib_cpu +from xla.codegen import testlib as testlib_base +from xla.codegen.testlib import utilities as testlib_utilities +from xla.python import xla_extension + +HloOpcode = testlib_base.HloOpcode +create_literal = testlib_base.utilities.create_literal_from_np +HloInstruction = testlib_base.HloInstruction +ComparisonDirection = testlib_base.ComparisonDirection +_inf = float("inf") + + +def create_input( + value_range: tuple[float, float], + shape: Sequence[int], + dtype: np.dtype, + shuffle: bool = False, +) -> np.ndarray: + size = np.prod(shape) if shape else 1 + result = np.linspace( + value_range[0], value_range[1], size, dtype=dtype + ).reshape(shape) + + if shuffle and (np.ndim(result) != 0): + np.random.shuffle(result) + + return result + + +def np_erf(x): + return np.vectorize(math.erf, otypes=[x.dtype])(x) + + +@dataclasses.dataclass(frozen=True) +class ElementalHloOpcodeDef: + op: HloOpcode + np_op: Callable[[np.ndarray, ...], np.ndarray] + input_ranges: tuple[float, float] = (-1.0, 1.0) + decimal_precision: int = 6 + + # For simple unpacking + def __iter__(self): + return iter( + (self.op, self.np_op, self.input_ranges, self.decimal_precision) + ) + + def __repr__(self): + return f"{self.op.name}({self.input_ranges})" + + +@parameterized.product( + op_def=[ + ElementalHloOpcodeDef(HloOpcode.sine, np.sin), + ElementalHloOpcodeDef(HloOpcode.cosine, np.cos), + ElementalHloOpcodeDef(HloOpcode.tan, np.tan), + ElementalHloOpcodeDef(HloOpcode.exponential, np.exp), + ElementalHloOpcodeDef(HloOpcode.log, np.log, (0.01, 10.0)), + ElementalHloOpcodeDef(HloOpcode.log_plus_one, np.log1p), + ElementalHloOpcodeDef(HloOpcode.sqrt, np.sqrt), + ElementalHloOpcodeDef( + HloOpcode.rsqrt, lambda x: np.reciprocal(np.sqrt(x)) + ), + ElementalHloOpcodeDef(HloOpcode.cbrt, np.cbrt), + ElementalHloOpcodeDef(HloOpcode.power, np.pow), + ElementalHloOpcodeDef(HloOpcode.add, np.add), + ElementalHloOpcodeDef(HloOpcode.subtract, np.subtract), + ElementalHloOpcodeDef(HloOpcode.multiply, np.multiply), + ElementalHloOpcodeDef(HloOpcode.divide, np.divide), + ElementalHloOpcodeDef(HloOpcode.maximum, np.maximum), + ElementalHloOpcodeDef(HloOpcode.minimum, np.minimum), + ElementalHloOpcodeDef(HloOpcode.sign, np.sign), + ElementalHloOpcodeDef(HloOpcode.negate, np.negative), + ElementalHloOpcodeDef(HloOpcode.is_finite, np.isfinite, (-_inf, _inf)), + ElementalHloOpcodeDef(HloOpcode.ceil, np.ceil, (-10.0, 10.0)), + ElementalHloOpcodeDef(HloOpcode.floor, np.floor, (-5.0, 5.0)), + ElementalHloOpcodeDef(HloOpcode.tanh, np.tanh), + ElementalHloOpcodeDef(HloOpcode.atan2, np.arctan2), + ElementalHloOpcodeDef(HloOpcode.erf, np_erf), + ElementalHloOpcodeDef(HloOpcode.exponential_minus_one, np.expm1), + # TODO(willfroom): Update to use better inputs for the following. + ElementalHloOpcodeDef(HloOpcode.clamp, np.clip), + # TODO(willfroom): Add complex ops. + # ElementalHloOpcodeDef(HloOpcode.complex, np.complex), + # ElementalHloOpcodeDef(HloOpcode.real, np.real), + # ElementalHloOpcodeDef(HloOpcode.imag, np.imag), + # TODO(willfroom): go through ElementalIrEmitter interface and ensure + # that all ops are implemented. + # ... + ], + shape=[(4,), (4, 3), (4, 3, 10)], + dtype=[np.dtype(np.float32), np.dtype(np.float64)], +) +class ElementalKernelRunnerTest(absltest.TestCase): + + def id(self): + return self._test_params_reprs.get(self._testMethodName, "") + + def test_elemental_kernel_emitter( + self, + op_def: ElementalHloOpcodeDef, + shape: tuple[int, ...], + dtype: np.dtype, + ): + + [op, np_op, input_ranges, decimal_precision] = op_def + + num_inputs = testlib_utilities.opcode_arity(op) + self.assertIsNotNone(num_inputs) + + np_inputs = [ + create_input(input_ranges, shape, dtype) for _ in range(num_inputs) + ] + input_literals = [create_literal(input_array) for input_array in np_inputs] + + expected_output = np_op(*np_inputs) + output_literal = create_literal( + np.ndarray(shape, dtype=expected_output.dtype) + ) + + hlo_parameters = [ + HloInstruction.create_parameter(idx, literal.shape(), f"input_{idx}") + for [idx, literal] in enumerate(input_literals) + ] + + hlo_op = HloInstruction.create_variadic( + output_literal.shape(), op, hlo_parameters + ) + + emitter = testlib_cpu.ElementalKernelEmitter(hlo_op) + kernel_spec = emitter.emit_kernel_spec() + self.assertIsNotNone(kernel_spec) + + # kernel_spec is consumed by the runner, so we need to save the IR string + # before passing it to the runner. + ir_string = str(kernel_spec.kernel_source()) + + runner = testlib_cpu.KernelRunner.create(kernel_spec) + + runner.call(list(itertools.chain(input_literals, [output_literal]))) + np.testing.assert_array_almost_equal( + np.asarray(output_literal), + expected_output, + decimal=decimal_precision, + err_msg=ir_string, + ) + + +@parameterized.product( + op_def=[ + (ComparisonDirection.kEq, np.equal), + (ComparisonDirection.kNe, np.not_equal), + (ComparisonDirection.kGe, np.greater_equal), + (ComparisonDirection.kGt, np.greater), + (ComparisonDirection.kLe, np.less_equal), + (ComparisonDirection.kLt, np.less), + ], + shape=[(4,), (4, 3), (4, 3, 10)], + dtype=[ + np.dtype(np.uint8), + np.dtype(np.uint16), + np.dtype(np.uint32), + np.dtype(np.uint64), + np.dtype(np.int8), + np.dtype(np.int16), + np.dtype(np.int32), + np.dtype(np.int64), + np.dtype(np.float16), + np.dtype(np.float32), + np.dtype(np.float64), + ], +) +class ElementalComparisonKernelRunnerTest(absltest.TestCase): + + def test_elemental_comparision_kernel_emitter(self, op_def, shape, dtype): + [direction, np_op] = op_def + + is_unsigned = np.issubdtype(dtype, np.unsignedinteger) + value_range = (0.0, 20.0) if is_unsigned else (-10.0, 10.0) + lhs_np = create_input(value_range, shape, dtype, shuffle=True) + rhs_np = create_input(value_range, shape, dtype, shuffle=True) + + lhs_literal = create_literal(lhs_np) + rhs_literal = create_literal(rhs_np) + + output_literal = create_literal(np.ndarray(shape, dtype=np.bool)) + + lhs_param = HloInstruction.create_parameter(0, lhs_literal.shape(), "lhs") + rhs_param = HloInstruction.create_parameter(1, rhs_literal.shape(), "rhs") + + hlo_op = HloInstruction.create_compare( + output_literal.shape(), lhs_param, rhs_param, direction + ) + + emitter = testlib_cpu.ElementalKernelEmitter(hlo_op) + + runner = testlib_cpu.KernelRunner.create(emitter.emit_kernel_spec()) + + runner.call([lhs_literal, rhs_literal, output_literal]) + np.testing.assert_equal( + np.asarray(output_literal), + np_op(lhs_np, rhs_np), + ) + + +@parameterized.product( + input_dimensions=[(4,), (4, 3), (4, 3, 10)], + dtype=[ + np.dtype(np.uint8), + np.dtype(np.uint16), + np.dtype(np.uint32), + np.dtype(np.uint64), + np.dtype(np.int8), + np.dtype(np.int16), + np.dtype(np.int32), + np.dtype(np.int64), + np.dtype(np.float16), + np.dtype(np.float32), + np.dtype(np.float64), + ], +) +class HloModuleKernelRunnerTest(absltest.TestCase): + + def id(self): + return self._test_params_reprs.get(self._testMethodName, "") + + def test_map(self, input_dimensions, dtype): + scalar_shape = xla_extension.Shape.scalar_shape(dtype) + shape = xla_extension.Shape.array_shape(dtype, input_dimensions) + + # Please note the double curly braces is to escape the python string + # formatting. + hlo = """ + HloModule test_map + + double {{ + a = {scalar_shape} parameter(0) + b = {scalar_shape} constant(2) + ROOT doubled = {scalar_shape} multiply(a, b) + }} + + ENTRY main {{ + a = {shape} parameter(0) + ROOT mapped = {shape} map(a), to_apply=double + }} + """.format(scalar_shape=scalar_shape, shape=shape) + + hlo_compiler = testlib_cpu.HloCompiler() + hlo_module = testlib_base.HloModule.parse_from_string(hlo) + hlo_module.set_schedule(hlo_compiler.create_hlo_schedule(hlo_module)) + buffer_assignment = hlo_compiler.create_buffer_assignment(hlo_module) + + jit_compiler = testlib_cpu.JitCompiler() + + emitter = testlib_cpu.ElementalKernelEmitter( + hlo_module.get_root_instruction(), + buffer_assignment, + jit_compiler.get_target_machine(), + ) + + input_np = create_input([0, 10], input_dimensions, dtype, shuffle=True) + + input_literal = create_literal(input_np) + + output_literal = xla_extension.Literal(shape) + + runner = testlib_cpu.KernelRunner.create( + emitter.emit_kernel_spec(), jit_compiler + ) + + runner.call([input_literal, output_literal]) + + np.testing.assert_equal( + np.asarray(output_literal), + input_np * 2, + ) + + def test_reduce(self, input_dimensions, dtype): + # Iterate over all combinations of reduce dimensions. + for reduce_dimensions in itertools.chain.from_iterable( + itertools.combinations(range(len(input_dimensions)), r) + for r in range(1, len(input_dimensions)) + ): + scalar_shape = xla_extension.Shape.scalar_shape(dtype) + input_shape = xla_extension.Shape.array_shape(dtype, input_dimensions) + + output_dimensions = [ + dim + for idx, dim in enumerate(input_dimensions) + if idx not in reduce_dimensions + ] + # Result can overflow in int8 (which results in undefined behavior), + # so we use int16 instead. + output_dtype = np.dtype(np.int16) if (dtype == np.int8) else dtype + output_shape = xla_extension.Shape.array_shape( + output_dtype, output_dimensions + ) + + # Please note the double curly braces is to escape the python string + # formatting. + hlo = """ + HloModule test_reduce + + add_method {{ + a = {scalar_shape} parameter(0) + b = {scalar_shape} parameter(1) + ROOT add = {scalar_shape} add(a, b) + }} + + ENTRY main {{ + array = {input_shape} parameter(0) + initial_value = {scalar_shape} parameter(1) + ROOT reduced = {output_shape} reduce(array, initial_value), + dimensions={{{reduce_dimensions}}}, to_apply=add_method + }} + """.format( + scalar_shape=scalar_shape, + input_shape=input_shape, + reduce_dimensions=",".join(map(str, reduce_dimensions)), + output_shape=output_shape, + ) + + hlo_compiler = testlib_cpu.HloCompiler() + hlo_module = testlib_base.HloModule.parse_from_string(hlo) + hlo_module.set_schedule(hlo_compiler.create_hlo_schedule(hlo_module)) + buffer_assignment = hlo_compiler.create_buffer_assignment(hlo_module) + + jit_compiler = testlib_cpu.JitCompiler() + + emitter = testlib_cpu.ElementalKernelEmitter( + hlo_module.get_root_instruction(), + buffer_assignment, + jit_compiler.get_target_machine(), + ) + + input_np = create_input([0, 10], input_dimensions, dtype) + input_literal = create_literal(input_np) + + initial_value_np = create_input([0, 10], (), dtype) + initial_value_literal = create_literal(initial_value_np) + + output_literal = xla_extension.Literal(output_shape) + + runner = testlib_cpu.KernelRunner.create( + emitter.emit_kernel_spec(), jit_compiler + ) + + runner.call([input_literal, initial_value_literal, output_literal]) + + np.testing.assert_array_almost_equal_nulp( + np.asarray(output_literal), + np.add.reduce( + input_np, axis=reduce_dimensions, initial=initial_value_np + ), + nulp=3, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc index d90c52120873f2..595f9c274e6c35 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc @@ -23,14 +23,17 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "xla/backends/cpu/codegen/jit_compiler.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/kernel.h" #include "xla/backends/cpu/runtime/kernel_c_api.h" -#include "xla/backends/cpu/testlib/llvm_ir_kernel_spec.h" #include "xla/codegen/kernel_spec.h" #include "xla/codegen/llvm_ir_kernel_source.h" +#include "xla/service/cpu/runtime_symbol_generator.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -42,21 +45,18 @@ absl::StatusOr KernelRunner::Create( // creation of KernelRunner from different kernel spec types. if (auto* llvm_kernel_spec = dynamic_cast(kernel_spec.get())) { - return Create(std::move(*llvm_kernel_spec)); + TF_ASSIGN_OR_RETURN(JitCompiler compiler, CreateJitCompiler()); + return Create(std::move(*llvm_kernel_spec), std::move(compiler)); } return absl::InvalidArgumentError("Unrecognised kernel spec type"); } -absl::StatusOr KernelRunner::Create( - LlvmIrKernelSpec kernel_spec) { +absl::StatusOr KernelRunner::Create(LlvmIrKernelSpec kernel_spec, + JitCompiler compiler) { LlvmIrKernelSource& kernel_source = kernel_spec.kernel_source(); - TF_ASSIGN_OR_RETURN( - JitCompiler compiler, - JitCompiler::Create(llvm::TargetOptions{}, JitCompiler::Options{})); - - // intentional copy as we need to use the kernel name after consuming + // Intentional copy as we need to use the kernel name after consuming // (std::move) the kernel source. std::string kernel_name = kernel_source.kernel_name(); @@ -89,4 +89,19 @@ absl::Status KernelRunner::Call(absl::Span arguments) { return kernel_.Launch(thread_dim_, kernel_args); } +absl::StatusOr KernelRunner::CreateJitCompiler() { + llvm::TargetOptions target_options; + target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast; + + // Needed to resolve symbols such as built in intrinsics (sin, cos etc). + JitCompiler::Options jit_compiler_options; + jit_compiler_options.definition_generator = + [](llvm::TargetMachine* target_machine) { + return std::make_unique( + target_machine->createDataLayout()); + }; + + return JitCompiler::Create(target_options, jit_compiler_options); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h index a102c6ad04197a..503ab81d0eadb8 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.h @@ -21,10 +21,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/codegen/jit_compiler.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/kernel.h" -#include "xla/backends/cpu/runtime/kernel_c_api.h" -#include "xla/backends/cpu/testlib/llvm_ir_kernel_spec.h" #include "xla/codegen/kernel_spec.h" #include "xla/codegen/testlib/kernel_runner.h" @@ -42,13 +42,16 @@ class KernelRunner final : public xla::KernelRunner { // Keep this llvm specific constructor for python bindings: // nanobind will do the downcasting for us and give the python specific // error if there is not a valid Create(...) call. - static absl::StatusOr Create(LlvmIrKernelSpec kernel_spec); + static absl::StatusOr Create(LlvmIrKernelSpec kernel_spec, + JitCompiler compiler); KernelRunner(KernelRunner&&) = default; KernelRunner& operator=(KernelRunner&&) = default; absl::Status Call(absl::Span arguments) final; + static absl::StatusOr CreateJitCompiler(); + private: KernelRunner(std::unique_ptr library, Kernel kernel, Kernel::ThreadDim thread_dim); diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc new file mode 100644 index 00000000000000..c9ba0f848d12d5 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc @@ -0,0 +1,176 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/backends/cpu/codegen/elemental_kernel_emitter.h" +#include "xla/backends/cpu/codegen/jit_compiler.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/backends/cpu/testlib/kernel_runner.h" +#include "xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h" +#include "xla/codegen/kernel_emitter.h" +#include "xla/codegen/kernel_spec.h" +#include "xla/codegen/testlib/kernel_runner.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/cpu_compiler.h" +#include "xla/stream_executor/launch_dim.h" + +namespace xla::cpu { + +namespace nb = nanobind; + +void ImportBaseClasses(const nb::module_& kernel_runner_module) { + absl::string_view module_name = + nb::borrow(nb::getattr(kernel_runner_module, "__name__")) + .c_str(); + + // Sequentially strip the module name until we get to the base xla module. + absl::string_view cpu_testlib_module = + module_name.substr(0, module_name.find_last_of('.')); + absl::string_view cpu_module = + cpu_testlib_module.substr(0, cpu_testlib_module.find_last_of('.')); + absl::string_view backends_module = + cpu_module.substr(0, cpu_module.find_last_of('.')); + absl::string_view xla_module = + backends_module.substr(0, backends_module.find_last_of('.')); + + nb::module_::import_(absl::StrCat(xla_module, ".codegen.testlib").c_str()); +} + +NB_MODULE(_extension, kernel_runner_module) { + // We depend on the base classes so must import them before python tries to + // register the derived versions. + ImportBaseClasses(kernel_runner_module); + + nb::class_ kernel_spec(kernel_runner_module, + "LlvmIrKernelSpec"); + + // Use a tuple and cast to ThreadDim to take advantage of built in bindings. + using NbThreadDim = std::tuple; + nb::class_(kernel_runner_module, + "LlvmIrKernelEmitter") + .def("__init__", [](LlvmIrKernelEmitter* self, absl::string_view ir, + absl::string_view kernel_name, + NbThreadDim thread_dim) { + new (self) LlvmIrKernelEmitter( + ir, kernel_name, + se::ThreadDim{std::get<0>(thread_dim), std::get<1>(thread_dim), + std::get<2>(thread_dim)}, + {}); + }); + + nb::class_(kernel_runner_module, "HloCompiler") + .def(nb::init<>()) + .def("create_buffer_assignment", + [](const CpuCompiler& self, const HloModule& hlo_module) { + absl::StatusOr> + buffer_assignment = self.CreateBufferAssignment(hlo_module); + + if (!buffer_assignment.ok()) { + throw std::runtime_error( + std::string(buffer_assignment.status().message())); + } + + return std::move(buffer_assignment).value(); + }) + .def("create_hlo_schedule", [](const CpuCompiler& self, + const HloModule& hlo_module) { + absl::StatusOr schedule = + self.CreateHloSchedule(hlo_module); + + if (!schedule.ok()) { + throw std::runtime_error(std::string(schedule.status().message())); + } + + return std::move(schedule).value(); + }); + + nb::class_(kernel_runner_module, + "TargetMachineFeatures") + .def("__str__", &TargetMachineFeatures::get_target_feature_string); + + nb::class_(kernel_runner_module, + "ElementalKernelEmitter") + .def(nb::init(), nb::keep_alive<1, 2>()) + .def(nb::init(), + nb::keep_alive<1, 2>(), nb::keep_alive<1, 3>(), + nb::keep_alive<1, 4>()); + + nb::class_(kernel_runner_module, "JitCompiler") + .def(nb::new_([]() { + absl::StatusOr compiler = + KernelRunner::CreateJitCompiler(); + + if (!compiler.ok()) { + throw std::runtime_error(std::string(compiler.status().message())); + } + + return std::make_unique( + JitCompiler(std::move(compiler).value())); + })) + .def( + "get_target_machine", + [](JitCompiler* self) { + return std::make_unique( + self->target_machine()); + }, + nb::rv_policy::reference_internal); + + nb::class_(kernel_runner_module, + "KernelRunner") + .def_static( + "create", + [](std::unique_ptr kernel_spec, + std::unique_ptr jit_compiler) { + absl::StatusOr runner = KernelRunner::Create( + std::move(*kernel_spec), std::move(*jit_compiler)); + + if (!runner.ok()) { + throw std::runtime_error(std::string(runner.status().message())); + } + + return std::move(runner).value(); + }) + .def_static("create", [](std::unique_ptr kernel_spec) { + absl::StatusOr runner = + KernelRunner::Create(std::move(kernel_spec)); + + if (!runner.ok()) { + throw std::runtime_error(std::string(runner.status().message())); + } + + return std::move(runner).value(); + }); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extention.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extention.cc deleted file mode 100644 index a53d625e22edb4..00000000000000 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extention.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "nanobind/nanobind.h" -#include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "nanobind/stl/tuple.h" // IWYU pragma: keep -#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep -#include "xla/backends/cpu/testlib/kernel_runner.h" -#include "xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h" -#include "xla/backends/cpu/testlib/llvm_ir_kernel_spec.h" -#include "xla/codegen/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" - -namespace xla::cpu { - -namespace nb = nanobind; - -void ImportBaseClasses(const nb::module_& kernel_runner_module) { - absl::string_view module_name = - nb::borrow(nb::getattr(kernel_runner_module, "__name__")) - .c_str(); - - // Sequentially strip the module name until we get to the base xla module. - absl::string_view cpu_testlib_module = - module_name.substr(0, module_name.find_last_of('.')); - absl::string_view cpu_module = - cpu_testlib_module.substr(0, cpu_testlib_module.find_last_of('.')); - absl::string_view backends_module = - cpu_module.substr(0, cpu_module.find_last_of('.')); - absl::string_view xla_module = - backends_module.substr(0, backends_module.find_last_of('.')); - - nb::module_::import_( - absl::StrCat(xla_module, ".codegen.testlib.kernel_runner").c_str()); -} - -NB_MODULE(kernel_runner_extention, kernel_runner_module) { - // We depend on the base classes so must import them before python tries to - // register the derived versions. - ImportBaseClasses(kernel_runner_module); - - nb::class_(kernel_runner_module, - "LlvmIrKernelSpec"); - - // Use a tuple and cast to ThreadDim to take advantage of built in bindings. - using NbThreadDim = std::tuple; - nb::class_(kernel_runner_module, - "LlvmIrKernelEmitter") - .def("__init__", [](LlvmIrKernelEmitter* self, std::string_view ir, - std::string_view kernel_name, - std::tuple thread_dim) { - new (self) LlvmIrKernelEmitter( - ir, kernel_name, - se::ThreadDim{std::get<0>(thread_dim), std::get<1>(thread_dim), - std::get<2>(thread_dim)}, - {}); - }); - - nb::class_(kernel_runner_module, - "KernelRunner") - .def_static("create", [](std::unique_ptr kernel_spec) { - absl::StatusOr runner = - KernelRunner::Create(std::move(*kernel_spec)); - - if (!runner.ok()) { - throw std::runtime_error(std::string(runner.status().message())); - } - - return std::move(runner).value(); - }); -} - -} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.cc index 8339aa52f04175..b1cd7123305bcf 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -42,7 +41,7 @@ namespace xla::cpu { using ::testing::Eq; TEST(KernelRunnerTest, Add) { - static constexpr std::string_view kLlvmAddI32 = R"( + static constexpr absl::string_view kLlvmAddI32 = R"( %struct.XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr } %struct.XLA_CPU_KernelArg = type { ptr, i64 } ; c = a + b (per thread) diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py index 9fbee631a5ea18..01fb0e3c24dadf 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_test.py @@ -16,10 +16,10 @@ from absl.testing import absltest import numpy as np -from xla.backends.cpu.testlib import kernel_runner -from xla.codegen.testlib import kernel_runner as kernel_runner_base +from xla.backends.cpu import testlib as cpu_testlib +from xla.codegen.testlib import utilities as testlib_utilities -create_literal = kernel_runner_base.create_literal_from_np +create_literal = testlib_utilities.create_literal_from_np class LLvmKernelRunnerTest(absltest.TestCase): @@ -51,13 +51,11 @@ def test_llvm_ir_kernel_runner(self): ret ptr null } """ - llvm_emitter = kernel_runner.LlvmIrKernelEmitter( - ir, "LlvmAddI32", (4, 1, 1) - ) + llvm_emitter = cpu_testlib.LlvmIrKernelEmitter(ir, "LlvmAddI32", (4, 1, 1)) llvm_spec = llvm_emitter.emit_kernel_spec() - runner = kernel_runner.KernelRunner.create(llvm_spec) + runner = cpu_testlib.KernelRunner.create(llvm_spec) a = create_literal(np.array([1, 2, 3, 4], dtype=np.int32)) b = create_literal(np.array([5, 6, 7, 8], dtype=np.int32)) c = create_literal(np.array([0, 0, 0, 0], dtype=np.int32)) diff --git a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.cc index e244f0177c16f7..765da32a5cc086 100644 --- a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.cc +++ b/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.cc @@ -16,18 +16,17 @@ limitations under the License. #include "xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h" #include -#include -#include #include #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" -#include "xla/backends/cpu/testlib/llvm_ir_kernel_spec.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/codegen/kernel_spec.h" #include "xla/codegen/llvm_ir_kernel_source.h" #include "xla/runtime/buffer_use.h" @@ -40,8 +39,8 @@ namespace { } // namespace -LlvmIrKernelEmitter::LlvmIrKernelEmitter(std::string_view llvm_ir, - std::string_view kernel_name, +LlvmIrKernelEmitter::LlvmIrKernelEmitter(absl::string_view llvm_ir, + absl::string_view kernel_name, se::ThreadDim thread_dim, absl::Span args) : llvm_ir_(llvm_ir), diff --git a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h b/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h index 1606efd3cbe2c0..60e737b583278b 100644 --- a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h +++ b/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/statusor.h" @@ -45,7 +44,7 @@ class LlvmIrKernelEmitter : public KernelEmitter { BufferUse::MemoryAccess memory_access; }; - LlvmIrKernelEmitter(std::string_view llvm_ir, std::string_view kernel_name, + LlvmIrKernelEmitter(absl::string_view llvm_ir, absl::string_view kernel_name, se::ThreadDim thread_dim, absl::Span args); diff --git a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter_test.cc b/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter_test.cc index ba3a66a3b7e2fe..91717bfbdbb80f 100644 --- a/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter_test.cc +++ b/third_party/xla/xla/backends/cpu/testlib/llvm_ir_kernel_emitter_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/backends/cpu/testlib/llvm_ir_kernel_emitter.h" #include -#include #include "xla/codegen/kernel_spec.h" #include "xla/codegen/llvm_ir_kernel_source.h" @@ -29,7 +28,7 @@ limitations under the License. namespace xla::cpu { TEST(LlvmIrKernelEmitterTest, ParseLlvmIr) { - static constexpr std::string_view kLlvmIr = R"( + static constexpr absl::string_view kLlvmIr = R"( define ptr @noop(ptr noundef %0) { ret ptr null } diff --git a/third_party/xla/xla/backends/cpu/xnn_emitter.cc b/third_party/xla/xla/backends/cpu/xnn_emitter.cc new file mode 100644 index 00000000000000..99c68e9ce35d07 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/xnn_emitter.cc @@ -0,0 +1,225 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/xnn_emitter.h" + +#include +#include +#include +#include + +#include "xnnpack.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +// A mapping from HloInstruction to XNNPACK subgraph tensor id. +using TensorIdMap = absl::flat_hash_map; + +//===----------------------------------------------------------------------===// +// XLA <-> XNNPACK type conversion library. +//===----------------------------------------------------------------------===// + +static absl::StatusOr XnnDatatype(const PrimitiveType& type) { + switch (type) { + case F16: + return xnn_datatype_fp16; + case F32: + return xnn_datatype_fp32; + default: + return InvalidArgument("Unsupported XNNPACK data type: %s", + primitive_util::LowercasePrimitiveTypeName(type)); + } +} + +static absl::StatusOr XnnBinaryOperator( + const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kAdd: + return xnn_binary_add; + case HloOpcode::kMultiply: + return xnn_binary_multiply; + case HloOpcode::kSubtract: + return xnn_binary_subtract; + default: + return InvalidArgument("Unsupported XNNPACK binary operator: %s", + HloOpcodeString(opcode)); + } +} + +static std::vector XnnDimensions(const Shape& shape) { + std::vector dims; + for (auto& dim : shape.dimensions()) { + dims.push_back(dim); + } + return dims; +} + +//===----------------------------------------------------------------------===// +// XLA <-> XNNPACK emitters. +//===----------------------------------------------------------------------===// + +static absl::StatusOr FindTensorValue(const TensorIdMap& tensor_ids, + const HloInstruction* instr) { + if (auto it = tensor_ids.find(instr); it != tensor_ids.end()) { + return it->second; + } + return Internal("Can't fine XNNPACK tensor value for instruction %s", + instr->ToString()); +} + +static absl::StatusOr DefineTensorValue(xnn_subgraph_t subgraph, + const HloInstruction* instr) { + // We do not support instructions with multiple results (tuples). + if (!instr->shape().IsArray()) { + return Internal("Unsupported XNNPACK instruction shape: %s", + instr->ToString()); + } + + auto dims = XnnDimensions(instr->shape()); + TF_ASSIGN_OR_RETURN(auto type, XnnDatatype(instr->shape().element_type())); + + uint32_t tensor_id = XNN_INVALID_VALUE_ID; + uint32_t tensor_flags = 0; + + // If instruction is a root instruction of the parent computation we assign it + // an external tensor id corresponding to the result index. + const HloComputation* computation = instr->parent(); + if (computation->root_instruction() == instr) { + tensor_id = computation->num_parameters(); + tensor_flags = XNN_VALUE_FLAG_EXTERNAL_OUTPUT; + } + + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, type, dims.size(), dims.data(), nullptr, + /*external_id=*/tensor_id, tensor_flags, &tensor_id)); + + return tensor_id; +} + +static absl::StatusOr DefineParameter(xnn_subgraph_t subgraph, + const HloInstruction* param) { + VLOG(3) << absl::StreamFormat("Define tensor value for parameter: %s", + param->ToString()); + + auto dims = XnnDimensions(param->shape()); + TF_ASSIGN_OR_RETURN(auto type, XnnDatatype(param->shape().element_type())); + + uint32_t tensor_id = param->parameter_number(); + XNN_RETURN_IF_ERROR(xnn_define_tensor_value( + subgraph, type, dims.size(), dims.data(), nullptr, + /*external_id=*/tensor_id, XNN_VALUE_FLAG_EXTERNAL_INPUT, &tensor_id)); + + return tensor_id; +} + +static absl::StatusOr DefineBinaryOp(xnn_subgraph_t subgraph, + TensorIdMap& tensor_ids, + const HloInstruction* instr) { + VLOG(3) << absl::StreamFormat("Define tensor value for binary op: %s", + instr->ToString()); + + TF_ASSIGN_OR_RETURN(auto binary_op, XnnBinaryOperator(instr->opcode())); + + TF_ASSIGN_OR_RETURN(auto lhs, FindTensorValue(tensor_ids, instr->operand(0))); + TF_ASSIGN_OR_RETURN(auto rhs, FindTensorValue(tensor_ids, instr->operand(1))); + TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + + VLOG(3) << absl::StreamFormat(" tensors: lhs=%d, rhs=%d, out=%d", lhs, rhs, + out); + + xnn_binary_params params = {-std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + XNN_RETURN_IF_ERROR(xnn_define_binary(subgraph, binary_op, ¶ms, lhs, rhs, + out, /*flags=*/0)); + + return out; +} + +//===----------------------------------------------------------------------===// +// Emit XNNPACK subgraph for the given HLO computation. +//===----------------------------------------------------------------------===// + +static absl::StatusOr EmitXnnSubgraph( + const HloComputation* computation) { + VLOG(3) << "Emit XNNPACK subgraph for computation: " << computation->name(); + + xnn_subgraph_t subgraph = nullptr; + XNN_RETURN_IF_ERROR(xnn_create_subgraph(/*external_value_ids=*/3, + /*flags=*/0, &subgraph)); + + // Traverse fused computation in post-order and define XNNPACK operations + // corresponding to each HLO instruction. + TensorIdMap tensor_ids; + auto instructions = computation->MakeInstructionPostOrder(); + + for (const HloInstruction* instr : instructions) { + switch (instr->opcode()) { + case HloOpcode::kParameter: { + TF_ASSIGN_OR_RETURN(tensor_ids[instr], + DefineParameter(subgraph, instr)); + } break; + + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: { + TF_ASSIGN_OR_RETURN(tensor_ids[instr], + DefineBinaryOp(subgraph, tensor_ids, instr)); + } break; + + default: + return InvalidArgument("Unsupported XNNPACK fusion instruction: %s", + instr->ToString()); + } + } + + return subgraph; +} + +absl::StatusOr()>> +EmitXnnFusionBuilder(const HloComputation* computation) { + // We do not support non-array parameters for XNNPACK operations. + for (auto& param : computation->parameter_instructions()) { + if (!param->shape().IsArray()) { + return InvalidArgument( + "XNNPACK fusion parameters must have array shapes, got %s", + param->shape().ToString()); + } + } + + // Result also must be a single array. + if (!computation->root_instruction()->shape().IsArray()) { + return InvalidArgument("XNNPACK fusion result must be an array, got %s", + computation->root_instruction()->shape().ToString()); + } + + return [computation] { return EmitXnnSubgraph(computation); }; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/xnn_emitter.h b/third_party/xla/xla/backends/cpu/xnn_emitter.h new file mode 100644 index 00000000000000..fb6b1b9b3ccca5 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/xnn_emitter.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_XNN_EMITTER_H_ +#define XLA_BACKENDS_CPU_XNN_EMITTER_H_ + +#include "xnnpack.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_computation.h" + +namespace xla::cpu { + +absl::StatusOr()>> +EmitXnnFusionBuilder(const HloComputation* computation); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_XNN_EMITTER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/backends/gpu/codegen/ir/BUILD similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/ir/BUILD rename to third_party/xla/xla/backends/gpu/codegen/ir/BUILD diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD b/third_party/xla/xla/backends/gpu/codegen/ir/tests/BUILD similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD rename to third_party/xla/xla/backends/gpu/codegen/ir/tests/BUILD diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/backends/gpu/codegen/ir/tests/invalid.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir rename to third_party/xla/xla/backends/gpu/codegen/ir/tests/invalid.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/backends/gpu/codegen/ir/tests/ops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir rename to third_party/xla/xla/backends/gpu/codegen/ir/tests/ops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/types.mlir b/third_party/xla/xla/backends/gpu/codegen/ir/tests/types.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/types.mlir rename to third_party/xla/xla/backends/gpu/codegen/ir/tests/types.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_attrs.cc similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_attrs.cc index 16de41e05cf5c6..d71c07ad064444 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_attrs.cc @@ -29,10 +29,10 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/analysis/indexing_map_serialization.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_attrs.td similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_attrs.td index 858d4ec82278ec..3708aba936e55b 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_attrs.td @@ -18,7 +18,7 @@ limitations under the License. include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" -include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/backends/gpu/codegen/ir/xla_gpu_dialect.td" include "xla/codegen/ir/xla_attrs.td" class XLAGPU_Attr traits = []> : diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_dialect.cc similarity index 80% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_dialect.cc index b7ee5f43d9d68a..185e27a7ec88a9 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_dialect.cc @@ -17,14 +17,14 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Transforms/InliningUtils.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" // The order of these includes is important. -#include "xla/service/gpu/fusions/ir/xla_gpu_enums.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_enums.cc.inc" #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_attrs.cc.inc" #define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_types.cc.inc" namespace xla { namespace gpu { @@ -48,16 +48,16 @@ struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { void XlaGpuDialect::initialize() { addOperations< #define GET_OP_LIST -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.cc.inc" >(); addAttributes< #define GET_ATTRDEF_LIST -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_attrs.cc.inc" >(); addInterfaces(); addTypes< #define GET_TYPEDEF_LIST -#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_types.cc.inc" >(); } diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_dialect.td similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_dialect.td diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc index bdb4a8cc516fb8..846925a925ce12 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include #include @@ -48,9 +48,9 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_dialect.cc.inc" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/analysis/indexing_map_serialization.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc" namespace xla { namespace gpu { @@ -114,7 +114,7 @@ LogicalResult MaterializeOp::verify() { return emitOpError() << "must have thread_id dimension in both indexing maps"; } - if (map_in.GetDimVars(0).bounds != map_out.GetDimVars(0).bounds) { + if (map_in.GetDimVar(0).bounds != map_out.GetDimVar(0).bounds) { return emitOpError() << "thread_id dimension must have the same bounds in " "both indexing maps"; } @@ -376,4 +376,4 @@ void SyncThreadsOp::getAsmResultNames( } // namespace xla #define GET_OP_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.cc.inc" diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.h similarity index 78% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.h index bec4116943f732..0d712d90846337 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ -#define XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_IR_XLA_GPU_OPS_H_ +#define XLA_BACKENDS_GPU_CODEGEN_IR_XLA_GPU_OPS_H_ #include @@ -30,15 +30,15 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep +#include "xla/backends/gpu/codegen/ir/xla_gpu_dialect.h.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_enums.h.inc" #include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" // IWYU pragma: keep -#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" -#include "xla/service/gpu/fusions/ir/xla_gpu_enums.h.inc" #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_attrs.h.inc" #define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_types.h.inc" #define GET_OP_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h.inc" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h.inc" -#endif // XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ +#endif // XLA_BACKENDS_GPU_CODEGEN_IR_XLA_GPU_OPS_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.td similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.td index 9c184716ffb913..39e1206fa4d8f3 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_ops.td @@ -23,9 +23,9 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" -include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" -include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td" -include "xla/service/gpu/fusions/ir/xla_gpu_types.td" +include "xla/backends/gpu/codegen/ir/xla_gpu_dialect.td" +include "xla/backends/gpu/codegen/ir/xla_gpu_attrs.td" +include "xla/backends/gpu/codegen/ir/xla_gpu_types.td" class XLAGPU_Op traits = []> : Op { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_types.cc similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_types.cc index 057a3a5f01a16e..c40da45158347a 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_types.cc @@ -21,8 +21,8 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/hlo/analysis/indexing_map.h" // IWYU pragma: keep -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_types.td similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td rename to third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_types.td index bcb9a9a66c89df..7df1aeb714973c 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td +++ b/third_party/xla/xla/backends/gpu/codegen/ir/xla_gpu_types.td @@ -19,7 +19,7 @@ limitations under the License. include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" include "mlir/IR/BuiltinTypeInterfaces.td" -include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/backends/gpu/codegen/ir/xla_gpu_dialect.td" include "xla/codegen/ir/xla_attrs.td" class XLAGPU_Type traits = []> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/transforms/BUILD similarity index 86% rename from third_party/xla/xla/service/gpu/fusions/transforms/BUILD rename to third_party/xla/xla/backends/gpu/codegen/transforms/BUILD index 9a67d32682236c..f11e5322cc4370 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/BUILD @@ -1,4 +1,9 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load( + "//xla/tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -33,6 +38,7 @@ gentbl_cc_library( cc_library( name = "passes", srcs = [ + "atomic_rmw_utils.cc", "convert_float_nvidia.cc", "convert_xla_gpu_pure_call_ops.cc", "erase_dead_functions.cc", @@ -52,19 +58,20 @@ cc_library( "vectorize_loads_stores.cc", ], hdrs = ["passes.h"], + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ ":passes_inc_gen", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/llvm_gpu_backend", "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", @@ -112,5 +119,9 @@ cc_library( "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:VectorTransforms", "@local_tsl//tsl/platform:protobuf", - ], + ] + if_cuda_is_configured([ + "//xla/service/gpu/llvm_gpu_backend:nvptx_backend", + ]) + if_rocm_is_configured([ + "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend", + ]), ) diff --git a/third_party/xla/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc new file mode 100644 index 00000000000000..ad1c769447e012 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc @@ -0,0 +1,120 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/ilist.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/ir/xla_ops.h" + +namespace xla { +namespace gpu { + +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" + +using mlir::Operation; +using mlir::Type; +using mlir::Value; + +namespace ml = ::mlir::LLVM; +namespace arith = ::mlir::arith; + +bool IsAtomicIntegral(Type element_type) { + if (!element_type.isInteger()) { + return false; + } + unsigned element_bitwidth = element_type.getIntOrFloatBitWidth(); + return element_bitwidth == 32 || element_bitwidth == 64; +} + +std::optional GetAtomicBinOp(Operation* modifier_op, + Type element_type) { + return llvm::TypeSwitch>( + modifier_op) + // Floating-point operations. + .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; }) + .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; }) + .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; }) + // Integer operations. + .Case([&](arith::AddIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::add) + : std::nullopt; + }) + .Case([&](arith::MaxUIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::umax) + : std::nullopt; + }) + .Case([&](arith::MinUIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::umin) + : std::nullopt; + }) + .Case([&](arith::MaxSIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::max) + : std::nullopt; + }) + .Case([&](arith::MinSIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::min) + : std::nullopt; + }) + .Default([](Operation* op) { return std::nullopt; }); +} + +// Returns atomic op modifier and the atomic bin op kind. +std::optional> GetAtomicModifierParameters( + AtomicRMWOp op) { + Type element_type = op.getInput().getType().getElementType(); + auto& operations = op.getBody()->getOperations(); + auto terminator = op.getBody()->getTerminator(); + if (operations.size() > 2) { + return std::nullopt; + } + // If the body contains only the terminator, then it is an atomic store. + if (operations.size() == 1) { + // TODO(b/336367145): Support complex atomic store. + if (element_type.isF32() || IsAtomicIntegral(element_type)) { + return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg); + } + return std::nullopt; + } + // Match the kind of the atomic op. + mlir::Operation* modifier_op = &operations.front(); + auto kind = GetAtomicBinOp(modifier_op, element_type); + if (!kind.has_value()) { + return std::nullopt; + } + // Find the modifier arg that does not match the argument of `atomic_rmw` + // body. + Value block_arg = op.getBody()->getArgument(0); + Value modifier_arg = modifier_op->getOperand(0) == block_arg + ? modifier_op->getOperand(1) + : modifier_op->getOperand(0); + return std::make_pair(modifier_arg, *kind); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/convert_float_nvidia.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/convert_float_nvidia.cc index 8f899228f0fb94..ba41ce1ae64cc9 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/convert_float_nvidia.cc @@ -30,16 +30,19 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/transforms/passes.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" +#ifdef GOOGLE_CUDA +#include "xla/service/gpu/llvm_gpu_backend/nvptx_backend.h" +#endif + namespace xla { namespace gpu { #define GEN_PASS_DEF_CONVERTFLOATNVIDIAPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { @@ -252,6 +255,7 @@ std::unique_ptr CreateConvertFloatNvidiaPass() { std::optional> MaybeCreateConvertFloatNvidiaPass( const se::DeviceDescription& device_description) { +#ifdef GOOGLE_CUDA se::SemanticVersion ptx_version = nvptx::DetermineHighestSupportedPtxVersionFromCudaVersion( device_description.runtime_version()); @@ -263,6 +267,7 @@ std::optional> MaybeCreateConvertFloatNvidiaPass( (ptx_version >= se::SemanticVersion(7, 8, 0) && cc.IsAtLeast(9, 0))) { return CreateConvertFloatNvidiaPass(); } +#endif return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/convert_xla_gpu_pure_call_ops.cc similarity index 94% rename from third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/convert_xla_gpu_pure_call_ops.cc index 0c9053a5570654..14739b9c9adeae 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/convert_xla_gpu_pure_call_ops.cc @@ -17,14 +17,14 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" namespace xla { namespace gpu { namespace { #define GEN_PASS_DEF_CONVERTPURECALLOPSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" struct RewriteCall : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/erase_dead_functions.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/erase_dead_functions.cc index 2c3d53834b14c9..5a2f216e135aeb 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/erase_dead_functions.cc @@ -21,13 +21,13 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/expand_float_ops.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/expand_float_ops.cc index 6fea3a97527f9b..81cb99d66f82d9 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/expand_float_ops.cc @@ -40,9 +40,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/xla_data.pb.h" namespace xla { @@ -54,7 +54,7 @@ using ma::SelectOp; using mlir::Value; #define GEN_PASS_DEF_EXPANDFLOATOPSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/flatten_tensors.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/flatten_tensors.cc index 384d80752c7d87..5a9de31de91154 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/flatten_tensors.cc @@ -47,9 +47,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/layout_util.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" @@ -58,7 +58,7 @@ namespace gpu { namespace { #define GEN_PASS_DEF_FLATTENTENSORSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" using mlir::Attribute; using mlir::Location; @@ -495,7 +495,7 @@ struct RewriteFor : public OpRewritePattern { .getResult(0); } rewriter.replaceOp(op, new_results); - return mlir::failure(); + return mlir::success(); } }; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/fuse_loops.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/fuse_loops.cc index 6af46a36e0d6a0..1fe6862689b860 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/fuse_loops.cc @@ -29,8 +29,8 @@ limitations under the License. #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/hlo/analysis/indexing_map.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { @@ -44,7 +44,7 @@ using mlir::ValueRange; namespace mv = ::mlir::vector; #define GEN_PASS_DEF_FUSELOOPSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" bool LoopsUseSameDimOps(LoopOp& loop1, LoopOp& loop2) { for (auto [dim1, dim2] : llvm::zip(loop1.getDims(), loop2.getDims())) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/lower_tensors.cc similarity index 85% rename from third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/lower_tensors.cc index 9a314781097706..474b4572ebbb41 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/lower_tensors.cc @@ -59,8 +59,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -71,7 +71,7 @@ namespace gpu { namespace { #define GEN_PASS_DEF_LOWERTENSORSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" using mlir::failure; using mlir::Location; @@ -79,10 +79,14 @@ using mlir::LogicalResult; using mlir::MLIRContext; using mlir::OpBuilder; using mlir::Operation; +using mlir::OpResult; +using mlir::OpRewritePattern; +using mlir::SmallVector; using mlir::success; using mlir::Type; using mlir::TypedValue; using mlir::TypeRange; +using mlir::UnrealizedConversionCastOp; using mlir::Value; using mlir::ValueRange; @@ -97,7 +101,7 @@ bool IsAMD(const se::DeviceDescription& device_description) { Value GetDestinationBuffer(Value dest) { while (dest.getDefiningOp()) { - int result_number = mlir::cast(dest).getResultNumber(); + int result_number = mlir::cast(dest).getResultNumber(); if (auto insert = dest.getDefiningOp()) { dest = insert.getDest(); } else if (auto scf_if = dest.getDefiningOp()) { @@ -106,7 +110,7 @@ Value GetDestinationBuffer(Value dest) { result_number); } else if (auto scf_for = dest.getDefiningOp()) { dest = scf_for.getInitArgs()[result_number]; - } else if (dest.getDefiningOp() || + } else if (dest.getDefiningOp() || dest.getDefiningOp()) { break; } else if (auto transfer_write = @@ -127,7 +131,7 @@ bool IsSupportedTransfer(Op op) { op.getPermutationMap().isMinorIdentity(); } -struct RewriteFunctionSignatures : mlir::OpRewritePattern { +struct RewriteFunctionSignatures : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -157,11 +161,11 @@ struct RewriteFunctionSignatures : mlir::OpRewritePattern { rewriter.replaceOpWithNewOp(terminator); } - llvm::SmallVector new_operands(op.getFunctionType().getInputs()); + SmallVector new_operands(op.getFunctionType().getInputs()); for (auto&& [index, operand] : llvm::enumerate(new_operands)) { if (is_tensor(operand)) { rewriter.setInsertionPointToStart(&op.getBody().front()); - auto cast = rewriter.create( + auto cast = rewriter.create( op.getLoc(), operand, op.getArgument(index)); op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast); operand = mlir::LLVM::LLVMPointerType::get(op.getContext()); @@ -178,6 +182,98 @@ struct RewriteFunctionSignatures : mlir::OpRewritePattern { } }; +Value GetPtr(Value value) { + if (!mlir::isa(value.getType())) { + return nullptr; + } + if (auto cast = value.getDefiningOp()) { + if (cast.getNumOperands() == 1 && cast.getNumResults() == 1 && + mlir::isa(cast.getOperand(0).getType())) { + return cast.getOperand(0); + } + } + return nullptr; +} + +struct RewriteFor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + scf::ForOp op, mlir::PatternRewriter& rewriter) const override { + llvm::SmallBitVector inits_to_remove(op.getNumRegionIterArgs(), false); + SmallVector new_inits; + new_inits.reserve(op.getNumResults()); + SmallVector ptrs; + ptrs.reserve(op.getNumRegionIterArgs()); + for (auto [index, init] : llvm::enumerate(op.getInitArgs())) { + Value ptr = GetPtr(init); + if (ptr) { + ptrs.push_back(ptr); + inits_to_remove.set(index); + continue; + } + new_inits.push_back(init); + } + if (inits_to_remove.none()) { + return rewriter.notifyMatchFailure(op, "no args to remove"); + } + // Create new ForOp with updated init args. The empty body builder is needed + // to avoid implicit construction of scf.yield in the body block. + Location loc = op.getLoc(); + auto new_for_op = rewriter.create( + loc, op.getLowerBound(), op.getUpperBound(), op.getStep(), new_inits, + [](OpBuilder&, Location, Value, ValueRange) {}); + new_for_op->setAttrs(op->getAttrs()); + + // Collect a mapping for block arguments and results. If the init is + // removed, we can use the init of the original scf.for for replacement, + // since it was provided by the `builtin.unrealized_conversion_cast` cast to + // the correct type. + mlir::Block* new_body = new_for_op.getBody(); + mlir::Block* old_body = op.getBody(); + rewriter.setInsertionPoint(new_body, new_body->begin()); + + SmallVector bb_args_mapping; + bb_args_mapping.reserve(old_body->getNumArguments()); + bb_args_mapping.push_back(new_for_op.getInductionVar()); + SmallVector results_replacement; + results_replacement.reserve(old_body->getNumArguments()); + int num_removed_args = 0; + for (auto [index, arg] : llvm::enumerate(op.getRegionIterArgs())) { + if (!inits_to_remove.test(index)) { + bb_args_mapping.push_back( + new_for_op.getRegionIterArg(index - num_removed_args)); + results_replacement.push_back( + new_for_op.getResult(index - num_removed_args)); + continue; + } + bb_args_mapping.push_back(op.getInitArgs()[index]); + results_replacement.push_back(op.getInitArgs()[index]); + ++num_removed_args; + } + + // Move the body of the old ForOp to the new one. + rewriter.mergeBlocks(old_body, new_body, bb_args_mapping); + + // Update the terminator. + auto new_terminator = mlir::cast(new_body->getTerminator()); + SmallVector new_yielded_values; + new_yielded_values.reserve(new_terminator->getNumOperands()); + rewriter.setInsertionPoint(new_terminator); + for (auto [index, yielded_value] : + llvm::enumerate(new_terminator.getResults())) { + if (inits_to_remove.test(index)) continue; + new_yielded_values.push_back(yielded_value); + } + rewriter.replaceOpWithNewOp(new_terminator, + new_yielded_values); + + // Replace the op. + rewriter.replaceOp(op, results_replacement); + return mlir::success(); + } +}; + Value GetLinearIndex(ValueRange indices, mlir::ImplicitLocOpBuilder& b) { CHECK_LE(indices.size(), 1) << "Only 0D and 1D tensors are supported"; auto index = indices.empty() ? b.create(0) @@ -206,7 +302,7 @@ mlir::LLVM::GEPOp CreateGep(TypedValue tensor, } auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); auto tensor_ptr = - b.create(ptr, tensor).getResult(0); + b.create(ptr, tensor).getResult(0); mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); auto gep = b.create(ptr, llvm_element_type, tensor_ptr, @@ -220,7 +316,7 @@ mlir::LLVM::GEPOp CreateGep(TypedValue tensor, return CreateGep(tensor, GetLinearIndex(indices, b), b); } -struct RewriteTensorExtract : mlir::OpRewritePattern { +struct RewriteTensorExtract : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -249,8 +345,8 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { b.create(is_low_nibble, load, high_value)); } - rewriter.replaceOpWithNewOp( - op, op.getType(), load); + rewriter.replaceOpWithNewOp(op, op.getType(), + load); return success(); } }; @@ -271,8 +367,7 @@ Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) { return result; } -struct RewriteTransferRead - : mlir::OpRewritePattern { +struct RewriteTransferRead : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -312,13 +407,13 @@ struct RewriteTransferRead loaded = PermutePairsInVector(loaded, b); } - rewriter.replaceOpWithNewOp( - op, op.getType(), loaded); + rewriter.replaceOpWithNewOp(op, op.getType(), + loaded); return success(); } }; -struct RewriteTensorInsert : mlir::OpRewritePattern { +struct RewriteTensorInsert : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -351,7 +446,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { Type ty = b.getI8Type(); Type tensor_ty = tensor_dest.getType().clone(ty); auto tensor_dest_i8 = - b.create(tensor_ty, tensor_dest) + b.create(tensor_ty, tensor_dest) .getResult(0); scalar_value = b.create(ty, scalar_value); @@ -377,8 +472,8 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { body_builder.create(4, ty))); Value new_value = body_builder.create( is_low_nibble, low_updated, high_updated); - body_builder.create(new_value); - Value casted_result = b.create( + body_builder.create(new_value); + Value casted_result = b.create( tensor_dest.getType(), atomic_rmw.getResult()) .getResult(0); op.replaceAllUsesWith(casted_result); @@ -387,7 +482,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { mlir::LLVMTypeConverter converter(getContext()); auto llvm_type = converter.convertType(scalar_value.getType()); scalar_value = - b.create(llvm_type, scalar_value) + b.create(llvm_type, scalar_value) .getResult(0); b.create(scalar_value, gep); op.replaceAllUsesWith(op.getDest()); @@ -398,8 +493,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { } }; -struct RewriteTransferWrite - : mlir::OpRewritePattern { +struct RewriteTransferWrite : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -430,9 +524,8 @@ struct RewriteTransferWrite mlir::LLVMTypeConverter converter(getContext()); auto llvm_type = converter.convertType(vector_value.getType()); - vector_value = - b.create(llvm_type, vector_value) - .getResult(0); + vector_value = b.create(llvm_type, vector_value) + .getResult(0); b.create(vector_value, gep); rewriter.replaceOp(op, mlir::ValueRange{op.getSource()}); @@ -440,7 +533,7 @@ struct RewriteTransferWrite } }; -struct RewriteCall : mlir::OpRewritePattern { +struct RewriteCall : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -456,7 +549,7 @@ struct RewriteCall : mlir::OpRewritePattern { op.setOperand( index, rewriter - .create( + .create( op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()), arg) .getResult(0)); @@ -515,7 +608,7 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, addr_space); } -struct RewriteAllocateShared : mlir::OpRewritePattern { +struct RewriteAllocateShared : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -531,7 +624,7 @@ struct RewriteAllocateShared : mlir::OpRewritePattern { rewriter.setInsertionPoint(op); auto addr = rewriter.create(op.getLoc(), global); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getResult().getType(), rewriter .create( @@ -542,8 +635,7 @@ struct RewriteAllocateShared : mlir::OpRewritePattern { } }; -struct RewriteNonScalarConstants - : mlir::OpRewritePattern { +struct RewriteNonScalarConstants : OpRewritePattern { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite( @@ -568,7 +660,7 @@ struct RewriteNonScalarConstants rewriter.setInsertionPoint(op); auto addr = rewriter.create(op.getLoc(), global); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getResult().getType(), rewriter .create( @@ -579,7 +671,7 @@ struct RewriteNonScalarConstants } }; -struct RewriteSyncThreads : mlir::OpRewritePattern { +struct RewriteSyncThreads : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -592,8 +684,7 @@ struct RewriteSyncThreads : mlir::OpRewritePattern { // TODO(jreiffers): Generalize this to support index switches with some used // results and upstream it as a canonicalization pattern. -struct RemoveUnusedIndexSwitchResults - : mlir::OpRewritePattern { +struct RemoveUnusedIndexSwitchResults : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( @@ -626,7 +717,8 @@ bool IsAtomicIntegral(Type element_type) { return element_bitwidth == 32 || element_bitwidth == 64; } -Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) { +Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, mlir::Operation* op, + Value value, Type ty) { if (value.getType().isIntOrFloat() && ty.isIntOrFloat()) { return b.create(ty, value); } @@ -637,22 +729,26 @@ Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) { Type llvm_input_ty = converter.convertType(value.getType()); Type llvm_result_ty = converter.convertType(ty); Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext()); + auto func = op->getParentOfType(); + // AMDGPU backend needs allocas to be out of loops. + // Move them to the entry block to be on the safe side. + auto entry_builder = mlir::ImplicitLocOpBuilder::atBlockBegin( + b.getLoc(), &func.getBody().front(), b.getListener()); Value llvm_value = - b.create(llvm_input_ty, value) - .getResult(0); - Value alloca = b.create( + b.create(llvm_input_ty, value).getResult(0); + Value alloca = entry_builder.create( ptr_ty, llvm_input_ty, b.create(b.getI32Type(), 1)); b.create(llvm_value, alloca); auto result = b.create(llvm_result_ty, alloca).getResult(); - return b.create(ty, result).getResult(0); + return b.create(ty, result).getResult(0); }; -class RewriteAtomicRMW : public mlir::OpRewritePattern { +class RewriteAtomicRMW : public OpRewritePattern { public: RewriteAtomicRMW(mlir::MLIRContext* context, const se::DeviceDescription* device_description) - : mlir::OpRewritePattern(context), + : OpRewritePattern(context), device_description_(device_description) {} LogicalResult matchAndRewrite( @@ -665,71 +761,6 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { } private: - // Returns atomic op modifier and the atomic bin op kind. - std::optional> GetAtomicModifierParameters( - AtomicRMWOp op) const { - Type element_type = op.getInput().getType().getElementType(); - auto& operations = op.getBody()->getOperations(); - auto terminator = op.getBody()->getTerminator(); - if (operations.size() > 2) { - return std::nullopt; - } - // If the body contains only the terminator, then it is an atomic store. - if (operations.size() == 1) { - // TODO(b/336367145): Support complex atomic store. - if (element_type.isF32() || IsAtomicIntegral(element_type)) { - return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg); - } - return std::nullopt; - } - // Match the kind of the atomic op. - mlir::Operation* modifier_op = &operations.front(); - std::optional kind = - llvm::TypeSwitch>( - modifier_op) - // Floating-point operations. - .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; }) - .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; }) - .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; }) - // Integer operations. - .Case([&](arith::AddIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::add) - : std::nullopt; - }) - .Case([&](arith::MaxUIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::umax) - : std::nullopt; - }) - .Case([&](arith::MinUIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::umin) - : std::nullopt; - }) - .Case([&](arith::MaxSIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::max) - : std::nullopt; - }) - .Case([&](arith::MinSIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::min) - : std::nullopt; - }) - .Default([](Operation* op) { return std::nullopt; }); - if (!kind.has_value()) { - return std::nullopt; - } - // Find the modifier arg that does not match the argument of `atomic_rmw` - // body. - Value block_arg = op.getBody()->getArgument(0); - Value modifier_arg = modifier_op->getOperand(0) == block_arg - ? modifier_op->getOperand(1) - : modifier_op->getOperand(0); - return std::make_pair(modifier_arg, *kind); - } - // Certain computations, such as floating-point addition and integer // maximization, can be simply implemented using an LLVM atomic instruction. // If "computation" is one of this kind, emits code to do that and returns @@ -1008,7 +1039,7 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { b.create(old_value, shift)); input_value = b.create(result_ty, short_value); } else { - input_value = CreateBitcast(b, old_value, result_ty); + input_value = CreateBitcast(b, op, old_value, result_ty); } // Perform computation on the loaded input value. @@ -1028,7 +1059,7 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { b.create(b.create(old_value, mask), b.create(cast_value, shift)); } else { - new_value = CreateBitcast(b, result, atomic_ty); + new_value = CreateBitcast(b, op, result, atomic_ty); } // Try saving the result atomically, retry if failed. @@ -1073,18 +1104,19 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase { .add(mlir_context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( - getOperation(), std::move(tensor_patterns)))) { + if (mlir::failed(mlir::applyPatternsGreedily(getOperation(), + std::move(tensor_patterns)))) { signalPassFailure(); return; } mlir::RewritePatternSet function_patterns(mlir_context); function_patterns.add(mlir_context); + RemoveUnusedIndexSwitchResults, RewriteFor>( + mlir_context); scf::ForOp::getCanonicalizationPatterns(function_patterns, mlir_context); scf::IfOp::getCanonicalizationPatterns(function_patterns, mlir_context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(function_patterns)))) { signalPassFailure(); return; @@ -1095,8 +1127,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase { while (auto gep = addr.getDefiningOp()) { addr = gep.getBase(); } - while (auto cast = - addr.getDefiningOp()) { + while (auto cast = addr.getDefiningOp()) { addr = cast.getOperand(0); } if (addr.getDefiningOp() || diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/lower_to_llvm.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/lower_to_llvm.cc index b9a811104c5b4d..89c4b30eacfd8f 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/lower_to_llvm.cc @@ -42,7 +42,7 @@ limitations under the License. #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep @@ -51,7 +51,7 @@ namespace gpu { namespace { #define GEN_PASS_DEF_LOWERTOLLVMPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" class LowerToLLVMPass : public impl::LowerToLLVMPassBase { public: diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/lower_xla_gpu_to_scf.cc similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/lower_xla_gpu_to_scf.cc index 82a7be70a5011b..d98e41bbfb2914 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/lower_xla_gpu_to_scf.cc @@ -42,10 +42,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/hlo/analysis/indexing_map.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/util.h" @@ -55,7 +55,7 @@ namespace { #define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS #define GEN_PASS_DEF_LOWERXLAGPULOOPSTOSCFPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" using mlir::ImplicitLocOpBuilder; using mlir::Location; @@ -213,14 +213,13 @@ struct RewriteXlaGpuLoop : mlir::OpRewritePattern { IndexingMap indexing_map = op.getIndexingMap(); SmallVector lbs, ubs, steps; - mlir_converter::GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs, - &steps); + emitters::GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs, &steps); mlir::scf::LoopNest loop_nest = mlir::scf::buildLoopNest( b, loc, lbs, ubs, steps, op.getInits(), [&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values, ValueRange iter_args) -> mlir::scf::ValueVector { mlir::ImplicitLocOpBuilder nested_b(loc, nested_builder); - auto is_in_bounds = mlir_converter::CheckConstraints( + auto is_in_bounds = emitters::CheckConstraints( indexing_map, op.getDims(), symbol_values, nested_b); auto if_op = nested_b.create( is_in_bounds, @@ -228,10 +227,9 @@ struct RewriteXlaGpuLoop : mlir::OpRewritePattern { ImplicitLocOpBuilder then_b(then_loc, then_builder); mlir::IRMapping mapping; mapping.map(op.getInductionVars(), symbol_values); - mapping.map( - op.getIndexingMapResults(), - mlir_converter::ApplyIndexing(indexing_map, op.getDims(), - symbol_values, then_b)); + mapping.map(op.getIndexingMapResults(), + emitters::ApplyIndexing(indexing_map, op.getDims(), + symbol_values, then_b)); mapping.map(op.getRegionIterArgs(), iter_args); mlir::Block* old_block = op.getBody(); for (auto& old_op : old_block->without_terminator()) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/merge_pointers_to_same_slice.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/merge_pointers_to_same_slice.cc index 50193e3a2a29f4..83dffe970d4794 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/merge_pointers_to_same_slice.cc @@ -30,7 +30,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/optimize_loops.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/optimize_loops.cc index 029c67dbe0660a..63677821ead8bd 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/optimize_loops.cc @@ -42,15 +42,15 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/hlo/analysis/indexing_map.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/gpu_fusible.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_OPTIMIZELOOPSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/backends/gpu/codegen/transforms/passes.h similarity index 75% rename from third_party/xla/xla/service/gpu/fusions/transforms/passes.h rename to third_party/xla/xla/backends/gpu/codegen/transforms/passes.h index c05a1d1ce19a85..de12227f94c0cf 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/passes.h @@ -12,31 +12,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ -#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ +#ifndef XLA_BACKENDS_GPU_CODEGEN_TRANSFORMS_PASSES_H_ +#define XLA_BACKENDS_GPU_CODEGEN_TRANSFORMS_PASSES_H_ #include #include #include #include +#include +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "xla/hlo/analysis/indexing_map.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { #define GEN_PASS_DECL -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" -// Returns the range of a given value, if it can be statically determined. -std::optional GetRange(mlir::Value value); - -// Returns the range for the induction variable, if it can be statically -// determined. -std::optional GetIVRange(mlir::Value iv); +// Returns atomic op modifier and the atomic bin op kind. +std::optional> +GetAtomicModifierParameters(AtomicRMWOp op); std::unique_ptr CreateConvertFloatNvidiaPass(); std::optional> MaybeCreateConvertFloatNvidiaPass( @@ -63,12 +62,15 @@ std::unique_ptr CreatePropagateSliceIndicesPass(); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); std::unique_ptr CreateUnswitchLoopsPass(); -std::unique_ptr CreateVectorizeLoadsAndStoresPass(); +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const std::string& gpu_device_info = ""); +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description); #define GEN_PASS_REGISTRATION -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ +#endif // XLA_BACKENDS_GPU_CODEGEN_TRANSFORMS_PASSES_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/backends/gpu/codegen/transforms/passes.td similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/transforms/passes.td rename to third_party/xla/xla/backends/gpu/codegen/transforms/passes.td index 1b5ffbdb24636e..53b20387c62aad 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/passes.td @@ -256,6 +256,11 @@ def VectorizeLoadsAndStoresPass : "mlir::vector::VectorDialect", ]; + let options = [ + Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", + "Serialized stream_executor::GPUDeviceInfo proto.">, + ]; + let constructor = "CreateVectorizeLoadsAndStoresPass()"; } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/peel_loops.cc similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/peel_loops.cc index 9f533c87447fea..3446ad5544f93b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/peel_loops.cc @@ -32,16 +32,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/analysis/indexing_map_serialization.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { namespace { #define GEN_PASS_DEF_PEELLOOPSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" using mlir::Location; using mlir::OpBuilder; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/propagate_slice_indices.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/propagate_slice_indices.cc index 31a637900c8a7a..a23bf00f70d3ac 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/propagate_slice_indices.cc @@ -19,13 +19,13 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/simplify_affine.cc similarity index 81% rename from third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/simplify_affine.cc index bee8dc383a0848..20c00ca28672fd 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/simplify_affine.cc @@ -41,10 +41,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/transforms/passes.h" namespace xla { namespace gpu { @@ -70,7 +70,7 @@ using mlir::affine::AffineApplyOp; namespace arith = mlir::arith; #define GEN_PASS_DEF_SIMPLIFYAFFINEPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" int Distance(ImplicitLocOpBuilder& builder, Value a) { auto* block = builder.getInsertionBlock(); @@ -314,65 +314,6 @@ struct SimplifyAffinePass } // namespace -std::optional GetRange(mlir::Value value) { - auto attr_to_range = [](mlir::Attribute attr) -> std::optional { - if (!attr) { - return std::nullopt; - } - auto values = llvm::to_vector( - mlir::cast(attr).getAsValueRange()); - return {{values[0].getSExtValue(), values[1].getSExtValue()}}; - }; - - if (auto apply = value.getDefiningOp()) { - return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange( - apply.getIndexingMap().GetAffineMap().getResult( - mlir::cast(value).getResultNumber())); - } else if (auto cst = value.getDefiningOp()) { - return {{cst.value(), cst.value()}}; - } else if (value.getDefiningOp()) { - return attr_to_range(value.getDefiningOp()->getAttr("xla.range")); - } - - auto bbarg = mlir::dyn_cast(value); - if (!bbarg) { - return std::nullopt; - } - - auto parent = bbarg.getParentBlock()->getParentOp(); - if (auto func_op = mlir::dyn_cast(parent)) { - return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range")); - } - return GetIVRange(value); -} - -std::optional GetIVRange(mlir::Value iv) { - auto bbarg = mlir::dyn_cast(iv); - if (!bbarg) { - return std::nullopt; - } - auto parent = bbarg.getParentBlock()->getParentOp(); - if (auto for_op = mlir::dyn_cast(parent)) { - llvm::APInt lb, ub; - if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) && - mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) { - return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; - } - } - if (auto loop_op = mlir::dyn_cast(parent)) { - const auto& indexing_map = loop_op.getIndexingMap(); - if (bbarg.getArgNumber() >= loop_op.getNumInductionVars() && - bbarg.getArgNumber() < - loop_op.getNumInductionVars() + indexing_map.GetNumResults()) { - RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator(); - return range_evaluator.ComputeExpressionRange( - indexing_map.GetAffineMap().getResult(bbarg.getArgNumber() - - loop_op.getNumInductionVars())); - } - } - return std::nullopt; -} - std::unique_ptr CreateSimplifyAffinePass() { return std::make_unique(); } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/simplify_arith.cc similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/simplify_arith.cc index 95f9ebc2ff0338..c8c92d0d44df5c 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/simplify_arith.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -31,16 +32,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" // IWYU pragma: keep +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/hlo/analysis/indexing_map.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/transforms/passes.h" namespace xla { namespace gpu { namespace { #define GEN_PASS_DEF_SIMPLIFYARITHPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" using mlir::LogicalResult; using mlir::OpRewritePattern; @@ -366,6 +367,14 @@ class SimplifyArithPass mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) { signalPassFailure(); } + + mlir::RewritePatternSet scf_patterns(ctx); + mlir::scf::ForOp::getCanonicalizationPatterns(scf_patterns, ctx); + mlir::scf::IfOp::getCanonicalizationPatterns(scf_patterns, ctx); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + func, std::move(scf_patterns)))) { + signalPassFailure(); + } } }; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/BUILD similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/BUILD diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_float_nvidia.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/convert_float_nvidia.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_float_nvidia.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/convert_float_nvidia.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/convert_xla_gpu_pure_calls.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/convert_xla_gpu_pure_calls.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/flatten_tensors.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/flatten_tensors.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/fuse_loops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/fuse_loops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/inlining.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/inlining.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir index 455837a698bee0..646c7a00ff756f 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir @@ -732,3 +732,35 @@ func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 { // CHECK: llvm.mlir.global private constant // CHECK-SAME: dense<[18, 48]> // CHECK-LABEL: @int4_constant + +// ----- + +func.func @for_op(%arg0: tensor<500xf32>) -> f32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32> + %for:2 = scf.for %i = %c0 to %c2 step %c1 + iter_args(%cst_ = %cst, %arg_ = %arg0) + -> (vector<4xf32>, tensor<500xf32>) { + %nested_for:2 = scf.for %j = %c0 to %c2 step %c1 + iter_args(%cst__ = %cst_, %arg__ = %arg_) + -> (vector<4xf32>, tensor<500xf32>) { + %index = arith.addi %i, %j : index + %tensor_elem = tensor.extract %arg__[%index] : tensor<500xf32> + %vector_elem = vector.extract %cst__[%index] : f32 from vector<4xf32> + %sum = arith.addf %tensor_elem, %vector_elem : f32 + %v_update = vector.insert %sum, %cst__[%index] : f32 into vector<4xf32> + %t_update = tensor.insert %sum into %arg__[%index] : tensor<500xf32> + scf.yield %v_update, %t_update : vector<4xf32>, tensor<500xf32> + } + scf.yield %nested_for#0, %nested_for#1 : vector<4xf32>, tensor<500xf32> + } + %result = tensor.extract %for#1[%c0] : tensor<500xf32> + func.return %result : f32 +} + +// CHECK-LABEL: @for_op +// CHECK: scf.for {{.*}} -> (vector<4xf32>) { +// CHECK-NEXT: scf.for {{.*}} -> (vector<4xf32>) { \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_xla_gpu_loops_to_scf.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_xla_gpu_loops_to_scf.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_xla_gpu_to_scf.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/lower_xla_gpu_to_scf.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/merge_pointers_to_same_slice.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/merge_pointers_to_same_slice.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/optimize_loops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/optimize_loops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/peel_loops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/peel_loops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/propagate_slice_indices.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/propagate_slice_indices.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/simplify_affine.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/simplify_affine.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/simplify_arith.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/simplify_arith.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/unswitch_loops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/unswitch_loops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir rename to third_party/xla/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir index a3b7e816bb05fb..d5d3d0a74fe4a2 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir @@ -1,5 +1,6 @@ // RUN: emitters_opt -allow-unregistered-dialect %s -split-input-file \ -// RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s +// RUN: -xla-gpu-vectorize-loads-stores="gpu_device_info='cuda_compute_capability {major: 6}'" -cse -canonicalize \ +// RUN: | FileCheck %s #map = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> @@ -251,7 +252,7 @@ func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f3 func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index + %c4 = arith.constant 4 : index %cst = arith.constant 0.0 : f32 %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> @@ -263,6 +264,7 @@ func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[V:.*]] = scf.for +// CHECK-SAME: (vector<4xf32>) // CHECK-NEXT: vector.insert // CHECK-NEXT: scf.yield // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[C0]]] diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/unswitch_loops.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/unswitch_loops.cc index d514a678624162..d35911464aaf2b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/unswitch_loops.cc @@ -30,7 +30,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DEF_UNSWITCHLOOPSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc similarity index 88% rename from third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc rename to third_party/xla/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc index 34e90b1ebb3368..19e6b7faf5e36a 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc +++ b/third_party/xla/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "llvm/ADT/APInt.h" @@ -40,14 +41,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/ir/xla_ops.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { namespace { #define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" using mlir::Value; @@ -326,21 +329,45 @@ class VectorizeLoadsAndStoresPass : public impl::VectorizeLoadsAndStoresPassBase< VectorizeLoadsAndStoresPass> { public: + explicit VectorizeLoadsAndStoresPass( + const VectorizeLoadsAndStoresPassOptions& options) + : VectorizeLoadsAndStoresPassBase(options) {} + + explicit VectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} + void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (!gpu_device_info_.empty()) { + se::GpuDeviceInfoProto device_info; + CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_, + &device_info)); + device_description_ = se::DeviceDescription(device_info); + } + mlir::MLIRContext* mlir_context = &getContext(); + mlir::RewritePatternSet patterns(mlir_context); + patterns.add(mlir_context); + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } + + se::DeviceDescription device_description_; }; } // namespace -std::unique_ptr> -CreateVectorizeLoadsAndStoresPass() { - return std::make_unique(); +std::unique_ptr<::mlir::Pass> CreateVectorizeLoadsAndStoresPass( + const std::string& gpu_device_info) { + VectorizeLoadsAndStoresPassOptions options; + options.gpu_device_info_ = gpu_device_info; + return std::make_unique(options); +} + +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description) { + return std::make_unique(device_description); } } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 7121377146e064..7dad08220ed4c7 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -40,6 +40,7 @@ cc_library( "//xla/core/collectives:rank_id", "//xla/service:lockable", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -59,6 +60,7 @@ cc_library( "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -76,6 +78,7 @@ xla_cc_test( "//xla/service:global_device_id", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -83,9 +86,9 @@ xla_cc_test( ) cc_library( - name = "gpu_clique_locking", - srcs = ["gpu_clique_locking.cc"], - hdrs = ["gpu_clique_locking.h"], + name = "gpu_cliques", + srcs = ["gpu_cliques.cc"], + hdrs = ["gpu_cliques.h"], deps = [ ":gpu_clique", ":gpu_clique_key", @@ -103,11 +106,16 @@ cc_library( "//xla/service:lockable", "//xla/service:rendezvous", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -115,11 +123,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -140,6 +144,8 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:casts", @@ -200,16 +206,17 @@ cc_library( "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", ]) + if_rocm_is_configured([ @@ -235,17 +242,20 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_stream", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_nccl//:nccl", ]) + if_rocm_is_configured([ diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc index 36bfe1015559f5..affc92419f5cc3 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/container/btree_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc index d949fb52da85a1..378ae084038b0d 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/hash/hash.h" +#include "absl/log/check.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key_test.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key_test.cc index f55b72bdc18c42..c9a584e47ad952 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key_test.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_key_test.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include +#include #include "absl/container/btree_map.h" #include "xla/core/collectives/clique_id.h" #include "xla/service/global_device_id.h" @@ -150,18 +152,20 @@ TEST(GpuCliqueKeyGetterTest, ToString) { } TEST(GpuCliqueIdGettersTest, Data) { - std::array id; + std::array id; std::fill(id.begin(), id.end(), 0x01); + id[128] = 0; CliqueId clique_id(id.data()); EXPECT_EQ(std::memcmp(clique_id.data().data(), id.data(), 128), 0); } TEST(GpuCliqueIdStringTest, ToString) { - std::array id; + std::array id; std::fill(id.begin(), id.end(), 0x01); + id[128] = 0; CliqueId clique_id(id.data()); for (int i = 0; i < 128; ++i) { - EXPECT_THAT(clique_id.ToString().substr(i, 1), "\x1"); + EXPECT_EQ(clique_id.ToString()[i], id[i]); } } diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc similarity index 97% rename from third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc rename to third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc index be53a701c1192e..77398835588d82 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/backends/gpu/collectives/gpu_clique_locking.h" +#include "xla/backends/gpu/collectives/gpu_cliques.h" #include #include @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/btree_map.h" #include "absl/container/node_hash_map.h" #include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -51,12 +52,12 @@ limitations under the License. #include "xla/service/lockable.h" #include "xla/service/rendezvous.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/hash.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::gpu { @@ -196,7 +197,6 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, const GpuCollectives::CliqueIdCallback& clique_id_callback, int32_t num_local_participants, RankId rank, const GpuCollectives::Config& config) { - int nranks = clique_key.devices().size(); VLOG(3) << "Initialize GPU clique " << clique_key.ToString() << " rank #" << rank << "; num_local_participants=" << num_local_participants; @@ -239,8 +239,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, TF_ASSIGN_OR_RETURN( std::vector> created_comms, - collectives->CreateCommunicators(nranks, clique_key, clique_id, ranks, - config)); + collectives->CreateCommunicators(clique_key, clique_id, ranks, config)); absl::btree_map> comms; for (size_t i = 0; i < ranks.size(); ++i) { @@ -293,7 +292,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, // processes are not able to synchronize device activity. RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized); - return RendezvousSingle>( + return Rendezvous>( initialization_rendezvous_name, rendezvous_key, rendezvous_arg, num_local_participants, initialize, WarnStuckTimeout(), TerminateTimeout()); @@ -432,7 +431,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, rank.value(), clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); - return RendezvousSingle>( + return Rendezvous>( initialization_rendezvous_name, rendezvous_key, rank_pair, num_local_participants, split, WarnStuckTimeout(), TerminateTimeout()); } @@ -467,7 +466,7 @@ absl::StatusOr> AcquireGpuClique( TF_ASSIGN_OR_RETURN( std::shared_ptr clique, - RendezvousSingle>( + Rendezvous>( rendezvous_name, rendezvous_key, num_local_participants, [&] { tsl::profiler::TraceMe trace("LockGpuClique"); diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.h b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.h similarity index 94% rename from third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.h rename to third_party/xla/xla/backends/gpu/collectives/gpu_cliques.h index d9e3f6b7b6d340..9825949cf37f2b 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_cliques.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_LOCKING_H_ -#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_LOCKING_H_ +#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUES_H_ +#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUES_H_ #include #include @@ -70,4 +70,4 @@ absl::StatusOr> AcquireGpuClique( } // namespace xla::gpu -#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_LOCKING_H_ +#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUES_H_ diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc index 38b574dd362a3b..196a37a9a6a9f7 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/collectives_registry.h" @@ -25,8 +27,8 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" -#include "tsl/platform/logging.h" namespace xla::gpu { diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h index ad64b910c6c97e..590d085450ee1a 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h @@ -50,7 +50,7 @@ class GpuCollectivesStub : public GpuCollectives { } absl::StatusOr>> - CreateCommunicators(int32_t, const CliqueKey&, const std::optional&, + CreateCommunicators(const CliqueKey&, const std::optional&, absl::Span, const Collectives::Config&) final { return UnimplementedError(); diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc index eeb6201aed71e6..59d0117c325c93 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -29,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_communicator.h" @@ -40,11 +40,11 @@ limitations under the License. #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "tsl/platform/casts.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -69,7 +69,7 @@ absl::StatusOr NcclCollectives::CreateUniqueCliqueId() const { VLOG(3) << "Create NCCL unique clique id"; ncclUniqueId id; XLA_NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&id)); - return CliqueId(std::string_view(id.internal, NCCL_UNIQUE_ID_BYTES)); + return CliqueId(absl::string_view(id.internal, NCCL_UNIQUE_ID_BYTES)); } bool NcclCollectives::IsGlobalConfig() const { @@ -115,8 +115,7 @@ static absl::StatusOr AsNcclUniqueId(const CliqueId& clique_id) { } absl::StatusOr>> -NcclCollectives::CreateCommunicators(int32_t nranks, - const CliqueKey& clique_key, +NcclCollectives::CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Collectives::Config& config) { @@ -140,15 +139,15 @@ NcclCollectives::CreateCommunicators(int32_t nranks, TF_RETURN_IF_ERROR(GroupStart()); for (size_t i = 0; i < ranks.size(); ++i) { VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank - << " of " << nranks + << " of " << clique_key.num_devices() << "; fingerprint(id)=" << clique_id->fingerprint(); TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device)); auto activate_context = device->stream_executor()->Activate(); TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(*clique_id)); - XLA_NCCL_RETURN_IF_ERROR( - ncclCommInitRankConfig(&comm_handles[i], nranks, nccl_unique_id, - ranks[i].rank.value(), &comm_config)); + XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig( + &comm_handles[i], clique_key.num_devices(), nccl_unique_id, + ranks[i].rank.value(), &comm_config)); } TF_RETURN_IF_ERROR(GroupEnd()); diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h index c8fb34f6276355..721e94d0bc4214 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h @@ -49,7 +49,7 @@ class NcclCollectives : public GpuCollectives { absl::Status GroupEnd() final; absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Collectives::Config& config) final; diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc index 3cd333a395024b..17f92e9575d544 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc @@ -18,23 +18,28 @@ limitations under the License. #include #include #include +#include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/collectives/nccl_errors.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "tsl/platform/casts.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -230,7 +235,7 @@ absl::Status NcclCommunicator::AllReduce( absl::Status NcclCommunicator::Broadcast(se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, - size_t root, + RankId root, const Executor& executor) { TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); @@ -240,13 +245,13 @@ absl::Status NcclCommunicator::Broadcast(se::DeviceMemoryBase send_buffer, "stream=%p", stream->parent()->device_ordinal(), send_buffer.opaque(), recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), - count, root, comm_, stream); + count, root.value(), comm_, stream); TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); return XLA_NCCL_STATUS(ncclBroadcast( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), - nccl_dtype, root, comm_, se::gpu::AsGpuStreamValue(stream))); + nccl_dtype, root.value(), comm_, se::gpu::AsGpuStreamValue(stream))); } absl::Status NcclCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, @@ -292,67 +297,142 @@ absl::Status NcclCommunicator::AllGather(se::DeviceMemoryBase send_buffer, nccl_dtype, comm_, se::gpu::AsGpuStreamValue(stream))); } -absl::Status NcclCommunicator::Send(se::DeviceMemoryBase send_buffer, - PrimitiveType dtype, size_t count, - int32_t peer, const Executor& executor) { +absl::Status NcclCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); + auto buffer_formatter = [](std::string* out, se::DeviceMemoryBase buffer) { + absl::StrAppendFormat(out, "%p", buffer.opaque()); + }; + VLOG(3) << absl::StreamFormat( - "Launch NCCL Send operation on device #%d; send_buffer=%p; dtype=%s; " - "count=%d; peer=%d; comm=%p; stream=%p", - stream->parent()->device_ordinal(), send_buffer.opaque(), - primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm_, - stream); + "Launch NCCL AllToAll operation on device #%d; send_buffers=[%s]; " + "recv_buffers=[%s]; dtype=%s; count=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), + absl::StrJoin(send_buffers, ", ", buffer_formatter), + absl::StrJoin(recv_buffers, ", ", buffer_formatter), + primitive_util::LowercasePrimitiveTypeName(dtype), count, comm_, stream); + + if (send_buffers.size() != recv_buffers.size()) { + return InvalidArgument( + "Number of send buffers must match number of recv buffers: %d != %d", + send_buffers.size(), recv_buffers.size()); + } + + int32_t num_ranks; + XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(comm_, &num_ranks)); + + if (send_buffers.size() != num_ranks) { + return InvalidArgument( + "Number of send buffers must match number of ranks: %d != %d", + send_buffers.size(), num_ranks); + } TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclSend(send_buffer.opaque(), - ToNcclCount(dtype, count), nccl_dtype, peer, - comm_, se::gpu::AsGpuStreamValue(stream))); + XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); + + for (size_t i = 0; i < send_buffers.size(); ++i) { + se::DeviceMemoryBase send_buffer = send_buffers[i]; + se::DeviceMemoryBase recv_buffer = recv_buffers[i]; + + XLA_NCCL_RETURN_IF_ERROR( + ncclSend(send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, i, + comm_, se::gpu::AsGpuStreamValue(stream))); + + XLA_NCCL_RETURN_IF_ERROR( + ncclRecv(recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, i, + comm_, se::gpu::AsGpuStreamValue(stream))); + } + + XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + + return absl::OkStatus(); } -absl::Status NcclCommunicator::SendPtrToPeer(void* ptr, int32_t peer, - const Executor& executor) { +absl::Status NcclCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); + auto rank_formatter = [](std::string* out, RankId rank) { + absl::StrAppendFormat(out, "%d", rank.value()); + }; + VLOG(3) << absl::StreamFormat( - "Launch NCCL RecvPtrFromPeer operation on device #%d; " - "peer=%d; comm=%p; stream=%p", - stream->parent()->device_ordinal(), peer, comm_, stream); - return XLA_NCCL_STATUS(ncclSend(ptr, 1, ncclUint64, peer, comm_, - se::gpu::AsGpuStreamValue(stream))); + "Launch NCCL CollectivePermute operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; source_rank=%s; target_ranks=[%s]; count=%d; " + "comm=%p; stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + source_rank ? absl::StrCat(source_rank->value()) : "", + absl::StrJoin(target_ranks, ", ", rank_formatter), count, comm_, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + // Short-circuit if there is no source or target rank. + if (!source_rank && target_ranks.empty()) { + return absl::OkStatus(); + } + + XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); + + if (source_rank) { + XLA_NCCL_RETURN_IF_ERROR(ncclRecv( + recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + source_rank->value(), comm_, se::gpu::AsGpuStreamValue(stream))); + } + + for (auto target_rank : target_ranks) { + XLA_NCCL_RETURN_IF_ERROR(ncclSend( + send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + target_rank.value(), comm_, se::gpu::AsGpuStreamValue(stream))); + } + + XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + + return absl::OkStatus(); } -absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer, +absl::Status NcclCommunicator::Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, - int32_t peer, const Executor& executor) { + RankId peer, const Executor& executor) { TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); VLOG(3) << absl::StreamFormat( - "Launch NCCL Recv operation on device #%d; recv_buffer=%p; dtype=%s; " + "Launch NCCL Send operation on device #%d; send_buffer=%p; dtype=%s; " "count=%d; peer=%d; comm=%p; stream=%p", - stream->parent()->device_ordinal(), recv_buffer.opaque(), - primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm_, - stream); + stream->parent()->device_ordinal(), send_buffer.opaque(), + primitive_util::LowercasePrimitiveTypeName(dtype), count, peer.value(), + comm_, stream); TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclRecv(recv_buffer.opaque(), - ToNcclCount(dtype, count), nccl_dtype, peer, - comm_, se::gpu::AsGpuStreamValue(stream))); + return XLA_NCCL_STATUS( + ncclSend(send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + peer.value(), comm_, se::gpu::AsGpuStreamValue(stream))); } -absl::Status NcclCommunicator::RecvPtrFromPeer(void* ptr, int32_t peer, - const Executor& executor) { +absl::Status NcclCommunicator::Recv(se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + RankId peer, const Executor& executor) { TF_ASSIGN_OR_RETURN(se::Stream * stream, ToStream(executor)); VLOG(3) << absl::StreamFormat( - "Launch NCCL RecvPtrFromPeer operation on device #%d; " - "peer=%d; comm=%p; stream=%p", - stream->parent()->device_ordinal(), peer, comm_, stream); + "Launch NCCL Recv operation on device #%d; recv_buffer=%p; dtype=%s; " + "count=%d; peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), recv_buffer.opaque(), + primitive_util::LowercasePrimitiveTypeName(dtype), count, peer.value(), + comm_, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); - return XLA_NCCL_STATUS(ncclRecv(ptr, 1, ncclUint64, peer, comm_, - se::gpu::AsGpuStreamValue(stream))); + return XLA_NCCL_STATUS( + ncclRecv(recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + peer.value(), comm_, se::gpu::AsGpuStreamValue(stream))); } std::string NcclCommunicator::ToString() const { diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h index ca59dd554885e1..07211c0be93992 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.h @@ -19,11 +19,14 @@ limitations under the License. #include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" @@ -61,7 +64,7 @@ class NcclCommunicator : public Communicator { absl::Status Broadcast(se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, size_t root, + size_t count, RankId root, const Executor& executor) final; absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, @@ -74,17 +77,23 @@ class NcclCommunicator : public Communicator { se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, const Executor& executor) final; - absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, - size_t count, int32_t peer, const Executor& executor) final; + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) final; - absl::Status SendPtrToPeer(void* ptr, int32_t peer, - const Executor& executor) final; + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) final; - absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, int32_t peer, const Executor& executor) final; + absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + size_t count, RankId peer, const Executor& executor) final; - absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, - const Executor& executor) final; + absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, RankId peer, const Executor& executor) final; std::string ToString() const final; diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_errors.h b/third_party/xla/xla/backends/gpu/collectives/nccl_errors.h index 61feee68cbdc31..473fc9f10a14ac 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_errors.h +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_errors.h @@ -25,7 +25,7 @@ limitations under the License. //===----------------------------------------------------------------------===// #define XLA_NCCL_STATUS(expr) \ - [](ncclResult_t s, std::string_view str) -> absl::Status { \ + [](ncclResult_t s, absl::string_view str) -> absl::Status { \ if (s == ncclSuccess) return absl::OkStatus(); \ return xla::Internal( \ "NCCL operation %s failed: %s. Last NCCL warning(error) log " \ diff --git a/third_party/xla/xla/backends/profiler/cpu/BUILD b/third_party/xla/xla/backends/profiler/cpu/BUILD index dad2b81f1b70ab..a02568bdbbb5dc 100644 --- a/third_party/xla/xla/backends/profiler/cpu/BUILD +++ b/third_party/xla/xla/backends/profiler/cpu/BUILD @@ -76,6 +76,7 @@ cc_library( ]), deps = [ "//xla/python/profiler/internal:python_hooks", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -133,9 +134,9 @@ xla_cc_test( "//xla/tsl/profiler/utils:timespan", "//xla/tsl/profiler/utils:xplane_schema", "//xla/tsl/profiler/utils:xplane_visitor", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:types", diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc index 68fe3fc32c385c..7f7a12ff52b524 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc @@ -19,14 +19,15 @@ limitations under the License. #include #include +#include #include +#include "absl/synchronization/blocking_counter.h" #include "absl/types/optional.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/timespan.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" @@ -51,7 +52,7 @@ using ::tsl::profiler::XPlaneVisitor; using ::tsl::profiler::XStatVisitor; TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { - tsl::uint32 thread_id; + int64_t thread_id; std::string thread_name = "MyThreadName"; tensorflow::profiler::XSpace space; @@ -166,7 +167,7 @@ TEST(HostTracerTest, CollectEventsFromThreadPool) { std::make_unique(/*env=*/Env::Default(), /*name=*/"HostTracerTest", /*num_threads=*/1); - tsl::BlockingCounter counter(1); + absl::BlockingCounter counter(1); auto tracer = CreateHostTracer({}); TF_EXPECT_OK(tracer->Start()); thread_pool->Schedule([&counter] { diff --git a/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc b/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc index 30c9982d9b132c..22704dec287566 100644 --- a/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc +++ b/third_party/xla/xla/backends/profiler/cpu/python_tracer.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "xla/python/profiler/internal/python_hooks.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 468d113f0f80cc..a85720e74d0a2c 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -5,6 +5,7 @@ load( "//xla/tsl:tsl.bzl", "if_google", "if_nvcc", + "if_oss", "internal_visibility", "tsl_copts", "tsl_gpu_library", @@ -178,7 +179,9 @@ tsl_gpu_library( "//xla/tsl/profiler/utils:per_thread", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", @@ -355,6 +358,7 @@ tsl_gpu_library( "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", @@ -421,3 +425,34 @@ xla_test( "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) + +cuda_library( + name = "nvtx_with_cuda_kernels", + testonly = 1, + srcs = ["nvtx_with_cuda_kernels.cu.cc"], + hdrs = ["nvtx_with_cuda_kernels.h"], + copts = if_nvcc([ + "-nvcc_options", + "ptxas-options=-v", + ]), + local_defines = if_oss(["NVTX_VERSION_3_1=1"]), + tags = ["cuda-only"], + visibility = ["//visibility:public"], +) + +xla_test( + name = "nvtx_with_cuda_kernels_test", + size = "small", + srcs = ["nvtx_with_cuda_kernels_test.cc"], + backends = ["gpu"], + copts = tf_profiler_copts() + tsl_copts(), + tags = [ + "cuda-only", + "no_mac", + ], + deps = [ + ":nvtx_with_cuda_kernels", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc index ba9a352793e062..4f34107808e813 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/backends/profiler/gpu/cupti_interface.h" #include "tsl/platform/errors.h" @@ -99,6 +101,14 @@ using CuptiActivityMemsetTy = CUpti_ActivityMemset; using CuptiActivityGraphTraceTy = CUpti_ActivityGraphTrace; #endif // CUDA_VERSION >= 11070 +#if CUDA_VERSION >= 8000 +using CuptiActivityMarkerTy = CUpti_ActivityMarker2; +constexpr int kCuptiActivityMarkerVersion = 2; +#else +using CuptiActivityMarkerTy = CUpti_ActivityMarker; +constexpr int kCuptiActivityMarkerVersion = 1; +#endif // CUDA_VERSION >= 11070 + // Maps an OverheadKind enum to a const string. const char *getActivityOverheadKindString(CUpti_ActivityOverheadKind kind) { switch (kind) { @@ -208,6 +218,55 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector, }); } +template +const char *GetActivityMarkerDomain(const CuptiActivityMarkerTy *marker_trace) { + if constexpr (CuptiActivityMarkerVersion == 1) { + return ""; + } else { + return marker_trace->domain; + } +} + +void AddMarkerActivityEvent(CuptiEventCollectorDelegate &collector, + CuptiActivityMarkerTy *marker_trace) { + // Currently only support thread marker (i.e., nvtx range push/pop) + if (marker_trace->objectKind != CUPTI_ACTIVITY_OBJECT_THREAD) return; + if (marker_trace->flags == CUPTI_ACTIVITY_FLAG_MARKER_START) { + collector.receive(CuptiTracerEvent{ + /* .type = */ CuptiTracerEventType::ThreadMarkerStart, + /* .source = */ CuptiTracerEventSource::Activity, + /* .name = */ marker_trace->name, + /* .annotation = */ "", + /* .nvtx_range = */ + GetActivityMarkerDomain(marker_trace), + /* .start_time_ns = */ marker_trace->timestamp, + /* .end_time_ns = */ marker_trace->timestamp, + /* .device_id = */ 0, + /* .correlation_id = */ 0, + /* .thread_id = */ marker_trace->objectId.pt.threadId, + /* .context_id = */ 0, + /* .stream_id = */ 0, + /* .graph_id = */ marker_trace->id, + }); + } else if (marker_trace->flags == CUPTI_ACTIVITY_FLAG_MARKER_END) { + collector.receive(CuptiTracerEvent{ + /* .type = */ CuptiTracerEventType::ThreadMarkerEnd, + /* .source = */ CuptiTracerEventSource::Activity, + /* .name = */ "", + /* .annotation = */ "", + /* .nvtx_range = */ "", + /* .start_time_ns = */ marker_trace->timestamp, + /* .end_time_ns = */ marker_trace->timestamp, + /* .device_id = */ 0, + /* .correlation_id = */ 0, + /* .thread_id = */ marker_trace->objectId.pt.threadId, + /* .context_id = */ 0, + /* .stream_id = */ 0, + /* .graph_id = */ marker_trace->id, + }); + } +} + void AddMemcpyActivityEvent(CuptiEventCollectorDelegate &collector, const CuptiActivityMemcpyTy *memcpy) { CuptiTracerEvent event{}; @@ -512,6 +571,10 @@ static absl::Status ConvertActivityBuffer( collector, reinterpret_cast(record)); break; #endif + case CUPTI_ACTIVITY_KIND_MARKER: + AddMarkerActivityEvent( + collector, reinterpret_cast(record)); + break; default: VLOG(3) << "Activity type " << record->kind << " is not supported."; break; diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h index add9875ac27148..fb77d4c080816d 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h @@ -174,6 +174,9 @@ enum class CuptiTracerEventType { HostRegister = 13, HostUnregister = 14, CudaGraph = 15, + ThreadMarkerRange = 16, + ThreadMarkerStart = 17, + ThreadMarkerEnd = 18, Generic = 100, }; @@ -187,8 +190,8 @@ enum class CuptiTracerEventSource { }; struct CuptiTracerEvent { - static constexpr uint32_t kInvalidThreadId = - std::numeric_limits::max(); + static constexpr uint64_t kInvalidThreadId = + std::numeric_limits::max(); static constexpr uint32_t kInvalidCorrelationId = std::numeric_limits::max(); static constexpr uint64_t kInvalidContextId = @@ -209,7 +212,7 @@ struct CuptiTracerEvent { uint64_t end_time_ns = 0; uint32_t device_id = 0; uint32_t correlation_id = kInvalidCorrelationId; - uint32_t thread_id = kInvalidThreadId; + uint64_t thread_id = kInvalidThreadId; int64_t context_id = kInvalidContextId; int64_t stream_id = kInvalidStreamId; uint32_t graph_id = 0; @@ -363,11 +366,11 @@ class CallbackAnnotationsAndEvents { size_t NumAnnotations() const { return annotations_.Size(); } - std::string_view DedupAnnotation(std::string_view str) { + absl::string_view DedupAnnotation(absl::string_view str) { return annotations_.Dedup(str); } - std::string_view DedupNvtxRange(std::string_view str) { + absl::string_view DedupNvtxRange(absl::string_view str) { return nvtx_ranges_.Dedup(str); } diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc index 81cc6b6ca3b1ee..fc8a396e5aa07d 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc @@ -62,11 +62,14 @@ namespace profiler { namespace { using tensorflow::profiler::XEventMetadata; +using tensorflow::profiler::XLine; +using tensorflow::profiler::XPlane; using tensorflow::profiler::XSpace; using tensorflow::profiler::XStatMetadata; using tsl::mutex; using tsl::mutex_lock; using tsl::profiler::Annotation; +using tsl::profiler::FindMutablePlaneWithName; using tsl::profiler::FindOrAddMutablePlaneWithName; using tsl::profiler::GpuPlaneName; using tsl::profiler::kCuptiDriverApiPlaneName; @@ -79,13 +82,27 @@ using tsl::profiler::XEventBuilder; using tsl::profiler::XLineBuilder; using tsl::profiler::XPlaneBuilder; +static constexpr int64_t kNvtxLineIdStart = 1LL << 32; +static constexpr int64_t kNvtxLineIdEnd = 2LL << 32; + +bool IsNvtxLine(int64_t line_id) { + return line_id >= kNvtxLineIdStart && line_id < kNvtxLineIdEnd; +} + bool IsHostEvent(const CuptiTracerEvent& event, int64_t* line_id) { // DriverCallback(i.e. kernel launching) events are host events. if (event.source == CuptiTracerEventSource::DriverCallback) { *line_id = event.thread_id; return true; } - // Non-overhead activity events are device events. + // nvtx marker events from activity source are host events. Those markers + // are put into a separate line whose id value greater than kNvtxLineIdStart. + if (event.source == CuptiTracerEventSource::Activity && + event.type == CuptiTracerEventType::ThreadMarkerRange) { + *line_id = kNvtxLineIdStart + event.thread_id; + return true; + } + // Other non-overhead activity events are device events. if (event.type != CuptiTracerEventType::Overhead) { *line_id = event.stream_id; return false; @@ -106,6 +123,37 @@ bool IsHostEvent(const CuptiTracerEvent& event, int64_t* line_id) { } } +int64_t GetNextAvailableLineId(absl::flat_hash_set& occupied_line_ids, + int64_t next_line_id) { + while (occupied_line_ids.contains(next_line_id)) ++next_line_id; + occupied_line_ids.insert(next_line_id); + return next_line_id; +} + +// Change the line id of the lines where line id >= kNvtxLineIdStart to +// any non-occupied line id start from 1, making sure the lower 32 bits value of +// the line ids are unique. This is to avoid the effective line id conflict +// which only count on the lower 32 bits of the line id in further analysis. +void AdjustHostPlaneNvtxLines(XPlane* plane) { + // Get all occupied line ids with value less than kNvtxLineIdStart. + absl::flat_hash_set occupied_line_ids; + for (const XLine& line : plane->lines()) { + if (line.id() < kNvtxLineIdStart) { + occupied_line_ids.insert(line.id()); + } + } + + // Change the line id, whose id value > kNvtxLineIdStart, to a non-occupied + // line id in uint32 range. + int64_t next_line_id = 0; + for (XLine& line : *plane->mutable_lines()) { + if (line.id() >= kNvtxLineIdStart) { + next_line_id = GetNextAvailableLineId(occupied_line_ids, next_line_id); + line.set_id(next_line_id); + } + } +} + struct DeviceOccupancyParams { cudaOccFuncAttributes attributes = {}; int block_size = 0; @@ -165,7 +213,7 @@ class PerDeviceCollector { return stats; } - void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, + void CreateXEvent(CuptiTracerEvent& event, XPlaneBuilder* plane, uint64_t start_gpu_ns, uint64_t end_gpu_ns, XLineBuilder* line) { if (event.start_time_ns < start_gpu_ns || event.end_time_ns > end_gpu_ns || @@ -183,6 +231,12 @@ class PerDeviceCollector { if (event.graph_id != 0 && event.type == CuptiTracerEventType::CudaGraph && event.source == CuptiTracerEventSource::DriverCallback) { absl::StrAppend(&kernel_name, " (CudaGraph:", event.graph_id, ")"); + } else if (event.type == CuptiTracerEventType::ThreadMarkerRange) { + kernel_name = + event.nvtx_range.empty() + ? absl::StrCat("NVTX:", kernel_name) + : absl::StrCat("NVTX:", event.nvtx_range, ":", kernel_name); + event.nvtx_range = ""; } XEventMetadata* event_metadata = plane->GetOrCreateEventMetadata(std::move(kernel_name)); @@ -410,7 +464,15 @@ class PerDeviceCollector { GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()])); }); host_plane->ForEachLine([&](XLineBuilder line) { - line.SetName(absl::StrCat("Host Threads/", line.Id())); + if (IsNvtxLine(line.Id())) { + // Lines will order by name, by appending suffix to the normal cupti + // line name, the nvtx lines will be placed right after their + // corresponding cupti lines. + line.SetName(absl::StrCat("Host Threads/", + static_cast(line.Id()), "/NVTX")); + } else { + line.SetName(absl::StrCat("Host Threads/", line.Id())); + } }); size_t num_events = events_.size(); events_.clear(); @@ -680,6 +742,7 @@ void CuptiTraceCollector::OnTracerCachedActivityBuffers( // CuptiTraceCollectorImpl store the CuptiTracerEvents from CuptiTracer and // eventually convert and filter them to XSpace. +// It also add support to handle cupti activity events for nvtx thread markers. class CuptiTraceCollectorImpl : public CuptiTraceCollector { public: CuptiTraceCollectorImpl(const CuptiTracerCollectorOptions& option, @@ -699,6 +762,13 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { } else { num_activity_events_++; } + if (event.type == CuptiTracerEventType::ThreadMarkerStart || + event.type == CuptiTracerEventType::ThreadMarkerEnd) { + // Process the nvtx marker, merge thread range start/end if appropriate. + // If merged, the event will contains the merged content, and be used for + // followed AddEvent() processing. + if (!AddNvtxMarker(event)) return; + } per_device_collector_[event.device_id].AddEvent(std::move(event)); } void OnEventsDropped(const std::string& reason, @@ -745,6 +815,8 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { start_gpu_ns_, end_gpu_ns, &device_plane, &host_plane); NormalizeTimeStamps(&device_plane, start_walltime_ns_); } + AdjustHostPlaneNvtxLines( + FindMutablePlaneWithName(space, kCuptiDriverApiPlaneName)); NormalizeTimeStamps(&host_plane, start_walltime_ns_); return num_events > 0; } @@ -775,6 +847,39 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { uint64_t start_walltime_ns_; uint64_t start_gpu_ns_; int num_gpus_; + uint32_t num_duplicate_nvtx_marker_start_ = 0; + uint32_t num_unmatched_nvtx_marker_end_ = 0; + + // process the nvtx marker, a)cache range start event, or b)merge range end + // with its corresponding start event. If merged, the event be updated with + // the merged content and return true. If not merged, return false. + bool AddNvtxMarker(CuptiTracerEvent& event) { + const uint32_t marker_id = event.graph_id; + auto it = nvtx_markers_.find(marker_id); + if (event.type == CuptiTracerEventType::ThreadMarkerStart) { + if (it == nvtx_markers_.end()) { + nvtx_markers_[marker_id] = + std::make_unique(std::move(event)); + } else { + LOG_IF(ERROR, ++num_duplicate_nvtx_marker_start_ < 100) + << "Duplicate nvtx thread range start marker id: " << marker_id; + } + } else if (event.type == CuptiTracerEventType::ThreadMarkerEnd) { + if (it != nvtx_markers_.end()) { + it->second->type = CuptiTracerEventType::ThreadMarkerRange; + it->second->end_time_ns = event.end_time_ns; + it->second->graph_id = 0; + event = std::move(*it->second); + nvtx_markers_.erase(it); + return true; // The event is merged for further processing. + } else { + LOG_IF(ERROR, ++num_unmatched_nvtx_marker_end_ < 100) + << "Unmatched nvtx thread range end marker id: " << marker_id; + } + } + // No merged event is generated, return false. + return false; + } // Set the all XLines of specified XPlane to starting walltime. // Events time in both host and device planes are CUTPI timestamps. @@ -788,6 +893,8 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { } absl::FixedArray per_device_collector_; + absl::flat_hash_map> + nvtx_markers_; CuptiTraceCollectorImpl(const CuptiTraceCollectorImpl&) = delete; void operator=(const CuptiTraceCollectorImpl&) = delete; diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.cc index a4aab82e11ed31..94535afc9c249c 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.cc @@ -279,6 +279,14 @@ CUptiResult CuptiErrorManager::GetGraphExecId(CUgraphExec graph_exec, return error; } +CUptiResult CuptiErrorManager::SetThreadIdType( + CUpti_ActivityThreadIdType type) { + IGNORE_CALL_IF_DISABLED; + CUptiResult error = interface_->SetThreadIdType(type); + LOG_AND_DISABLE_IF_ERROR(error); + return error; +} + void CuptiErrorManager::CleanUp() { if (undo_disabled_) { // prevent deadlock return; diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.h b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.h index 82b547df1c8ded..79b124a5c194f5 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager.h @@ -117,6 +117,8 @@ class CuptiErrorManager : public xla::profiler::CuptiInterface { CUptiResult GetGraphExecId(CUgraphExec graph_exec, uint32_t* graph_id) override; + CUptiResult SetThreadIdType(CUpti_ActivityThreadIdType type) override; + // Clears Undo stack. We are maintaining undo stack for each profiling phase. // Once the profiling is done, we need to clear the undo stack. void CleanUp() override; diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc index 05aa020d84ab9e..7b369fa6fa59c0 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc @@ -124,6 +124,9 @@ TEST_F(CuptiErrorManagerTest, GpuTraceActivityEnableTest) { .InSequence(s1) .WillRepeatedly( Invoke(cupti_wrapper_.get(), &CuptiWrapper::EnableCallback)); + EXPECT_CALL(*mock_, SetThreadIdType(_)) + .InSequence(s1) + .WillOnce(Invoke(cupti_wrapper_.get(), &CuptiWrapper::SetThreadIdType)); EXPECT_CALL(*mock_, ActivityUsePerThreadBuffer()) .InSequence(s1) .WillOnce(Invoke(cupti_wrapper_.get(), @@ -182,6 +185,9 @@ TEST_F(CuptiErrorManagerTest, GpuTraceAutoEnableTest) { EXPECT_CALL(*mock_, EnableDomain(1, _, _)) .InSequence(s1) .WillOnce(Invoke(cupti_wrapper_.get(), &CuptiWrapper::EnableDomain)); + EXPECT_CALL(*mock_, SetThreadIdType(_)) + .InSequence(s1) + .WillOnce(Invoke(cupti_wrapper_.get(), &CuptiWrapper::SetThreadIdType)); EXPECT_CALL(*mock_, ActivityUsePerThreadBuffer()) .InSequence(s1) .WillOnce(Invoke(cupti_wrapper_.get(), diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_interface.h b/third_party/xla/xla/backends/profiler/gpu/cupti_interface.h index 35b0ae5ab1b997..c577b1e15a7a24 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_interface.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_interface.h @@ -99,6 +99,8 @@ class CuptiInterface { virtual CUptiResult GetGraphExecId(CUgraphExec graph_exec, uint32_t* graph_id) = 0; + virtual CUptiResult SetThreadIdType(CUpti_ActivityThreadIdType type) = 0; + // Interface maintenance functions. Not directly related to CUPTI, but // required for implementing an error resilient layer over CUPTI API. diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc index 374cc6e7746306..c6ccf2ece89fec 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc @@ -15,15 +15,18 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_tracer.h" +#include #include -#include #include #include #include #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/types/span.h" +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_result.h" #include "third_party/gpus/cuda/extras/CUPTI/include/generated_nvtx_meta.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/backends/profiler/gpu/cupti_buffer_events.h" @@ -851,11 +854,6 @@ absl::Status AddDriverApiCallbackEvent( CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata) { absl::string_view annotation = AnnotationStack::Get(); absl::string_view nvtx_range = ""; - if (!annotation.empty() && - cbid != CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice) { - nvtx_range = NVTXRangeTracker::CurrentRange(); - } - auto &guarded_annotations_and_events = PerThreadCallbackAnnotationsAndEvents::Get(); if (tracer->TooManyCallbackEvents()) { @@ -993,6 +991,12 @@ const char *GetTraceEventTypeName(const CuptiTracerEventType &type) { return "HostUnregister"; case CuptiTracerEventType::CudaGraph: return "CudaGraph"; + case CuptiTracerEventType::ThreadMarkerRange: + return "ThreadMarkerRange"; + case CuptiTracerEventType::ThreadMarkerStart: + return "ThreadMarkerStart"; + case CuptiTracerEventType::ThreadMarkerEnd: + return "ThreadMarkerEnd"; case CuptiTracerEventType::Unsupported: return ""; } @@ -1030,8 +1034,21 @@ void CuptiTracer::Enable(const CuptiTracerOptions &option, option_ = option; collector_ = collector; + // For nvtx tracking, utilize CUPTI activity marker and marker_data. + if (option_->enable_nvtx_tracking) { + std::vector &activities = option_->activities_selected; + if (std::find(activities.begin(), activities.end(), + CUPTI_ACTIVITY_KIND_MARKER) == activities.end()) { + VLOG(1) << "Adding CUPTI_ACTIVITY_KIND_MARKER to activities:" + << (int)CUPTI_ACTIVITY_KIND_MARKER; + activities.push_back(CUPTI_ACTIVITY_KIND_MARKER); + } + // TODO: Add CUPTI_ACTIVITY_KIND_MARKER_DATA to activities after cupti + // more detailed data could be provided by cupti. + } + cupti_driver_api_hook_ = std::make_unique( - option, cupti_interface_, this); + *option_, cupti_interface_, this); absl::Status status = EnableApiTracing(); need_root_access_ |= status.code() == tsl::error::PERMISSION_DENIED; @@ -1144,10 +1161,10 @@ absl::Status CuptiTracer::EnableApiTracing() { 1 /* ENABLE */, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API)); } - if (option_->enable_nvtx_tracking) { - RETURN_IF_CUPTI_ERROR(cupti_interface_->EnableDomain( - 1 /* ENABLE */, subscriber_, CUPTI_CB_DOMAIN_NVTX)); - } + // There is no easy api to get the domain string from CUPTI_CB_DOMAIN_NVTX + // callback. So we use ACTIVIY_MARKERS to get the domain/range_name strings, + // and generate the related nvtx range event. So we do not need to use the + // CUPTI_CB_DOMAIN_NVTX callback here. return absl::OkStatus(); } @@ -1172,11 +1189,6 @@ absl::Status CuptiTracer::DisableApiTracing() { 0 /* DISABLE */, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API)); } - if (option_->enable_nvtx_tracking) { - RETURN_IF_CUPTI_ERROR(cupti_interface_->EnableDomain( - 0 /* DISABLE */, subscriber_, CUPTI_CB_DOMAIN_NVTX)); - } - VLOG(1) << "Disable subscriber"; RETURN_IF_CUPTI_ERROR(cupti_interface_->Unsubscribe(subscriber_)); return absl::OkStatus(); @@ -1186,6 +1198,14 @@ absl::Status CuptiTracer::EnableActivityTracing() { if (activity_tracing_enabled_) return absl::OkStatus(); PrepareActivityStart(); if (!option_->activities_selected.empty()) { + if (cupti_interface_->SetThreadIdType( + CUPTI_ACTIVITY_THREAD_ID_TYPE_SYSTEM) != CUPTI_SUCCESS) { + LOG(WARNING) + << "Failed to set CUPTI activity thread id type to " + "CUPTI_ACTIVITY_THREAD_ID_TYPE_SYSTEM, CUPTI reported thread id " + "may be different from system thread id get with gettid()"; + }; + // Initialize callback functions for Cupti Activity API. VLOG(1) << "Registering CUPTI activity callbacks"; if (auto err = cupti_interface_->ActivityUsePerThreadBuffer(); @@ -1251,25 +1271,6 @@ absl::Status CuptiTracer::Finalize() { return 0; } -absl::Status CuptiTracer::HandleNVTXCallback(CUpti_CallbackId cbid, - const CUpti_CallbackData *cbdata) { - const CUpti_NvtxData *pdata = - reinterpret_cast(cbdata); - if (cbid == CUPTI_CBID_NVTX_nvtxDomainRangePushEx) { - const nvtxDomainRangePushEx_params *params = - reinterpret_cast( - pdata->functionParams); - // TODO(profiler): The messageType is actually NVTX_MESSAGE_TYPE_REGISTERED - // (which is 3), However it seems to me that we can not get the registered - // string from nvtxDomainRegisterStringA_params. If we reinterpret the - // payload as ascii, it happen to work. - NVTXRangeTracker::EnterRange(params->core.eventAttrib->message.ascii); - } else if (cbid == CUPTI_CBID_NVTX_nvtxDomainRangePop) { - NVTXRangeTracker::ExitRange(); - } - return absl::OkStatus(); -} - // Resource callback happens logically inside a driver API call's enter/exit. // Some per-thread data structure to record the graph ids. absl::Status CuptiTracer::HandleResourceCallback( @@ -1334,7 +1335,6 @@ absl::Status CuptiTracer::HandleCallback(CUpti_CallbackDomain domain, if (!api_tracing_enabled_) return absl::OkStatus(); // already unsubscribed. if (!cupti_driver_api_hook_) return absl::OkStatus(); // already unsubscribed. - if (domain == CUPTI_CB_DOMAIN_NVTX) return HandleNVTXCallback(cbid, cbdata); if (domain == CUPTI_CB_DOMAIN_DRIVER_API) return HandleDriverApiCallback(cbid, cbdata); if (domain == CUPTI_CB_DOMAIN_RESOURCE) diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.cc index 60a4ffc337cae8..e46d03b52c08b9 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.cc @@ -137,6 +137,10 @@ CUptiResult CuptiWrapper::GetGraphExecId(CUgraphExec graph_exec, return GetGraphId(reinterpret_cast(graph_exec), graph_id); } +CUptiResult CuptiWrapper::SetThreadIdType(CUpti_ActivityThreadIdType type) { + return cuptiSetThreadIdType(type); +} + CUptiResult CuptiWrapper::GetStreamIdEx(CUcontext context, CUstream stream, uint8_t per_thread_stream, uint32_t* stream_id) { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h index a9e081439503bf..9fc26c4c9e598c 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper.h @@ -94,6 +94,8 @@ class CuptiWrapper : public xla::profiler::CuptiInterface { CUptiResult GetGraphExecId(CUgraphExec graph_exec, uint32_t* graph_id) override; + CUptiResult SetThreadIdType(CUpti_ActivityThreadIdType type) override; + void CleanUp() override {} bool Disabled() const override { return false; } @@ -173,6 +175,8 @@ class CuptiWrapperStub : public xla::profiler::CuptiInterface { CUptiResult GetGraphExecId(CUgraphExec graph_exec, uint32_t* graph_id) override; + CUptiResult SetThreadIdType(CUpti_ActivityThreadIdType type) override; + void CleanUp() override {} bool Disabled() const override { return false; } diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc index 82fb8dd9bed593..e3c6f2438c036b 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_wrapper_stub.cc @@ -122,5 +122,9 @@ CUptiResult CuptiWrapperStub::GetGraphExecId(CUgraphExec graph_exec, return CUPTI_SUCCESS; } +CUptiResult CuptiWrapperStub::SetThreadIdType(CUpti_ActivityThreadIdType type) { + return CUPTI_SUCCESS; +} + } // namespace profiler } // namespace xla diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc index 578d4ab6d3021d..2d675afba107d4 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -46,8 +46,7 @@ using tsl::ReadBoolFromEnvVar; // GpuTracer for GPU. class GpuTracer : public tsl::profiler::ProfilerInterface { public: - GpuTracer(CuptiTracer* cupti_tracer, CuptiInterface* cupti_interface) - : cupti_tracer_(cupti_tracer) { + explicit GpuTracer(CuptiTracer* cupti_tracer) : cupti_tracer_(cupti_tracer) { VLOG(1) << "GpuTracer created."; } ~GpuTracer() override {} @@ -227,8 +226,7 @@ std::unique_ptr CreateGpuTracer( if (!cupti_tracer->IsAvailable()) { return nullptr; } - profiler::CuptiInterface* cupti_interface = profiler::GetCuptiInterface(); - return std::make_unique(cupti_tracer, cupti_interface); + return std::make_unique(cupti_tracer); } auto register_gpu_tracer_factory = [] { diff --git a/third_party/xla/xla/backends/profiler/gpu/mock_cupti.h b/third_party/xla/xla/backends/profiler/gpu/mock_cupti.h index 1f82ddda8a1ac6..6384a67c3b8625 100644 --- a/third_party/xla/xla/backends/profiler/gpu/mock_cupti.h +++ b/third_party/xla/xla/backends/profiler/gpu/mock_cupti.h @@ -85,6 +85,9 @@ class MockCupti : public xla::profiler::CuptiInterface { MOCK_METHOD(CUptiResult, GetGraphId, (CUgraph graph, uint32_t* graph_id), (override)); + MOCK_METHOD(CUptiResult, SetThreadIdType, (CUpti_ActivityThreadIdType type), + (override)); + MOCK_METHOD(CUptiResult, GetGraphExecId, (CUgraphExec graph_exec, uint32_t* graph_id), (override)); diff --git a/third_party/xla/xla/backends/profiler/gpu/nvtx_utils.h b/third_party/xla/xla/backends/profiler/gpu/nvtx_utils.h index 43f0c91bf917f7..9f253659957cf2 100644 --- a/third_party/xla/xla/backends/profiler/gpu/nvtx_utils.h +++ b/third_party/xla/xla/backends/profiler/gpu/nvtx_utils.h @@ -25,6 +25,8 @@ namespace xla { namespace profiler { /*** + * TODO: After using CUPTI activity marker, remove NVTXRangeTracker related + * code. * We have no intention to use NVTX in tensorflow right now, we use this class * to track NVTX instrumentation inside NVIDIA libraries (such as TensorRT). * This bears a lot of resemblance to ScopedAnnotation for now. In the future, diff --git a/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels.cu.cc b/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels.cu.cc new file mode 100644 index 00000000000000..6f408a80735a41 --- /dev/null +++ b/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels.cu.cc @@ -0,0 +1,148 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/profiler/gpu/nvtx_with_cuda_kernels.h" + +#include + +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h" + +namespace xla { +namespace profiler { +namespace test { + +namespace { + +nvtxDomainHandle_t XProfNvtxDomain() { + static nvtxDomainHandle_t domain = nvtxDomainCreateA("xprof"); + return domain; +} + +nvtxStringHandle_t RegisteredMessage(const char* message) { + return nvtxDomainRegisterStringA(XProfNvtxDomain(), message); +} + +class NvtxScopedRange final { + public: + explicit NvtxScopedRange(const char* range_name) { + nvtxEventAttributes_t event_attr{0}; + event_attr.version = NVTX_VERSION; + event_attr.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; + event_attr.messageType = NVTX_MESSAGE_TYPE_REGISTERED; + event_attr.message.registered = RegisteredMessage(range_name); + nvtxDomainRangePushEx(XProfNvtxDomain(), &event_attr); + } + + ~NvtxScopedRange() { nvtxDomainRangePop(XProfNvtxDomain()); } +}; + +__global__ void VecAdd(const int* a, const int* b, int* c, int n) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) c[i] = a[i] + b[i]; +} + +__global__ void VecSub(const int* a, const int* b, int* c, int n) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) c[i] = a[i] - b[i]; +} + +} // namespace + +#define SCOPEDRANGE(N) NvtxScopedRange range##__LINE__(N) + +std::vector SimpleAddSubWithNvtxTag(int num_elements) { + SCOPEDRANGE(__func__); + + std::vector vec_a; + std::vector vec_b; + std::vector vec_c; + { + SCOPEDRANGE("InitializeHostMemoryVectors"); + // Allocates input/output vectors in host memory. + vec_a.resize(num_elements, 10); + vec_b.resize(num_elements, 20); + vec_c.resize(num_elements, -1); + } + + int* d_a = nullptr; + int* d_b = nullptr; + int* d_c = nullptr; + cudaStream_t stream = nullptr; + const size_t num_bytes = num_elements * sizeof(int); + + { + SCOPEDRANGE("Preparing"); + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + // Allocates vectors in device memory. + cudaMalloc((void**)&d_a, num_bytes); + cudaMalloc((void**)&d_b, num_bytes); + cudaMalloc((void**)&d_c, num_bytes); + } + + { + SCOPEDRANGE("Processing"); + { + SCOPEDRANGE("CopyToDevice"); + // Copies vectors from host to device memory. + cudaMemcpyAsync(d_a, vec_a.data(), num_bytes, cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(d_b, vec_b.data(), num_bytes, cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(d_c, vec_c.data(), num_bytes, cudaMemcpyHostToDevice, + stream); + } + + { + SCOPEDRANGE("ComputeOnDevice"); + constexpr int kThreadsPerBlock = 256; + const int blocks_per_grid = + (num_elements + kThreadsPerBlock - 1) / kThreadsPerBlock; + + // b1[i] = a[i] + b[i] + VecAdd<<>>(d_a, d_b, d_b, + num_elements); + // c1[i] = a[i] - b1[i] = a[i] - (a[i] + b[i]) = -b[i] + VecSub<<>>(d_a, d_b, d_c, + num_elements); + // c2[i] = c1[i] + b1[i] ==> -b[i] + (a[i] + b[i]) = a[i] + VecAdd<<>>(d_c, d_b, d_c, + num_elements); + // c3[i] = c2[i] - a[i] = a[i] - a[i] = 0 + VecSub<<>>(d_c, d_a, d_c, + num_elements); + } + + { + SCOPEDRANGE("CopyToHost"); + // Copies vectors from device to host memory. + cudaMemcpyAsync(vec_c.data(), d_c, num_bytes, cudaMemcpyDeviceToHost, + stream); + } + } + + { + SCOPEDRANGE("WaitResult"); + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + } + + return vec_c; +} + +} // namespace test +} // namespace profiler +} // namespace xla diff --git a/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels.h b/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels.h new file mode 100644 index 00000000000000..7f50e4bc68e95f --- /dev/null +++ b/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_PROFILER_GPU_NVTX_WITH_CUDA_KERNELS_H_ +#define XLA_BACKENDS_PROFILER_GPU_NVTX_WITH_CUDA_KERNELS_H_ + +#include + +namespace xla { +namespace profiler { +namespace test { + +// If runs correctly, the returned vector will only contain num_elements of 0. +std::vector SimpleAddSubWithNvtxTag(int num_elements); + +} // namespace test +} // namespace profiler +} // namespace xla + +#endif // XLA_BACKENDS_PROFILER_GPU_NVTX_WITH_CUDA_KERNELS_H_ diff --git a/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels_test.cc b/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels_test.cc new file mode 100644 index 00000000000000..33a24beafa409d --- /dev/null +++ b/third_party/xla/xla/backends/profiler/gpu/nvtx_with_cuda_kernels_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/profiler/gpu/nvtx_with_cuda_kernels.h" + +#include + +#include + +namespace xla { +namespace profiler { +namespace test { + +namespace { + +// This test just verify the cuda kernels ares running well and generate correct +// output. +TEST(NvtxCudaKernelSanityTest, SimpleAddSub) { + constexpr int kNumElements = 2048; + std::vector vec = SimpleAddSubWithNvtxTag(kNumElements); + + EXPECT_EQ(vec.size(), kNumElements); + for (int i = 0; i < kNumElements; ++i) { + EXPECT_EQ(vec[i], 0) << "index: " << i; + } +} + +} // namespace + +} // namespace test +} // namespace profiler +} // namespace xla diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h index 220fa2bb13e4a2..46e8e71eee77f0 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_BACKENDS_PROFILER_GPU_ROCM_COLLECTOR_H_ #define XLA_BACKENDS_PROFILER_GPU_ROCM_COLLECTOR_H_ +#include +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_set.h" #include "xla/tsl/profiler/utils/xplane_builder.h" @@ -114,7 +117,7 @@ enum class RocmTracerEventDomain { HIP_OPS, }; const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain); -// RocmTracerSyncTypes forward decleration +// RocmTracerSyncTypes forward declaration enum class RocmTracerSyncTypes; struct SynchronizationDetails { @@ -124,8 +127,8 @@ struct SynchronizationDetails { struct RocmTracerEvent { static constexpr uint32_t kInvalidDeviceId = std::numeric_limits::max(); - static constexpr uint32_t kInvalidThreadId = - std::numeric_limits::max(); + static constexpr uint64_t kInvalidThreadId = + std::numeric_limits::max(); static constexpr uint32_t kInvalidCorrelationId = std::numeric_limits::max(); static constexpr uint64_t kInvalidStreamId = @@ -142,7 +145,7 @@ struct RocmTracerEvent { uint64_t end_time_ns = 0; uint32_t device_id = kInvalidDeviceId; uint32_t correlation_id = kInvalidCorrelationId; - uint32_t thread_id = kInvalidThreadId; + uint64_t thread_id = kInvalidThreadId; int64_t stream_id = kInvalidStreamId; union { MemcpyDetails memcpy_info; // If type == Memcpy* diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc index fad3e39831c49a..2134c7f9d4e28a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/backends/profiler/gpu/rocm_tracer.h" +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "rocm/rocm_config.h" @@ -52,8 +54,8 @@ namespace { // GetCachedTID() caches the thread ID in thread-local storage (which is a // userspace construct) to avoid unnecessary system calls. Without this caching, // it can take roughly 98ns, while it takes roughly 1ns with this caching. -int32_t GetCachedTID() { - static thread_local int32_t current_thread_id = +int64_t GetCachedTID() { + static thread_local int64_t current_thread_id = tsl::Env::Default()->GetCurrentThreadId(); return current_thread_id; } diff --git a/third_party/xla/xla/bit_cast_test.cc b/third_party/xla/xla/bit_cast_test.cc index 8445b75aaaa5ad..c8d264662c72bd 100644 --- a/third_party/xla/xla/bit_cast_test.cc +++ b/third_party/xla/xla/bit_cast_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "Eigen/Core" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "tsl/platform/bfloat16.h" namespace xla { diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 7af360379dcf38..75a63a1047ac19 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -66,6 +66,7 @@ cc_library( "//xla/hlo/builder:xla_computation", "//xla/service", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -108,6 +109,7 @@ xla_cc_test( ":executable_build_options", "//xla:protobuf_util", "//xla:shape_util", + "//xla/pjrt:compile_options_proto_cc", "//xla/service:computation_placer", "//xla/service:test_compilation_environment_proto_cc", "//xla/tsl/lib/core:status_test_util", @@ -205,6 +207,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc index d6d4e8abb40fbc..8d20613b8542b5 100644 --- a/third_party/xla/xla/client/client.cc +++ b/third_party/xla/xla/client/client.cc @@ -15,13 +15,15 @@ limitations under the License. #include "xla/client/client.h" +#include #include #include -#include #include #include +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_computation.h" diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index dfefdb615e86a3..9216d752b28abe 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_CLIENT_H_ #define XLA_CLIENT_CLIENT_H_ +#include #include #include #include diff --git a/third_party/xla/xla/client/client_library.cc b/third_party/xla/xla/client/client_library.cc index 476208d78b0bfb..cfcc029b9807e1 100644 --- a/third_party/xla/xla/client/client_library.cc +++ b/third_party/xla/xla/client/client_library.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "xla/client/compile_only_client.h" #include "xla/client/local_client.h" diff --git a/third_party/xla/xla/client/compile_only_client.cc b/third_party/xla/xla/client/compile_only_client.cc index 1aa6a4f1a8c54c..0836abe955bbc2 100644 --- a/third_party/xla/xla/client/compile_only_client.cc +++ b/third_party/xla/xla/client/compile_only_client.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/client/compile_only_client.h" +#include #include #include diff --git a/third_party/xla/xla/client/compile_only_client.h b/third_party/xla/xla/client/compile_only_client.h index 8f755691940d49..a786bd1c6131ea 100644 --- a/third_party/xla/xla/client/compile_only_client.h +++ b/third_party/xla/xla/client/compile_only_client.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #define XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ +#include #include #include diff --git a/third_party/xla/xla/client/executable_build_options.h b/third_party/xla/xla/client/executable_build_options.h index e73d9d763102c6..76d5d415f6babf 100644 --- a/third_party/xla/xla/client/executable_build_options.h +++ b/third_party/xla/xla/client/executable_build_options.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #define XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#include #include +#include #include #include #include diff --git a/third_party/xla/xla/client/executable_build_options_test.cc b/third_party/xla/xla/client/executable_build_options_test.cc index f21c64f8922199..cdba65c6aa82f5 100644 --- a/third_party/xla/xla/client/executable_build_options_test.cc +++ b/third_party/xla/xla/client/executable_build_options_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/pjrt/compile_options.pb.h" #include "xla/protobuf_util.h" #include "xla/service/computation_placer.h" #include "xla/service/test_compilation_environment.pb.h" diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index c60804e557ba2a..df3229809034dc 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/client/local_client.h" +#include #include #include #include diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 6216dcf4ba78b3..6cb2dd22355b95 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LOCAL_CLIENT_H_ #define XLA_CLIENT_LOCAL_CLIENT_H_ +#include #include #include #include diff --git a/third_party/xla/xla/codegen/BUILD b/third_party/xla/xla/codegen/BUILD index f776a30dd5946a..9a9147fb32a908 100644 --- a/third_party/xla/xla/codegen/BUILD +++ b/third_party/xla/xla/codegen/BUILD @@ -35,9 +35,11 @@ cc_library( cc_library( name = "llvm_ir_kernel_source", + srcs = ["llvm_ir_kernel_source.cc"], hdrs = ["llvm_ir_kernel_source.h"], deps = [ ":kernel_spec", + "//xla/service/llvm_ir:llvm_util", "@llvm-project//llvm:Core", "@llvm-project//llvm:JITLink", ], diff --git a/third_party/xla/xla/codegen/emitters/BUILD b/third_party/xla/xla/codegen/emitters/BUILD new file mode 100644 index 00000000000000..3017cd31b1d5d8 --- /dev/null +++ b/third_party/xla/xla/codegen/emitters/BUILD @@ -0,0 +1,171 @@ +load("//xla:xla.bzl", "xla_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "computation_partitioner", + srcs = ["computation_partitioner.cc"], + hdrs = ["computation_partitioner.h"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/ir:hlo", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + ], +) + +xla_cc_test( + name = "computation_partitioner_test", + srcs = ["computation_partitioner_test.cc"], + deps = [ + ":computation_partitioner", + "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "elemental_hlo_to_mlir", + srcs = ["elemental_hlo_to_mlir.cc"], + hdrs = ["elemental_hlo_to_mlir.h"], + deps = [ + ":computation_partitioner", + ":type_util", + "//xla:comparison_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/codegen/ir:xla", + "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "//xla/hlo/utils:hlo_traversal", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service:algorithm_util", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:VectorDialect", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "elemental_hlo_to_mlir_test", + srcs = ["elemental_hlo_to_mlir_test.cc"], + deps = [ + ":computation_partitioner", + ":elemental_hlo_to_mlir", + "//xla:status_macros", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/codegen/ir:xla", + "//xla/hlo/analysis:indexing_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", + "//xla/mlir_hlo", + "//xla/service/llvm_ir:llvm_util", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "type_util", + srcs = ["type_util.cc"], + hdrs = ["type_util.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "//xla/mlir/utils:type_util", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "type_util_test", + srcs = ["type_util_test.cc"], + deps = [ + ":type_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc rename to third_party/xla/xla/codegen/emitters/computation_partitioner.cc index 60abff497bf91a..53ec9f49bada84 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc +++ b/third_party/xla/xla/codegen/emitters/computation_partitioner.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include #include @@ -44,19 +44,18 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/emitters/type_util.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { namespace { int Arity(const Shape& shape) { @@ -443,6 +442,5 @@ mlir::func::FuncOp CreateSubgraphMlirFunction( return func_op; } -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h b/third_party/xla/xla/codegen/emitters/computation_partitioner.h similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h rename to third_party/xla/xla/codegen/emitters/computation_partitioner.h index d644ee810743d2..41bd0b1b500f45 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h +++ b/third_party/xla/xla/codegen/emitters/computation_partitioner.h @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ +#ifndef XLA_CODEGEN_EMITTERS_COMPUTATION_PARTITIONER_H_ +#define XLA_CODEGEN_EMITTERS_COMPUTATION_PARTITIONER_H_ +#include #include #include #include @@ -32,8 +33,7 @@ limitations under the License. #include "xla/util.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { struct EpilogueSpecification { // Creates an epilogue with output indices matching the given root's shape. @@ -205,8 +205,7 @@ mlir::func::FuncOp CreateSubgraphMlirFunction( const PartitionedComputation::Subgraph& subgraph, mlir::ImplicitLocOpBuilder& b); -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_COMPUTATION_PARTITIONER_H_ +#endif // XLA_CODEGEN_EMITTERS_COMPUTATION_PARTITIONER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc b/third_party/xla/xla/codegen/emitters/computation_partitioner_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc rename to third_party/xla/xla/codegen/emitters/computation_partitioner_test.cc index ff60dd53ab95ac..39297d8cf9fc81 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner_test.cc +++ b/third_party/xla/xla/codegen/emitters/computation_partitioner_test.cc @@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include -#include #include #include @@ -32,8 +31,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { namespace { using ::testing::ElementsAre; @@ -334,6 +332,5 @@ TEST_F(ComputationPartitionerTest, SubgraphSignatures) { } } // namespace -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc rename to third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc index 8ec559156ca0bb..f82eeca401ce82 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include #include @@ -61,6 +61,8 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/type_util.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/indexing_analysis.h" @@ -75,8 +77,6 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" @@ -84,8 +84,7 @@ limitations under the License. #include "tsl/platform/statusor.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { namespace { using llvm::SmallVector; @@ -1481,8 +1480,8 @@ ValueRange EmitLoopNestImpl( ValueRange symbol_values, ValueRange iter_args) -> scf::ValueVector { ImplicitLocOpBuilder nested_b(loc, nested_builder); - auto is_in_bounds = mlir_converter::CheckConstraints( - indexing_map, dim_values, symbol_values, nested_b); + auto is_in_bounds = + CheckConstraints(indexing_map, dim_values, symbol_values, nested_b); auto if_op = nested_b.create( is_in_bounds, [&](OpBuilder& then_builder, Location then_loc) -> void { @@ -1534,14 +1533,14 @@ ValueRange EmitLoopNestImpl( } // namespace -ValueRange EmitXlaLoopOp(ImplicitLocOpBuilder& b, ValueRange dim_values, - ValueRange iter_args_inits, - const IndexingMap& indexing_map, - mlir::function_ref( - ValueRange /*ivs*/, ValueRange /*map_results*/, - ValueRange /*iter_args*/)> - create_body, - bool vectorize) { +ValueRange EmitXlaLoopOp( + ImplicitLocOpBuilder& b, ValueRange dim_values, ValueRange iter_args_inits, + const IndexingMap& indexing_map, + mlir::function_ref( + ImplicitLocOpBuilder& nested_b, ValueRange /*ivs*/, + ValueRange /*map_results*/, ValueRange /*iter_args*/)> + create_body, + bool vectorize) { SmallVector vector_inits; if (vectorize) { CHECK_EQ(indexing_map.GetSymbolBounds().back().lower, 0); @@ -1557,6 +1556,7 @@ ValueRange EmitXlaLoopOp(ImplicitLocOpBuilder& b, ValueRange dim_values, } auto bb = [&](OpBuilder& nested_builder, Location loc, ValueRange ivs, ValueRange map_results, ValueRange iter_args) { + ImplicitLocOpBuilder nested_b(loc, nested_builder); SmallVector results; if (vectorize) { SmallVector vector_args; @@ -1564,11 +1564,10 @@ ValueRange EmitXlaLoopOp(ImplicitLocOpBuilder& b, ValueRange dim_values, // Extract the vector elements. for (auto& init : vector_args) { if (mlir::isa(init.getType())) { - init = nested_builder.create(loc, init, - ivs.back()); + init = nested_b.create(init, ivs.back()); } } - results = create_body(ivs, map_results, vector_args); + results = create_body(nested_b, ivs, map_results, vector_args); // Insert the results. for (auto [index, init] : llvm::enumerate(iter_args)) { if (mlir::isa(init.getType())) { @@ -1577,9 +1576,9 @@ ValueRange EmitXlaLoopOp(ImplicitLocOpBuilder& b, ValueRange dim_values, } } } else { - results = create_body(ivs, map_results, iter_args); + results = create_body(nested_b, ivs, map_results, iter_args); } - nested_builder.create(loc, results); + nested_b.create(results); }; return b.create(indexing_map, dim_values, iter_args_inits, bb) .getResults(); @@ -1701,6 +1700,5 @@ SmallVector InlineBlock(OpBuilder& builder, Block& src_block, return mapped_results; } -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.h similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h rename to third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.h index 435820148ebc5c..a1767a27b662b8 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h +++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir.h @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ +#ifndef XLA_CODEGEN_EMITTERS_ELEMENTAL_HLO_TO_MLIR_H_ +#define XLA_CODEGEN_EMITTERS_ELEMENTAL_HLO_TO_MLIR_H_ +#include #include #include "absl/status/status.h" @@ -29,16 +30,15 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_traversal.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/stream_executor/device_description.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { using OperandProvider = std::function>( @@ -110,8 +110,8 @@ mlir::ValueRange EmitXlaLoopOp( mlir::ImplicitLocOpBuilder& b, mlir::ValueRange dim_values, mlir::ValueRange iter_args_inits, const IndexingMap& indexing_map, mlir::function_ref( - mlir::ValueRange ivs, mlir::ValueRange map_results, - mlir::ValueRange iter_args)> + mlir::ImplicitLocOpBuilder& nested_b, mlir::ValueRange ivs, + mlir::ValueRange map_results, mlir::ValueRange iter_args)> create_body, bool vectorize = false); @@ -143,8 +143,7 @@ void GetLoopBoundsFromIndexingMap(mlir::ImplicitLocOpBuilder& b, llvm::SmallVectorImpl* ubs, llvm::SmallVectorImpl* steps); -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ELEMENTAL_HLO_TO_MLIR_H_ +#endif // XLA_CODEGEN_EMITTERS_ELEMENTAL_HLO_TO_MLIR_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc rename to third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc index aba8e66e13e9f6..543a0b230108f1 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/codegen/emitters/elemental_hlo_to_mlir_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include #include @@ -37,14 +37,14 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" @@ -53,8 +53,7 @@ limitations under the License. #include "tsl/platform/statusor.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { namespace { using ::testing::HasSubstr; @@ -1802,6 +1801,5 @@ TEST_F(ElementalHloToMlirTest, BroadcastSelect) { } } // namespace -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc b/third_party/xla/xla/codegen/emitters/type_util.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc rename to third_party/xla/xla/codegen/emitters/type_util.cc index 76d4b284ebc331..04d3a6613ec1c5 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc +++ b/third_party/xla/xla/codegen/emitters/type_util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/codegen/emitters/type_util.h" #include "absl/log/check.h" #include "llvm/ADT/SmallVector.h" @@ -28,8 +28,7 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { mlir::Type PrimitiveTypeToMlirType(PrimitiveType type, mlir::OpBuilder& b) { if (primitive_util::IsIntegralType(type)) { @@ -82,6 +81,5 @@ llvm::SmallVector ShapeToMlirTypes(const Shape& shape, return types; } -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.h b/third_party/xla/xla/codegen/emitters/type_util.h similarity index 87% rename from third_party/xla/xla/service/gpu/fusions/mlir/type_util.h rename to third_party/xla/xla/codegen/emitters/type_util.h index 2e9eeae14efb84..60e8a9390aa27b 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.h +++ b/third_party/xla/xla/codegen/emitters/type_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ +#ifndef XLA_CODEGEN_EMITTERS_TYPE_UTIL_H_ +#define XLA_CODEGEN_EMITTERS_TYPE_UTIL_H_ #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" @@ -22,8 +22,7 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { // Converts an XLA tensor to an MLIR ranked tensor. The layout is stored in the // encoding attribute, if it is not the default layout. `shape` must be an @@ -42,8 +41,7 @@ mlir::Type PrimitiveTypeToMlirTypeWithSign(PrimitiveType type, llvm::SmallVector ShapeToMlirTypes(const Shape& shape, mlir::OpBuilder& b); -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_TYPE_UTIL_H_ +#endif // XLA_CODEGEN_EMITTERS_TYPE_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util_test.cc b/third_party/xla/xla/codegen/emitters/type_util_test.cc similarity index 94% rename from third_party/xla/xla/service/gpu/fusions/mlir/type_util_test.cc rename to third_party/xla/xla/codegen/emitters/type_util_test.cc index 63c0454300fd67..c11c4d5f768568 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/type_util_test.cc +++ b/third_party/xla/xla/codegen/emitters/type_util_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/codegen/emitters/type_util.h" #include @@ -28,8 +28,7 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { -namespace gpu { -namespace mlir_converter { +namespace emitters { namespace { using ::testing::ElementsAre; @@ -92,6 +91,5 @@ TEST(ShapeTest, ConvertsTuple) { } } // namespace -} // namespace mlir_converter -} // namespace gpu +} // namespace emitters } // namespace xla diff --git a/third_party/xla/xla/codegen/ir/BUILD b/third_party/xla/xla/codegen/ir/BUILD index 078d617b29a910..b6ed803e1e7a13 100644 --- a/third_party/xla/xla/codegen/ir/BUILD +++ b/third_party/xla/xla/codegen/ir/BUILD @@ -118,6 +118,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], diff --git a/third_party/xla/xla/codegen/ir/tests/canonicalize.mlir b/third_party/xla/xla/codegen/ir/tests/canonicalize.mlir index ae0e54c70f9ba4..3f81000f6efd74 100644 --- a/third_party/xla/xla/codegen/ir/tests/canonicalize.mlir +++ b/third_party/xla/xla/codegen/ir/tests/canonicalize.mlir @@ -234,7 +234,7 @@ func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) // CHECK-NEXT: xla.apply_indexing #[[$MAP]] // CHECK-SAME: (%[[ARG0:.*]], %[[ARG1:.*]]) -// // ----- +// ----- #map0 = #xla.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]"> #map1 = #xla.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," @@ -278,3 +278,51 @@ func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index // CHECK: xla.loop (%[[ARG0]], %[[ARG1]]) // CHECK-SAME: in #[[$MAP]] + +// ----- + +#map = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z, p1, p2, p3)[idx]" +"-> ((th_x floordiv 64) * 100 + bl_x * 200 + idx + th_x + th_y + th_z + bl_x + bl_y + bl_z + p1 + p2 + p3)," +"domain:" +"th_x in [0, 127], th_y in [0, 0], th_z in [0, 10]," +"bl_x in [0, 174], bl_y in [2, 2], bl_z in [3, 3], p1 in [1, 5], p2 in [1, 5], p3 in [0,1000]," +"idx in [0, 99], bl_x + bl_y + bl_z in [0, 200]," +"th_x + th_y + th_z + idx in [-1, 200]," +"th_y + bl_y in [0,4],p1+p2+p3 in [0,10]"> + +func.func private @compute(%in: tensor<350xf32>) -> (tensor<350xf32>) + +func.func @fold_constant_dimensions(%input: tensor<350xf32>, %a1 : index) -> (tensor<350xf32>) { + %c1 = arith.constant 4 : index + %c2 = arith.constant 9 : index // Outside of map bounds. + %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %thread_id_y = gpu.thread_id y {xla.range = [0 : index, 0 : index]} + %thread_id_z = gpu.thread_id z {xla.range = [1 : index, 1 : index]} + %block_id_x = gpu.block_id x {xla.range = [0 : index, 174 : index]} + %block_id_y = gpu.block_id y {xla.range = [2 : index, 2 : index]} + %block_id_z = gpu.block_id z {xla.range = [3 : index, 3 : index]} + + %result = xla.loop (%thread_id_x, %thread_id_y, %thread_id_z, + %block_id_x, %block_id_y, %block_id_z, %c1, %c2, %a1)[%i] -> (%ra) in #map + iter_args(%iter_ = %input) -> (tensor<350xf32>) { + %0 = func.call @compute(%iter_) : (tensor<350xf32>) -> (tensor<350xf32>) + xla.yield %0 : tensor<350xf32> + } + func.return %result : tensor<350xf32> +} + +// CHECK: #[[$MAP:.*]] = #xla.indexing_map<"(th_x, bl_x, p2, p3)[idx] -> ( +// CHECK-SAME: (th_x floordiv 64) * 100 + bl_x * 200 + idx + th_x + bl_x + p2 + p3 + 10) +// CHECK-SAME: domain: th_x in [0, 127], bl_x in [0, 174], +// CHECK-SAME: p2 in [1, 5], p3 in [0, 1000], idx in [0, 99], +// CHECK-SAME: bl_x + 5 in [0, 200], +// CHECK-SAME: p2 + p3 + 4 in [0, 10], +// CHECK-SAME: th_x + idx + 1 in [-1, 200]"> + +// CHECK-LABEL: func.func @fold_constant_dimensions( +// CHECK-SAME: %[[ARG:.*]]: tensor<350xf32>, %[[SCALAR:.*]]: index) +// CHECK: %[[C9:.*]] = arith.constant 9 +// CHECK: %[[TH_X:.*]] = gpu.thread_id x +// CHECK: %[[BL_X:.*]] = gpu.block_id x +// CHECK: xla.loop (%[[TH_X]], %[[BL_X]], %[[C9]], %[[SCALAR]]) +// CHECK-SAME: in #[[$MAP]] diff --git a/third_party/xla/xla/codegen/ir/xla_attrs.cc b/third_party/xla/xla/codegen/ir/xla_attrs.cc index 5f0e416064ec6c..ce84444d4de797 100644 --- a/third_party/xla/xla/codegen/ir/xla_attrs.cc +++ b/third_party/xla/xla/codegen/ir/xla_attrs.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include #include -#include #include #include "llvm/ADT/StringRef.h" diff --git a/third_party/xla/xla/codegen/ir/xla_ops.cc b/third_party/xla/xla/codegen/ir/xla_ops.cc index b77435146232f6..c2e2941d5f3745 100644 --- a/third_party/xla/xla/codegen/ir/xla_ops.cc +++ b/third_party/xla/xla/codegen/ir/xla_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -26,16 +27,20 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" // IWYU pragma: keep #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/MLIRContext.h" // IWYU pragma: keep +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" @@ -86,6 +91,69 @@ namespace arith = mlir::arith; } // namespace +std::optional GetRange(mlir::Value value) { + auto attr_to_range = [](mlir::Attribute attr) -> std::optional { + if (!attr) { + return std::nullopt; + } + auto values = llvm::to_vector( + mlir::cast(attr).getAsValueRange()); + return {{values[0].getSExtValue(), values[1].getSExtValue()}}; + }; + + if (auto apply = value.getDefiningOp()) { + return apply.getIndexingMap().GetRangeEvaluator().ComputeExpressionRange( + apply.getIndexingMap().GetAffineMap().getResult( + mlir::cast(value).getResultNumber())); + } else if (auto cst = value.getDefiningOp()) { + return {{cst.value(), cst.value()}}; + } else if (value.getDefiningOp()) { + return attr_to_range(value.getDefiningOp()->getAttr("xla.range")); + } + + auto bbarg = mlir::dyn_cast(value); + if (!bbarg) { + return std::nullopt; + } + + auto parent = bbarg.getParentBlock()->getParentOp(); + if (auto func_op = mlir::dyn_cast(parent)) { + return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range")); + } + return GetIVRange(value); +} + +std::optional GetIVRange(mlir::Value iv) { + auto bbarg = mlir::dyn_cast(iv); + if (!bbarg) { + return std::nullopt; + } + auto parent = bbarg.getParentBlock()->getParentOp(); + if (auto for_op = mlir::dyn_cast(parent)) { + llvm::APInt lb, ub; + if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) && + mlir::matchPattern(for_op.getUpperBound(), mlir::m_ConstantInt(&ub))) { + return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; + } + } + if (auto loop_op = mlir::dyn_cast(parent)) { + const auto& indexing_map = loop_op.getIndexingMap(); + if (bbarg.getArgNumber() >= loop_op.getNumInductionVars() && + bbarg.getArgNumber() < + loop_op.getNumInductionVars() + indexing_map.GetNumResults()) { + RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator(); + return range_evaluator.ComputeExpressionRange( + indexing_map.GetAffineMap().getResult(bbarg.getArgNumber() - + loop_op.getNumInductionVars())); + } + } + return std::nullopt; +} + +//===----------------------------------------------------------------------===// +// PureCallOp +//===----------------------------------------------------------------------===// + LogicalResult PureCallOp::verifySymbolUses( mlir::SymbolTableCollection& symbolTable) { auto callee = getCalleeAttr(); @@ -256,7 +324,7 @@ absl::StatusOr GetNewIndexingMapAfterFoldingSequence( replacement_expr = getAffineDimExpr(num_dims + added_dim_args.size(), ctx); added_dim_args.push_back(producer_operand.get()); - new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); + new_dim_vars.push_back(producer_map.GetDimVar(dim_num)); } producer_dim_replacements.push_back(replacement_expr); } @@ -462,7 +530,7 @@ struct FoldApplyIndexingOperands } else { new_operands.push_back(operand.get()); dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); - new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); + new_dim_vars.push_back(indexing_map.GetDimVar(operand_id)); } } rewriter.replaceOpWithNewOp( @@ -934,6 +1002,72 @@ struct SimplifyLoopOfApplyIndexing : public mlir::OpRewritePattern { } }; +// Folds dimensions that are constants. +// Only works on dimensions assuming as MoveSymbolsToDims has converted symbols +// and runtime variables already. +struct FoldConstantDimensions : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LoopOp loop_op, + PatternRewriter& rewriter) const override { + auto loop_indexing_map = loop_op.getIndexingMap(); + auto ctx = loop_op.getContext(); + int num_dims = loop_indexing_map.GetDimVarsCount(); + + SmallVector used_operands; + used_operands.reserve(num_dims); + std::vector used_dim_vars; + used_dim_vars.reserve(num_dims); + SmallVector dim_replacements; + dim_replacements.reserve(num_dims); + + for (auto [operand, dim_var] : + llvm::zip(loop_op->getOpOperands().take_front(num_dims), + loop_indexing_map.GetDimVars())) { + auto range = GetRange(operand.get()); + // Note that if range is constant we have to check that it is within the + // bounds of the dimension and can be safely replaced. + if (range && range->IsPoint() && dim_var.bounds.Contains(range->lower)) { + dim_replacements.push_back(getAffineConstantExpr(range->lower, ctx)); + } else { + dim_replacements.push_back(getAffineDimExpr(used_dim_vars.size(), ctx)); + used_operands.push_back(operand.get()); + used_dim_vars.push_back(dim_var); + } + } + + if (used_dim_vars.size() == num_dims) { + return rewriter.notifyMatchFailure(loop_op, + "No constant dimensions found"); + } + + auto new_affine_map = + loop_indexing_map.GetAffineMap().replaceDimsAndSymbols( + dim_replacements, {}, used_dim_vars.size(), + loop_indexing_map.GetSymbolCount()); + + llvm::DenseMap new_constraints; + for (auto [expr, interval] : loop_indexing_map.GetConstraints()) { + new_constraints[expr.replaceDims(dim_replacements)] = interval; + } + + IndexingMap new_indexing_map(new_affine_map, std::move(used_dim_vars), + loop_indexing_map.GetRangeVars(), + loop_indexing_map.GetRTVars(), + new_constraints); + + auto new_loop_op = rewriter.create( + loop_op.getLoc(), new_indexing_map, used_operands, loop_op.getInits()); + + Block* original_block = &loop_op.getRegion().front(); + Block* new_block = &new_loop_op.getRegion().front(); + rewriter.mergeBlocks(original_block, new_block, new_block->getArguments()); + rewriter.replaceOp(loop_op, new_loop_op.getResults()); + + return success(); + } +}; + } // namespace VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { @@ -973,7 +1107,7 @@ std::optional parseChainOfStringsAsIndexingMap( void LoopOp::getCanonicalizationPatterns(mlir::RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } } // namespace xla diff --git a/third_party/xla/xla/codegen/ir/xla_ops.h b/third_party/xla/xla/codegen/ir/xla_ops.h index 0888d0485567b3..a13540c921577c 100644 --- a/third_party/xla/xla/codegen/ir/xla_ops.h +++ b/third_party/xla/xla/codegen/ir/xla_ops.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef XLA_CODEGEN_IR_XLA_OPS_H_ #define XLA_CODEGEN_IR_XLA_OPS_H_ +#include #include #include "llvm/ADT/DenseMap.h" @@ -61,6 +62,13 @@ mlir::ParseResult parseOperands( std::optional parseChainOfStringsAsIndexingMap( mlir::AsmParser& parser); +// Returns the range of a given value, if it can be statically determined. +std::optional GetRange(mlir::Value value); + +// Returns the range for the induction variable, if it can be statically +// determined. +std::optional GetIVRange(mlir::Value iv); + } // namespace xla #endif // XLA_CODEGEN_IR_XLA_OPS_H_ diff --git a/third_party/xla/xla/codegen/ir/xla_ops.td b/third_party/xla/xla/codegen/ir/xla_ops.td index 1e32237111587a..11fc208c33cc83 100644 --- a/third_party/xla/xla/codegen/ir/xla_ops.td +++ b/third_party/xla/xla/codegen/ir/xla_ops.td @@ -153,7 +153,7 @@ def XLA_PredicatedExtractOp : XLA_Op<"predicated_extract", TypesMatchWith<"result type matches element type of src", "src", "result", "::llvm::cast($_self).getElementType()">]> { - let summary = "Inserts a value into a tensor if a condition holds"; + let summary = "Extracts a value from a tensor if a condition holds"; let arguments = (ins I1:$condition, AnyType:$fallback, AnyStaticShapeTensor:$src, Variadic:$indices); let results = (outs AnyType:$result); diff --git a/third_party/xla/xla/codegen/kernel_spec.cc b/third_party/xla/xla/codegen/kernel_spec.cc index 7d2dbd2b000520..dba19a7737fa9a 100644 --- a/third_party/xla/xla/codegen/kernel_spec.cc +++ b/third_party/xla/xla/codegen/kernel_spec.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/codegen/kernel_spec.h" #include -#include #include #include diff --git a/third_party/xla/xla/codegen/kernel_spec.h b/third_party/xla/xla/codegen/kernel_spec.h index b3b5680195e90b..1bfea45797f760 100644 --- a/third_party/xla/xla/codegen/kernel_spec.h +++ b/third_party/xla/xla/codegen/kernel_spec.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_CODEGEN_KERNEL_SPEC_H_ #include -#include #include +#include #include "absl/container/inlined_vector.h" #include "xla/runtime/buffer_use.h" @@ -32,6 +32,9 @@ namespace xla { class KernelSource { public: virtual ~KernelSource() = default; + + // Get a human readable string representation of the kernel source. + virtual std::string ToString() const = 0; }; // KernelSpec is a specification of an XLA kernel produced by the XLA codegen. diff --git a/third_party/xla/xla/codegen/llvm_ir_kernel_source.cc b/third_party/xla/xla/codegen/llvm_ir_kernel_source.cc new file mode 100644 index 00000000000000..bc9af24f45cfce --- /dev/null +++ b/third_party/xla/xla/codegen/llvm_ir_kernel_source.cc @@ -0,0 +1,28 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/codegen/llvm_ir_kernel_source.h" + +#include + +#include "xla/service/llvm_ir/llvm_util.h" + +namespace xla { + +std::string LlvmIrKernelSource::ToString() const { + return llvm_ir::DumpToString(module_.get()); +} + +} // namespace xla diff --git a/third_party/xla/xla/codegen/llvm_ir_kernel_source.h b/third_party/xla/xla/codegen/llvm_ir_kernel_source.h index 2564e3667546f6..0726380b81aa8f 100644 --- a/third_party/xla/xla/codegen/llvm_ir_kernel_source.h +++ b/third_party/xla/xla/codegen/llvm_ir_kernel_source.h @@ -31,9 +31,9 @@ namespace xla { // implementation we might emit a single LLVM module with multiple kernels or a // separate LLVM module for each kernel. Kernel function signature is defined by // the backend specific ABI. -class LlvmIrKernelSource : public KernelSource { +class LlvmIrKernelSource final : public KernelSource { public: - LlvmIrKernelSource(std::unique_ptr context, + LlvmIrKernelSource(llvm::orc::ThreadSafeContext context, std::unique_ptr module, std::string kernel_name) : context_(std::move(context)), @@ -44,7 +44,7 @@ class LlvmIrKernelSource : public KernelSource { LlvmIrKernelSource& operator=(LlvmIrKernelSource&& other) = default; llvm::orc::ThreadSafeModule thread_safe_module() && { - return llvm::orc::ThreadSafeModule(std::move(module_), std::move(context_)); + return llvm::orc::ThreadSafeModule(std::move(module_), context_); } const std::string& kernel_name() const { return kernel_name_; } @@ -53,8 +53,10 @@ class LlvmIrKernelSource : public KernelSource { return module_->getFunction(kernel_name_); } + std::string ToString() const final; + private: - std::unique_ptr context_; + llvm::orc::ThreadSafeContext context_; std::unique_ptr module_; std::string kernel_name_; }; diff --git a/third_party/xla/xla/codegen/testlib/BUILD b/third_party/xla/xla/codegen/testlib/BUILD index 60db48782ecd50..e0caabe9f8f8d9 100644 --- a/third_party/xla/xla/codegen/testlib/BUILD +++ b/third_party/xla/xla/codegen/testlib/BUILD @@ -33,34 +33,46 @@ cc_library( ) tsl_pybind_extension( - name = "kernel_runner_extention", + name = "_extension", testonly = 1, - srcs = ["kernel_runner_extention.cc"], - visibility = ["//visibility:private"], # the extention should always be linked via kernel_runner_pylib + srcs = ["kernel_runner_extension.cc"], + visibility = ["//visibility:private"], # the extension should always be linked via testlib deps = [ ":kernel_runner", - # placeholder for index annotation deps + # placeholder for index annotation deps # buildcleaner: keep "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@nanobind", "@local_config_python//:python_headers", # buildcleaner: keep + "//xla:comparison_util", "//xla:literal", + "//xla:shape_util", "//xla:util", "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/python:nb_absl_inlined_vector", + "//xla/python:nb_absl_span", + "//xla/service:buffer_assignment", ], ) pytype_strict_library( - name = "kernel_runner_pylib", + name = "testlib", testonly = 1, - srcs = ["kernel_runner.py"], + srcs = [ + "__init__.py", + "utilities.py", + ], srcs_version = "PY3", deps = [ - ":kernel_runner_extention", + ":_extension", "//third_party/py/numpy", "//xla/python:xla_extension", + "@ml_dtypes", # buildcleaner: keep (transitively depend on it via xla_extension) ], ) @@ -74,7 +86,8 @@ py_strict_test( "no_oss", ], deps = [ - ":kernel_runner_pylib", + ":_extension", + ":testlib", "//third_party/py/numpy", "@absl_py//absl/testing:absltest", ], diff --git a/third_party/xla/xla/codegen/testlib/__init__.py b/third_party/xla/xla/codegen/testlib/__init__.py index e69de29bb2d1d6..9f33a797b5b384 100644 --- a/third_party/xla/xla/codegen/testlib/__init__.py +++ b/third_party/xla/xla/codegen/testlib/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Public API for codegen testlib.""" + +from xla.codegen.testlib import _extension + +# Classes +# go/keep-sorted start +BufferAssignment = _extension.BufferAssignment +ComparisonDirection = _extension.ComparisonDirection +HloInstruction = _extension.HloInstruction +HloModule = _extension.HloModule +HloOpcode = _extension.HloOpcode +KernelEmmitter = _extension.KernelEmitter +KernelRunner = _extension.KernelRunner +KernelSpec = _extension.KernelSpec +# go/keep-sorted end diff --git a/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc b/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc new file mode 100644 index 00000000000000..2f51e7a776013b --- /dev/null +++ b/third_party/xla/xla/codegen/testlib/kernel_runner_extension.cc @@ -0,0 +1,220 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/codegen/kernel_emitter.h" +#include "xla/codegen/kernel_spec.h" +#include "xla/codegen/testlib/kernel_runner.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/literal.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/service/buffer_assignment.h" +#include "xla/shape.h" +#include "xla/util.h" + +namespace xla { + +namespace { + +// Use `std::vector` instead of `absl::Span` to take +// advantage of the built in bindings. +void KernelRunnerCall(KernelRunner* kernel_runner, + std::vector literals) { + absl::Status status = kernel_runner->Call(absl::MakeSpan(literals)); + if (!status.ok()) { + throw std::runtime_error(std::string(status.message())); + } +} + +// Need this helper as Literal rquires an explicit clone. +std::unique_ptr CreateConstantHloInstruction( + const Literal& literal) { + return HloInstruction::CreateConstant(literal.Clone()); +} + +std::unique_ptr CreateComparisonHloInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + Comparison::Direction direction) { + return HloInstruction::CreateCompare(shape, lhs, rhs, direction); +} + +// A dummy kernel runner that implements a simple elementwise add. +class DummyAddKernelRunner final : public KernelRunner { + public: + absl::Status Call(absl::Span arguments) override { + if (arguments.size() != 3) { + return InvalidArgument("Expected 3 arguments, got %u", arguments.size()); + } + + if (arguments[0].size() != arguments[1].size()) { + return InvalidArgument( + "Expected argument 0 to be the same size as argument " + "1, got %u and %u", + arguments[0].size(), arguments[1].size()); + } + + if (arguments[1].size() != arguments[2].size()) { + return InvalidArgument( + "Expected argument 1 to be the same size as argument " + "2, got %u and %u", + arguments[1].size(), arguments[2].size()); + } + + constexpr size_t element_bytes = sizeof(int32_t); + + if ((arguments[0].size() % element_bytes) != 0) { + return InvalidArgument( + "Expected arguments to be a multiple of %u bytes, got %u", + element_bytes, arguments[0].size()); + } + + size_t num_elements = arguments[0].size() / element_bytes; + + auto* in_arg1 = reinterpret_cast(arguments[0].data()); + auto* in_arg2 = reinterpret_cast(arguments[1].data()); + auto* out_arg = reinterpret_cast(arguments[2].data()); + + for (int i = 0; i < num_elements; ++i) { + out_arg[i] = in_arg1[i] + in_arg2[i]; + } + + return absl::OkStatus(); + } +}; + +} // namespace + +NB_MODULE(_extension, kernel_runner_module) { + namespace nb = nanobind; + + nb::class_(kernel_runner_module, "KernelSource") + .def("__str__", &KernelSource::ToString); + + nb::class_(kernel_runner_module, "KernelSpec") + .def("kernel_source", &KernelSpec::kernel_source, + nb::rv_policy::reference_internal); + + nb::class_(kernel_runner_module, "KernelEmitter") + .def("emit_kernel_spec", [](KernelEmitter* self) { + absl::StatusOr> spec = + self->EmitKernelSpec(); + if (!spec.ok()) { + throw std::runtime_error(std::string(spec.status().message())); + } + return std::move(spec).value(); + }); + + nb::class_(kernel_runner_module, "KernelRunner") + .def("call", &KernelRunnerCall); + + nb::class_(kernel_runner_module, + "DummyAddKernelRunner") + .def(nb::init<>()); + + nb::enum_ hlo_opcode(kernel_runner_module, "HloOpcode"); +#define DECLARE_ENUM(enum_name, opcode_name, ...) \ + hlo_opcode.value(absl::StrReplaceAll(opcode_name, {{"-", "_"}}).c_str(), \ + HloOpcode::enum_name); + HLO_OPCODE_LIST(DECLARE_ENUM) +#undef DECLARE_ENUM + + kernel_runner_module.def("opcode_arity", &HloOpcodeArity); + + nb::enum_(kernel_runner_module, "ComparisonDirection") + .value("kEq", Comparison::Direction::kEq) + .value("kNe", Comparison::Direction::kNe) + .value("kGe", Comparison::Direction::kGe) + .value("kGt", Comparison::Direction::kGt) + .value("kLe", Comparison::Direction::kLe) + .value("kLt", Comparison::Direction::kLt); + + nb::class_ hlo_instruction(kernel_runner_module, + "HloInstruction"); + // Factory methods + hlo_instruction + .def_static("create_parameter", &HloInstruction::CreateParameter) + .def_static("create_constant", &CreateConstantHloInstruction) + .def_static("create_unary", &HloInstruction::CreateUnary) + .def_static("create_binary", &HloInstruction::CreateBinary) + .def_static("create_ternary", &HloInstruction::CreateTernary) + .def_static("create_variadic", &HloInstruction::CreateVariadic) + .def_static("create_compare", &CreateComparisonHloInstruction); + + // Accessors + hlo_instruction.def("opcode", &HloInstruction::opcode); + hlo_instruction.def("shape", &HloInstruction::shape); + hlo_instruction.def("operands", &HloInstruction::operands, + nb::rv_policy::reference_internal); + hlo_instruction.def( + "__str__", [](const HloInstruction& self) { return self.ToString(); }); + + nb::class_(kernel_runner_module, "BufferAssignment") + .def("__str__", &BufferAssignment::ToString); + + nb::class_(kernel_runner_module, "HloSchedule") + .def("__str__", &HloSchedule::ToString); + + nb::class_(kernel_runner_module, "HloModule") + .def_static("parse_from_string", + [](absl::string_view str) { + absl::StatusOr> hlo_module = + ParseAndReturnUnverifiedModule(str); + + if (!hlo_module.ok()) { + throw std::runtime_error( + std::string(hlo_module.status().message())); + } + + return std::move(hlo_module).value(); + }) + .def("set_schedule", + [](HloModule& self, HloSchedule schedule) { + absl::Status status = self.set_schedule(std::move(schedule)); + if (!status.ok()) { + throw std::runtime_error(std::string(status.message())); + } + }) + .def( + "get_root_instruction", + [](HloModule* self) { + return self->entry_computation()->root_instruction(); + }, + nb::rv_policy::reference_internal); +} + +} // namespace xla diff --git a/third_party/xla/xla/codegen/testlib/kernel_runner_extention.cc b/third_party/xla/xla/codegen/testlib/kernel_runner_extention.cc deleted file mode 100644 index 8a4eb07c83f893..00000000000000 --- a/third_party/xla/xla/codegen/testlib/kernel_runner_extention.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "nanobind/nanobind.h" -#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "xla/codegen/kernel_emitter.h" -#include "xla/codegen/kernel_spec.h" -#include "xla/codegen/testlib/kernel_runner.h" -#include "xla/literal.h" -#include "xla/util.h" - -namespace xla { - -namespace { - -// Use `std::vector` instead of `absl::Span` to take -// advantage of the built in bindings. -void KernelRunnerCall(KernelRunner* kernel_runner, - std::vector literals) { - absl::Status status = kernel_runner->Call(absl::MakeSpan(literals)); - if (!status.ok()) { - throw std::runtime_error(std::string(status.message())); - } -} - -// A dummy kernel runner that implements a simple elementwise add. -class DummyAddKernelRunner final : public KernelRunner { - public: - absl::Status Call(absl::Span arguments) override { - if (arguments.size() != 3) { - return InvalidArgument("Expected 3 arguments, got %u", arguments.size()); - } - - if (arguments[0].size() != arguments[1].size()) { - return InvalidArgument( - "Expected argument 0 to be the same size as argument " - "1, got %u and %u", - arguments[0].size(), arguments[1].size()); - } - - if (arguments[1].size() != arguments[2].size()) { - return InvalidArgument( - "Expected argument 1 to be the same size as argument " - "2, got %u and %u", - arguments[1].size(), arguments[2].size()); - } - - constexpr size_t element_bytes = sizeof(int32_t); - - if ((arguments[0].size() % element_bytes) != 0) { - return InvalidArgument( - "Expected arguments to be a multiple of %u bytes, got %u", - element_bytes, arguments[0].size()); - } - - size_t num_elements = arguments[0].size() / element_bytes; - - auto* in_arg1 = reinterpret_cast(arguments[0].data()); - auto* in_arg2 = reinterpret_cast(arguments[1].data()); - auto* out_arg = reinterpret_cast(arguments[2].data()); - - for (int i = 0; i < num_elements; ++i) { - out_arg[i] = in_arg1[i] + in_arg2[i]; - } - - return absl::OkStatus(); - } -}; - -} // namespace - -NB_MODULE(kernel_runner_extention, kernel_runner_module) { - namespace nb = nanobind; - - nb::class_(kernel_runner_module, "KernelSpec"); - - nb::class_(kernel_runner_module, "KernelEmitter") - .def("emit_kernel_spec", [](KernelEmitter* self) { - absl::StatusOr> spec = - self->EmitKernelSpec(); - if (!spec.ok()) { - throw std::runtime_error(std::string(spec.status().message())); - } - return std::move(spec).value(); - }); - - nb::class_(kernel_runner_module, "KernelRunner") - .def("call", &KernelRunnerCall); - - nb::class_(kernel_runner_module, - "DummyAddKernelRunner") - .def(nb::init<>()); -} - -} // namespace xla diff --git a/third_party/xla/xla/codegen/testlib/kernel_runner_test.py b/third_party/xla/xla/codegen/testlib/kernel_runner_test.py index 0cdd81ece4d286..dda24e4d34a37b 100644 --- a/third_party/xla/xla/codegen/testlib/kernel_runner_test.py +++ b/third_party/xla/xla/codegen/testlib/kernel_runner_test.py @@ -15,9 +15,11 @@ from absl.testing import absltest import numpy as np -from xla.codegen.testlib import kernel_runner +from xla.codegen.testlib import _extension +from xla.codegen.testlib import utilities as testlib_utilities -create_literal = kernel_runner.create_literal_from_np + +create_literal = testlib_utilities.create_literal_from_np class LiteralFromNpTest(absltest.TestCase): @@ -31,7 +33,7 @@ def test_output_same_as_input(self): class DummyKernelRunnerTest(absltest.TestCase): def test_dummy_kernel(self): - runner = kernel_runner.DummyAddKernelRunner() + runner = _extension.DummyAddKernelRunner() in_arg1 = create_literal(np.array([1, 2, 3, 4], dtype=np.int32)) in_arg2 = create_literal(np.array([5, 6, 7, 8], dtype=np.int32)) out_arg = create_literal(np.array([0, 0, 0, 0], dtype=np.int32)) diff --git a/third_party/xla/xla/codegen/testlib/kernel_runner.py b/third_party/xla/xla/codegen/testlib/utilities.py similarity index 65% rename from third_party/xla/xla/codegen/testlib/kernel_runner.py rename to third_party/xla/xla/codegen/testlib/utilities.py index 11ddd15396ad1e..1ae1b15ae0c958 100644 --- a/third_party/xla/xla/codegen/testlib/kernel_runner.py +++ b/third_party/xla/xla/codegen/testlib/utilities.py @@ -12,22 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base classes for running kernels.""" +"""Boilerplate utilities for kernel testing.""" import numpy as np -from xla.codegen.testlib import kernel_runner_extention +from xla.codegen.testlib import _extension from xla.python import xla_extension -KernelSpec = kernel_runner_extention.KernelSpec -KernelEmmitter = kernel_runner_extention.KernelEmitter -KernelRunner = kernel_runner_extention.KernelRunner -DummyAddKernelRunner = kernel_runner_extention.DummyAddKernelRunner +def create_scalar_literal(value, dtype: np.dtype) -> xla_extension.Literal: + shape = xla_extension.Shape.scalar_shape(dtype) + literal = xla_extension.Literal(shape) + np.copyto(np.asarray(literal), value) + return literal def create_literal_from_np(array: np.ndarray) -> xla_extension.Literal: + if np.ndim(array) == 0: + return create_scalar_literal(array.item(), array.dtype) + shape = xla_extension.Shape.array_shape(array.dtype, array.shape) literal = xla_extension.Literal(shape) np.copyto(np.asarray(literal), array) return literal + + +# Intentionally rexport-ed to be avalable in the public API. +opcode_arity = _extension.opcode_arity diff --git a/third_party/xla/xla/codegen/tools/BUILD b/third_party/xla/xla/codegen/tools/BUILD index 96e73bff4f1668..bfec97ab9be11b 100644 --- a/third_party/xla/xla/codegen/tools/BUILD +++ b/third_party/xla/xla/codegen/tools/BUILD @@ -13,16 +13,20 @@ xla_cc_binary( # symlinked from the lit_lib directory. linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], visibility = [ + "//xla/backends/cpu/codegen:__subpackages__", + "//xla/backends/gpu/codegen:__subpackages__", "//xla/codegen/ir/tests:__subpackages__", "//xla/service/gpu/fusions:__subpackages__", ], deps = [ + "//xla/backends/cpu/codegen/ir:xla_cpu", + "//xla/backends/cpu/codegen/transforms:passes", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/backends/gpu/codegen/transforms:passes", "//xla/codegen/ir:xla", "//xla/mlir_hlo", "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/transforms:passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/third_party/xla/xla/codegen/tools/emitters_opt.cc b/third_party/xla/xla/codegen/tools/emitters_opt.cc index 940a655245faba..6d88aa371d95cf 100644 --- a/third_party/xla/xla/codegen/tools/emitters_opt.cc +++ b/third_party/xla/xla/codegen/tools/emitters_opt.cc @@ -34,29 +34,32 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" +#include "xla/backends/cpu/codegen/transforms/passes.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" -#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" int main(int argc, char** argv) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert< + mlir::DLTIDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect, + mlir::affine::AffineDialect, mlir::arith::ArithDialect, + mlir::complex::ComplexDialect, mlir::func::FuncDialect, + mlir::gpu::GPUDialect, mlir::math::MathDialect, mlir::mhlo::MhloDialect, + mlir::mhlo::MhloDialect, mlir::scf::SCFDialect, + mlir::tensor::TensorDialect, mlir::vector::VectorDialect, xla::XlaDialect, + xla::cpu::XlaCpuDialect, xla::gpu::XlaGpuDialect>(); mlir::func::registerAllExtensions(registry); mlir::LLVM::registerInlinerInterface(registry); mlir::registerCanonicalizerPass(); mlir::registerCSEPass(); mlir::registerInliner(); xla::gpu::registerGpuFusionTransformsPasses(); + xla::cpu::registerXlaCpuTransformsPasses(); mlir::registerPassPipeline( "xla-gpu-test-optimize", "Test pipeline of passes up to inlining. No vectorization, also does not " diff --git a/third_party/xla/xla/comparison_util_test.cc b/third_party/xla/xla/comparison_util_test.cc index 1581569a5d284c..f41db68363d953 100644 --- a/third_party/xla/xla/comparison_util_test.cc +++ b/third_party/xla/xla/comparison_util_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD index b0f77890ebd584..190a9d17acd6f7 100644 --- a/third_party/xla/xla/core/collectives/BUILD +++ b/third_party/xla/xla/core/collectives/BUILD @@ -53,8 +53,10 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -65,11 +67,14 @@ cc_library( name = "communicator", hdrs = ["communicator.h"], deps = [ + ":rank_id", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) @@ -78,9 +83,8 @@ cc_library( srcs = ["clique_id.cc"], hdrs = ["clique_id.h"], deps = [ - "//xla:util", "@com_google_absl//absl/crc:crc32c", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/core/collectives/clique.cc b/third_party/xla/xla/core/collectives/clique.cc index 6eb73c1ea91cba..1a0a5d659aecba 100644 --- a/third_party/xla/xla/core/collectives/clique.cc +++ b/third_party/xla/xla/core/collectives/clique.cc @@ -21,8 +21,10 @@ limitations under the License. #include "absl/container/btree_map.h" #include "absl/functional/function_ref.h" +#include "absl/status/status.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/util.h" namespace xla { @@ -44,4 +46,13 @@ void Clique::ForEachComm( } } +absl::Status Clique::AddComm(RankId rank, + std::unique_ptr communicator) { + auto emplaced = communicators_.emplace(rank, std::move(communicator)); + if (!emplaced.second) { + return InvalidArgument("Rank %d already exists in clique", rank.value()); + } + return absl::OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/core/collectives/clique.h b/third_party/xla/xla/core/collectives/clique.h index 69705ccfa524c5..24f80a3f1682c9 100644 --- a/third_party/xla/xla/core/collectives/clique.h +++ b/third_party/xla/xla/core/collectives/clique.h @@ -49,6 +49,9 @@ class Clique { // Returns a communicator for a given rank if it's in a clique. std::optional comm(RankId rank) const; + // Adds a communicator to the clique. + absl::Status AddComm(RankId rank, std::unique_ptr communicator); + // Calls `fn` for each communicator in the clique. void ForEachComm(absl::FunctionRef fn) const; @@ -61,8 +64,8 @@ class Clique { size_t num_communicators() const { return communicators_.size(); } private: - // We keep communicators in a sorted order by rank to guarantee deterministic - // traversal order in `ForEachComm`. + // We keep communicators in a sorted order by rank to guarantee + // deterministic traversal order in `ForEachComm`. absl::btree_map> communicators_; }; diff --git a/third_party/xla/xla/core/collectives/clique_id.cc b/third_party/xla/xla/core/collectives/clique_id.cc index b58e8ea54191d8..f59b7ce5999692 100644 --- a/third_party/xla/xla/core/collectives/clique_id.cc +++ b/third_party/xla/xla/core/collectives/clique_id.cc @@ -18,14 +18,14 @@ limitations under the License. #include #include #include -#include #include "absl/crc/crc32c.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" namespace xla { -CliqueId::CliqueId(std::string_view data) : data_(data.begin(), data.end()) {} +CliqueId::CliqueId(absl::string_view data) : data_(data.begin(), data.end()) {} absl::Span CliqueId::data() const { return data_; } @@ -34,7 +34,7 @@ std::string CliqueId::ToString() const { } uint32_t CliqueId::fingerprint() const { - std::string_view data_view(data_.data(), data_.size()); + absl::string_view data_view(data_.data(), data_.size()); return static_cast(absl::ComputeCrc32c(data_view)); } diff --git a/third_party/xla/xla/core/collectives/clique_id.h b/third_party/xla/xla/core/collectives/clique_id.h index c9d56a49cacadf..104e1dbde2d9c8 100644 --- a/third_party/xla/xla/core/collectives/clique_id.h +++ b/third_party/xla/xla/core/collectives/clique_id.h @@ -19,9 +19,9 @@ limitations under the License. #include #include #include -#include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" namespace xla { @@ -40,7 +40,7 @@ class CliqueId { public: CliqueId() = default; - explicit CliqueId(std::string_view data); + explicit CliqueId(absl::string_view data); absl::Span data() const; std::string ToString() const; diff --git a/third_party/xla/xla/core/collectives/clique_key.cc b/third_party/xla/xla/core/collectives/clique_key.cc index 2da8d6651c3548..1ff3c355dbe9c2 100644 --- a/third_party/xla/xla/core/collectives/clique_key.cc +++ b/third_party/xla/xla/core/collectives/clique_key.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/core/collectives/clique_key.h" +#include #include -#include #include #include "absl/algorithm/container.h" @@ -26,11 +26,13 @@ limitations under the License. namespace xla { -CliqueKey::CliqueKey(std::vector devices) - : devices_(std::move(devices)) {} +CliqueKey::CliqueKey(absl::Span devices) + : devices_(devices.begin(), devices.end()) {} absl::Span CliqueKey::devices() const { return devices_; } +size_t CliqueKey::num_devices() const { return devices_.size(); } + std::optional CliqueKey::rank(GlobalDeviceId id) const { if (auto it = absl::c_find(devices_, id); it != devices_.end()) { return RankId(it - devices_.begin()); diff --git a/third_party/xla/xla/core/collectives/clique_key.h b/third_party/xla/xla/core/collectives/clique_key.h index 05411773431507..7e5fddbbb2e674 100644 --- a/third_party/xla/xla/core/collectives/clique_key.h +++ b/third_party/xla/xla/core/collectives/clique_key.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ #define XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ +#include #include #include #include @@ -39,7 +40,7 @@ namespace xla { // these cliques launch operations (device kernels) on different device streams. class CliqueKey { public: - explicit CliqueKey(std::vector devices); + explicit CliqueKey(absl::Span devices); virtual ~CliqueKey() = default; CliqueKey(const CliqueKey& other) = default; @@ -52,6 +53,7 @@ class CliqueKey { std::optional rank(GlobalDeviceId id) const; absl::Span devices() const; + size_t num_devices() const; // Returns true if this clique is a subset of `other`. virtual bool IsSubsetOf(const CliqueKey& other) const = 0; diff --git a/third_party/xla/xla/core/collectives/collectives.h b/third_party/xla/xla/core/collectives/collectives.h index 4b41a0dd440816..68f061252b94c7 100644 --- a/third_party/xla/xla/core/collectives/collectives.h +++ b/third_party/xla/xla/core/collectives/collectives.h @@ -70,7 +70,7 @@ class Collectives { // Creates communicators for given clique key and id. virtual absl::StatusOr>> - CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + CreateCommunicators(const CliqueKey& clique_key, const std::optional& clique_id, absl::Span ranks, const Config& config) = 0; diff --git a/third_party/xla/xla/core/collectives/collectives_registry.cc b/third_party/xla/xla/core/collectives/collectives_registry.cc index e42da891cdeccb..83f40ec337305a 100644 --- a/third_party/xla/xla/core/collectives/collectives_registry.cc +++ b/third_party/xla/xla/core/collectives/collectives_registry.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -28,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/core/collectives/collectives.h" #include "xla/service/platform_util.h" @@ -65,7 +65,7 @@ static Registry& GetCollectivesRegistry() { } absl::Status CollectivesRegistry::Register( - std::string_view platform_name, std::string_view name, int32_t priority, + absl::string_view platform_name, absl::string_view name, int32_t priority, std::unique_ptr collectives) { TF_ASSIGN_OR_RETURN(std::string canonical_platform_name, PlatformUtil::CanonicalPlatformName(platform_name)); @@ -83,7 +83,7 @@ absl::Status CollectivesRegistry::Register( } absl::StatusOr CollectivesRegistry::Default( - std::string_view platform_name) { + absl::string_view platform_name) { TF_ASSIGN_OR_RETURN(std::string canonical_platform_name, PlatformUtil::CanonicalPlatformName(platform_name)); diff --git a/third_party/xla/xla/core/collectives/collectives_registry.h b/third_party/xla/xla/core/collectives/collectives_registry.h index eb9549f6d435a9..558deb647243b5 100644 --- a/third_party/xla/xla/core/collectives/collectives_registry.h +++ b/third_party/xla/xla/core/collectives/collectives_registry.h @@ -18,11 +18,12 @@ limitations under the License. #include #include -#include #include "absl/base/attributes.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/core/collectives/collectives.h" #include "tsl/platform/logging.h" @@ -38,12 +39,12 @@ class CollectivesRegistry { // the given platform. Higher priority wins. // // Returns an error if the implementation is already registered. - static absl::Status Register(std::string_view platform_name, - std::string_view name, int32_t priority, + static absl::Status Register(absl::string_view platform_name, + absl::string_view name, int32_t priority, std::unique_ptr collectives); // Returns the default collectives implementation for the given platform. - static absl::StatusOr Default(std::string_view platform_name); + static absl::StatusOr Default(absl::string_view platform_name); }; } // namespace xla diff --git a/third_party/xla/xla/core/collectives/communicator.h b/third_party/xla/xla/core/collectives/communicator.h index 14b5cdb8c0f432..af95f7063fc803 100644 --- a/third_party/xla/xla/core/collectives/communicator.h +++ b/third_party/xla/xla/core/collectives/communicator.h @@ -17,15 +17,18 @@ limitations under the License. #define XLA_CORE_COLLECTIVES_COMMUNICATOR_H_ #include -#include #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -51,23 +54,24 @@ class Communicator { virtual absl::Status Unregister() = 0; }; + // Register `buffer` for efficient collective operations (i.e. on NCCL backend + // it registers the buffer for zero-copy collective operations). + virtual absl::StatusOr> + RegisterBuffer(stream_executor::DeviceMemoryBase buffer) { + return Unimplemented("User-managed buffer registration is not supported"); + } + // Abort any uncompleted operations and destroys the underlying communicator // object. It is undefined behavior to use the communicator after calling // this method. - virtual absl::Status Abort() = 0; + virtual absl::Status Abort() { + return Unimplemented("Aborting communicator is not implemented"); + } // Checks the health of the communicator. It might return an error from the // previously launched asynchronous collective operations, and it does not // have to wait for the completion of scheduled operations. - virtual absl::Status HealthCheck() const = 0; - - // Returns the number of ranks in the communicator. - virtual absl::StatusOr NumRanks() const = 0; - - // Register `buffer` for efficient collective operations (i.e. on NCCL backend - // it registers the buffer for zero-copy collective operations). - virtual absl::StatusOr> - RegisterBuffer(stream_executor::DeviceMemoryBase buffer) = 0; + virtual absl::Status HealthCheck() const { return absl::OkStatus(); } // Reduce buffers of length `count` in `send_buff` using `reduction_kind` // reduction and leaves identical copies of the result on each `recv_buff`. @@ -81,7 +85,7 @@ class Communicator { // all other devices. virtual absl::Status Broadcast(se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, size_t root, + PrimitiveType dtype, size_t count, RankId root, const Executor& executor) = 0; // Reduce data in `send_buff` from all devices using the `reduction_kind` @@ -91,7 +95,6 @@ class Communicator { se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) = 0; // Gather `count` values from all devices into `recv_buffer`, receiving data @@ -101,24 +104,37 @@ class Communicator { PrimitiveType dtype, size_t count, const Executor& executor) = 0; + // Sends data from `send_buffer` to `target_ranks` and receives data from + // `source_rank` into `recv_buffer`. If `source_rank` is not specified, the + // output is filled with zeros. + virtual absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) = 0; + + // Sends `count` values from `send_buffers` to other ranks and receives data + // from other ranks into `recv_buffers`. + virtual absl::Status AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) = 0; + // Send data from `send_buff` to rank `peer`. virtual absl::Status Send(se::DeviceMemoryBase send_buffer, - PrimitiveType dtype, size_t count, int32_t peer, + PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; - // Send a pointer `ptr` to rank `peer`. - virtual absl::Status SendPtrToPeer(void* ptr, int32_t peer, - const Executor& executor) = 0; - // Receive data from rank `peer` into `recv_buff`. virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, int32_t peer, + PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; - // Receive a pointer from rank `peer` into `ptr`. - virtual absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, - const Executor& executor) = 0; + // Returns the number of ranks in the communicator. + virtual absl::StatusOr NumRanks() const = 0; + // Returns a human-readable description of the communicator. virtual std::string ToString() const = 0; }; diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 69c234bc49e86b..4821b83c25a442 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -78,6 +79,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_dump_hlo_as_long_text(false); opts.set_xla_dump_large_constants(false); opts.set_xla_dump_enable_mlir_pretty_form(true); + opts.set_xla_gpu_unsupported_annotate_with_emitter_loc(false); opts.set_xla_debug_buffer_assignment_show_max(15); #ifdef ENABLE_MKL opts.set_xla_cpu_use_mkl_dnn(true); @@ -86,6 +88,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_use_acl(true); #endif opts.set_xla_cpu_use_thunk_runtime(true); + opts.set_xla_cpu_use_xnnpack(false); opts.set_xla_cpu_parallel_codegen_split_count(32); opts.set_xla_cpu_copy_insertion_use_region_analysis(false); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); @@ -124,6 +127,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_graph_enable_concurrent_region(false); opts.set_xla_cmd_buffer_trace_cache_size(16); + opts.set_xla_gpu_collectives_use_persistent_cliques(false); + // Despite the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. opts.set_xla_gpu_enable_fast_min_max(false); @@ -166,8 +171,26 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_shape_checks(DebugOptions::RUNTIME); opts.set_xla_dump_latency_hiding_schedule(false); opts.set_xla_gpu_enable_latency_hiding_scheduler(false); - opts.set_xla_gpu_lhs_enable_gpu_async_tracker(true); opts.set_xla_gpu_enable_analytical_latency_estimator(false); + opts.set_xla_gpu_enable_analytical_sol_latency_estimator(false); + auto* sol_estimator_defaults = + opts.mutable_xla_gpu_analytical_latency_estimator_options(); + sol_estimator_defaults->emplace( + "nccl_op_launch_us", + absl::StrCat(static_cast(100.0f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "nic_speed_gbps", + absl::StrCat(static_cast(55.56f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "chunk_prep_us", + absl::StrCat(static_cast(13.34f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "rtt_us", + absl::StrCat(static_cast(68.89f * kDefaultNcclCostModelCoeff))); + sol_estimator_defaults->emplace( + "chunk_size_bytes", absl::StrCat(kDefaultNcclCostModelChunkSizeBytes)); + sol_estimator_defaults->emplace( + "gpus_per_node", absl::StrCat(kDefaultNcclCostModelGPUsPerNode)); opts.set_xla_gpu_pgle_profile_file_or_directory_path(""); opts.set_xla_gpu_memory_limit_slop_factor(95); opts.set_xla_gpu_enable_highest_priority_async_stream(true); @@ -209,10 +232,11 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_exhaustive_tiling_search(false); opts.set_xla_gpu_experimental_enable_triton_heroless_priority_fusion(false); - opts.set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(true); + opts.set_xla_gpu_experimental_enable_triton_i4_rewrites(false); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_ratio(1.1); + opts.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction(false); opts.set_xla_gpu_unsafe_pipelined_loop_annotator(false); opts.set_xla_gpu_copy_insertion_use_region_analysis(false); @@ -250,7 +274,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_bf16_3way_gemm(false); opts.set_xla_gpu_nccl_collective_max_nchannels(0); opts.set_xla_gpu_nccl_p2p_max_nchannels(0); - opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_multi_streamed_windowed_einsum(false); opts.set_xla_gpu_experimental_stream_annotation(false); // Minimum combined size of matrices in matrix multiplication to @@ -298,7 +322,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1); opts.set_xla_pjrt_allow_auto_layout_in_hlo(false); opts.set_xla_gpu_enable_scatter_determinism_expander(true); - opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(true); + opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(false); return opts; } @@ -469,6 +493,17 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return true; }; + // Custom "sub-parser" lambda for + // xla_gpu_analytical_latency_estimator_options. + auto setter_for_xla_gpu_analytical_latency_estimator_options = + [debug_options](std::string comma_separated_values) { + google::protobuf::Map* options_map = + debug_options + ->mutable_xla_gpu_analytical_latency_estimator_options(); + parse_xla_backend_extra_options(options_map, comma_separated_values); + return true; + }; + // Custom "sub-parser" lambda for xla_partitioning_algorithm. auto setter_for_xla_partitioning_algorithm = [debug_options](const std::string& value) { @@ -889,6 +924,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_cpu_use_thunk_runtime), debug_options->xla_cpu_use_thunk_runtime(), "Use Thunk-based runtime for the CPU backend.")); + flag_list->push_back( + tsl::Flag("xla_cpu_use_xnnpack", + bool_setter_for(&DebugOptions::set_xla_cpu_use_xnnpack), + debug_options->xla_cpu_use_xnnpack(), + "Use XNNPACK for supported operations.")); flag_list->push_back(tsl::Flag( "xla_cpu_parallel_codegen_split_count", int32_setter_for(&DebugOptions::set_xla_cpu_parallel_codegen_split_count), @@ -995,6 +1035,15 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "and \"test_undeclared_outputs_dir\" have a special meaning: They cause " "us to dump into the directory specified by the environment variable " "TEST_UNDECLARED_OUTPUTS_DIR.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_unsupported_annotate_with_emitter_loc", + bool_setter_for( + &DebugOptions::set_xla_gpu_unsupported_annotate_with_emitter_loc), + debug_options->xla_gpu_unsupported_annotate_with_emitter_loc(), + "Forces emitters that use MLIR to annotate all the created MLIR " + "instructions with the emitter's C++ source file and line number. The " + "annotations should appear in the MLIR dumps. The emitters should use " + "EmitterLocOpBuilder for that.")); flag_list->push_back(tsl::Flag( "xla_dump_hlo_as_text", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), @@ -1353,6 +1402,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_cublaslt), debug_options->xla_gpu_enable_cublaslt(), "Use cuBLASLt for GEMMs when possible.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_collectives_use_persistent_cliques", + bool_setter_for( + &DebugOptions::set_xla_gpu_collectives_use_persistent_cliques), + debug_options->xla_gpu_collectives_use_persistent_cliques(), + "Use persistent per-process XLA:GPU collectives cliques")); flag_list->push_back(tsl::Flag( "xla_gpu_graph_level", setter_for_xla_gpu_graph_level, 1, "The legacy flag for setting GPU graph level. Use " @@ -1561,22 +1616,55 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_analytical_latency_estimator(), "Enable analytical latency estimator for latency-hiding scheduler for " "XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_analytical_sol_latency_estimator", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_analytical_sol_latency_estimator), + debug_options->xla_gpu_enable_analytical_sol_latency_estimator(), + "Enable analytical Speed-of-Light latency estimator for latency-hiding " + "scheduler for XLA:GPU, must be used without " + "xla_gpu_enable_analytical_latency_estimator. It can also benefit from " + "user-passed options in xla_gpu_analytical_latency_estimator_options")); + flag_list->push_back(tsl::Flag( + "xla_gpu_analytical_latency_estimator_options", + setter_for_xla_gpu_analytical_latency_estimator_options, "", + "Extra platform-specific options to improve analytical latency " + "estimator precision; comma-separated list of 'key=val' " + "strings (=val may be omitted); no whitespace around commas." + "Available options: " + "--xla_gpu_analytical_latency_estimator_options='nccl_op_launch_ms=55," + "nic_speed_gbps=40,chunk_prep_ms=1,rtt_ms=2,gpus_per_node=4," + "chunk_size_bytes=1024'")); flag_list->push_back(tsl::Flag( "xla_gpu_pgle_profile_file_or_directory_path", string_setter_for( &DebugOptions::set_xla_gpu_pgle_profile_file_or_directory_path), debug_options->xla_gpu_pgle_profile_file_or_directory_path(), "Directory or file for PGLE profiles in XLA:GPU")); - flag_list->push_back(tsl::Flag( - "xla_gpu_lhs_enable_gpu_async_tracker", - bool_setter_for(&DebugOptions::set_xla_gpu_lhs_enable_gpu_async_tracker), - debug_options->xla_gpu_lhs_enable_gpu_async_tracker(), - "Enable GPU async tracker for latency-hiding scheduler in XLA:GPU")); flag_list->push_back(tsl::Flag( "xla_gpu_memory_limit_slop_factor", int32_setter_for(&DebugOptions::set_xla_gpu_memory_limit_slop_factor), debug_options->xla_gpu_memory_limit_slop_factor(), - "Slop factor for memory limits in XLA:GPU")); + "Slop factor for memory limits in XLA:GPU. This flag serves as a " + "multiplier " + "applied to the total available memory, creating a threshold that guides " + "the " + "Latency Hiding Scheduler (LHS) in balancing memory reduction and " + "latency " + "hiding optimizations. This factor effectively establishes a memory " + "limit " + "for compiler passes, determining when the scheduler should prioritize: " + " 1. Memory reduction: When memory usage approaches or exceeds the " + "calculated " + " threshold. " + " 2. Latency hiding: When memory usage is below the threshold, allowing " + "for " + " more aggressive optimizations that may temporarily increase memory " + "usage " + " but improve overall performance. " + "By adjusting this factor, users can fine-tune the trade-off between " + "memory " + "efficiency and performance optimizations. The default value is 95.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_highest_priority_async_stream", bool_setter_for( @@ -1676,14 +1764,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Enable heroless Triton fusions in the PriorityFusion pass. The pass " "will try to make Triton fusions first and foremost where it is " "possible.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_experimental_enable_triton_softmax_priority_fusion", - bool_setter_for( - &DebugOptions:: - set_xla_gpu_experimental_enable_triton_softmax_priority_fusion), - debug_options - ->xla_gpu_experimental_enable_triton_softmax_priority_fusion(), - "Enable fusion into Triton Softmax in PriorityFusion pass.")); flag_list->push_back(tsl::Flag( "xla_gpu_dump_autotune_results_to", string_setter_for(&DebugOptions::set_xla_gpu_dump_autotune_results_to), @@ -2037,6 +2117,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, flag_list->push_back(tsl::Flag("xla_gpu_enable_triton_gemm_int4", noop_flag_setter, true, "[Deprecated, do not use]")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_enable_triton_i4_rewrites", + bool_setter_for( + &DebugOptions::set_xla_gpu_experimental_enable_triton_i4_rewrites), + debug_options->xla_gpu_experimental_enable_triton_i4_rewrites(), + "When enabled, the Triton emitter for dot will use int4 as native type " + "and later the Triton IR will be rewritten by Triton IR rewriting pass " + "to use int4 packed into int8.")); flag_list->push_back( tsl::Flag("xla_gpu_async_dot", bool_setter_for(&DebugOptions::set_xla_gpu_async_dot), @@ -2064,13 +2152,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for( &DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds), debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(), - "Set timeout for RendezvousSingle stuck warning")); + "Set timeout for Rendezvous stuck warning")); flag_list->push_back(tsl::Flag( "xla_gpu_executable_terminate_timeout", int32_setter_for( &DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds), debug_options->xla_gpu_executable_terminate_timeout_seconds(), - "Set timeout for RendezvousSingle termination")); + "Set timeout for Rendezvous termination")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_disable_binary_libraries", bool_setter_for( @@ -2126,6 +2214,15 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(), "Internal: Enable the RaggedAllToAllDecomposer, an experimental pass " "that rewrites ragged-all-to-all as a dense all-to-all operation.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_enable_alltoall_windowed_einsum", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_experimental_enable_alltoall_windowed_einsum), + debug_options->xla_gpu_experimental_enable_alltoall_windowed_einsum(), + "Enable windowed einsum rewrite for all-to-all+gemm pattern, " + "This optimization slices the all-to-all into smaller all-to-alls." + "It is an experimental feature.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/ef57_test.cc b/third_party/xla/xla/ef57_test.cc index 1f5d48cfda0166..4143b58277e567 100644 --- a/third_party/xla/xla/ef57_test.cc +++ b/third_party/xla/xla/ef57_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/log/log_streamer.h" #include "absl/random/random.h" #include "absl/types/span.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/error_spec.h b/third_party/xla/xla/error_spec.h index 42cc6014e45832..70cf7cf9e896fc 100644 --- a/third_party/xla/xla/error_spec.h +++ b/third_party/xla/xla/error_spec.h @@ -22,7 +22,8 @@ namespace xla { // Structure describing permissible absolute and relative error bounds. struct ErrorSpec { - explicit ErrorSpec(double aabs, double arel = 0, bool relaxed_nans = false) + explicit constexpr ErrorSpec(double aabs, double arel = 0, + bool relaxed_nans = false) : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} double abs; // Absolute error bound. diff --git a/third_party/xla/xla/executable_run_options.cc b/third_party/xla/xla/executable_run_options.cc index 0ab7a4bbf77135..706b4143b91e3e 100644 --- a/third_party/xla/xla/executable_run_options.cc +++ b/third_party/xla/xla/executable_run_options.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/executable_run_options.h" #include +#include #include namespace xla { diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 8ebc495f27ceaf..ff9a8aa773e203 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -48,6 +48,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -78,6 +79,7 @@ xla_cc_test( ":type_id_registry", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -91,6 +93,7 @@ cc_library( deps = [ ":type_id_registry", "//xla:util", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", @@ -105,6 +108,7 @@ xla_cc_test( ":execution_state", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -135,6 +139,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", ], @@ -166,6 +171,8 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -215,7 +222,9 @@ xla_cc_test( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:status_matchers", @@ -247,6 +256,7 @@ xla_cc_test( ":type_id_registry", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index cf98210af1b717..389d2d2a9a7aec 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -224,13 +224,13 @@ class Ffi { // Registers FFI handler bundle with an XLA runtime under the given name on a // given platform. - static inline XLA_FFI_Error* RegisterStaticHandler( + static XLA_FFI_Error* RegisterStaticHandler( const XLA_FFI_Api* api, std::string_view name, std::string_view platform, XLA_FFI_Handler_Bundle bundle, XLA_FFI_Handler_Traits traits = 0); // Registers FFI execute handler with an XLA runtime under the given name on a // given platform. - static inline XLA_FFI_Error* RegisterStaticHandler( + static XLA_FFI_Error* RegisterStaticHandler( const XLA_FFI_Api* api, std::string_view name, std::string_view platform, XLA_FFI_Handler* execute, XLA_FFI_Handler_Traits traits = 0) { return RegisterStaticHandler( @@ -238,6 +238,15 @@ class Ffi { XLA_FFI_Handler_Bundle{nullptr, nullptr, nullptr, execute}, traits); } + // Registers a custom type so that it can be used with State and UserData + // arguments to external FFI handlers. The `name` argument must be a unique + // identifier for the type, and duplicate registrations with the same name + // are not allowed. When successful, a unique ID will be returned by updating + // `type_id`. + static XLA_FFI_Error* RegisterTypeId(const XLA_FFI_Api* api, + std::string_view name, + XLA_FFI_TypeId* type_id); + protected: template static std::string StrCat(Args... args); @@ -260,11 +269,9 @@ class Ffi { size_t actual); }; -XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, - std::string_view name, - std::string_view platform, - XLA_FFI_Handler_Bundle bundle, - XLA_FFI_Handler_Traits traits) { +inline XLA_FFI_Error* Ffi::RegisterStaticHandler( + const XLA_FFI_Api* api, std::string_view name, std::string_view platform, + XLA_FFI_Handler_Bundle bundle, XLA_FFI_Handler_Traits traits) { XLA_FFI_Handler_Register_Args args; args.struct_size = XLA_FFI_Handler_Register_Args_STRUCT_SIZE; args.extension_start = nullptr; @@ -275,6 +282,17 @@ XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, return api->XLA_FFI_Handler_Register(&args); } +inline XLA_FFI_Error* Ffi::RegisterTypeId(const XLA_FFI_Api* api, + std::string_view name, + XLA_FFI_TypeId* type_id) { + XLA_FFI_TypeId_Register_Args args; + args.struct_size = XLA_FFI_TypeId_Register_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.name = XLA_FFI_ByteSpan{name.data(), name.size()}; + args.type_id = type_id; + return api->XLA_FFI_TypeId_Register(&args); +} + template std::string Ffi::StrCat(Args... args) { std::stringstream ss; @@ -568,7 +586,8 @@ inline Binding Ffi::BindInstantiate() { } //===----------------------------------------------------------------------===// -// Template metaprogramming to automatially infer Binding from invocable object. +// Template metaprogramming to automatically infer Binding from invocable +// object. //===----------------------------------------------------------------------===// // A little bit of metaprogramming that automatically infers the binding schema @@ -1663,22 +1682,6 @@ XLA_FFI_REGISTER_SCALAR_ATTR_DECODING(std::complex, #undef XLA_FFI_REGISTER_SCALAR_ATTR_DECODING -template <> -struct AttrDecoding { - using Type = std::string_view; - static std::optional Decode(XLA_FFI_AttrType type, - void* attr, - DiagnosticEngine& diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) { - return diagnostic.Emit("Wrong attribute type: expected ") - << XLA_FFI_AttrType_STRING << " but got " << type; - } - - auto* span = reinterpret_cast(attr); - return std::string_view(span->ptr, span->len); - } -}; - //===----------------------------------------------------------------------===// // Automatic dictionary attributes to structs decoding. //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 36cf4337564e61..8d6f1095fad24a 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -255,30 +255,30 @@ typedef struct XLA_FFI_ExecutionContext XLA_FFI_ExecutionContext; //===----------------------------------------------------------------------===// // TypeId uniquely identifies a user-defined type in a given XLA FFI instance. -struct XLA_FFI_TypeId { +typedef struct XLA_FFI_TypeId { int64_t type_id; -}; +} XLA_FFI_TypeId; // We use byte spans to pass strings to handlers because strings might not be // null terminated, and even if they are, looking for a null terminator can // become very expensive in tight loops. -struct XLA_FFI_ByteSpan { +typedef struct XLA_FFI_ByteSpan { const char* ptr; size_t len; -}; +} XLA_FFI_ByteSpan; // A struct to pass a scalar value to FFI handler. -struct XLA_FFI_Scalar { +typedef struct XLA_FFI_Scalar { XLA_FFI_DataType dtype; void* value; -}; +} XLA_FFI_Scalar; // A struct to pass a dense array to FFI handler. -struct XLA_FFI_Array { +typedef struct XLA_FFI_Array { XLA_FFI_DataType dtype; size_t size; void* data; -}; +} XLA_FFI_Array; //===----------------------------------------------------------------------===// // Future @@ -431,12 +431,12 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs); typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame); // XLA FFI handlers for execution stages (see XLA_FFI_ExecutionStage). -struct XLA_FFI_Handler_Bundle { +typedef struct XLA_FFI_Handler_Bundle { XLA_FFI_Handler* instantiate; // optional XLA_FFI_Handler* prepare; // optional XLA_FFI_Handler* initialize; // optional XLA_FFI_Handler* execute; // required -}; +} XLA_FFI_Handler_Bundle; enum XLA_FFI_Handler_TraitsBits { // Calls to FFI handler are safe to trace into the command buffer. It means diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index 1099bcb0bed43f..f264451da34735 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -893,6 +893,22 @@ XLA_FFI_REGISTER_ARRAY_ATTR_DECODING(double, XLA_FFI_DataType_F64); #undef XLA_FFI_REGISTER_ARRAY_ATTR_DECODING +template <> +struct AttrDecoding { + using Type = std::string_view; + static std::optional Decode(XLA_FFI_AttrType type, + void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) { + return diagnostic.Emit("Wrong attribute type: expected ") + << XLA_FFI_AttrType_STRING << " but got " << type; + } + + auto* span = reinterpret_cast(attr); + return std::string_view(span->ptr, span->len); + } +}; + // A type tag to mark i64 attributes as pointers to `T`. template struct Pointer {}; @@ -1310,30 +1326,14 @@ inline ThreadPool::ThreadPool(const XLA_FFI_Api* api, // Type Registration //===----------------------------------------------------------------------===// -namespace internal { - -inline XLA_FFI_Error* RegisterType(const XLA_FFI_Api* api, - std::string_view name, - XLA_FFI_TypeId* type_id) { - XLA_FFI_TypeId_Register_Args args; - args.struct_size = XLA_FFI_TypeId_Register_Args_STRUCT_SIZE; - args.extension_start = nullptr; - args.name = XLA_FFI_ByteSpan{name.data(), name.size()}; - args.type_id = type_id; - return api->XLA_FFI_TypeId_Register(&args); -} - -} // namespace internal - #define XLA_FFI_REGISTER_TYPE(API, NAME, TYPE_ID) \ XLA_FFI_REGISTER_TYPE_(API, NAME, TYPE_ID, __COUNTER__) #define XLA_FFI_REGISTER_TYPE_(API, NAME, TYPE_ID, N) \ XLA_FFI_REGISTER_TYPE__(API, NAME, TYPE_ID, N) -#define XLA_FFI_REGISTER_TYPE__(API, NAME, TYPE_ID, N) \ - XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error* \ - xla_ffi_type_##N##_registered_ = [] { \ - return ::xla::ffi::internal::RegisterType(API, NAME, TYPE_ID); \ - }() +#define XLA_FFI_REGISTER_TYPE__(API, NAME, TYPE_ID, N) \ + XLA_FFI_ATTRIBUTE_UNUSED static const XLA_FFI_Error* \ + xla_ffi_type_##N##_registered_ = \ + [] { return ::xla::ffi::Ffi::RegisterTypeId(API, NAME, TYPE_ID); }() //===----------------------------------------------------------------------===// // UserData diff --git a/third_party/xla/xla/ffi/call_frame_test.cc b/third_party/xla/xla/ffi/call_frame_test.cc index 89d306455e6a19..c74a51870df3ff 100644 --- a/third_party/xla/xla/ffi/call_frame_test.cc +++ b/third_party/xla/xla/ffi/call_frame_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include "absl/strings/str_cat.h" #include "xla/ffi/api/c_api.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/ffi/execution_context_test.cc b/third_party/xla/xla/ffi/execution_context_test.cc index 6a5cdfa40b07b6..c8d37ea5c64858 100644 --- a/third_party/xla/xla/ffi/execution_context_test.cc +++ b/third_party/xla/xla/ffi/execution_context_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "xla/ffi/type_id_registry.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/ffi/execution_state.cc b/third_party/xla/xla/ffi/execution_state.cc index e94a3a944fe4ef..5aab4a7a3a575c 100644 --- a/third_party/xla/xla/ffi/execution_state.cc +++ b/third_party/xla/xla/ffi/execution_state.cc @@ -17,7 +17,9 @@ limitations under the License. #include +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/ffi/type_id_registry.h" #include "xla/util.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/ffi/execution_state_test.cc b/third_party/xla/xla/ffi/execution_state_test.cc index dd8244f00183ff..d32c80f6d92ff4 100644 --- a/third_party/xla/xla/ffi/execution_state_test.cc +++ b/third_party/xla/xla/ffi/execution_state_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include +#include #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 19a61728594687..9335bfa0241357 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -39,6 +39,7 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/base/optimization.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" @@ -403,6 +404,22 @@ XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING(double, XLA_FFI_DataType_F64); #undef XLA_FFI_REGISTER_ARRRAY_ATTR_DECODING +template <> +struct AttrDecoding { + using Type = absl::string_view; + static std::optional Decode(XLA_FFI_AttrType type, + void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_STRING)) { + return diagnostic.Emit("Wrong attribute type: expected ") + << XLA_FFI_AttrType_STRING << " but got " << type; + } + + auto* span = reinterpret_cast(attr); + return std::string_view(span->ptr, span->len); + } +}; + // A type tag to mark i64 attributes as pointers to `T`. template struct Pointer {}; diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 507c756a764f24..f52be8b94e6e5d 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -26,10 +26,12 @@ limitations under the License. #include #include "absl/base/optimization.h" -#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -370,7 +372,7 @@ static absl::Status RegisterHandler(std::string_view name, api_version.minor_version != XLA_FFI_API_MINOR) { return InvalidArgument( "FFI handler registration for %s on platform %s (canonical %s) failed " - "because the hander's API version (%d.%d) is incompatible with the " + "because the handler's API version (%d.%d) is incompatible with the " "framework's API version (%d.%d)", name, platform, canonical_platform, api_version.major_version, api_version.minor_version, XLA_FFI_API_MAJOR, XLA_FFI_API_MINOR); diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 372f02cfe8d67d..1f612ddd747754 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -26,9 +26,12 @@ limitations under the License. #include #include +#include +#include #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/call_frame.h" @@ -220,7 +223,7 @@ TEST(FfiTest, BuiltinAttributes) { auto call_frame = builder.Build(); auto fn = [&](bool pred, int8_t i8, int16_t i16, int32_t i32, int64_t i64, - float f32, double f64, std::string_view str) { + float f32, double f64, absl::string_view str) { EXPECT_EQ(pred, true); EXPECT_EQ(i8, 42); EXPECT_EQ(i16, 42); @@ -240,7 +243,7 @@ TEST(FfiTest, BuiltinAttributes) { .Attr("i64") .Attr("f32") .Attr("f64") - .Attr("str") + .Attr("str") .To(fn); auto status = Call(*handler, call_frame); @@ -263,7 +266,7 @@ TEST(FfiTest, BuiltinAttributesAutoBinding) { static constexpr char kStr[] = "str"; auto fn = [&](Attr i32, Attr f32, - Attr str) { + Attr str) { EXPECT_EQ(*i32, 42); EXPECT_EQ(*f32, 42.0f); EXPECT_EQ(*str, "foo"); @@ -357,7 +360,7 @@ TEST(FfiTest, AttrsAsDictionary) { absl::StatusOr i32 = dict.get("i32"); absl::StatusOr f32 = dict.get("f32"); - absl::StatusOr str = dict.get("str"); + absl::StatusOr str = dict.get("str"); EXPECT_TRUE(i32.ok()); EXPECT_TRUE(f32.ok()); @@ -435,7 +438,7 @@ TEST(FfiTest, StructAttr) { builder.AddAttributes(attrs.Build()); auto call_frame = builder.Build(); - auto fn = [&](std::string_view str, PairOfI32AndF32 i32_and_f32) { + auto fn = [&](absl::string_view str, PairOfI32AndF32 i32_and_f32) { EXPECT_EQ(str, "foo"); EXPECT_EQ(i32_and_f32.i32, 42); EXPECT_EQ(i32_and_f32.f32, 42.0f); @@ -443,7 +446,7 @@ TEST(FfiTest, StructAttr) { }; auto handler = Ffi::Bind() - .Attr("str") + .Attr("str") .Attr("i32_and_f32") .To(fn); @@ -484,7 +487,7 @@ TEST(FfiTest, DecodingErrors) { builder.AddAttributes(attrs.Build()); auto call_frame = builder.Build(); - auto fn = [](int32_t, int64_t, float, std::string_view) { + auto fn = [](int32_t, int64_t, float, absl::string_view) { return absl::OkStatus(); }; @@ -492,7 +495,7 @@ TEST(FfiTest, DecodingErrors) { .Attr("not_i32_should_fail") .Attr("not_i64_should_fail") .Attr("f32") - .Attr("not_str_should_fail") + .Attr("not_str_should_fail") .To(fn); auto status = Call(*handler, call_frame); diff --git a/third_party/xla/xla/ffi/type_id_registry.h b/third_party/xla/xla/ffi/type_id_registry.h index 5672ac691e253b..6b7455542c51c4 100644 --- a/third_party/xla/xla/ffi/type_id_registry.h +++ b/third_party/xla/xla/ffi/type_id_registry.h @@ -41,7 +41,7 @@ namespace xla::ffi { // of time and explicitly get a unique type id for them. // // 2. Internal type id. When FFI handler defined in the same binary we rely -// on a global static registry to automatically assing type ids. +// on a global static registry to automatically assign type ids. class TypeIdRegistry { public: TSL_LIB_GTL_DEFINE_INT_TYPE(TypeId, int64_t); diff --git a/third_party/xla/xla/ffi/type_id_registry_test.cc b/third_party/xla/xla/ffi/type_id_registry_test.cc index d34b61a66ac09f..b26e385968c338 100644 --- a/third_party/xla/xla/ffi/type_id_registry_test.cc +++ b/third_party/xla/xla/ffi/type_id_registry_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include +#include #include "absl/status/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/fp_util_test.cc b/third_party/xla/xla/fp_util_test.cc index 3eb7c54f919b0a..3eb3561a264d40 100644 --- a/third_party/xla/xla/fp_util_test.cc +++ b/third_party/xla/xla/fp_util_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/numeric/bits.h" #include "xla/bit_cast.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/util.h" #include "tsl/platform/ml_dtypes.h" diff --git a/third_party/xla/xla/hlo/analysis/BUILD b/third_party/xla/xla/hlo/analysis/BUILD index e29673b83b68a0..f14588791291ae 100644 --- a/third_party/xla/xla/hlo/analysis/BUILD +++ b/third_party/xla/xla/hlo/analysis/BUILD @@ -39,9 +39,9 @@ xla_cc_test( ":hlo_dfs_reachability", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", "//xla/service:computation_placer_hdr", "//xla/service:hlo_module_config", "@local_tsl//tsl/platform:status", @@ -70,10 +70,10 @@ xla_cc_test( ":hlo_reachability", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/service:computation_placer", "//xla/service:hlo_module_config", "@com_google_absl//absl/random", @@ -93,8 +93,8 @@ cc_library( "//xla:status_macros", "//xla:types", "//xla:util", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/service:call_graph", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", @@ -165,10 +165,10 @@ xla_cc_test( deps = [ ":while_loop_analysis", "//xla:comparison_util", - "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", "//xla/service:constant_value", "//xla/service:value_range", "@com_google_absl//absl/log", @@ -221,13 +221,13 @@ xla_cc_test( "//xla:comparison_util", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:flatten_call_graph", + "//xla/hlo/testlib:test", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:hlo_creation_utils", - "//xla/service:hlo_dce", "//xla/service:hlo_value", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", @@ -294,8 +294,10 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -345,10 +347,10 @@ xla_cc_test( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:test", - "//xla:test_helpers", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -391,13 +393,13 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", "//xla/hlo/utils:hlo_matchers", - "//xla/service:flatten_call_graph", "//xla/service:hlo_buffer", "//xla/service:hlo_value", "//xla/tsl/lib/core:status_test_util", @@ -462,11 +464,11 @@ xla_cc_test( ":tuple_points_to_analysis", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/service:logical_buffer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", @@ -537,7 +539,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:gather_simplifier", + "//xla/hlo/transforms/simplifiers:gather_simplifier", "//xla/hlo/utils:hlo_traversal", "//xla/service/gpu:matmul_indexing_utils", "@com_google_absl//absl/algorithm:container", @@ -565,7 +567,6 @@ xla_cc_test( ":indexing_test_utils", "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/strings:string_view", @@ -603,7 +604,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", diff --git a/third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc index 00109570e14d18..65b0915bef2fb9 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc @@ -27,14 +27,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include "xla/literal_util.h" -#include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc index 3a5e93f85678ee..07e9853f20d81f 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc @@ -33,14 +33,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/literal_util.h" -#include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc index 8687bdff76b8a1..d717759643c103 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc @@ -18,18 +18,17 @@ limitations under the License. #include #include #include -#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "tsl/platform/status.h" #include "tsl/platform/test_benchmark.h" @@ -145,7 +144,7 @@ TEST_F(HloDfsReachabilityTest, ChannelReachability) { class HloDfsReachabilityBenchmark { public: - HloDfsReachabilityBenchmark(int size, std::string_view name) : name_(name) { + HloDfsReachabilityBenchmark(int size, absl::string_view name) : name_(name) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(name); diff --git a/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc index 436f5dedfef321..0e164504056b5c 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/hlo/analysis/hlo_ordering.h b/third_party/xla/xla/hlo/analysis/hlo_ordering.h index 644c3881fd2233..ded9fed8ccd1d5 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_ordering.h +++ b/third_party/xla/xla/hlo/analysis/hlo_ordering.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/call_graph.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc b/third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc index 98958516f124df..e9aae9531bc51a 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc @@ -18,19 +18,18 @@ limitations under the License. #include #include #include -#include #include "absl/random/random.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "tsl/platform/status.h" #include "tsl/platform/test_benchmark.h" @@ -287,7 +286,7 @@ BENCHMARK(BM_HloReachabilityBitSetUnion)->BM_ARGS; class HloReachabilityBenchmark { public: - HloReachabilityBenchmark(int size, std::string_view name) : name_(name) { + HloReachabilityBenchmark(int size, absl::string_view name) : name_(name) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(name); diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index 5657f354a42e45..2b1973ee823617 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -16,8 +16,10 @@ limitations under the License. #include "xla/hlo/analysis/hlo_replication_analysis.h" #include +#include #include #include +#include #include #include #include @@ -25,8 +27,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -37,9 +42,58 @@ limitations under the License. #include "xla/xla_data.pb.h" namespace xla { +namespace { +// When cross_partition_spmd is true, returns the partition IDs of all +// replica groups in which a given replica participates. Specfically, the k-th +// element of the outermost vector in the returned data structure holds the +// partition IDs converted from the global IDs in a collective's +// replica_groups field for replica k. +// +// When cross_partition_spmd is false, returns the replica IDs of all +// replica groups in which a given partition participates. Specfically, the k-th +// element of the outermost vector in the returned data structure holds the +// replica IDs converted from the global IDs in a collective's replica_groups +// field for partition k. +std::vector>> GroupsForReplicas( + absl::Span groups, int64_t num_partitions, + int64_t replica_count, bool cross_partition_spmd) { + int64_t num_replicas = cross_partition_spmd ? replica_count : num_partitions; + std::vector>> groups_for_replicas( + num_replicas); + for (const ReplicaGroup& group : groups) { + absl::flat_hash_map> id_to_ids; + for (int64_t id : group.replica_ids()) { + int64_t rid = id / num_partitions; + int64_t pid = id % num_partitions; + if (cross_partition_spmd) { + CHECK_LT(rid, num_replicas) + << "Got replica ID " << rid + << " which is greater or equal to the number of replicas: " + << num_replicas; + id_to_ids[rid].push_back(pid); + } else { + CHECK_LT(pid, num_partitions) + << "Got partition ID " << rid + << " which is greater or equal to the number of partitions: " + << num_partitions; + id_to_ids[pid].push_back(rid); + } + } + for (const auto& [id, ids] : id_to_ids) { + groups_for_replicas[id].push_back(std::move(ids)); + } + } + + return groups_for_replicas; +} + +} // namespace // Determines whether an HLO instruction is replicated at index based on current -// knowledge in hlo_replication. +// knowledge in hlo_replication. When cross_partition_spmd is true, the +// instruction must be replicated across all partitions on each replica. +// Similarly, when cross_partition_spmd is false, the instruction must be +// replicated across all replicas on each partition. HloReplicationAnalysis::HloReplication HloReplicationAnalysis::DetermineHloInstructionIsReplicated( const HloInstruction* hlo, const ShapeIndex& index, @@ -78,11 +132,16 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( return HloReplication::ReplicatedOnAllDevices(); } if (support_partial_replication) { - std::vector> device_sets; + std::vector>> device_sets_per_replica( + 1); for (const ReplicaGroup& replica_group : hlo->replica_groups()) { - device_sets.push_back(replica_group.replica_ids()); + std::vector device_set; + for (auto id : replica_group.replica_ids()) { + device_set.push_back(id); + } + device_sets_per_replica[0].push_back(device_set); } - return HloReplication::PartiallyReplicated(device_sets); + return HloReplication::PartiallyReplicated(device_sets_per_replica); } else { return HloReplication::UniqueOnAllDevices(); } @@ -94,48 +153,29 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( global_id = Cast(hlo)->use_global_device_ids(); } if (global_id) { - // TODO(philipphack): The following is incorrect if partitions are - // replicated differently on replicas, or if replicas are replicated - // differently on partitions. - bool replicated_across_partitions = true; - bool replicated_across_replicas = true; const int64_t num_partitions = hlo->GetModule()->config().num_partitions(); - absl::flat_hash_set visited_partitions; - absl::flat_hash_set visited_replicas; - std::vector device_set; - std::vector> device_sets; - std::vector> device_sets_storage; - for (const auto& group : hlo->replica_groups()) { - device_set.clear(); - visited_partitions.clear(); - visited_replicas.clear(); - visited_replicas.reserve(group.replica_ids().size()); - visited_partitions.reserve(group.replica_ids().size()); - for (int64_t id : group.replica_ids()) { - int64_t rid = id / num_partitions; - int64_t pid = id % num_partitions; - visited_partitions.insert(pid); - visited_replicas.insert(rid); - if (support_partial_replication) { - device_set.push_back(cross_partition_spmd ? pid : rid); - } - } - replicated_across_partitions &= - visited_partitions.size() == num_partitions; - replicated_across_replicas &= - visited_replicas.size() == - hlo->GetModule()->config().replica_count(); - if (support_partial_replication) { - device_sets_storage.push_back(device_set); - device_sets.push_back(device_sets_storage.back()); - } + const int64_t replica_count = + hlo->GetModule()->config().replica_count(); + std::vector>> device_sets_per_replica = + GroupsForReplicas(hlo->replica_groups(), num_partitions, + replica_count, cross_partition_spmd); + + // In the fully replicated case, there is one set of partition or + // replica IDs on each replica or partition. Since the flattened ID + // replica groups must contain every device, the size of the set is the + // number of partitions or replicas. + bool fully_replicated = true; + for (auto device_sets : device_sets_per_replica) { + fully_replicated &= + device_sets.size() == 1 && + (*device_sets.begin()).size() == + (cross_partition_spmd ? num_partitions : replica_count); } - if ((cross_partition_spmd && replicated_across_partitions) || - (!cross_partition_spmd && replicated_across_replicas)) { + if (fully_replicated) { return HloReplication::ReplicatedOnAllDevices(); } else if (support_partial_replication) { - return HloReplication::PartiallyReplicated(device_sets); + return HloReplication::PartiallyReplicated(device_sets_per_replica); } else { return HloReplication::UniqueOnAllDevices(); } @@ -210,12 +250,12 @@ HloReplicationAnalysis::DetermineHloInstructionIsReplicated( ds_buffer->literal().GetIntegralAsS64({device_id}); value_to_device_set[*value].push_back(device_id); } - std::vector> device_sets; + std::vector>> device_sets_per_replica( + 1); for (const auto& value_and_device_set : value_to_device_set) { - device_sets.push_back( - absl::Span(value_and_device_set.second)); + device_sets_per_replica[0].push_back(value_and_device_set.second); } - return HloReplication::PartiallyReplicated(device_sets); + return HloReplication::PartiallyReplicated(device_sets_per_replica); } } } @@ -539,10 +579,12 @@ HloReplicationAnalysis::HloReplication::HloReplication() HloReplicationAnalysis::HloReplication::HloReplication( HloReplicationAnalysis::HloReplication::State state, - absl::Span device_set_root) + absl::Span> device_set_root_per_replica) : state_(state), - device_set_root_(device_set_root.begin(), device_set_root.end()) { - CHECK(state == State::kPartiallyReplicated || device_set_root_.empty()); + device_set_root_per_replica_(device_set_root_per_replica.begin(), + device_set_root_per_replica.end()) { + CHECK(state == State::kPartiallyReplicated || + device_set_root_per_replica_.empty()); } HloReplicationAnalysis::HloReplication @@ -557,22 +599,30 @@ HloReplicationAnalysis::HloReplication::UniqueOnAllDevices() { HloReplicationAnalysis::HloReplication HloReplicationAnalysis::HloReplication::PartiallyReplicated( - absl::Span> device_sets) { - int64_t max_device_id = 0; - for (const absl::Span& device_set : device_sets) { - for (int64_t device_id : device_set) { - max_device_id = std::max(max_device_id, device_id); + absl::Span>> + device_sets_per_replica) { + std::vector> device_set_root_per_replica; + for (int i = 0; i < device_sets_per_replica.size(); ++i) { + const std::vector>& device_sets = + device_sets_per_replica[i]; + int64_t max_device_id = 0; + for (const std::vector& device_set : device_sets) { + for (int64_t device_id : device_set) { + max_device_id = std::max(max_device_id, device_id); + } } - } - std::vector device_set_root; - device_set_root.resize(max_device_id + 1); - for (const absl::Span& device_set : device_sets) { - int64_t min_device_id = *absl::c_min_element(device_set); - for (int64_t device_id : device_set) { - device_set_root[device_id] = min_device_id; + std::vector device_set_root; + device_set_root.resize(max_device_id + 1); + for (const std::vector& device_set : device_sets) { + int64_t min_device_id = *absl::c_min_element(device_set); + for (int64_t device_id : device_set) { + device_set_root[device_id] = min_device_id; + } } + device_set_root_per_replica.push_back(std::move(device_set_root)); } - return HloReplication(State::kPartiallyReplicated, device_set_root); + return HloReplication(State::kPartiallyReplicated, + device_set_root_per_replica); } HloReplicationAnalysis::HloReplication @@ -590,27 +640,36 @@ HloReplicationAnalysis::HloReplication::Merge( case State::kUniqueOnAllDevices: return other; case State::kPartiallyReplicated: { - absl::flat_hash_map> - value_to_device_set; - size_t num_devices = device_set_root_.size(); - for (int64_t device_id = 0; device_id < num_devices; ++device_id) { - int64_t new_value = device_set_root_[device_id] * num_devices + - other.device_set_root_[device_id]; - value_to_device_set[new_value].push_back(device_id); - } - CHECK_LE(value_to_device_set.size(), num_devices); - if (value_to_device_set.size() == 1) { - return ReplicatedOnAllDevices(); - } else if (value_to_device_set.size() < num_devices) { - std::vector> device_sets; + bool unique_on_all_devices = true; + std::vector>> + device_sets_per_replica; + CHECK_EQ(device_set_root_per_replica_.size(), + other.device_set_root_per_replica_.size()); + for (int i = 0; i < device_set_root_per_replica_.size(); ++i) { + const std::vector& my_device_set_root = + device_set_root_per_replica_[i]; + const std::vector& other_device_set_root = + other.device_set_root_per_replica_[i]; + absl::flat_hash_map> + value_to_device_set; + size_t num_devices = my_device_set_root.size(); + for (int64_t device_id = 0; device_id < num_devices; ++device_id) { + int64_t new_value = my_device_set_root[device_id] * num_devices + + other_device_set_root[device_id]; + value_to_device_set[new_value].push_back(device_id); + } + CHECK_LE(value_to_device_set.size(), num_devices); + std::vector> device_sets; for (const auto& value_and_device_set : value_to_device_set) { - device_sets.push_back( - absl::Span(value_and_device_set.second)); + device_sets.push_back(value_and_device_set.second); } - return PartiallyReplicated(device_sets); - } else { - return UniqueOnAllDevices(); + device_sets_per_replica.push_back(std::move(device_sets)); + unique_on_all_devices &= value_to_device_set.size() == num_devices; + } + if (unique_on_all_devices) { + return HloReplication::UniqueOnAllDevices(); } + return HloReplication::PartiallyReplicated(device_sets_per_replica); } } } @@ -622,7 +681,14 @@ bool HloReplicationAnalysis::HloReplication::Equal( if (state_ != other.state_) { return false; } - return absl::c_equal(device_set_root_, other.device_set_root_); + for (int i = 0; i < device_set_root_per_replica_.size(); ++i) { + if (device_set_root_per_replica_[i] != + other.device_set_root_per_replica_[i]) { + return false; + } + } + + return true; } bool HloReplicationAnalysis::HloReplication::IsReplicatedOnAllDevices() const { @@ -636,9 +702,16 @@ bool HloReplicationAnalysis::HloReplication::IsUniqueOnAllDevices() const { bool HloReplicationAnalysis::HloReplication::IsReplicatedWithinSubgroup( absl::Span device_ids) const { if (device_ids.empty()) return true; - return absl::c_all_of(device_ids, [this, &device_ids](int device_id) { - return device_set_root_[device_id] == device_set_root_[device_ids.front()]; - }); + for (std::vector device_set_roots : device_set_root_per_replica_) { + if (!absl::c_all_of(device_ids, + [&device_ids, &device_set_roots](int device_id) { + return device_set_roots[device_id] == + device_set_roots[device_ids.front()]; + })) { + return false; + } + } + return true; } std::string HloReplicationAnalysis::HloReplication::ToString() const { @@ -648,8 +721,17 @@ std::string HloReplicationAnalysis::HloReplication::ToString() const { case State::kUniqueOnAllDevices: return "UniqueOnAllDevices"; case State::kPartiallyReplicated: - return absl::StrCat("PartiallyReplicated{", - absl::StrJoin(device_set_root_, ","), "}"); + std::ostringstream oss; + oss << "PartiallyReplicated{"; + for (int k = 0; k < device_set_root_per_replica_.size(); ++k) { + if (k > 0) { + oss << ", "; + } + oss << absl::StrCat( + "{", absl::StrJoin(device_set_root_per_replica_[k], ","), "}"); + } + oss << "}"; + return oss.str(); } } diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h index 2818e1ff61196e..aa4f15ab98b3e6 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h @@ -69,7 +69,8 @@ class HloReplicationAnalysis { static HloReplication ReplicatedOnAllDevices(); static HloReplication UniqueOnAllDevices(); static HloReplication PartiallyReplicated( - absl::Span> device_sets); + absl::Span>> + device_sets_per_replica); HloReplication(); HloReplication(const HloReplication& other) = default; HloReplication(HloReplication&& other) = default; @@ -87,14 +88,20 @@ class HloReplicationAnalysis { kUniqueOnAllDevices = 1, kPartiallyReplicated = 2, }; - explicit HloReplication(State state, - absl::Span device_set_root); + explicit HloReplication( + State state, + absl::Span> device_set_root_per_replica); State state_; // Empty if state_ is kReplicatedOnAllDevices or kUniqueOnAllDevices. - // Otherwise, its size equals to the number of devices (either partitions - // or replications). Maps each device ID to the smallest device ID in the - // set. - std::vector device_set_root_; + + // If cross_partition_spmd is true, groups_for_replicas_[k]'s size equals + // the number of partitions, and within replica k, groups_for_replicas_[k] + // maps each partition ID to the smallest partition ID in the set. + // + // If cross_partition_spmd is false, groups_for_replicas_[k]'s size equals + // the number of replicas, and within partition k, groups_for_replicas_[k] + // maps each replica to the smallest replica ID in the set. + std::vector> device_set_root_per_replica_; }; static HloReplication DetermineHloInstructionIsReplicated( diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc index 0f2b0061e45c78..eb0a2b1852f5d8 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -29,7 +30,19 @@ limitations under the License. namespace xla { namespace { -class HloReplicationAnalysisTest : public HloHardwareIndependentTestBase {}; +class HloReplicationAnalysisTest : public HloHardwareIndependentTestBase { + public: + std::vector CreateReplicaGroups( + std::vector> replica_ids) { + std::vector replica_groups(replica_ids.size()); + for (int i = 0; i < replica_ids.size(); ++i) { + for (int id : replica_ids[i]) { + replica_groups[i].add_replica_ids(id); + } + } + return replica_groups; + } +}; TEST_F(HloReplicationAnalysisTest, NoControlFlow) { const std::string module_str = R"( @@ -596,7 +609,9 @@ ENTRY entry { use_global_device_ids=true, channel_id=2 ag3 = f32[4] all-gather(param), replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true, channel_id=3 - ROOT tuple = (f32[2], f32[2], f32[4]) tuple(ag1, ag2, ag3) + ag4 = f32[2] all-gather(param), replica_groups={{0,3},{1,2}}, dimensions={0}, + use_global_device_ids=true, channel_id=4 + ROOT tuple = (f32[2], f32[2], f32[4], f32[2]) tuple(ag1, ag2, ag3, ag4) } )"; @@ -617,6 +632,8 @@ ENTRY entry { FindInstruction(module.get(), "ag2"), {})); EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "ag3"), {})); + EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag4"), {})); EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "ag1"), {})); @@ -624,6 +641,8 @@ ENTRY entry { FindInstruction(module.get(), "ag2"), {})); EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "ag3"), {})); + EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag4"), {})); } TEST_F(HloReplicationAnalysisTest, PartiallyReplicatedDynamicSlice) { @@ -636,41 +655,30 @@ ENTRY entry { ROOT dynamic-slice = s32[1] dynamic-slice(constant, replica-id), dynamic_slice_sizes={1} } )"; + const int replica_count = 8; + const int num_partitions = 1; + const bool cross_partition_spmd = false; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 4}, {1, 5}, {2, 6}, {3, 7}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 1, 2, 3}, {4, 5, 6, 7}}); TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(module_str, /*replica_count=*/8, - /*num_partitions=*/1)); + auto module, + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr replica_analysis, - HloReplicationAnalysis::RunWithPartialReplication( - module.get(), - /*cross_partition_spmd=*/false)); + HloReplicationAnalysis::RunWithPartialReplication(module.get(), + cross_partition_spmd)); EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "dynamic-slice"), {})); - std::vector replica_groups(4); - replica_groups[0].add_replica_ids(0); - replica_groups[0].add_replica_ids(4); - replica_groups[1].add_replica_ids(1); - replica_groups[1].add_replica_ids(5); - replica_groups[2].add_replica_ids(2); - replica_groups[2].add_replica_ids(6); - replica_groups[3].add_replica_ids(3); - replica_groups[3].add_replica_ids(7); + EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt( - FindInstruction(module.get(), "dynamic-slice"), {}, replica_groups)); - - std::vector replica_groups_2(2); - replica_groups_2[0].add_replica_ids(0); - replica_groups_2[0].add_replica_ids(1); - replica_groups_2[0].add_replica_ids(2); - replica_groups_2[0].add_replica_ids(3); - replica_groups_2[1].add_replica_ids(4); - replica_groups_2[1].add_replica_ids(5); - replica_groups_2[1].add_replica_ids(6); - replica_groups_2[1].add_replica_ids(7); + FindInstruction(module.get(), "dynamic-slice"), {}, replica_groups0)); + EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt( - FindInstruction(module.get(), "dynamic-slice"), {}, replica_groups_2)); + FindInstruction(module.get(), "dynamic-slice"), {}, replica_groups1)); } TEST_F(HloReplicationAnalysisTest, @@ -685,28 +693,21 @@ ENTRY entry { ROOT tuple = (s32[4], s32[4]) tuple(all-gather0, all-gather1) } )"; + const int replica_count = 4; + const int num_partitions = 2; + const bool cross_partition_spmd = false; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 1}, {2, 3}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 2}, {1, 3}}); TF_ASSERT_OK_AND_ASSIGN( auto module_replica_analysis, - ParseAndReturnVerifiedModule(module_str, /*replica_count=*/4, - /*num_partitions=*/2)); + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr replica_analysis, HloReplicationAnalysis::RunWithPartialReplication( - module_replica_analysis.get(), - /*cross_partition_spmd=*/false)); - - std::array replica_groups0; - replica_groups0[0].add_replica_ids(0); - replica_groups0[0].add_replica_ids(1); - replica_groups0[1].add_replica_ids(2); - replica_groups0[1].add_replica_ids(3); - - std::array replica_groups1; - replica_groups1[0].add_replica_ids(0); - replica_groups1[0].add_replica_ids(2); - replica_groups1[1].add_replica_ids(1); - replica_groups1[1].add_replica_ids(3); + module_replica_analysis.get(), cross_partition_spmd)); EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt( FindInstruction(module_replica_analysis.get(), "all-gather0"), {})); @@ -743,28 +744,21 @@ ENTRY entry { ROOT tuple = (s32[4], s32[4]) tuple(all-gather0, all-gather1) } )"; + const int replica_count = 2; + const int num_partitions = 4; + const bool cross_partition_spmd = true; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 1}, {2, 3}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 2}, {1, 3}}); TF_ASSERT_OK_AND_ASSIGN( auto module_partition_analysis, - ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2, - /*num_partitions=*/4)); + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr partition_analysis, HloReplicationAnalysis::RunWithPartialReplication( - module_partition_analysis.get(), - /*cross_partition_spmd=*/true)); - - std::array replica_groups0; - replica_groups0[0].add_replica_ids(0); - replica_groups0[0].add_replica_ids(1); - replica_groups0[1].add_replica_ids(2); - replica_groups0[1].add_replica_ids(3); - - std::array replica_groups1; - replica_groups1[0].add_replica_ids(0); - replica_groups1[0].add_replica_ids(2); - replica_groups1[1].add_replica_ids(1); - replica_groups1[1].add_replica_ids(3); + module_partition_analysis.get(), cross_partition_spmd)); EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( FindInstruction(module_partition_analysis.get(), "all-gather0"), {})); @@ -789,6 +783,174 @@ ENTRY entry { replica_groups0)); } +TEST_F( + HloReplicationAnalysisTest, + PartiallyReplicatedAllGatherFlattenedIDPartitionAnalysisAsymmetricGroups) { + const std::string module_str = R"( +HloModule GlobalIdAllGather + +ENTRY entry { + param = f32[1] parameter(0) + ROOT all_gather = f32[6] all-gather(param), replica_groups={{0,1,2,3,6,7},{4,5,8,9,10,11}}, dimensions={0}, use_global_device_ids=true, channel_id=1 +} +)"; + const int replica_count = 2; + const int num_partitions = 6; + const bool cross_partition_spmd = true; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 1}, {2, 3}, {4, 5}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 1, 2}, {3, 4, 5}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr partition_analysis, + HloReplicationAnalysis::RunWithPartialReplication(module.get(), + cross_partition_spmd)); + + EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups0)); + EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups1)); +} + +TEST_F(HloReplicationAnalysisTest, + PartiallyReplicatedAllGatherFlattenedIDReplicaAnalysisAsymmetricGroups) { + const std::string module_str = R"( +HloModule GlobalIdAllGather + +ENTRY entry { + param = f32[1] parameter(0) + ROOT all_gather = f32[6] all-gather(param), replica_groups={{0,1,2,3,4,6},{5,7,8,9,10,11}}, dimensions={0}, use_global_device_ids=true, channel_id=1 +} +)"; + const int replica_count = 6; + const int num_partitions = 2; + const bool cross_partition_spmd = false; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 1}, {2, 3}, {4, 5}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 1, 2}, {3, 4, 5}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr partition_analysis, + HloReplicationAnalysis::RunWithPartialReplication(module.get(), + cross_partition_spmd)); + + EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups0)); + EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups1)); +} + +TEST_F( + HloReplicationAnalysisTest, + PartiallyReplicatedAllGatherFlattenedIDPartitionAnalysisAsymmetricPartial) { + const std::string module_str = R"( +HloModule GlobalIdAllGather + +ENTRY entry { + param = f32[1] parameter(0) + ROOT all_gather = f32[6] all-gather(param), replica_groups={{0,1,2,3,6,7},{4,5,8,9,10,11},{12,13,14,15,16,17}}, dimensions={0}, use_global_device_ids=true, channel_id=1 +} +)"; + const int replica_count = 3; + const int num_partitions = 6; + const bool cross_partition_spmd = true; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 1}, {2, 3}, {4, 5}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 1, 2}, {3, 4, 5}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr partition_analysis, + HloReplicationAnalysis::RunWithPartialReplication(module.get(), + cross_partition_spmd)); + + EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups0)); + EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups1)); +} + +TEST_F(HloReplicationAnalysisTest, + PartiallyReplicatedAllGatherFlattenedIDPartitionAnalysisAsymmetricAll) { + const std::string module_str = R"( +HloModule GlobalIdAllGather + +ENTRY entry { + param = f32[1] parameter(0) + ROOT all_gather = f32[4] all-gather(param), replica_groups={{0,2,5,7},{1,3,4,6}}, dimensions={0}, use_global_device_ids=true, channel_id=1 +} +)"; + const int replica_count = 2; + const int num_partitions = 4; + const bool cross_partition_spmd = true; + const std::vector replica_groups = + CreateReplicaGroups({{0, 1}, {2, 3}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr partition_analysis, + HloReplicationAnalysis::RunWithPartialReplication(module.get(), + cross_partition_spmd)); + + EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all_gather"), {}, replica_groups)); +} + +TEST_F(HloReplicationAnalysisTest, + PartiallyReplicatedAllGatherFlattenedIDPartitionAnalysisMerge) { + const std::string module_str = R"( + HloModule module + + ENTRY entry { + param0 = f32[2] parameter(0) + param1 = f32[4] parameter(1) + all_gather0 = f32[8] all-gather(param0), dimensions={0}, replica_groups={{0,1,2,3},{4,5,6,7},{8,9,10,11},{12,13,14,15}}, use_global_device_ids=true, channel_id=1 + all_gather1 = f32[8] all-gather(param1), dimensions={0}, replica_groups={{0,1},{2,3},{4,5},{6,7},{8,9},{10,11},{12,13},{14,15}}, use_global_device_ids=true, channel_id=2 + all_gather2 = f32[8] all-gather(param0), dimensions={0}, replica_groups={{0,3,4,5},{1,2,6,7},{8,11,12,13},{9,10,14,15}}, use_global_device_ids=true, channel_id=3 + add0 = f32[8] add(all_gather0, all_gather1) + add1 = f32[8] add(all_gather0, all_gather2) + ROOT tuple = (f32[8], f32[8]) tuple(add0, add1) + } + )"; + const int replica_count = 2; + const int num_partitions = 8; + const bool cross_partition_spmd = true; + const std::vector replica_groups0 = + CreateReplicaGroups({{0, 1, 2, 3}, {4, 5, 6, 7}}); + const std::vector replica_groups1 = + CreateReplicaGroups({{0, 1}, {2, 3}, {4, 5}, {6, 7}}); + const std::vector replica_groups2 = + CreateReplicaGroups({{1, 2}, {0, 3}, {4, 5}, {6, 7}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(module_str, replica_count, num_partitions)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::RunWithPartialReplication( + module.get(), cross_partition_spmd)); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add0"), {}, replica_groups0)); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add0"), {}, replica_groups1)); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add1"), {}, replica_groups0)); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add1"), {}, replica_groups2)); +} + TEST_F(HloReplicationAnalysisTest, OptimizationBarrier) { const std::string module_str = R"( HloModule OptimizationBarrier diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc index 49b7c78fd2b9a1..f2454620fd4665 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -944,7 +943,7 @@ std::string HloValueSemanticsTreeToString( HloValueSemanticsAnalysis::HloValueSemanticsAnalysis( const HloModule& module, - const absl::flat_hash_set& execution_threads) + const absl::flat_hash_set& execution_threads) : module_(module), execution_threads_(execution_threads), next_id_(0) {} const HloValueSemantics* HloValueSemanticsAnalysis::GetSemantics( @@ -969,7 +968,7 @@ int HloValueSemanticsAnalysis::GetHeight(const HloInstruction* instruction, absl::StatusOr> HloValueSemanticsAnalysis::Run( const HloModule& module, - const absl::flat_hash_set& execution_threads) { + const absl::flat_hash_set& execution_threads) { std::unique_ptr value_semantics_analysis = absl::WrapUnique( new HloValueSemanticsAnalysis(module, execution_threads)); diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h index c6fa0284e7cf97..ec1f6df405206c 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h @@ -247,7 +247,7 @@ class HloValueSemanticsAnalysis { public: static absl::StatusOr> Run( const HloModule& module, - const absl::flat_hash_set& execution_threads = {}); + const absl::flat_hash_set& execution_threads = {}); virtual ~HloValueSemanticsAnalysis() = default; bool HasSemanticsFor(const HloInstruction* instruction) const; const HloValueSemantics* GetSemantics(const HloInstruction* instruction, @@ -277,7 +277,7 @@ class HloValueSemanticsAnalysis { friend class HloValueSemanticsPropagation; explicit HloValueSemanticsAnalysis( const HloModule& module, - const absl::flat_hash_set& execution_threads); + const absl::flat_hash_set& execution_threads); virtual absl::Status InitializeEinsumDepth(); virtual absl::Status InitializeEinsumHeight(); // We match send and recv HLOs to propagate semantics from send to recv. diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc index 4c66f9de7207fb..46cc4afa41ccb0 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc @@ -722,7 +722,7 @@ TEST_F(EinsumHeightAnalysisTest, MnistTrainingLoop) { TEST_F(HloValueSemanticsAnalysisTest, HandleIncompleteForeignThreadComputation) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule Module ENTRY entry { diff --git a/third_party/xla/xla/hlo/analysis/indexing_analysis_test.cc b/third_party/xla/xla/hlo/analysis/indexing_analysis_test.cc index cf08cd8a1f3e83..ae4bf1bc96f966 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/indexing_analysis_test.cc @@ -1782,7 +1782,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { d0 in [0, 9] )")); - constexpr std::string_view kInputToOutputIndexing = R"( + constexpr absl::string_view kInputToOutputIndexing = R"( (d0, d1) -> (d1), domain: d0 in [0, 255], @@ -1800,7 +1800,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(kInputToOutputIndexing)), ElementsAre(MatchIndexingMap(kInputToOutputIndexing)))); - constexpr std::string_view kInitToOutputIndexing = R"( + constexpr absl::string_view kInitToOutputIndexing = R"( ()[s0] -> (s0), domain: s0 in [0, 9] diff --git a/third_party/xla/xla/hlo/analysis/indexing_map.cc b/third_party/xla/xla/hlo/analysis/indexing_map.cc index 027cf17a010c3a..ec48eb5fd3b1ca 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_map.cc +++ b/third_party/xla/xla/hlo/analysis/indexing_map.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -785,17 +784,17 @@ SmallVector MapSymbolsToComposedSymbolsList( } // namespace -static constexpr std::string_view kVarKindDefault = "default"; -static constexpr std::string_view kVarKindThreadX = "th_x"; -static constexpr std::string_view kVarKindThreadY = "th_y"; -static constexpr std::string_view kVarKindThreadZ = "th_z"; -static constexpr std::string_view kVarKindBlockX = "bl_x"; -static constexpr std::string_view kVarKindBlockY = "bl_y"; -static constexpr std::string_view kVarKindBlockZ = "bl_z"; -static constexpr std::string_view kVarKindWarp = "warp"; -static constexpr std::string_view kVarKindWarpThread = "th_w"; - -std::string_view ToVariableName(VariableKind var_kind) { +static constexpr absl::string_view kVarKindDefault = "default"; +static constexpr absl::string_view kVarKindThreadX = "th_x"; +static constexpr absl::string_view kVarKindThreadY = "th_y"; +static constexpr absl::string_view kVarKindThreadZ = "th_z"; +static constexpr absl::string_view kVarKindBlockX = "bl_x"; +static constexpr absl::string_view kVarKindBlockY = "bl_y"; +static constexpr absl::string_view kVarKindBlockZ = "bl_z"; +static constexpr absl::string_view kVarKindWarp = "warp"; +static constexpr absl::string_view kVarKindWarpThread = "th_w"; + +absl::string_view ToVariableName(VariableKind var_kind) { switch (var_kind) { case VariableKind::kDefault: return kVarKindDefault; @@ -819,7 +818,7 @@ std::string_view ToVariableName(VariableKind var_kind) { llvm_unreachable("Unknown VariableType"); } -VariableKind ToVariableType(std::string_view var_name) { +VariableKind ToVariableType(absl::string_view var_name) { if (var_name == kVarKindThreadX) return VariableKind::kThreadX; if (var_name == kVarKindThreadY) return VariableKind::kThreadY; if (var_name == kVarKindThreadZ) return VariableKind::kThreadZ; diff --git a/third_party/xla/xla/hlo/analysis/indexing_map.h b/third_party/xla/xla/hlo/analysis/indexing_map.h index 342853f01bd078..77ea7ec24f3be4 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_map.h +++ b/third_party/xla/xla/hlo/analysis/indexing_map.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -55,8 +54,8 @@ enum class VariableKind : char { kWarpThread }; -std::string_view ToVariableName(VariableKind var_kind); -VariableKind ToVariableType(std::string_view var_name); +absl::string_view ToVariableName(VariableKind var_kind); +VariableKind ToVariableType(absl::string_view var_name); std::ostream& operator<<(std::ostream& out, VariableKind var_type); // Interval represents a closed interval [lower_bound, upper_bound]. @@ -252,7 +251,7 @@ class IndexingMap { const llvm::DenseMap& constraints); IndexingMap(const IndexingMap&) = default; - IndexingMap(IndexingMap&&) = default; + IndexingMap(IndexingMap&&) noexcept = default; IndexingMap& operator=(const IndexingMap&) = default; IndexingMap& operator=(IndexingMap&&) = default; @@ -287,7 +286,7 @@ class IndexingMap { RangeEvaluator GetRangeEvaluator() const; // Getters for dimension vars. - const Variable& GetDimVars(int64_t id) const { return dim_vars_[id]; } + const Variable& GetDimVar(int64_t id) const { return dim_vars_[id]; } const std::vector& GetDimVars() const { return dim_vars_; } int64_t GetDimVarsCount() const { return dim_vars_.size(); } @@ -408,18 +407,18 @@ class IndexingMap { mlir::AffineMap affine_map_; - // Dimension variable represents a dimension of a tensor or a GPU grid. - // Dimensions correspond to the dimension parameter of `affine_map_`. + // A dimension variable represents a dimension of a tensor or a GPU grid. + // Dimension variables correspond to the dimensions of the `affine_map_`. std::vector dim_vars_; - // RangeSymbol variable represents a range of values, e.g. to compute a single + // A range variable represents a range of values, e.g. to compute a single // element of the reduction's result we need a range of values from the input - // tensor. RangeSymbol variables correspond to the front portion of the + // tensor. Range variables correspond to the front portion of the // symbols in `affine_map_`. std::vector range_vars_; - // RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in - // HLO dynamic-update-slice op. RTSymbol variables correspond to the back + // A runtime variable represents a runtime symbol, e.g. a dynamic offset in of + // a HLO dynamic-update-slice op. Runtime variables correspond to the back // portion of the symbols in `affine_map_`. std::vector rt_vars_; diff --git a/third_party/xla/xla/hlo/analysis/indexing_map_serialization.cc b/third_party/xla/xla/hlo/analysis/indexing_map_serialization.cc index 9b61bbeeb77c88..7ce84492350549 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_map_serialization.cc +++ b/third_party/xla/xla/hlo/analysis/indexing_map_serialization.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -407,24 +406,24 @@ bool ParseAffineExprsWithMLIR(ArrayRef dim_var_names, return true; } -std::string GetVarName(int64_t id, std::string_view name, - std::string_view prefix) { +std::string GetVarName(int64_t id, absl::string_view name, + absl::string_view prefix) { if (!name.empty()) { return std::string(name); } return absl::StrFormat("%s%d", prefix, id); } -std::string GetDimVarName(int64_t dim_id, std::string_view dim_name = "") { +std::string GetDimVarName(int64_t dim_id, absl::string_view dim_name = "") { return GetVarName(dim_id, dim_name, "d"); } std::string GetRangeVarName(int64_t range_id, - std::string_view range_name = "") { + absl::string_view range_name = "") { return GetVarName(range_id, range_name, "s"); } -std::string GetRTVarName(int64_t rt_id, std::string_view rt_name = "") { +std::string GetRTVarName(int64_t rt_id, absl::string_view rt_name = "") { return GetVarName(rt_id, rt_name, "rt"); } diff --git a/third_party/xla/xla/hlo/analysis/indexing_test_utils.cc b/third_party/xla/xla/hlo/analysis/indexing_test_utils.cc index 9fb1d03aaa9d8e..52e62fb0210673 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_test_utils.cc +++ b/third_party/xla/xla/hlo/analysis/indexing_test_utils.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -161,7 +160,7 @@ AffineExpr ParseAffineExpr(absl::string_view serialized_affine_expr, .getResult(0); } -bool ApproximateMatch(std::string_view lhs, std::string_view rhs) { +bool ApproximateMatch(absl::string_view lhs, absl::string_view rhs) { size_t lhs_length = lhs.size(); size_t rhs_length = rhs.size(); size_t l = 0, r = 0; diff --git a/third_party/xla/xla/hlo/analysis/indexing_test_utils.h b/third_party/xla/xla/hlo/analysis/indexing_test_utils.h index aa1566a6015c00..92ccc2de73460f 100644 --- a/third_party/xla/xla/hlo/analysis/indexing_test_utils.h +++ b/third_party/xla/xla/hlo/analysis/indexing_test_utils.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -36,12 +35,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" namespace xla { // Matches two strings ignoring whitespaces. -bool ApproximateMatch(std::string_view lhs, std::string_view rhs); +bool ApproximateMatch(absl::string_view lhs, absl::string_view rhs); MATCHER(UndefinedMap, "") { return arg.IsUndefined(); } diff --git a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc index 723f0c4f3d095f..e33d21052b588b 100644 --- a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc @@ -29,12 +29,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc b/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc index 9121587f3a2608..6e69f2f277ad96 100644 --- a/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc @@ -504,7 +504,7 @@ optional MatchTrivialLoopRange(const HloInstruction* while_op) { return nullopt; } - // Check that `i` goes as `i += k` in the while body where k is a natural + // Check that `i` goes as `i += C` in the while body where C is a natural // number. auto* while_body = while_op->while_body(); auto* while_body_indvar_update = @@ -589,6 +589,35 @@ optional MatchTrivialLoopRange(const HloInstruction* while_op) { return nullopt; } + // If the while loop condition does not support equality, then we need to + // deduct one from the bound. + bool while_cond_bound_supports_equality; + if (Match(while_cond_root, + m::Op().WithComparisonDirection(ComparisonDirection::kLt)) || + Match(while_cond_root, + m::Op().WithComparisonDirection(ComparisonDirection::kGt))) { + while_cond_bound_supports_equality = false; + } else if (Match(while_cond_root, + m::Op().WithComparisonDirection(ComparisonDirection::kLe)) || + Match(while_cond_root, + m::Op().WithComparisonDirection(ComparisonDirection::kGe))) { + while_cond_bound_supports_equality = true; + } else { + VLOG(2) << "Pattern-match failed: while condition comparison is not " + "LT, GT, LE, or GE."; + return nullopt; + } + if (!while_cond_bound_supports_equality) { + while_cond_bound_val.value()--; + } + + // We also need to round the bound down so that the difference between bound + // and init_value is a multiple of the step size. + while_cond_bound_val.value() = + (while_cond_bound_val.value() - indvar_init_val.value()) / + trip_count_step * trip_count_step + + indvar_init_val.value(); + const int64_t init_bitwidth = primitive_util::BitWidth(indvar_init.shape().element_type()); const bool init_is_signed = diff --git a/third_party/xla/xla/hlo/analysis/while_loop_analysis.h b/third_party/xla/xla/hlo/analysis/while_loop_analysis.h index edb154749eaa2f..8a99e2b434332c 100644 --- a/third_party/xla/xla/hlo/analysis/while_loop_analysis.h +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis.h @@ -50,18 +50,18 @@ std::optional GetLoopInductionVarTupleIdx( const HloInstruction *while_op); // Checks the following conditions: -// - `i`, the induction varaiable, is initialized to a scalar constant K +// - `i`, the induction variable, is initialized to a scalar constant K // (namely, `indvar_init`), -// - the while condition does `i < N` or `i <= N` (where N is a know constant) -// - the while body does `i++`. -// If so, it's trivial to compute the loop bound as `N - k` or `N - k + 1`, -// respectively. +// - the while condition does `i < N` or `i <= N` (where N is a known constant) +// - the while body does `i += C` (where C is a positive constant) +// If so, it's trivial to compute the loop bound as `(N - K) div C` or +// `(N - K + 1) div C`, respectively. std::optional MatchTrivialLoopTripCount(const HloInstruction *while_op, int64_t indvar_tuple_idx, const Literal &indvar_init); // Same as above, but returns the loop range, i.e., start (inclusive), end -// (exclusive) and step instead of the trip count. +// (inclusive) and step instead of the trip count. std::optional MatchTrivialLoopRange(const HloInstruction *while_op); } // namespace xla diff --git a/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc b/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc index 5252bda64ff871..ab69ff36512a69 100644 --- a/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc @@ -34,9 +34,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/constant_value.h" #include "xla/service/value_range.h" -#include "xla/test.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -301,27 +301,27 @@ bool RangeEqualIgnoreBitwidth(const Range& range, int init, int limit, : r.min().GetUnsignedValue(); }; auto range_max = [](const Range& r) { - return r.min().IsSigned() ? r.max().GetSignedValue() - : r.max().GetUnsignedValue(); + return r.max()->IsSigned() ? r.max()->GetSignedValue() + : r.max()->GetUnsignedValue(); }; return range_min(range) == init && range_max(range) == limit && - range.step().GetSignedValue() == step; + range.step()->GetSignedValue() == step; } TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialRange) { // LT cases EXPECT_TRUE(RangeEqualIgnoreBitwidth( MakeWhileLoopAndGetRange(0, 42, 1, ComparisonDirection::kLt).value(), 0, - 42, 1)); + 41, 1)); EXPECT_TRUE(RangeEqualIgnoreBitwidth( MakeWhileLoopAndGetRange(0, 42, 2, ComparisonDirection::kLt).value(), 0, - 42, 2)); + 40, 2)); EXPECT_TRUE(RangeEqualIgnoreBitwidth( MakeWhileLoopAndGetRange(0, 42, 5, ComparisonDirection::kLt).value(), 0, - 42, 5)); + 40, 5)); EXPECT_TRUE(RangeEqualIgnoreBitwidth( MakeWhileLoopAndGetRange(0, 40, 5, ComparisonDirection::kLt).value(), 0, - 40, 5)); + 35, 5)); // LE cases EXPECT_TRUE(RangeEqualIgnoreBitwidth( @@ -332,7 +332,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBoundTrivialRange) { 42, 2)); EXPECT_TRUE(RangeEqualIgnoreBitwidth( MakeWhileLoopAndGetRange(0, 42, 5, ComparisonDirection::kLe).value(), 0, - 42, 5)); + 40, 5)); EXPECT_TRUE(RangeEqualIgnoreBitwidth( MakeWhileLoopAndGetRange(0, 40, 5, ComparisonDirection::kLe).value(), 0, 40, 5)); diff --git a/third_party/xla/xla/hlo/builder/lib/BUILD b/third_party/xla/xla/hlo/builder/lib/BUILD index 489259e694a7b2..fbfc13188c4edd 100644 --- a/third_party/xla/xla/hlo/builder/lib/BUILD +++ b/third_party/xla/xla/hlo/builder/lib/BUILD @@ -89,6 +89,7 @@ xla_test( "//xla/tests:test_macros_header", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:test_main", ], @@ -141,6 +142,7 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test_main", ], ) @@ -299,6 +301,7 @@ xla_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test_main", ], ) @@ -374,6 +377,7 @@ xla_test( "//xla/tests:test_macros_header", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test_main", ], ) @@ -414,6 +418,7 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], @@ -566,6 +571,7 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:bfloat16", "@local_tsl//tsl/platform:test_main", ], @@ -680,6 +686,7 @@ cc_library( "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -705,6 +712,7 @@ xla_test( "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc index 16e9c090e9dd3b..f6df6a9ac486ce 100644 --- a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/approx_topk.h" -#include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.h b/third_party/xla/xla/hlo/builder/lib/approx_topk.h index f940d26967cc76..b4f63c1ec9a315 100644 --- a/third_party/xla/xla/hlo/builder/lib/approx_topk.h +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ #define XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ +#include + #include "absl/types/span.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h index 83b2b71d1054e5..f373ee5165edad 100644 --- a/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ #define XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ +#include #include #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc index 3cde6bf0f4e5c3..2e5b546f801e84 100644 --- a/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/arithmetic.h" +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/comparators.cc b/third_party/xla/xla/hlo/builder/lib/comparators.cc index fec1874a0373d4..a4965caab0d931 100644 --- a/third_party/xla/xla/hlo/builder/lib/comparators.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators.cc @@ -74,8 +74,8 @@ XlaComputation CreateScalarComparisonComputation( absl::StrCat("p.", parameter_count, ".lhs")); auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, absl::StrCat("p.", parameter_count, ".rhs")); - lhs_params.emplace_back(lhs_param); - rhs_params.emplace_back(rhs_param); + lhs_params.push_back(lhs_param); + rhs_params.push_back(rhs_param); if (generators[parameter_count].has_value()) { last_generator_index = parameter_count; } diff --git a/third_party/xla/xla/hlo/builder/lib/comparators_test.cc b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc index 39bf073171a86b..66352ea0296673 100644 --- a/third_party/xla/xla/hlo/builder/lib/comparators_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "xla/hlo/builder/lib/constants.h" diff --git a/third_party/xla/xla/hlo/builder/lib/constants_test.cc b/third_party/xla/xla/hlo/builder/lib/constants_test.cc index 61aa0ae71dee5b..6e934f09c44fc9 100644 --- a/third_party/xla/xla/hlo/builder/lib/constants_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/constants_test.cc @@ -15,8 +15,10 @@ limitations under the License. #include "xla/hlo/builder/lib/constants.h" +#include #include +#include #include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc index 9bbe184a9d6140..019d4a6e8e673d 100644 --- a/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc +++ b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/builder/lib/conv_grad_size_util.h" #include +#include #include "absl/log/log.h" #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h index 91e43d226c180b..862c2da1a219da 100644 --- a/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h +++ b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ #define XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ +#include + #include "absl/status/statusor.h" #include "xla/hlo/builder/padding.h" diff --git a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc index ba82ec343ce55a..8644da4aa80ae5 100644 --- a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h index 71188b8fb80a22..6073e0325fd6a5 100644 --- a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ #define XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ +#include + #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/builder/lib/constants.h" diff --git a/third_party/xla/xla/hlo/builder/lib/logdet.cc b/third_party/xla/xla/hlo/builder/lib/logdet.cc index cc17d0ec26ffe6..0fa69e6b186383 100644 --- a/third_party/xla/xla/hlo/builder/lib/logdet.cc +++ b/third_party/xla/xla/hlo/builder/lib/logdet.cc @@ -15,9 +15,8 @@ limitations under the License. #include "xla/hlo/builder/lib/logdet.h" +#include #include -#include -#include #include "absl/status/statusor.h" #include "xla/hlo/builder/lib/arithmetic.h" diff --git a/third_party/xla/xla/hlo/builder/lib/loops.cc b/third_party/xla/xla/hlo/builder/lib/loops.cc index e7dbad01163d93..e652fcee1262f2 100644 --- a/third_party/xla/xla/hlo/builder/lib/loops.cc +++ b/third_party/xla/xla/hlo/builder/lib/loops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/loops.h" +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/loops.h b/third_party/xla/xla/hlo/builder/lib/loops.h index 540ab784f34684..cef4d16176d4a9 100644 --- a/third_party/xla/xla/hlo/builder/lib/loops.h +++ b/third_party/xla/xla/hlo/builder/lib/loops.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_LOOPS_H_ #define XLA_HLO_BUILDER_LIB_LOOPS_H_ +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc index 78e9c00e07ca1a..9c9b56bdfac3f5 100644 --- a/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc +++ b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/builder/lib/lu_decomposition.h" #include +#include #include #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/math.cc b/third_party/xla/xla/hlo/builder/lib/math.cc index f2a77df3d7ddaa..3a72875d2733de 100644 --- a/third_party/xla/xla/hlo/builder/lib/math.cc +++ b/third_party/xla/xla/hlo/builder/lib/math.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" diff --git a/third_party/xla/xla/hlo/builder/lib/math_test.cc b/third_party/xla/xla/hlo/builder/lib/math_test.cc index 9755643b7586a0..cf56e0e39cf2b0 100644 --- a/third_party/xla/xla/hlo/builder/lib/math_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/math_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.h b/third_party/xla/xla/hlo/builder/lib/matrix.h index 8fdf01d438d7a1..6b69b1d0baa95b 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix.h +++ b/third_party/xla/xla/hlo/builder/lib/matrix.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_BUILDER_LIB_MATRIX_H_ #include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/matrix_test.cc b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc index debb6e20ae0108..9afd0cd19e0973 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc @@ -15,11 +15,13 @@ limitations under the License. #include "xla/hlo/builder/lib/matrix.h" +#include #include #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/third_party/xla/xla/hlo/builder/lib/pooling.cc b/third_party/xla/xla/hlo/builder/lib/pooling.cc index 81dd1a7c4c0f95..913a399ad4a972 100644 --- a/third_party/xla/xla/hlo/builder/lib/pooling.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/pooling.h" +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/pooling.h b/third_party/xla/xla/hlo/builder/lib/pooling.h index 15176888939c04..294000817126ee 100644 --- a/third_party/xla/xla/hlo/builder/lib/pooling.h +++ b/third_party/xla/xla/hlo/builder/lib/pooling.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_POOLING_H_ #define XLA_HLO_BUILDER_LIB_POOLING_H_ +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/pooling_test.cc b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc index 97b874d81c04ce..83ebbb50337fdb 100644 --- a/third_party/xla/xla/hlo/builder/lib/pooling_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/pooling.h" +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/prng_test.cc b/third_party/xla/xla/hlo/builder/lib/prng_test.cc index 0e5f9772c35d26..88345e4b61324e 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/builder/lib/constants.h" diff --git a/third_party/xla/xla/hlo/builder/lib/qr_test.cc b/third_party/xla/xla/hlo/builder/lib/qr_test.cc index 9f8e28e53cef66..97d5e3c947ee7d 100644 --- a/third_party/xla/xla/hlo/builder/lib/qr_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/hlo/builder/lib/qr.h" +#include +#include + +#include #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" diff --git a/third_party/xla/xla/hlo/builder/lib/quantize_test.cc b/third_party/xla/xla/hlo/builder/lib/quantize_test.cc index 6520bb4a07fef1..f887e529b01825 100644 --- a/third_party/xla/xla/hlo/builder/lib/quantize_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/quantize_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "xla/hlo/builder/lib/quantize.h" +#include #include +#include #include +#include #include "xla/array2d.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/test.h" diff --git a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc index a7f3a3c00b6933..0acccb15b7deb8 100644 --- a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/self_adjoint_eig.h" -#include +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h index f0dffdc41218bf..3a9a7d213ce87e 100644 --- a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ #define XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ +#include + #include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/slicing.h b/third_party/xla/xla/hlo/builder/lib/slicing.h index dfb880805d2153..2e40c00e8a8798 100644 --- a/third_party/xla/xla/hlo/builder/lib/slicing.h +++ b/third_party/xla/xla/hlo/builder/lib/slicing.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/types/span.h" diff --git a/third_party/xla/xla/hlo/builder/lib/slicing_test.cc b/third_party/xla/xla/hlo/builder/lib/slicing_test.cc index 72e8e1ca7026d8..c92c160c54745f 100644 --- a/third_party/xla/xla/hlo/builder/lib/slicing_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/hlo/builder/lib/slicing.h" +#include + #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/error_spec.h" diff --git a/third_party/xla/xla/hlo/builder/lib/sorting.cc b/third_party/xla/xla/hlo/builder/lib/sorting.cc index 456accc515e111..8d4eea1e3b6e1d 100644 --- a/third_party/xla/xla/hlo/builder/lib/sorting.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/builder/lib/sorting.h" +#include #include #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/sorting.h b/third_party/xla/xla/hlo/builder/lib/sorting.h index b951f26b97b043..c96d68002dbb71 100644 --- a/third_party/xla/xla/hlo/builder/lib/sorting.h +++ b/third_party/xla/xla/hlo/builder/lib/sorting.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_SORTING_H_ #define XLA_HLO_BUILDER_LIB_SORTING_H_ +#include + #include "xla/hlo/builder/xla_builder.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/sorting_test.cc b/third_party/xla/xla/hlo/builder/lib/sorting_test.cc index 2230eb73ecc4fb..c2bedc27667b11 100644 --- a/third_party/xla/xla/hlo/builder/lib/sorting_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/builder/lib/sorting.h" #include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/svd.cc b/third_party/xla/xla/hlo/builder/lib/svd.cc index 537dd4482ea87b..d28a252d3dee6b 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/hlo/builder/lib/svd.h" -#include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/svd.h b/third_party/xla/xla/hlo/builder/lib/svd.h index 42d165f766ab43..0560a8cb4d8a62 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd.h +++ b/third_party/xla/xla/hlo/builder/lib/svd.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_BUILDER_LIB_SVD_H_ #define XLA_HLO_BUILDER_LIB_SVD_H_ +#include + #include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/svd_test.cc b/third_party/xla/xla/hlo/builder/lib/svd_test.cc index 7266cde21684fe..cbf9a4bcabc58d 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/hlo/builder/lib/svd.h" +#include #include -#include #include #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc index 9538a742e4cfce..9282560e879205 100644 --- a/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include #include -#include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/loops.h" @@ -124,7 +125,7 @@ struct TridiagonalMatMulShapeParams { }; absl::Status ValidateTridiagonalMatMulDiagonal( - const Shape& diagonal_shape, const std::string_view diagonal_name, + const Shape& diagonal_shape, const absl::string_view diagonal_name, const Shape& rhs_shape) { const int64_t diagonal_rank = diagonal_shape.rank(); const int64_t rhs_rank = rhs_shape.rank(); diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc index 5948c8840303e1..87102d7431a9b3 100644 --- a/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "xla/array.h" #include "xla/array3d.h" diff --git a/third_party/xla/xla/hlo/builder/padding.cc b/third_party/xla/xla/hlo/builder/padding.cc index b8951735619e92..08fc4c0cb9f5ee 100644 --- a/third_party/xla/xla/hlo/builder/padding.cc +++ b/third_party/xla/xla/hlo/builder/padding.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/hlo/builder/padding.h" #include +#include +#include #include #include diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 0a08168dd214ab..08d65ba9359b2c 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -2033,6 +2033,41 @@ XlaOp XlaBuilder::SparseDot( }); } +XlaOp XlaBuilder::RaggedAllToAll( + XlaOp input, XlaOp input_offsets, XlaOp send_sizes, XlaOp output, + XlaOp output_offsets, XlaOp recv_sizes, + absl::Span replica_groups, + const std::optional& channel_id) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input)); + TF_ASSIGN_OR_RETURN(const Shape* input_offsets_shape, + GetShapePtr(input_offsets)); + TF_ASSIGN_OR_RETURN(const Shape* send_sizes_shape, GetShapePtr(send_sizes)); + TF_ASSIGN_OR_RETURN(const Shape* output_shape, GetShapePtr(output)); + TF_ASSIGN_OR_RETURN(const Shape* output_offsets_shape, + GetShapePtr(output_offsets)); + TF_ASSIGN_OR_RETURN(const Shape* recv_sizes_shape, GetShapePtr(recv_sizes)); + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferRaggedAllToAllShape( + {input_shape, input_offsets_shape, send_sizes_shape, output_shape, + output_offsets_shape, recv_sizes_shape})); + + std::vector operands{input, input_offsets, send_sizes, + output, output_offsets, recv_sizes}; + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } + return AddInstruction(std::move(instr), HloOpcode::kRaggedAllToAll, + operands); + }); +} + XlaOp XlaBuilder::RaggedDot( XlaOp lhs, XlaOp rhs, XlaOp group_sizes, const RaggedDotDimensionNumbers& dimension_numbers, @@ -3421,7 +3456,7 @@ XlaOp XlaBuilder::ConditionalImpl( std::vector operands(1, branch_index); for (const XlaOp branch_operand : branch_operands) { - operands.emplace_back(branch_operand); + operands.push_back(branch_operand); } return AddInstruction(std::move(instr), HloOpcode::kConditional, absl::MakeSpan(operands)); @@ -5144,6 +5179,16 @@ XlaOp SparseDot(const XlaOp lhs, const XlaOp rhs, preferred_element_type); } +XlaOp RaggedAllToAll(const XlaOp input, const XlaOp input_offsets, + const XlaOp send_sizes, const XlaOp output, + const XlaOp output_offsets, const XlaOp recv_sizes, + absl::Span replica_groups, + const std::optional& channel_id) { + return input.builder()->RaggedAllToAll(input, input_offsets, send_sizes, + output, output_offsets, recv_sizes, + replica_groups, channel_id); +} + XlaOp RaggedDot(const XlaOp lhs, const XlaOp rhs, const XlaOp group_sizes, const RaggedDotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config, diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h index 69ebf5e5ed0c37..789b22ea65c988 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.h +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -609,6 +609,12 @@ class XlaBuilder { const PrecisionConfig* precision_config = nullptr, std::optional preferred_element_type = std::nullopt); + XlaOp RaggedAllToAll( + XlaOp input, XlaOp input_offsets, XlaOp send_sizes, XlaOp output, + XlaOp output_offsets, XlaOp recv_sizes, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt); + XlaOp RaggedDot( XlaOp lhs, XlaOp rhs, XlaOp group_sizes, const RaggedDotDimensionNumbers& dimension_numbers, @@ -1314,6 +1320,11 @@ class XlaBuilder { const DotDimensionNumbers& dimension_number, const PrecisionConfig* precision_config, std::optional preferred_element_type); + friend XlaOp RaggedAllToAll(XlaOp input, XlaOp input_offsets, + XlaOp send_sizes, XlaOp output, + XlaOp output_offsets, XlaOp recv_sizes, + absl::Span replica_groups, + const std::optional& channel_id); friend XlaOp RaggedDot(XlaOp lhs, XlaOp rhs, XlaOp group_sizes, const RaggedDotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config, @@ -2190,6 +2201,13 @@ XlaOp SparseDot( const PrecisionConfig* precision_config = nullptr, std::optional preferred_element_type = std::nullopt); +// Enqueues a ragged all to all instruction onto the computation. +XlaOp RaggedAllToAll( + XlaOp input, XlaOp input_offsets, XlaOp send_sizes, XlaOp output, + XlaOp output_offsets, XlaOp recv_sizes, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt); + // Enqueues a ragged dot instruction onto the computation. XlaOp RaggedDot( XlaOp lhs, XlaOp rhs, XlaOp group_sizes, diff --git a/third_party/xla/xla/hlo/builder/xla_builder_test.cc b/third_party/xla/xla/hlo/builder/xla_builder_test.cc index baf36f52fadc3d..5f4c0c739b1306 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder_test.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder_test.cc @@ -2164,7 +2164,7 @@ struct BinaryOpTestCase { absl::Span broadcast_dimensions; std::string expected; std::function)> binary_op; - std::optional error_message; + std::optional error_message; }; constexpr absl::string_view kBroadcastDimensionMismatch = diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index 6b94430b0b2ad3..897e8f8d08be7c 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -30,6 +30,7 @@ cc_library( "hlo_evaluator_typed_visitor_float.cc", "hlo_evaluator_typed_visitor_float8.cc", "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int1.cc", "hlo_evaluator_typed_visitor_int16.cc", "hlo_evaluator_typed_visitor_int2.cc", "hlo_evaluator_typed_visitor_int32.cc", @@ -139,7 +140,7 @@ xla_cc_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/transforms:hlo_element_type_converter", + "//xla/hlo/transforms/simplifiers:hlo_element_type_converter", "//xla/service:call_graph", "//xla/service:dynamic_dimension_inference", "//xla/service:hlo_module_config", diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index cc43e22e4f2449..35fac878f104da 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -223,7 +223,7 @@ absl::Status MakeEvalErrorDueToParamOrInfeed( return error; } -// Repesents a value that might or might not be determined statically. +// Represents a value that might or might not be determined statically. struct DynamicOrStaticInteger { std::optional static_value; bool is_dynamic() const { return !static_value.has_value(); } diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 41cd753d987201..74feab55e5e9c8 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1716,12 +1716,14 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { // instantiating it. We explicitly instantiate this class in the various // hlo_evaluator_typed_visitor*.cc files. extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int1.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int1.cc new file mode 100644 index 00000000000000..0bdbb86bfb1401 --- /dev/null +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int1.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "xla/types.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index a1ad715cac5401..8a17de2b2da191 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -51,10 +51,10 @@ cc_library( "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:hlo_memory_scheduler", - "//xla/hlo/transforms:optimize_input_output_buffer_alias", + "//xla/hlo/transforms/simplifiers:hlo_constant_splitter", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:buffer_value", @@ -394,7 +394,7 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:verified_hlo_module", - "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", "//xla/service:buffer_value", diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 57a3533a4d9509..1d19ce6757cbca 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3521,6 +3521,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( bool module_is_changed = false; bool set_to_memory_lower_bound = (option_.memory_budget_per_device == 0); + bool hard_memory_constraint = (option_.memory_budget_ratio < 0); // Remove CustomCalls with custom_call_target="Sharding" and move their // shardings to their input ops. @@ -3684,7 +3685,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( option_.memory_budget_per_device = memory_lower_bound * std::abs(option_.memory_budget_ratio); // TODO(b/341299984): Document this flag syntax, or automate the behavior. - if (option_.memory_budget_ratio < 0) { + if (hard_memory_constraint) { option_.memory_overbudget_coeff = -1.0; // Disables the soft constraint } } else if (option_.memory_budget_per_device > 0) { @@ -3807,7 +3808,12 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( option_, request_name, sharding_propagation_solution)); if (mesh_idx == partial_mesh_shapes.size() - 1) { this->solver_optimal_objective_value_ = output.cost; + } else if (hard_memory_constraint) { + // If the memory budget constraint is *hard*, we're already guaranteed + // that this intermediate solution honors the maximum value. } else { + // If the memory budget constraint is *soft*, we require the intermediate + // solution to be optimal (since otherwise, it's probably degenerate). TF_RET_CHECK(output.is_optimal) << "The solver did not find an optimal solution for a partial mesh " << "shape."; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 28f6bba67d0730..9d2f16908f1af8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -42,7 +42,7 @@ struct AutoShardingOption { enum class PreserveShardingsType { // AutoSharding constrains the search space using all user shardings. kKeepAllShardings, - // AutoSharding constains the search space using input and output shardings + // AutoSharding constrains the search space using input and output shardings // of HloModule's entry computations and remove shardings of all // intermediate tensors. kKeepInputOutputShardings, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 12526e5ff3ec6b..354a00ba21aa88 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -539,7 +539,7 @@ void AddMemoryTerms( absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& unscaled_request) { const absl::Time start_time = absl::Now(); - const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); + const AutoShardingSolverRequest request = ScaleRequest(unscaled_request); const size_t num_edges = request.edges_size(); const int num_workers = 32; // SAT or SCIP @@ -1014,30 +1014,6 @@ std::optional ShardingStrategyHasViolation( return std::nullopt; } -// Computes the objective value of the sharding strategy. If the objective value -// is infinite or the sharding is infeasible (e.g., violates the peak-memory -// constraint), then a negated `AutoShardingViolationCode` value is returned. -double ComputeShardingStrategyCost( - const AutoShardingSolverRequest& request, - const std::vector& node_strategies) { - double cost = 0.0; - for (NodeIdx v = 0; v < request.num_nodes(); ++v) { - NodeStrategyIdx strategy = node_strategies[v]; - cost += request.computation_costs(v).costs(strategy) + - request.communication_costs(v).costs(strategy); - } - for (EdgeIdx e = 0; e < request.edges_size(); ++e) { - EdgeStrategyIdx strategy = GetEdgeStrategy(request, node_strategies, e); - cost += request.resharding_costs(e).costs(strategy); - } - std::optional violation_code = - ShardingStrategyHasViolation(request, node_strategies); - if (violation_code.has_value()) { - cost = -1 * (*violation_code); - } - return cost; -} - // Assigns all nodes to their first sharding configuration. If the assignment is // infeasible, the output cost is negative and encodes the violation code. AutoShardingSolverOutput SolveTrivial( @@ -1149,6 +1125,8 @@ absl::StatusOr RunHeuristicSolver( output = SolveGreedy(request, "node-cost"); } else if (algorithm == "greedy-node-memory") { output = SolveGreedy(request, "node-memory"); + } else if (algorithm == "brkga") { + output = SolveBrkga(request); } else { CHECK(false) << absl::Substitute("Algorithm $0 is not implemented.", algorithm); @@ -1156,6 +1134,8 @@ absl::StatusOr RunHeuristicSolver( auto duration = absl::Now() - start_time; LOG(INFO) << "Solver took " << absl::ToInt64Milliseconds(duration) << " ms"; LOG(INFO) << "Objective value: " << output.cost; + LOG(INFO) << "Total Cost: " + << ComputeShardingStrategyCost(unscaled_request, output.s_val); return output; } @@ -1371,6 +1351,27 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, return evaluation; } +double ComputeShardingStrategyCost( + const AutoShardingSolverRequest& request, + const std::vector& node_strategies) { + double cost = 0.0; + for (NodeIdx v = 0; v < request.num_nodes(); ++v) { + NodeStrategyIdx strategy = node_strategies[v]; + cost += request.computation_costs(v).costs(strategy) + + request.communication_costs(v).costs(strategy); + } + for (EdgeIdx e = 0; e < request.edges_size(); ++e) { + EdgeStrategyIdx strategy = GetEdgeStrategy(request, node_strategies, e); + cost += request.resharding_costs(e).costs(strategy); + } + std::optional violation_code = + ShardingStrategyHasViolation(request, node_strategies); + if (violation_code.has_value()) { + cost = -1 * (*violation_code); + } + return cost; +} + absl::Status ValidateRequest(const AutoShardingSolverRequest& request) { const int num_nodes = request.num_nodes(); const int num_edges = request.edges_size(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index 7852e1abfb91f7..d3f79dddf4cc72 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -51,6 +51,7 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( // - "random" // - "greedy-node-cost" // - "greedy-node-memory" +// - "brkga" absl::StatusOr RunHeuristicSolver( const AutoShardingSolverRequest& request, const std::string& algorithm); @@ -101,6 +102,15 @@ struct AutoShardingEvaluation { AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const AutoShardingSolverOutput& result); +// Computes the objective value of the sharding strategy. If the objective value +// is infinite or the sharding is infeasible (e.g., violates the peak-memory +// constraint), then a negated `AutoShardingViolationCode` value is returned. +// This function is used instead of `Evaluate` for faster iteration loops in the +// heuristic solver library. +double ComputeShardingStrategyCost( + const AutoShardingSolverRequest& request, + const std::vector& node_strategies); + // Creates and returns a variable for makespan. operations_research::MPVariable* CreateMakespanVar( const AutoShardingSolverRequest& request, @@ -143,6 +153,8 @@ absl::Status ValidateRequest(const AutoShardingSolverRequest& request); void SolverRequestCallback(const AutoShardingSolverRequest& request); +AutoShardingSolverOutput SolveBrkga(const AutoShardingSolverRequest& request); + } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc index 570a21268c50e9..ded0be31f34a63 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc @@ -49,5 +49,11 @@ void SolverRequestCallback(const AutoShardingSolverRequest& request) { // TODO(mofftt): Implement this. } +AutoShardingSolverOutput SolveBrkga(const AutoShardingSolverRequest& request) { + // TODO(fahrbach): Implement this. + AutoShardingSolverOutput output; + return output; +} + } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index fd49246177863f..1f5fd5eff6d0fa 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -264,7 +264,7 @@ BuildStrategyAndCost( // usually "follows" other instruction's sharding. If the instruction it // follows is an intermediate instruction, it may be able to choose // unevenly sharded strategiyes. Usually if we constraint input's sharding - // strategies, outputs would be constrained as welll, but if outputs are + // strategies, outputs would be constrained as well, but if outputs are // still unevely sharded in some cases, we need to fix the implementation // in auto sharding. only_allow_divisible = option.only_allow_divisible_input_output; @@ -286,7 +286,7 @@ BuildStrategyAndCost( // We use this following relationship to ensure that the input tuple // of the while loop, and the parameter of the body of that while // loop. Therefore, this followinf relationship is necessary for - // correctness, and is not merely an optmization. + // correctness, and is not merely an optimization. is_follow_necessary_for_correctness = true; for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { std::unique_ptr child_strategies = diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index c4065bf05066f9..660b344b3bdb71 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -3124,6 +3124,44 @@ ENTRY %entry { op::Sharding("{devices=[8,16]<=[128] last_tile_dim_replicate}")); } +TEST_F(AutoShardingTest, NegativeMemoryBudgetRatioTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +region { + Arg_0 = s32[] parameter(0) + ROOT Arg_1 = s32[] parameter(1) +} + +ENTRY %Scatter { + call = s32[4,128]{1,0} parameter(0) + clamp = s32[4,2]{1,0} parameter(1) + broadcast = s32[4,8]{1,0} parameter(2) + ROOT scatter = s32[4,128]{1,0} scatter(call, clamp, broadcast), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=region +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {2, 2}; + option.device_mesh_ids = {0, 1, 2, 3}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + // Memory budget a tad higher than what would be required if the largest + // tensors are sharded 4-ways + option.memory_budget_per_device = 0; + option.memory_budget_ratio = -1.1; // Disables the soft memory constraint. + + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* scatter = FindInstruction(module.get(), "scatter"); + ASSERT_NE(scatter, nullptr); + EXPECT_EQ(scatter->sharding().NumTiles(), 4); + TF_EXPECT_OK(scatter->sharding().Validate(scatter->shape(), 4)); +} + TEST(NormalizeTest, NormalizeHandlesNegativeCosts) { EdgeReshardingCostMatrix edge_cost(2, 2); edge_cost(0, 0).communication_cost = -100; diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index ef30fbe69d0bff..eb51d248d268b0 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -65,6 +65,7 @@ cc_library( "//xla:array", "//xla:comparison_util", "//xla:literal", + "//xla:literal_pool", "//xla:literal_util", "//xla:printer", "//xla:protobuf_util", @@ -77,7 +78,6 @@ cc_library( "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", "//xla/hlo/parser:hlo_lexer", "//xla/service:compilation_environments", "//xla/service:computation_layout", @@ -161,6 +161,7 @@ xla_cc_test( deps = [ ":hlo", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/parser:hlo_parser", "//xla/service:hlo_module_config", "@com_google_absl//absl/hash", @@ -231,6 +232,19 @@ cc_library( ], ) +xla_cc_test( + name = "hlo_instruction_utils_test", + srcs = ["hlo_instruction_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_instruction_utils", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_query", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "hlo_reachability", hdrs = ["hlo_reachability.h"], diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc index d0f6bfb6fe1357..d7afa4b1b92683 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/hlo/ir/dfs_hlo_visitor.h" -#include - #include "absl/status/status.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h index b0c75595651c7d..15c039db8383a5 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_DFS_HLO_VISITOR_H_ #define XLA_HLO_IR_DFS_HLO_VISITOR_H_ +#include #include #include diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h index c9ba49231955ab..56846cac1d9647 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h @@ -380,6 +380,7 @@ class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { // Mark the computation as having changed. void MarkAsChanged() { changed_ = true; } + void MarkAsMaybeChanged(bool changed) { changed_ |= changed; } private: bool changed_ = false; diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index f1420ecc549bb5..0cb22c0964e572 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -1728,7 +1728,7 @@ std::unique_ptr HloComputation::CloneInContext( for (HloInstruction* operand : cur->operands()) { const HloInstruction* new_operand = replace(operand); if (new_operand) { - dfs_stack.emplace_back(new_operand); + dfs_stack.push_back(new_operand); } } } diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 757505980a079e..4411e3102b5a26 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_COMPUTATION_H_ #define XLA_HLO_IR_HLO_COMPUTATION_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -420,11 +422,23 @@ class HloComputation { // with respect to HloComputation::Equal() method. template friend H AbslHashValue(H h, const HloComputation& computation) { + // Walk the computation in post-order, computing (and caching) the + // Absl::Hash after each instruction to use to as an operand for + // subsequent instructions. auto instructions = computation.MakeInstructionPostOrder(); + absl::flat_hash_map instruction_hash_cache; + instruction_hash_cache.reserve(instructions.size()); for (auto* instruction : instructions) { - h = H::combine(std::move(h), *instruction); + absl::InlinedVector operand_hashes; + for (auto* operand : instruction->operands()) { + operand_hashes.push_back(instruction_hash_cache[operand]); + } + instruction_hash_cache.emplace( + instruction, absl::HashOf(*instruction, operand_hashes)); } - return H::combine(std::move(h), instructions.size()); + return H::combine(std::move(h), + instruction_hash_cache[computation.root_instruction()], + instructions.size()); } using InstructionSequence = tsl::gtl::iterator_range< diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 402c0a97019f32..b051c285743e69 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -1508,7 +1508,8 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, /* static */ std::unique_ptr HloInstruction::CreateVariadic( const Shape& shape, HloOpcode opcode, absl::Span operands) { - CHECK_EQ(HloOpcode::kTuple, opcode); + std::optional arity = HloOpcodeArity(opcode); + CHECK(!arity.has_value() || arity.value() == operands.size()); return CreateNary(shape, opcode, operands); } @@ -2647,7 +2648,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kTan: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); - clone = CreateUnary(shape, opcode_, new_operands[0]); + clone = CreateUnary(shape, opcode_, new_operands[0], result_accuracy()); break; // Binary ops. case HloOpcode::kAdd: diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 827792b65d8a61..bd1d92132503e6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -41,6 +41,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -117,7 +118,8 @@ class HloPrintOptions { print_extra_attributes_(true), syntax_sugar_async_ops_(true), print_name_after_closing_brace_(false), - print_full_replica_group_list_(false) {} + print_full_replica_group_list_(false), + print_parameter_number_(true) {} // Static reference to a default construction HloPrintOptions, to avoid // constructing a new one each time default is needed. static const HloPrintOptions& Default() { @@ -399,6 +401,12 @@ class HloPrintOptions { return *this; } + // If true, prints the parameter number of a parameter instruction. + HloPrintOptions& set_print_parameter_number(bool value) { + print_parameter_number_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } bool print_only_essential_constants() const { return print_only_essential_constants_; @@ -444,6 +452,7 @@ class HloPrintOptions { bool print_full_replica_group_list() const { return print_full_replica_group_list_; } + bool print_parameter_number() const { return print_parameter_number_; } private: // The interval between the /*index=*/ annotated operands. 0 means never print @@ -475,6 +484,7 @@ class HloPrintOptions { bool syntax_sugar_async_ops_; bool print_name_after_closing_brace_; bool print_full_replica_group_list_; + bool print_parameter_number_; }; // For canonical string output, we need to have a canonical way to rename @@ -1073,13 +1083,44 @@ class HloInstruction { // Index 'data' at 'offsets'[2], 'sizes'[2]' // {m,n,o},{p,q,r},{s,t,u},{v,w,x} // + // + // ``output_offsets`` must be sharded in a way that each replica has offsets + // in the target replica output perspective. + // + // For i-th output offset, the current replica will send + // `input[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to + // `i`-th replica that will be written to + // `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th + // replica ``output``. + // + // For example, if we have 2 replicas: + // + // replica 0: + // input: [1, 2, 2] + // output: [0, 0, 0, 0] + // input_offsets: [0, 1] + // send_sizes: [1, 2] + // output_offsets: [0, 0] + // recv_sizes: [1, 1] + // + // replica 1: + // input: [3, 4, 0] + // output: [0, 0, 0, 0] + // input_offsets: [0, 1] + // send_sizes: [1, 1] + // output_offsets: [1, 2] + // recv_sizes: [2, 1] + // + // replica 0's result will be: [1, 3, 0, 0] + // replica 1's result will be: [2, 2, 4, 0] + // // The ragged all-to-all HLO has the following arguments: - // input: ragged input data tensor. - // output: ragged output data tensor. - // input_offsets: ragged input offsets tensor. - // send_sizes: ragged send sizes tensor. - // output_offsets: ragged output offsets tensor. - // recv_sizes: ragged recv sizes tensor. + // input: ragged input data tensor. + // output: ragged output data tensor. + // input_offsets: ragged input offsets tensor. + // send_sizes: ragged send sizes tensor. + // output_offsets: array of ragged offsets in the target replica output. + // recv_sizes: ragged recv sizes tensor. // // The '*_offsets' and '*_sizes' tensors must have the same shape. // The output buffer is passed in as an input (and aliased in the output), @@ -1705,27 +1746,20 @@ class HloInstruction { /*ignore_commutative_operand_order=*/true); } + // Allow subclasses to contribute additional attributes to the hash. + virtual void HashAdditionalAttributes(absl::HashState h) const {}; + // Generates a hash value of an HLO instruction. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO instructions, - // with respect to HloInstruction::Identical() method. - // TODO(majnemer): Make the comment here more crisp & accurate. + // information on opcode, shape, number of operands, and other relevant + // additional attributes (e.g. literal values, parameters, etc.). template friend H AbslHashValue(H h, const HloInstruction& hlo) { h = H::combine(std::move(h), hlo.opcode(), hlo.shape()); - if (!hlo.IsCrossModuleAllReduce()) { - for (size_t i = 0; i < hlo.operands().size(); ++i) { - h = H::combine(std::move(h), hlo.operand(i)->shape()); - } h = H::combine(std::move(h), hlo.operand_count()); } - - if (hlo.opcode() == HloOpcode::kFusion) { - h = H::combine(std::move(h), *hlo.fused_expression_root(), - hlo.fusion_kind(), hlo.fused_instruction_count(), - hlo.fused_parameters().size()); - } + // Allow subclasses to mix additional data into h before returning + hlo.HashAdditionalAttributes(absl::HashState::Create(&h)); return h; } @@ -2194,8 +2228,6 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } - bool preserve_layout() const { return metadata_->preserve_layout(); } - bool has_backend_config() const { return !backend_config_.empty(); } void clear_backend_config() { backend_config_ = BackendConfigWrapper(); } @@ -2347,9 +2379,6 @@ class HloInstruction { void set_metadata_deduplicated_name(std::string deduplicated_name) { metadata_->set_deduplicated_name(std::move(deduplicated_name)); } - void set_metadata_preserve_layout(bool preserve_layout) { - metadata_->set_preserve_layout(preserve_layout); - } void set_metadata_scheduling_name(absl::string_view name) { metadata_->set_scheduling_name(std::string(name)); } diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc b/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc index c500b0ccd079c1..96ec5d59d72d11 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc @@ -16,6 +16,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction_utils.h" #include +#include +#include +#include #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction_utils.h b/third_party/xla/xla/hlo/ir/hlo_instruction_utils.h index 3721f0e65b3200..35d531122e25aa 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction_utils.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction_utils.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_INSTRUCTION_UTILS_H_ #define XLA_HLO_IR_HLO_INSTRUCTION_UTILS_H_ +#include +#include +#include +#include + #include "xla/hlo/ir/hlo_instruction.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction_utils_test.cc b/third_party/xla/xla/hlo/ir/hlo_instruction_utils_test.cc new file mode 100644 index 00000000000000..fe8c488b154e88 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_instruction_utils_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_instruction_utils.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_query.h" + +namespace xla { + +namespace hlo_instruction_utils { + +namespace { + +class HloInstructionUtilsTest : public HloHardwareIndependentTestBase {}; + +TEST_F(HloInstructionUtilsTest, TestIsUnstridedSlice) { + const char* hlo_text = R"( + HloModule test + ENTRY main { + param = f32[2,8] parameter(0) + strided_slice = f32[2,2] slice(param), slice={[0:2:1], [4:8:2]} + unstrided_slice = f32[2,4] slice(param), slice={[0:2:1], [4:8:1]} + ROOT tuple = (f32[2,2], f32[2,4]) tuple(strided_slice, unstrided_slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + HloInstruction* unstrided_slice = + hlo_query::FindInstruction(m->entry_computation(), "unstrided_slice"); + HloInstruction* strided_slice = + hlo_query::FindInstruction(m->entry_computation(), "strided_slice"); + EXPECT_NE(unstrided_slice, nullptr); + EXPECT_NE(strided_slice, nullptr); + EXPECT_TRUE(IsUnstridedSlice(unstrided_slice)); + EXPECT_FALSE(IsUnstridedSlice(strided_slice)); +} + +TEST_F(HloInstructionUtilsTest, TestAddOrUpdateVectorOfPairsAsAttribute) { + const char* hlo = R"( + HloModule test + ENTRY main { + ROOT param = s32[] parameter(0), frontend_attributes={foo="bar", baz="qux"} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + HloInstruction* param = m->entry_computation()->root_instruction(); + EXPECT_EQ(param->frontend_attributes().map().size(), 2); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + + std::string new_key = "quux"; + std::vector> value = {{1, 2}, {3, 4}}; + AddOrUpdateVectorOfPairsAsAttribute(param, new_key, value); + EXPECT_EQ(param->frontend_attributes().map().size(), 3); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + EXPECT_EQ(param->frontend_attributes().map().at("quux"), "{{1,2},{3,4}}"); + + std::vector> new_value = {{5, 6}, {7, 8}}; + AddOrUpdateVectorOfPairsAsAttribute(param, new_key, new_value); + EXPECT_EQ(param->frontend_attributes().map().size(), 3); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + EXPECT_EQ(param->frontend_attributes().map().at("quux"), "{{5,6},{7,8}}"); +} + +} // namespace + +} // namespace hlo_instruction_utils + +} // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index ca071468780d70..feccc9d78ae839 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -2761,7 +2761,9 @@ void HloParameterInstruction::PrintExtraAttributesImpl( void HloParameterInstruction::PrintOperandsWithCanonicalNameMap( Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { - printer->Append(parameter_number_); + if (options.print_parameter_number()) { + printer->Append(parameter_number_); + } } bool HloParameterInstruction::IdenticalSlowPath( diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index 6830061d85036e..c21dddeee907b5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_INSTRUCTIONS_H_ #define XLA_HLO_IR_HLO_INSTRUCTIONS_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -40,6 +42,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/literal_pool.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" @@ -1343,6 +1346,26 @@ class HloConstantInstruction : public HloInstruction { return hlo->opcode() == HloOpcode::kConstant; } + // Canonicalize constant literal using the given literal pool. + bool Canonicalize(LiteralPool* literal_pool) { + if (literal_pool && literal_) { + auto canonical = literal_pool->GetCanonicalLiteral(literal_); + if (canonical != literal_) { + literal_ = std::move(canonical); + return true; + } + } + return false; + } + + // Add literal to the hash state. + void HashAdditionalAttributes(absl::HashState h) const override { + if (HasLiteral()) { + absl::HashState::combine(std::move(h), + Literal::AbslHashable(literal())); + } + } + private: bool IsElementwiseImpl( const std::optional& operand_idx) const override; @@ -1582,6 +1605,13 @@ class HloFusionInstruction : public HloCallableInstruction { return hlo->opcode() == HloOpcode::kFusion; } + // Add various fusion parameters to the hash. + void HashAdditionalAttributes(absl::HashState h) const override { + absl::HashState::combine(std::move(h), *fused_expression_root(), + fusion_kind(), fused_instruction_count(), + fused_parameters().size()); + } + protected: std::string default_called_computation_name() const override { return "fused_computation"; @@ -1701,6 +1731,11 @@ class HloParameterInstruction : public HloInstruction { return hlo->opcode() == HloOpcode::kParameter; } + // Add parameter number to the hash. + void HashAdditionalAttributes(absl::HashState h) const override { + absl::HashState::combine(std::move(h), parameter_number()); + } + private: void PrintExtraAttributesImpl(AttributePrinter& printer, const HloPrintOptions& options) const override; diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 7ec1e26749b467..0ed24c019c6b73 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -58,7 +58,6 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" -#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index c00cd7ee7a7a1a..c9a33280a498db 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -54,7 +53,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tsl/lib/gtl/iterator_range.h" -#include "xla/xla.pb.h" #include "tsl/platform/logging.h" namespace xla { @@ -673,8 +671,8 @@ class HloModule { // Describes a stack frame. struct StackFrame { - std::string_view file_name; - std::string_view function_name; + absl::string_view file_name; + absl::string_view function_name; int line = 0; int column = 0; diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc index 238516dcd5633b..eb15af83d4f510 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module_metadata.h" #include +#include #include #include "absl/container/flat_hash_set.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h index 54de64a928e734..ef4c52e395ee15 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_MODULE_METADATA_H_ #define XLA_HLO_IR_HLO_MODULE_METADATA_H_ +#include #include #include #include diff --git a/third_party/xla/xla/hlo/ir/hlo_module_test.cc b/third_party/xla/xla/hlo/ir/hlo_module_test.cc index 32b5119eca2aed..01756318c93ec6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_test.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -32,9 +31,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/xla_data.pb.h" namespace xla { namespace { @@ -44,7 +41,7 @@ TEST(HloModuleTest, AbslHashValue) { HloModule module2("temp_module3", HloModuleConfig()); EXPECT_EQ(absl::HashOf(module1), absl::HashOf(module2)); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule m1 ENTRY main { a = f32[] parameter(0) @@ -109,7 +106,7 @@ TEST(HloModuleTest, GetModifySetConfig) { EXPECT_EQ(&m1.config(), &m1.mutable_config()); } -void CreateComputation(HloModule& module, std::string_view name, bool is_entry, +void CreateComputation(HloModule& module, absl::string_view name, bool is_entry, HloSchedule& schedule) { HloComputation::Builder builder(name); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -131,7 +128,7 @@ void CreateComputation(HloModule& module, std::string_view name, bool is_entry, const char* kCloneSuffix = "clone"; -std::string GetCloneName(std::string_view name) { +std::string GetCloneName(absl::string_view name) { return absl::StrCat(name, ".", kCloneSuffix); } @@ -204,5 +201,231 @@ TEST(HloModuleTest, CloneWithNewConfig) { m1.config().device_memory_size()); } +TEST(HloModuleTest, AbslHashInstructionOrdering) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Add.0 and add.1 are swapped. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.1 = f32[32,32] add(b, c) // Swapped with below + add.0 = f32[32,32] add(a, b) // Swapped with above + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_EQ(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionOpcodes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Second add changed to sub + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] subtract(b, c) // Changed from add to subtract + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionShapes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Second add has different shape. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + // Shapes changed from [32,32] to [16,16] + a = f32[16,16] parameter(0) + b = f32[16,16] parameter(1) + c = f32[16,16] parameter(2) + add.0 = f32[16,16] add(a, b) + add.1 = f32[16,16] add(b, c) + ROOT result = f32[16,16] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionNaming) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Add x to all names + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + // All names changed to x + ax = f32[32,32] parameter(0) + bx = f32[32,32] parameter(1) + cx = f32[32,32] parameter(2) + add.0x = f32[32,32] add(ax, bx) + add.1x = f32[32,32] add(bx, cx) + ROOT resultx = f32[32,32] add(add.0x, add.1x) + } + )")); + + EXPECT_EQ(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashGraphChanges) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Changed from (a+b)+(b+c) to ((a+b)+c)+a + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(add.0, c) // Changed from add(b, c) + ROOT result = f32[32,32] add(add.1, a) // Changed from add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashParameterChanges) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Change parameter numbers + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(1) // Changed from parameter(0) + b = f32[32,32] parameter(0) // Changed from parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashConstantValues) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = s32[32,32] parameter(0) + c = s32[] constant(42) + b = s32[32,32] broadcast(c), dimensions={} + ROOT result = s32[32,32] add(a, b) + } + )")); + + // Changed from 42 to 43 + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = s32[32,32] parameter(0) + c = s32[] constant(43) // Changed from constant(42) + b = s32[32,32] broadcast(c), dimensions={} + ROOT result = s32[32,32] add(a, b) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc index 30b1d2c3cfc6a6..462655f4dcab54 100644 --- a/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc @@ -59,9 +59,6 @@ std::string OpMetadataToString(const OpMetadata& metadata, bool only_op_name) { absl::CEscape(metadata.deduplicated_name()), "\"")); } - if (metadata.preserve_layout()) { - result.push_back(absl::StrCat("preserve_layout=true")); - } if (!metadata.scheduling_name().empty()) { result.push_back( absl::StrCat("scheduling_name=\"", metadata.scheduling_name(), "\"")); diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.cc b/third_party/xla/xla/hlo/ir/hlo_original_value.cc index c1617888510a4d..e76cd15d989ce0 100644 --- a/third_party/xla/xla/hlo/ir/hlo_original_value.cc +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.cc @@ -53,15 +53,14 @@ std::string OriginalValueToStringHelper(const OriginalValue& original_value, return result; } - // The original_value may refer to an empty array, such as origin {}, so let's - // check whether that's the case before accessing them. Generally speaking the - // index _should_ be good, but let's double check. const auto& leaf = original_value.element(shape_index); if (leaf.has_value()) { absl::StrAppend( &result, "{", "\"", leaf->instruction_name, "\"", (leaf->shape_index.empty() ? "" : " " + leaf->shape_index.ToString()), "}"); + } else { + absl::StrAppend(&result, "{}"); } return result; } diff --git a/third_party/xla/xla/hlo/ir/hlo_schedule.cc b/third_party/xla/xla/hlo/ir/hlo_schedule.cc index ddd2aaf4cffef5..b0898a39ed7777 100644 --- a/third_party/xla/xla/hlo/ir/hlo_schedule.cc +++ b/third_party/xla/xla/hlo/ir/hlo_schedule.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include diff --git a/third_party/xla/xla/hlo/ir/hlo_schedule.h b/third_party/xla/xla/hlo/ir/hlo_schedule.h index 37cbff34856a9a..b0a87284b62b7c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_schedule.h +++ b/third_party/xla/xla/hlo/ir/hlo_schedule.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_IR_HLO_SCHEDULE_H_ #include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc index b578408ed9dd80..3be3eef29cc847 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding_metadata.h" -#include +#include #include #include #include diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h index 5d963e931d96b2..95069a5d6ac492 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_SHARDING_METADATA_H_ #define XLA_HLO_IR_HLO_SHARDING_METADATA_H_ +#include #include #include #include diff --git a/third_party/xla/xla/hlo/ir/tile_assignment.h b/third_party/xla/xla/hlo/ir/tile_assignment.h index 7adc5ab50b2d70..31d874328b64cb 100644 --- a/third_party/xla/xla/hlo/ir/tile_assignment.h +++ b/third_party/xla/xla/hlo/ir/tile_assignment.h @@ -16,10 +16,14 @@ limitations under the License. #ifndef XLA_HLO_IR_TILE_ASSIGNMENT_H_ #define XLA_HLO_IR_TILE_ASSIGNMENT_H_ +#include +#include #include +#include #include #include #include +#include #include #include "absl/algorithm/container.h" diff --git a/third_party/xla/xla/hlo/parser/BUILD b/third_party/xla/xla/hlo/parser/BUILD index 9cd8a7d40153eb..cfee6c16a50180 100644 --- a/third_party/xla/xla/hlo/parser/BUILD +++ b/third_party/xla/xla/hlo/parser/BUILD @@ -1,5 +1,5 @@ # Description: -# XLA parser implementation. +# HLO parser implementation. load( "//xla:xla.bzl", diff --git a/third_party/xla/xla/hlo/parser/hlo_parser.cc b/third_party/xla/xla/hlo/parser/hlo_parser.cc index 38d87ee6316c4f..4475c268055df5 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser.cc @@ -2005,7 +2005,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } else { // Since async-{update,done} will inherit the computation from // async-start, we'll only need to make sure it matches what was - // specified explicitily. + // specified explicitly. if (operands[0]->async_wrapped_opcode() != *async_wrapped_opcode) { TokenError( StrFormat("Expect async wrapped opcode to be %s, but got %s", @@ -4586,6 +4586,9 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { } elems_seen_per_dim[0] = shape.dimensions(0); lexer_.Lex(); + if (!options_.fill_shortform_constants_with_random_values()) { + break; + } // Fill data with deterministic (garbage) values. Use static to avoid // creating identical constants which could potentially got CSE'ed // away. This is a best-effort approach to make sure replaying a HLO @@ -5136,7 +5139,7 @@ bool HloParserImpl::ParseAttributeHelper( return true; } case AttrTy::kOriginalValue: { - // By the time this attribute is added, the instruciton shape should + // By the time this attribute is added, the instruction shape should // have been inferred. if (!shape) { return TokenError("expects instruction shape"); @@ -6488,18 +6491,25 @@ bool HloParserImpl::ParseOriginalValue( ++leaf_shape_index.back(); } else if (lexer_.GetKind() == TokKind::kLbrace) { lexer_.Lex(); - std::string instruction_name; - ShapeIndex shape_index; - if (!ParseString(&instruction_name)) { - return false; - } if (lexer_.GetKind() != TokKind::kRbrace) { - if (!ParseShapeIndex(&shape_index)) { + std::string instruction_name; + ShapeIndex shape_index; + if (!ParseString(&instruction_name)) { return false; } + if (lexer_.GetKind() != TokKind::kRbrace) { + if (!ParseShapeIndex(&shape_index)) { + return false; + } + } + *(**original_value)->mutable_element(leaf_shape_index) = { + instruction_name, shape_index}; + } else { + // The original_value is not expected to have any leaf without values. + // However we should not fail the execution here. This should + // be done in HloVerifier instead. + LOG(WARNING) << "Found an empty leaf node in an original value"; } - *(**original_value)->mutable_element(leaf_shape_index) = { - instruction_name, shape_index}; if (!ParseToken(TokKind::kRbrace, "Expects '} at end of each OriginalArray'")) { return false; @@ -6522,7 +6532,6 @@ bool HloParserImpl::ParseMetadata(OpMetadata& metadata) { optional source_line; optional> profile_type; optional deduplicated_name; - optional preserve_layout; optional scheduling_name; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; @@ -6532,8 +6541,6 @@ bool HloParserImpl::ParseMetadata(OpMetadata& metadata) { &profile_type}; attrs["deduplicated_name"] = {/*required=*/false, AttrTy::kString, &deduplicated_name}; - attrs["preserve_layout"] = {/*required=*/false, AttrTy::kBool, - &preserve_layout}; attrs["scheduling_name"] = {/*required=*/false, AttrTy::kString, &scheduling_name}; if (!ParseSubAttributes(attrs)) { @@ -6562,11 +6569,6 @@ bool HloParserImpl::ParseMetadata(OpMetadata& metadata) { if (deduplicated_name) { metadata.set_deduplicated_name(*deduplicated_name); } - if (preserve_layout) { - metadata.set_preserve_layout(*preserve_layout); - } else { - metadata.set_preserve_layout(false); - } if (scheduling_name) { metadata.set_scheduling_name(*scheduling_name); } @@ -6629,7 +6631,7 @@ bool HloParserImpl::ParseListShardingType( if (!ParseOpShardingType(&type)) { return false; } - types->emplace_back(type); + types->push_back(type); } while (EatIfPresent(TokKind::kComma)); } diff --git a/third_party/xla/xla/hlo/parser/hlo_parser.h b/third_party/xla/xla/hlo/parser/hlo_parser.h index 302bc829f9bd92..3d1d2f25f999f9 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser.h +++ b/third_party/xla/xla/hlo/parser/hlo_parser.h @@ -40,8 +40,19 @@ class HloParserOptions { bool fill_missing_layouts() const { return fill_missing_layouts_; } + // Fill short form constants (dots) with deterministic random values. + HloParserOptions& set_fill_shortform_constants_with_random_values( + bool value) { + fill_shortform_constants_with_random_values_ = value; + return *this; + } + bool fill_shortform_constants_with_random_values() const { + return fill_shortform_constants_with_random_values_; + } + private: bool fill_missing_layouts_ = true; + bool fill_shortform_constants_with_random_values_ = true; }; // Given a string in the HloModule::ToString() format, parses the string and diff --git a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc index 9c1ecc1836513a..31ec363ee6df28 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -1531,18 +1530,6 @@ ENTRY %test (p: f32[100]) -> u32[100] { )" }, -{ -"MetadataPreserveLayout", -R"(HloModule test, entry_computation_layout={(f32[100]{0})->u32[100]{0}} - -ENTRY %test (p: f32[100]) -> u32[100] { - %p = f32[100]{0} parameter(0) - ROOT %root = u32[100]{0} bitcast-convert(f32[100]{0} %p), metadata={op_type="a" op_name="b" source_file="c" source_line=1 profile_type={1} deduplicated_name="d" preserve_layout=true} -} - -)" -}, - { "OriginalValue", R"(HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})} @@ -4720,7 +4707,7 @@ TEST_F(HloParserTest, ParseDynamicTuple) { } TEST_F(HloParserTest, ParseInvalidDimLevel) { - constexpr std::string_view shape_string = "f32[123]{0:D(D+~)}"; + constexpr absl::string_view shape_string = "f32[123]{0:D(D+~)}"; absl::StatusOr result = ParseShape(shape_string); ASSERT_THAT( result.status(), @@ -5727,6 +5714,20 @@ ENTRY %test { HasSubstr("expects instruction shape"))); } +TEST_F(HloParserTest, EmptyLeafInOriginalValue) { + const std::string hlo_string = R"(HloModule test + +ENTRY %test { + ROOT op = ((f32[], f32[3]{0}), f32[2,3]) parameter(0), origin={(({}, {"v2"}), {"v3"})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + ExpectHasSubstr(module->ToString(HloPrintOptions::ShortParsable()), + "origin={(({}, {\"v2\"}), {\"v3\"})}"); +} + TEST_F(HloParserTest, TranscendentalAccuracyMode) { constexpr absl::string_view hlo_string = R"( HloModule exponential_hw @@ -5843,21 +5844,5 @@ ENTRY main { "error: unexpected attribute \"result_accuracy\""); } -TEST_F(HloParserTest, EmptyOriginalValueIsPrintedCorrectly) { - const std::string hlo_string = R"(HloModule test - -ENTRY %test { - ROOT op = f32[] parameter(0), origin={} -} - - -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - ExpectHasSubstr(module->ToString(HloPrintOptions::Fingerprint()), - "origin={}"); -} - } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/pass/BUILD b/third_party/xla/xla/hlo/pass/BUILD index 94eb25100195f5..a6014f5256fb99 100644 --- a/third_party/xla/xla/hlo/pass/BUILD +++ b/third_party/xla/xla/hlo/pass/BUILD @@ -58,7 +58,6 @@ cc_library( "//xla:status_macros", "//xla:types", "//xla:util", - "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:compilation_stats", "//xla/service:dump", diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc index 20a3414a4e9c4e..e5ecfd6c22a123 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc @@ -30,7 +30,6 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/types.h" #include "xla/util.h" -#include "xla/xla.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h index d12ce3fa356786..e6b6cf4c7d7a52 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h @@ -31,7 +31,6 @@ limitations under the License. #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/compilation_stats.h" #include "xla/types.h" -#include "xla/xla.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/testlib/BUILD b/third_party/xla/xla/hlo/testlib/BUILD index b8c754075bd3ac..9bd094f95111a9 100644 --- a/third_party/xla/xla/hlo/testlib/BUILD +++ b/third_party/xla/xla/hlo/testlib/BUILD @@ -112,3 +112,40 @@ cc_library( "@local_tsl//tsl/platform:resource_loader", ], ) + +cc_library( + name = "pattern_matcher_gmock", + testonly = 1, + hdrs = ["pattern_matcher_gmock.h"], + deps = [ + "test", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "test", + testonly = 1, + hdrs = ["test.h"], + visibility = internal_visibility([":friends"]), + deps = [ + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "test_helpers", + testonly = 1, + hdrs = ["test_helpers.h"], + visibility = internal_visibility([":friends"]), + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc index bbe1ecea736a3e..d5af349ef6dece 100644 --- a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc +++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_replace.h" @@ -119,7 +120,7 @@ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); - return std::move(module); + return module; } /* static */ @@ -258,9 +259,11 @@ HloHardwareIndependentTestBase::RunAndCheckHloRewrite( VLOG(7) << "Input HLO: " << hlo_string; TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); + VLOG(7) << "Input HLO parsed. Running the pass: + " << hlo_pass.name(); TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); VLOG(7) << "Output HLO: " - << module->ToString(HloPrintOptions::ShortParsable()); + << module->ToString(HloPrintOptions::ShortParsable() + .set_print_control_dependencies(true)); EXPECT_EQ(changed, expect_change); return module; } diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h index 2a7f1f488b54e8..e41bcea3e4d828 100644 --- a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h +++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h @@ -55,6 +55,23 @@ class HloHardwareIndependentTestBase : public ::testing::Test { public: static PrecisionConfig DefaultPrecisionConfig(int operands); + // Gets the computation/instruction from the given module with the given name. + // Note that it is encouraged to use these functions directly via the + // hlo_query.h header instead since they are independent from any test-time + // variables or contexts. + + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + static HloComputation* FindComputation(HloModule* module, + absl::string_view name); + static HloInstruction* FindInstruction(HloModule* module, + absl::string_view name); + // Gets the instruction from the given module with the given opcode. + static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); + // Gets all the instructions from the given module with the given opcode. + static std::vector FindInstructions(HloModule* module, + HloOpcode opcode); + protected: explicit HloHardwareIndependentTestBase( bool verifier_layout_sensitive = false, @@ -199,22 +216,6 @@ class HloHardwareIndependentTestBase : public ::testing::Test { ->Clear(); } - // Gets the computation/instruction from the given module with the given name. - // Note that it is encouraged to use these functions directly via the - // hlo_query.h header instead since they are independent from any test-time - // variables or contexts. - - // This is useful for tests which create HLOs from a string and then want to - // inspect a particular computation or instruction. - static HloComputation* FindComputation(HloModule* module, - absl::string_view name); - static HloInstruction* FindInstruction(HloModule* module, - absl::string_view name); - // Gets the instruction from the given module with the given opcode. - static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); - // Gets all the instructions from the given module with the given opcode. - static std::vector FindInstructions(HloModule* module, - HloOpcode opcode); bool verifier_layout_sensitive() const { return verifier_layout_sensitive_; } void set_verifier_layout_sensitive(bool verifier_layout_sensitive) { diff --git a/third_party/xla/xla/hlo/testlib/pattern_matcher_gmock.h b/third_party/xla/xla/hlo/testlib/pattern_matcher_gmock.h new file mode 100644 index 00000000000000..a2558e9510000e --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/pattern_matcher_gmock.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TESTLIB_PATTERN_MATCHER_GMOCK_H_ +#define XLA_HLO_TESTLIB_PATTERN_MATCHER_GMOCK_H_ + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" +#include "xla/service/pattern_matcher.h" +#include "xla/shape.h" +#include "tsl/platform/test.h" + +namespace xla { + +namespace pattern_matcher_gmock_detail { +template +class GmockMatcher { + public: + explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {} + + // In service of better error messages, list out the overloads explicitly + // rather than just using a template. gMock's polymorphism plus + // pattern_matcher yields some pretty gnarly stuff. + bool MatchAndExplain(const Layout& l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&l, listener); + } + bool MatchAndExplain(const Layout* l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(l, listener); + } + bool MatchAndExplain(Layout* l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(l, listener); + } + + bool MatchAndExplain(const Shape& s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&s, listener); + } + bool MatchAndExplain(const Shape* s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(s, listener); + } + bool MatchAndExplain(Shape* s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(s, listener); + } + + bool MatchAndExplain(const HloInstruction& instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&instr, listener); + } + bool MatchAndExplain(const HloInstruction* instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(instr, listener); + } + bool MatchAndExplain(HloInstruction* instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(instr, listener); + } + + void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); } + + void DescribeNegationTo(std::ostream* os) const { + *os << "is NOT: "; + DescribeTo(os); + } + + private: + template + bool MatchAndExplainImpl(T* t, + ::testing::MatchResultListener* listener) const { + MatchOption options{/*.capture=*/true, /*.single_user_only=*/false, + /*.explain_os=*/listener->stream()}; + return Match(t, pattern_, options); + } + + Pattern pattern_; +}; +} // namespace pattern_matcher_gmock_detail + +template +::testing::PolymorphicMatcher< + pattern_matcher_gmock_detail::GmockMatcher> +GmockMatch(Pattern&& p) { + return ::testing::MakePolymorphicMatcher( + pattern_matcher_gmock_detail::GmockMatcher( + std::forward(p))); +} + +} // namespace xla + +#endif // XLA_HLO_TESTLIB_PATTERN_MATCHER_GMOCK_H_ diff --git a/third_party/xla/xla/hlo/testlib/test.h b/third_party/xla/xla/hlo/testlib/test.h new file mode 100644 index 00000000000000..adbeaffb90fc87 --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/test.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TESTLIB_TEST_H_ +#define XLA_HLO_TESTLIB_TEST_H_ + +// This header includes gmock.h and enables the use of gmock matchers in tests +// in third_party/tensorflow/compiler/xla. +// +// Test including this header can use the macros EXPECT_THAT(...) and +// ASSERT_THAT(...) in combination with gmock matchers. +// Example: +// std::vector vec = Foo(); +// EXPECT_THAT(vec, ::testing::ElementsAre(1,2,3)); +// +// For more details on gmock matchers see: +// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers +// +// The advantages of using gmock matchers instead of self defined matchers are +// better error messages, more maintainable tests and more test coverage. +// +// Note that while the use of gmock matchers is allowed in the xla project, the +// use of mocks is disallowed in the whole tensorflow project! + +#include "tsl/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) +#include // IWYU pragma: export +#else +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#endif + +#include "tsl/platform/test.h" // IWYU pragma: export + +#endif // XLA_HLO_TESTLIB_TEST_H_ diff --git a/third_party/xla/xla/hlo/testlib/test_helpers.h b/third_party/xla/xla/hlo/testlib/test_helpers.h new file mode 100644 index 00000000000000..6af0436ee7c963 --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/test_helpers.h @@ -0,0 +1,68 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TESTLIB_TEST_HELPERS_H_ +#define XLA_HLO_TESTLIB_TEST_HELPERS_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tsl/platform/test.h" + +// This module contains a minimal subset of gmock functionality just +// sufficient to execute the currently existing tests. + +namespace xla { +template +class Array2D; +class Literal; + +namespace testing { + +namespace internal_status { +// TODO(b/340953531) Eliminate this function. +inline const absl::Status& GetStatus(const absl::Status& status) { + return status; +} + +template +inline const absl::Status& GetStatus(const absl::StatusOr& status) { + return status.status(); +} +} // namespace internal_status + +} // namespace testing +} // namespace xla + +// The following macros are similar to macros in gmock, but deliberately named +// differently in order to avoid conflicts in files which include both. + +// Macros for testing the results of functions that return absl::Status or +// absl::StatusOr (for any type T). +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(::absl::OkStatus(), \ + xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(::absl::OkStatus(), \ + xla::testing::internal_status::GetStatus(expression)) +#undef ASSERT_IS_OK +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(::absl::OkStatus(), \ + xla::testing::internal_status::GetStatus(expression)) +#undef ASSERT_IS_NOT_OK +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(::absl::OkStatus(), \ + xla::testing::internal_status::GetStatus(expression)) + +#endif // XLA_HLO_TESTLIB_TEST_HELPERS_H_ diff --git a/third_party/xla/xla/hlo/tools/BUILD b/third_party/xla/xla/hlo/tools/BUILD index e0d0e8c984b953..eb2be4ab665bd1 100644 --- a/third_party/xla/xla/hlo/tools/BUILD +++ b/third_party/xla/xla/hlo/tools/BUILD @@ -187,3 +187,12 @@ xla_cc_binary( "@stablehlo//:register", ], ) + +xla_cc_binary( + name = "hlo-opt", + testonly = True, + linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], + deps = [ + "//xla/hlo/tools/hlo_opt:opt_main", + ], +) diff --git a/third_party/xla/xla/hlo/tools/hlo_opt/BUILD b/third_party/xla/xla/hlo/tools/hlo_opt/BUILD new file mode 100644 index 00000000000000..563647ec0d1b20 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/hlo_opt/BUILD @@ -0,0 +1,95 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], + licenses = ["notice"], +) + +cc_library( + name = "opt_main", + testonly = True, + srcs = ["opt_main.cc"], + deps = [ + ":opt_lib", + "//xla:debug_options_flags", + "//xla/hlo/ir:hlo", + "//xla/tools:hlo_module_loader", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:statusor", + ], +) + +# Includes a macro to register a provider. +cc_library( + name = "opt_lib", + srcs = ["opt_lib.cc"], + hdrs = ["opt_lib.h"], + deps = [ + "//xla/hlo/analysis:indexed_array_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:cholesky_expander", + "//xla/hlo/transforms:comparison_expander", + "//xla/hlo/transforms:convert_memory_placement_to_internal_annotations", + "//xla/hlo/transforms:convolution_4d_expander", + "//xla/hlo/transforms:convolution_pred_expander", + "//xla/hlo/transforms:dot_decomposer", + "//xla/hlo/transforms:dynamic_index_splitter", + "//xla/hlo/transforms:eigh_expander", + "//xla/hlo/transforms:logistic_expander", + "//xla/hlo/transforms:operand_upcaster", + "//xla/hlo/transforms:optimization_barrier_expander", + "//xla/hlo/transforms:qr_expander", + "//xla/hlo/transforms:real_imag_expander", + "//xla/hlo/transforms:reduce_decomposer", + "//xla/hlo/transforms:reshape_decomposer", + "//xla/hlo/transforms:rng_expander", + "//xla/hlo/transforms:stable_sort_expander", + "//xla/hlo/transforms:stochastic_convert_decomposer", + "//xla/hlo/transforms:while_loop_trip_count_annotator", + "//xla/hlo/transforms/collectives:all_gather_broadcast_reorder", + "//xla/hlo/transforms/collectives:all_reduce_contiguous", + "//xla/hlo/transforms/collectives:collective_quantizer", + "//xla/hlo/transforms/simplifiers:all_reduce_folder", + "//xla/hlo/transforms/simplifiers:batch_dot_simplification", + "//xla/hlo/transforms/simplifiers:broadcast_canonicalizer", + "//xla/hlo/transforms/simplifiers:conditional_canonicalizer", + "//xla/hlo/transforms/simplifiers:convert_mover", + "//xla/hlo/transforms/simplifiers:convolution_group_converter", + "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/hlo/transforms/simplifiers:gather_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:result_caster", + "//xla/hlo/transforms/simplifiers:simplify_fp_conversions", + "//xla/hlo/transforms/simplifiers:slice_sinker", + "//xla/hlo/transforms/simplifiers:sort_simplifier", + "//xla/hlo/transforms/simplifiers:sub_byte_normalization", + "//xla/hlo/transforms/simplifiers:tree_reduction_rewriter", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination", + "//xla/hlo/transforms/tests:dummy_passes", + "//xla/service:float_support", + "//xla/service:platform_util", + "//xla/stream_executor/platform:initialize", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc similarity index 79% rename from third_party/xla/xla/tools/hlo_opt/opt_lib.cc rename to third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc index 62b421d058b8b4..78fad847cf289a 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_lib.cc +++ b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tools/hlo_opt/opt_lib.h" +#include "xla/hlo/tools/hlo_opt/opt_lib.h" #include #include @@ -81,38 +81,11 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h" #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" +#include "xla/hlo/transforms/tests/dummy_passes.h" #include "xla/hlo/transforms/while_loop_trip_count_annotator.h" -#include "xla/service/all_reduce_simplifier.h" -#include "xla/service/all_to_all_decomposer.h" -#include "xla/service/batched_gather_scatter_normalizer.h" -#include "xla/service/bitcast_dtypes_expander.h" -#include "xla/service/call_inliner.h" -#include "xla/service/conditional_simplifier.h" -#include "xla/service/conditional_to_select.h" -#include "xla/service/copy_insertion.h" #include "xla/service/float_support.h" -#include "xla/service/gather_expander.h" -#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" -#include "xla/service/gpu/transforms/all_reduce_splitter.h" -#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" -#include "xla/service/gpu/transforms/scatter_expander.h" -#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" -#include "xla/service/map_inliner.h" #include "xla/service/platform_util.h" -#include "xla/service/reduce_scatter_reassociate.h" -#include "xla/service/scatter_determinism_expander.h" -#include "xla/service/scatter_simplifier.h" -#include "xla/service/select_and_scatter_expander.h" -#include "xla/service/sharding_remover.h" -#include "xla/service/spmd/shardy/shardy_xla_pass.h" -#include "xla/service/topk_rewriter.h" -#include "xla/service/triangular_solve_expander.h" -#include "xla/service/while_loop_all_reduce_code_motion.h" -#include "xla/service/while_loop_constant_sinking.h" -#include "xla/service/while_loop_invariant_code_motion.h" -#include "xla/service/while_loop_simplifier.h" #include "xla/stream_executor/platform/initialize.h" -#include "xla/tools/hlo_opt/transforms_example_passes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -214,75 +187,48 @@ void OptProvider::RegisterAllHardwareIndependentPasses() { // Hardware-independent HLO passes // go/keep-sorted start RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); - RegisterPass(); - RegisterPass(); RegisterPass(); - RegisterPass(); - RegisterPass(); RegisterPass(); - RegisterPass(); RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(GatherExpander::kEliminateSimpleGathers); RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(true); RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); - RegisterPass(); - RegisterPass(); - RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(); RegisterPass(SubByteNormalization::SET_ELEMENT_SIZE); - RegisterPass(); RegisterPass(); - RegisterPass(); RegisterPass(); - RegisterPass(); - RegisterPass(); - RegisterPass(); - RegisterPass(); RegisterPass(); RegisterPass(); - RegisterPass(); // go/keep-sorted end FloatSupport bf16_support(BF16); RegisterPass(&bf16_support); diff --git a/third_party/xla/xla/tools/hlo_opt/opt_lib.h b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.h similarity index 96% rename from third_party/xla/xla/tools/hlo_opt/opt_lib.h rename to third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.h index 841dcecb0363d2..2b487916631497 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_lib.h +++ b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TOOLS_HLO_OPT_OPT_LIB_H_ -#define XLA_TOOLS_HLO_OPT_OPT_LIB_H_ +#ifndef XLA_HLO_TOOLS_HLO_OPT_OPT_LIB_H_ +#define XLA_HLO_TOOLS_HLO_OPT_OPT_LIB_H_ #include #include @@ -96,4 +96,4 @@ class OptProvider { } // namespace xla -#endif // XLA_TOOLS_HLO_OPT_OPT_LIB_H_ +#endif // XLA_HLO_TOOLS_HLO_OPT_OPT_LIB_H_ diff --git a/third_party/xla/xla/tools/hlo_opt/opt_main.cc b/third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc similarity index 99% rename from third_party/xla/xla/tools/hlo_opt/opt_main.cc rename to third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc index 31dba72ca48c78..e2d0992611e9fd 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_main.cc +++ b/third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc @@ -33,9 +33,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_module_config.h" +#include "xla/hlo/tools/hlo_opt/opt_lib.h" #include "xla/tools/hlo_module_loader.h" -#include "xla/tools/hlo_opt/opt_lib.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index bc1c0c2424bdb9..476a8f97de338b 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -2,12 +2,7 @@ # Implementation of XLA’s HLO transformations. load("//xla:xla.bzl", "xla_cc_test") -load("//xla/tsl:tsl.bzl", "tsl_copts") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") -load( - "//xla/tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -23,1667 +18,600 @@ package_group( ) cc_library( - name = "hlo_constant_splitter", - srcs = ["simplifiers/hlo_constant_splitter.cc"], - hdrs = ["simplifiers/hlo_constant_splitter.h"], + name = "bfloat16_propagation", + srcs = ["bfloat16_propagation.cc"], + hdrs = ["bfloat16_propagation.h"], deps = [ + "//xla:literal", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/service:float_support", + "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "hlo_constant_splitter_test", - srcs = ["simplifiers/hlo_constant_splitter_test.cc"], + name = "bfloat16_propagation_test", + srcs = ["bfloat16_propagation_test.cc"], deps = [ - ":hlo_constant_splitter", - ":hlo_dce", - "//xla:test", - "//xla:util", + ":bfloat16_propagation", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/tsl/lib/core:status_test_util", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", + "//xla/service:float_support", + "//xla/service:hlo_verifier", + "//xla/tests:literal_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) cc_library( - name = "all_reduce_folder", - srcs = ["simplifiers/all_reduce_folder.cc"], - hdrs = ["simplifiers/all_reduce_folder.h"], + name = "op_expander_pass", + srcs = ["expanders/op_expander_pass.cc"], + hdrs = ["expanders/op_expander_pass.h"], deps = [ - "//xla:xla_data_proto_cc", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/service:all_reduce_key", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "all_reduce_folder_test", - srcs = ["simplifiers/all_reduce_folder_test.cc"], - deps = [ - ":all_reduce_folder", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "broadcast_canonicalizer", - srcs = ["simplifiers/broadcast_canonicalizer.cc"], - hdrs = ["simplifiers/broadcast_canonicalizer.h"], + name = "optimization_barrier_expander", + srcs = ["expanders/optimization_barrier_expander.cc"], + hdrs = ["expanders/optimization_barrier_expander.h"], deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", + ":op_expander_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "broadcast_canonicalizer_test", - srcs = ["simplifiers/broadcast_canonicalizer_test.cc"], - deps = [ - ":broadcast_canonicalizer", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) cc_library( - name = "bfloat16_conversion_folding", - srcs = ["simplifiers/bfloat16_conversion_folding.cc"], - hdrs = ["simplifiers/bfloat16_conversion_folding.h"], + name = "comparison_expander", + srcs = ["expanders/comparison_expander.cc"], + hdrs = ["expanders/comparison_expander.h"], deps = [ + ":op_expander_pass", + "//xla:comparison_util", + "//xla:literal_util", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:float_support", - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", + "@com_google_absl//absl/types:span", ], ) -xla_cc_test( - name = "bfloat16_conversion_folding_test", - srcs = ["simplifiers/bfloat16_conversion_folding_test.cc"], +cc_library( + name = "cholesky_expander", + srcs = ["expanders/cholesky_expander.cc"], + hdrs = ["expanders/cholesky_expander.h"], deps = [ - ":bfloat16_conversion_folding", + ":op_expander_pass", + "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", + "//xla:status_macros", + "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:float_support", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:loops", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:slicing", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:test_main", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", ], ) cc_library( - name = "float_normalization", - srcs = ["simplifiers/float_normalization.cc"], - hdrs = ["simplifiers/float_normalization.h"], + name = "qr_expander", + srcs = ["expanders/qr_expander.cc"], + hdrs = ["expanders/qr_expander.h"], deps = [ - ":hlo_dce", - ":tuple_simplifier", + ":op_expander_pass", + "//xla:literal", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:call_graph", - "//xla/service:float_support", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:loops", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:qr", + "//xla/hlo/builder/lib:slicing", + "//xla/service:hlo_creation_utils", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "real_imag_expander", + srcs = ["expanders/real_imag_expander.cc"], + hdrs = ["expanders/real_imag_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal_util", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) xla_cc_test( - name = "float_normalization_test", - srcs = ["simplifiers/float_normalization_test.cc"], + name = "real_imag_expander_test", + size = "small", + srcs = ["expanders/real_imag_expander_test.cc"], deps = [ - ":float_normalization", + ":real_imag_expander", + "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", + "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:float_support", + "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_creation_utils", - "//xla/service:hlo_verifier", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "bfloat16_propagation", - srcs = ["bfloat16_propagation.cc"], - hdrs = ["bfloat16_propagation.h"], + name = "eigh_expander", + srcs = ["expanders/eigh_expander.cc"], + hdrs = ["expanders/eigh_expander.h"], deps = [ - ":hlo_dce", - ":tuple_simplifier", - "//xla:literal", - "//xla:shape_tree", + ":op_expander_pass", + "//xla:literal_util", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_dataflow_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:float_support", - "//xla/service:hlo_value", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:comparators", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:loops", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:slicing", + "//xla/service:hlo_creation_utils", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) -xla_cc_test( - name = "bfloat16_propagation_test", - srcs = ["bfloat16_propagation_test.cc"], +cc_library( + name = "convolution_4d_expander", + srcs = ["expanders/convolution_4d_expander.cc"], + hdrs = ["expanders/convolution_4d_expander.h"], deps = [ - ":bfloat16_propagation", - "//xla:comparison_util", - "//xla:literal_util", + ":op_expander_pass", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:float_support", - "//xla/service:hlo_verifier", - "//xla/tests:literal_test_util", - "@com_google_absl//absl/log", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep + "@com_google_absl//absl/strings", ], ) -cc_library( - name = "flatten_call_graph", - srcs = ["simplifiers/flatten_call_graph.cc"], - hdrs = ["simplifiers/flatten_call_graph.h"], +xla_cc_test( + name = "convolution_4d_expander_test", + srcs = ["expanders/convolution_4d_expander_test.cc"], deps = [ - "//xla:util", + "convolution_4d_expander", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/service:call_graph", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "hlo_computation_deduplicator", - srcs = ["simplifiers/hlo_computation_deduplicator.cc"], - hdrs = ["simplifiers/hlo_computation_deduplicator.h"], + name = "convolution_pred_expander", + srcs = ["expanders/convolution_pred_expander.cc"], + hdrs = ["expanders/convolution_pred_expander.h"], deps = [ + ":op_expander_pass", "//xla:shape_util", - "//xla:status_macros", - "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/service:hlo_creation_utils", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", ], ) xla_cc_test( - name = "hlo_computation_deduplicator_test", - size = "small", - srcs = ["simplifiers/hlo_computation_deduplicator_test.cc"], - deps = [ - ":hlo_computation_deduplicator", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "flatten_call_graph_test", - srcs = ["simplifiers/flatten_call_graph_test.cc"], - deps = [ - ":flatten_call_graph", - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:call_graph", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "hlo_memory_scheduler", - srcs = ["simplifiers/hlo_memory_scheduler.cc"], - hdrs = ["simplifiers/hlo_memory_scheduler.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/analysis:hlo_alias_analysis", - "//xla/hlo/analysis:tuple_points_to_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:buffer_value", - "//xla/service:logical_buffer", - "//xla/service/heap_simulator", - "//xla/tsl/lib/gtl:map_util", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - ], -) - -xla_cc_test( - name = "hlo_memory_scheduler_test", - srcs = ["simplifiers/hlo_memory_scheduler_test.cc"], - deps = [ - ":hlo_dce", - ":hlo_memory_scheduler", - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_alias_analysis", - "//xla/hlo/analysis:hlo_ordering", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:buffer_value", - "//xla/service:hlo_value", - "//xla/service:logical_buffer", - "//xla/service/heap_simulator", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "op_expander_pass", - srcs = ["expanders/op_expander_pass.cc"], - hdrs = ["expanders/op_expander_pass.h"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "optimization_barrier_expander", - srcs = ["expanders/optimization_barrier_expander.cc"], - hdrs = ["expanders/optimization_barrier_expander.h"], - deps = [ - ":op_expander_pass", - ], -) - -cc_library( - name = "comparison_expander", - srcs = ["expanders/comparison_expander.cc"], - hdrs = ["expanders/comparison_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "cholesky_expander", - srcs = ["expanders/cholesky_expander.cc"], - hdrs = ["expanders/cholesky_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/builder:xla_builder", - "//xla/hlo/builder/lib:arithmetic", - "//xla/hlo/builder/lib:constants", - "//xla/hlo/builder/lib:loops", - "//xla/hlo/builder/lib:math", - "//xla/hlo/builder/lib:matrix", - "//xla/hlo/builder/lib:slicing", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "qr_expander", - srcs = ["expanders/qr_expander.cc"], - hdrs = ["expanders/qr_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/builder:xla_builder", - "//xla/hlo/builder/lib:arithmetic", - "//xla/hlo/builder/lib:constants", - "//xla/hlo/builder/lib:loops", - "//xla/hlo/builder/lib:math", - "//xla/hlo/builder/lib:matrix", - "//xla/hlo/builder/lib:qr", - "//xla/hlo/builder/lib:slicing", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "real_imag_expander", - srcs = ["expanders/real_imag_expander.cc"], - hdrs = ["expanders/real_imag_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - ], -) - -xla_cc_test( - name = "real_imag_expander_test", - size = "small", - srcs = ["expanders/real_imag_expander_test.cc"], - deps = [ - ":real_imag_expander", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:hlo_creation_utils", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "eigh_expander", - srcs = ["expanders/eigh_expander.cc"], - hdrs = ["expanders/eigh_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/builder:xla_builder", - "//xla/hlo/builder/lib:arithmetic", - "//xla/hlo/builder/lib:comparators", - "//xla/hlo/builder/lib:constants", - "//xla/hlo/builder/lib:loops", - "//xla/hlo/builder/lib:math", - "//xla/hlo/builder/lib:matrix", - "//xla/hlo/builder/lib:slicing", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "convolution_4d_expander", - srcs = ["expanders/convolution_4d_expander.cc"], - hdrs = ["expanders/convolution_4d_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "convolution_4d_expander_test", - srcs = ["expanders/convolution_4d_expander_test.cc"], - deps = [ - "convolution_4d_expander", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "convolution_pred_expander", - srcs = ["expanders/convolution_pred_expander.cc"], - hdrs = ["expanders/convolution_pred_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:pattern_matcher", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "convolution_pred_expander_test", - srcs = ["expanders/convolution_pred_expander_test.cc"], - deps = [ - ":convolution_pred_expander", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "algebraic_simplifier", - srcs = ["simplifiers/algebraic_simplifier.cc"], - hdrs = ["simplifiers/algebraic_simplifier.h"], - copts = tsl_copts(), - deps = [ - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_instruction_utils", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_sharding_util", - "//xla/service:gather_scatter_utils", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_module_config", - "//xla/service:host_memory_offload_annotations_hdr", - "//xla/service:host_offload_utils", - "//xla/service:pattern_matcher", - "//xla/service:shape_inference", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "tree_reduction_rewriter", - srcs = ["simplifiers/tree_reduction_rewriter.cc"], - hdrs = ["simplifiers/tree_reduction_rewriter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/builder:padding", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:shape_inference", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "algebraic_simplifier_test", - srcs = ["simplifiers/algebraic_simplifier_test.cc"], - deps = [ - ":algebraic_simplifier", - ":hlo_constant_folding", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:hlo_creation_utils", - "//xla/service:host_memory_offload_annotations_hdr", - "//xla/service:layout_assignment", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:shape_inference", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "simplify_fp_conversions", - srcs = ["simplifiers/simplify_fp_conversions.cc"], - hdrs = ["simplifiers/simplify_fp_conversions.h"], - deps = [ - "//xla:comparison_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "simplify_fp_conversions_test", - srcs = ["simplifiers/simplify_fp_conversions_test.cc"], - deps = [ - ":simplify_fp_conversions", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "logistic_expander", - srcs = ["expanders/logistic_expander.cc"], - hdrs = ["expanders/logistic_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "logistic_expander_test", - srcs = ["expanders/logistic_expander_test.cc"], - deps = [ - ":logistic_expander", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:dynamic_padder", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "bitcast_dtypes_expander", - srcs = ["expanders/bitcast_dtypes_expander.cc"], - hdrs = ["expanders/bitcast_dtypes_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/builder:xla_builder", - "//xla/hlo/builder:xla_computation", - "//xla/hlo/builder/lib:arithmetic", - "//xla/hlo/builder/lib:broadcast", - "//xla/hlo/builder/lib:constants", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_module_config", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "bitcast_dtypes_expander_test", - srcs = ["expanders/bitcast_dtypes_expander_test.cc"], - deps = [ - ":bitcast_dtypes_expander", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:filecheck", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "batch_dot_simplification", - srcs = ["simplifiers/batch_dot_simplification.cc"], - hdrs = ["simplifiers/batch_dot_simplification.h"], - deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "batch_dot_simplification_test", - srcs = ["simplifiers/batch_dot_simplification_test.cc"], - deps = [ - ":batch_dot_simplification", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "convolution_group_converter", - srcs = ["simplifiers/convolution_group_converter.cc"], - hdrs = ["simplifiers/convolution_group_converter.h"], - deps = [ - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "convolution_group_converter_test", - size = "small", - srcs = ["simplifiers/convolution_group_converter_test.cc"], - deps = [ - ":convolution_group_converter", - "//xla:test", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "while_loop_trip_count_annotator", - srcs = ["while_loop_trip_count_annotator.cc"], - hdrs = ["while_loop_trip_count_annotator.h"], - deps = [ - "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:while_loop_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "while_loop_trip_count_annotator_test", - srcs = ["while_loop_trip_count_annotator_test.cc"], - deps = [ - ":while_loop_trip_count_annotator", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "defuser", - srcs = ["defuser.cc"], - hdrs = ["defuser.h"], - deps = [ - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:call_graph", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "defuser_test", - srcs = ["defuser_test.cc"], - deps = [ - ":defuser", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "despecializer_test", - srcs = ["despecializer_test.cc"], - deps = [ - ":despecializer", - "//xla:literal", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "dot_decomposer", - srcs = ["expanders/dot_decomposer.cc"], - hdrs = ["expanders/dot_decomposer.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:shape_inference", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "dot_decomposer_test", - srcs = ["expanders/dot_decomposer_test.cc"], - deps = [ - ":dot_decomposer", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -cc_library( - name = "dot_dimension_merger", - srcs = ["simplifiers/dot_dimension_merger.cc"], - hdrs = ["simplifiers/dot_dimension_merger.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "dot_dimension_merger_test", - srcs = ["simplifiers/dot_dimension_merger_test.cc"], - deps = [ - ":dot_dimension_merger", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "dot_merger", - srcs = ["simplifiers/dot_merger.cc"], - hdrs = ["simplifiers/dot_merger.h"], - deps = [ - "//xla:protobuf_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:shape_inference", - "//xla/service/graphcycles", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "dot_merger_test", - srcs = ["simplifiers/dot_merger_test.cc"], + name = "convolution_pred_expander_test", + srcs = ["expanders/convolution_pred_expander_test.cc"], deps = [ - ":algebraic_simplifier", - ":dot_merger", - "//xla:shape_util", + ":convolution_pred_expander", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep + "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "convert_mover", - srcs = ["simplifiers/convert_mover.cc"], - hdrs = ["simplifiers/convert_mover.h"], + name = "logistic_expander", + srcs = ["expanders/logistic_expander.cc"], + hdrs = ["expanders/logistic_expander.h"], deps = [ - "//xla:literal", + ":op_expander_pass", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:logging", ], ) xla_cc_test( - name = "convert_mover_test", - srcs = ["simplifiers/convert_mover_test.cc"], + name = "logistic_expander_test", + srcs = ["expanders/logistic_expander_test.cc"], deps = [ - ":convert_mover", + ":logistic_expander", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/service:dynamic_padder", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) cc_library( - name = "tuple_simplifier", - srcs = ["simplifiers/tuple_simplifier.cc"], - hdrs = ["simplifiers/tuple_simplifier.h"], + name = "bitcast_dtypes_expander", + srcs = ["expanders/bitcast_dtypes_expander.cc"], + hdrs = ["expanders/bitcast_dtypes_expander.h"], deps = [ + ":op_expander_pass", + "//xla:literal_util", "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:broadcast", + "//xla/hlo/builder/lib:constants", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "tuple_simplifier_test", - srcs = ["simplifiers/tuple_simplifier_test.cc"], + name = "bitcast_dtypes_expander_test", + srcs = ["expanders/bitcast_dtypes_expander_test.cc"], deps = [ - ":tuple_simplifier", - "//xla:shape_util", - "//xla:test", + ":bitcast_dtypes_expander", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "reshape_mover", - srcs = ["simplifiers/reshape_mover.cc"], - hdrs = ["simplifiers/reshape_mover.h"], - deps = [ - "//xla:permutation_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "reshape_decomposer", - srcs = ["expanders/reshape_decomposer.cc"], - hdrs = ["expanders/reshape_decomposer.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/status", - ], -) - -cc_library( - name = "reduce_decomposer", - srcs = ["expanders/reduce_decomposer.cc"], - hdrs = ["expanders/reduce_decomposer.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/status", - ], -) - -xla_cc_test( - name = "reduce_decomposer_test", - srcs = ["expanders/reduce_decomposer_test.cc"], - deps = [ - ":reduce_decomposer", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "reshape_decomposer_test", - srcs = ["expanders/reshape_decomposer_test.cc"], - deps = [ - ":reshape_decomposer", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "dynamic_dimension_simplifier", - srcs = ["simplifiers/dynamic_dimension_simplifier.cc"], - hdrs = ["simplifiers/dynamic_dimension_simplifier.h"], + name = "while_loop_trip_count_annotator", + srcs = ["while_loop_trip_count_annotator.cc"], + hdrs = ["while_loop_trip_count_annotator.h"], deps = [ - "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", ], ) xla_cc_test( - name = "dynamic_dimension_simplifier_test", - srcs = ["simplifiers/dynamic_dimension_simplifier_test.cc"], + name = "while_loop_trip_count_annotator_test", + srcs = ["while_loop_trip_count_annotator_test.cc"], deps = [ - ":dynamic_dimension_simplifier", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:window_util", + ":while_loop_trip_count_annotator", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:hlo_creation_utils", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:shape_inference", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test_main", # fixdeps: keep - ], -) - -xla_cc_test( - name = "reshape_mover_test", - srcs = ["simplifiers/reshape_mover_test.cc"], - deps = [ - ":algebraic_simplifier", - ":reshape_mover", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:hlo_verifier", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", + "//xla/hlo/testlib:test", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "memory_space_propagation", - srcs = ["memory_space_propagation.cc"], - hdrs = ["memory_space_propagation.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/analysis:hlo_dataflow_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - ], -) - -xla_cc_test( - name = "memory_space_propagation_test", - srcs = ["memory_space_propagation_test.cc"], - deps = [ - ":memory_space_propagation", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "hlo_dce", - srcs = ["simplifiers/hlo_dce.cc"], - hdrs = ["simplifiers/hlo_dce.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) cc_library( - name = "hlo_rematerialization", - srcs = ["simplifiers/hlo_rematerialization.cc"], - hdrs = ["simplifiers/hlo_rematerialization.h"], + name = "defuser", + srcs = ["defuser.cc"], + hdrs = ["defuser.h"], deps = [ - ":hlo_dce", - "//xla:shape_util", "//xla:status_macros", + "//xla:types", "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_dataflow_analysis", - "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", "//xla/service:call_graph", - "//xla/service:hlo_cost_analysis", - "//xla/service:logical_buffer", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "hlo_rematerialization_test_utils", - testonly = 1, - hdrs = ["simplifiers/hlo_rematerialization_test_utils.h"], - deps = [ - "//xla:literal_util", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "hlo_rematerialization_test_utils_test", - srcs = ["simplifiers/hlo_rematerialization_test_utils_test.cc"], - deps = [ - ":hlo_rematerialization_test_utils", - "//xla/hlo/ir:hlo", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test_main", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", ], ) xla_cc_test( - name = "hlo_rematerialization_test", - srcs = ["simplifiers/hlo_rematerialization_test.cc"], + name = "defuser_test", + srcs = ["defuser_test.cc"], deps = [ - ":hlo_memory_scheduler", - ":hlo_rematerialization", - ":hlo_rematerialization_test_utils", + ":defuser", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", - "//xla/service:hlo_cost_analysis", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test_main", ], ) xla_cc_test( - name = "hlo_dce_test", - srcs = ["simplifiers/hlo_dce_test.cc"], + name = "despecializer_test", + srcs = ["despecializer_test.cc"], deps = [ - ":hlo_dce", - "//xla:literal_util", + ":despecializer", + "//xla:literal", "//xla:shape_util", - "//xla:types", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:literal_test_util", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/types:span", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "hlo_constant_folding", - srcs = ["simplifiers/hlo_constant_folding.cc"], - hdrs = ["simplifiers/hlo_constant_folding.h"], + name = "dot_decomposer", + srcs = ["expanders/dot_decomposer.cc"], + hdrs = ["expanders/dot_decomposer.h"], deps = [ - "//xla:literal", "//xla:shape_util", - "//xla/hlo/evaluator:hlo_evaluator", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/service:slow_operation_alarm", + "//xla/service:shape_inference", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/time", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "hlo_constant_folding_test", - srcs = ["simplifiers/hlo_constant_folding_test.cc"], + name = "dot_decomposer_test", + srcs = ["expanders/dot_decomposer_test.cc"], deps = [ - ":hlo_constant_folding", - "//xla:literal", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", + ":dot_decomposer", "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/utils:hlo_matchers", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) -# copybara:uncomment_begin(google-only) -# xla_cc_test( -# name = "hlo_constant_folding_peak_heap_test", -# srcs = ["simplifiers/hlo_constant_folding_peak_heap_test.cc"], -# deps = [ -# ":hlo_constant_folding", -# "@com_google_googletest//:gtest", -# "@com_google_absl//absl/strings:str_format", -# "//xla:test", -# "//xla/hlo/testlib:hlo_hardware_independent_test_base", -# "@local_tsl//tsl/platform:statusor", -# "@local_tsl//tsl/platform:test_main", -# ], -# ) -# copybara:uncomment_end - cc_library( - name = "hlo_element_type_converter", - srcs = ["simplifiers/hlo_element_type_converter.cc"], - hdrs = ["simplifiers/hlo_element_type_converter.h"], + name = "reshape_decomposer", + srcs = ["expanders/reshape_decomposer.cc"], + hdrs = ["expanders/reshape_decomposer.h"], deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_creation_utils", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "hlo_element_type_converter_test", - srcs = ["simplifiers/hlo_element_type_converter_test.cc"], - deps = [ - ":hlo_element_type_converter", - "//xla:xla_data_proto_cc", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "conditional_canonicalizer", - srcs = ["simplifiers/conditional_canonicalizer.cc"], - hdrs = ["simplifiers/conditional_canonicalizer.h"], + name = "reduce_decomposer", + srcs = ["expanders/reduce_decomposer.cc"], + hdrs = ["expanders/reduce_decomposer.h"], deps = [ - "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", ], ) xla_cc_test( - name = "conditional_canonicalizer_test", - srcs = ["simplifiers/conditional_canonicalizer_test.cc"], + name = "reduce_decomposer_test", + srcs = ["expanders/reduce_decomposer_test.cc"], deps = [ - ":conditional_canonicalizer", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", + ":reduce_decomposer", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:literal_test_util", - "//xla/tsl/lib/core:status_test_util", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) -cc_library( - name = "zero_sized_hlo_elimination", - srcs = ["simplifiers/zero_sized_hlo_elimination.cc"], - hdrs = ["simplifiers/zero_sized_hlo_elimination.h"], - deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - xla_cc_test( - name = "zero_sized_hlo_elimination_test", - srcs = ["simplifiers/zero_sized_hlo_elimination_test.cc"], + name = "reshape_decomposer_test", + srcs = ["expanders/reshape_decomposer_test.cc"], deps = [ - ":zero_sized_hlo_elimination", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", + ":reshape_decomposer", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "sort_simplifier", - srcs = ["simplifiers/sort_simplifier.cc"], - hdrs = ["simplifiers/sort_simplifier.h"], + name = "memory_space_propagation", + srcs = ["memory_space_propagation.cc"], + hdrs = ["memory_space_propagation.h"], deps = [ "//xla:shape_util", - "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "sort_simplifier_test", - srcs = ["simplifiers/sort_simplifier_test.cc"], + name = "memory_space_propagation_test", + srcs = ["memory_space_propagation_test.cc"], deps = [ - ":sort_simplifier", - "//xla:test", + ":memory_space_propagation", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) @@ -1694,10 +622,14 @@ cc_library( hdrs = ["expanders/stable_sort_expander.h"], deps = [ ":op_expander_pass", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", ], ) @@ -1705,42 +637,15 @@ xla_cc_test( name = "stable_sort_expander_test", srcs = ["expanders/stable_sort_expander_test.cc"], deps = [ - ":algebraic_simplifier", ":stable_sort_expander", - "//xla:test", "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:test", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/hlo/utils:hlo_matchers", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "root_instruction_sinker", - srcs = ["simplifiers/root_instruction_sinker.cc"], - hdrs = ["simplifiers/root_instruction_sinker.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:tuple_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "root_instruction_sinker_test", - srcs = ["simplifiers/root_instruction_sinker_test.cc"], - deps = [ - ":root_instruction_sinker", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], @@ -1757,11 +662,14 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:host_memory_offload_annotations_hdr", + "//xla/tsl/platform:errors", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1775,50 +683,10 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:verified_hlo_module", "//xla/service:host_memory_offload_annotations_hdr", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "host_memory_transfer_asyncifier", - srcs = ["simplifiers/host_memory_transfer_asyncifier.cc"], - hdrs = ["simplifiers/host_memory_transfer_asyncifier.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "host_memory_transfer_asyncifier_test", - srcs = ["simplifiers/host_memory_transfer_asyncifier_test.cc"], - deps = [ - ":host_memory_transfer_asyncifier", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -1843,6 +711,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", @@ -1860,9 +729,9 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:host_memory_offload_annotations_hdr", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -1884,6 +753,7 @@ cc_library( "//xla:side_effect_util", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -1920,56 +790,18 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_verifier", "//xla/service:host_memory_offload_annotations_hdr", "//xla/service:host_offload_utils", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "host_offloading_prepare", - srcs = ["host_offloading_prepare.cc"], - hdrs = ["host_offloading_prepare.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service:call_graph", - "//xla/service:host_memory_offload_annotations_hdr", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "host_offloading_prepare_test", - srcs = ["host_offloading_prepare_test.cc"], - deps = [ - ":host_offloading_prepare", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:host_memory_offload_annotations_hdr", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", @@ -1977,34 +809,38 @@ xla_cc_test( ) cc_library( - name = "fusion_constant_sinking", - srcs = ["simplifiers/fusion_constant_sinking.cc"], - hdrs = ["simplifiers/fusion_constant_sinking.h"], + name = "host_offloading_prepare", + srcs = ["host_offloading_prepare.cc"], + hdrs = ["host_offloading_prepare.h"], deps = [ - ":hlo_dce", "//xla:shape_util", - "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "//xla/service:host_memory_offload_annotations_hdr", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "fusion_constant_sinking_test", - srcs = ["simplifiers/fusion_constant_sinking_test.cc"], + name = "host_offloading_prepare_test", + srcs = ["host_offloading_prepare_test.cc"], deps = [ - ":fusion_constant_sinking", - "//xla:test", + ":host_offloading_prepare", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], @@ -2016,13 +852,13 @@ cc_library( hdrs = ["despecializer.h"], deps = [ ":defuser", - ":float_normalization", - ":hlo_memory_scheduler", - ":sub_byte_normalization", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:sub_byte_normalization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2032,72 +868,41 @@ cc_library( ) cc_library( - name = "optimize_input_output_buffer_alias", - srcs = ["simplifiers/optimize_input_output_buffer_alias.cc"], - hdrs = ["simplifiers/optimize_input_output_buffer_alias.h"], + name = "literal_canonicalizer", + srcs = ["literal_canonicalizer.cc"], + hdrs = ["literal_canonicalizer.h"], deps = [ - "//xla:shape_util", - "//xla:status_macros", + "//xla:literal_pool", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", + "//xla/hlo/pass:hlo_pass_pipeline", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", ], ) xla_cc_test( - name = "optimize_input_output_buffer_alias_test", - srcs = ["simplifiers/optimize_input_output_buffer_alias_test.cc"], + name = "literal_canonicalizer_test", + srcs = ["literal_canonicalizer_test.cc"], deps = [ - ":optimize_input_output_buffer_alias", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", + ":literal_canonicalizer", + "//xla:literal_pool", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], ) -cc_library( - name = "ar_crs_combiner", - srcs = ["simplifiers/ar_crs_combiner.cc"], - hdrs = ["simplifiers/ar_crs_combiner.h"], - deps = [ - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_replication_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/service:call_graph", - "//xla/service:pattern_matcher", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "dynamic_index_splitter", srcs = ["expanders/dynamic_index_splitter.cc"], @@ -2111,6 +916,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -2119,70 +925,13 @@ xla_cc_test( srcs = ["expanders/dynamic_index_splitter_test.cc"], deps = [ ":dynamic_index_splitter", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "ar_crs_combiner_test", - srcs = ["simplifiers/ar_crs_combiner_test.cc"], - deps = [ - ":ar_crs_combiner", - "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/utils:hlo_matchers", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "slice_sinker", - srcs = ["simplifiers/slice_sinker.cc"], - hdrs = ["simplifiers/slice_sinker.h"], - deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "slice_sinker_test", - srcs = ["simplifiers/slice_sinker_test.cc"], - deps = [ - ":hlo_dce", - ":slice_sinker", - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test_main", ], ) @@ -2195,9 +944,17 @@ cc_library( ":op_expander_pass", "//xla:literal_util", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:prng", "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", ], ) @@ -2215,7 +972,9 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_creation_utils", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -2245,36 +1004,6 @@ xla_cc_test( deps = [ ":operand_upcaster", "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "result_caster", - srcs = ["simplifiers/result_caster.cc"], - hdrs = ["simplifiers/result_caster.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:shape_inference", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "result_caster_test", - srcs = ["simplifiers/result_caster_test.cc"], - deps = [ - ":result_caster", - "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", @@ -2286,110 +1015,6 @@ xla_cc_test( ], ) -cc_library( - name = "convert_operand_folding", - srcs = ["simplifiers/convert_operand_folder.cc"], - hdrs = ["simplifiers/convert_operand_folder.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "convert_operand_folding_test", - srcs = ["simplifiers/convert_operand_folder_test.cc"], - deps = [ - ":convert_operand_folding", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_matchers", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "instruction_hoister", - srcs = ["simplifiers/instruction_hoister.cc"], - hdrs = ["simplifiers/instruction_hoister.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:status", - ], -) - -cc_library( - name = "gather_simplifier", - srcs = ["simplifiers/gather_simplifier.cc"], - hdrs = ["simplifiers/gather_simplifier.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:gather_scatter_utils", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "reduce_window_rewriter", - srcs = ["simplifiers/reduce_window_rewriter.cc"], - hdrs = ["simplifiers/reduce_window_rewriter.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "reduce_window_rewriter_test", - srcs = ["simplifiers/reduce_window_rewriter_test.cc"], - deps = [ - ":reduce_window_rewriter", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test_main", - ], -) - cc_library( name = "stochastic_convert_decomposer", srcs = ["expanders/stochastic_convert_decomposer.cc"], @@ -2402,8 +1027,12 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", "//xla/service:shape_inference", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -2418,26 +1047,9 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "sub_byte_normalization", - srcs = ["simplifiers/sub_byte_normalization.cc"], - hdrs = ["simplifiers/sub_byte_normalization.h"], - deps = [ - "//xla:shape_layout", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test_main", ], ) @@ -2458,18 +1070,6 @@ cc_library( ], ) -xla_cc_test( - name = "gather_simplifier_test", - srcs = ["simplifiers/gather_simplifier_test.cc"], - deps = [ - ":gather_simplifier", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:test_main", - ], -) - cc_library( name = "add_original_value", srcs = ["add_original_value.cc"], @@ -2494,9 +1094,9 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/testlib:verified_hlo_module", "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc b/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc index ff99c9215cbd1f..cf14c05d6a7365 100644 --- a/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc +++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc @@ -28,13 +28,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/service/float_support.h" #include "xla/service/hlo_verifier.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -1154,7 +1154,7 @@ ENTRY main { // This test demonstrates the need for invoking the ResolveAliasingBuffer // multiple times via a fixed-point algorithm. The key was the aliasing of the // two output buffers of the conditional, at subshape 0 (first element). This -// aliasing is not resolved until after the gte0 variale is already processed, +// aliasing is not resolved until after the gte0 variable is already processed, // triggering incorrect type for gte0 if not repeating the aliasing analysis. TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) { const std::string module_str = R"( diff --git a/third_party/xla/xla/hlo/transforms/collectives/BUILD b/third_party/xla/xla/hlo/transforms/collectives/BUILD index bf9014083eb7f1..5e991185bfba7e 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/BUILD +++ b/third_party/xla/xla/hlo/transforms/collectives/BUILD @@ -40,7 +40,7 @@ xla_cc_test( ":all_gather_cse", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_matchers", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", @@ -79,6 +79,7 @@ xla_cc_test( srcs = ["async_collective_creator_test.cc"], deps = [ ":async_collective_creator", + "//xla:side_effect_util", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", @@ -143,7 +144,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:collective_ops_utils", - "//xla/service:hlo_replication_analysis", "//xla/service:pattern_matcher", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -364,7 +364,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -401,16 +401,20 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/service:call_graph", + "//xla/service:tuple_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -425,6 +429,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], @@ -455,6 +460,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", # fixdeps: keep diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_gather_cse_test.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_cse_test.cc index 4e726e934df32e..e5d23ca53cf6df 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/all_gather_cse_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_cse_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_gather_cse.h" #include -#include #include #include diff --git a/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc index 159ab382a364ce..bf794478dc3711 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/side_effect_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -363,11 +364,13 @@ TEST_F(AsyncCollectiveCreatorTest, PreserveFrontendAttributesAllGather) { HloInstruction* done = hlo_module->entry_computation()->root_instruction(); HloInstruction* start = done->mutable_operand(0); EXPECT_TRUE( - done->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(done->frontend_attributes().map().at("_scheduling_group_id"), "0"); + done->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(done->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); EXPECT_TRUE( - start->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(start->frontend_attributes().map().at("_scheduling_group_id"), "0"); + start->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(start->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); } TEST_F(AsyncCollectiveCreatorTest, PreserveFrontendAttributesAllReduce) { @@ -394,11 +397,13 @@ TEST_F(AsyncCollectiveCreatorTest, PreserveFrontendAttributesAllReduce) { HloInstruction* done = hlo_module->entry_computation()->root_instruction(); HloInstruction* start = done->mutable_operand(0); EXPECT_TRUE( - done->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(done->frontend_attributes().map().at("_scheduling_group_id"), "0"); + done->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(done->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); EXPECT_TRUE( - start->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(start->frontend_attributes().map().at("_scheduling_group_id"), "0"); + start->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(start->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); } TEST_F(AsyncCollectiveCreatorTest, @@ -421,11 +426,13 @@ TEST_F(AsyncCollectiveCreatorTest, HloInstruction* done = hlo_module->entry_computation()->root_instruction(); HloInstruction* start = done->mutable_operand(0); EXPECT_TRUE( - done->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(done->frontend_attributes().map().at("_scheduling_group_id"), "0"); + done->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(done->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); EXPECT_TRUE( - start->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(start->frontend_attributes().map().at("_scheduling_group_id"), "0"); + start->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(start->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); } TEST_F(AsyncCollectiveCreatorTest, PreserveFrontendAttributesAllToAll) { @@ -447,11 +454,13 @@ TEST_F(AsyncCollectiveCreatorTest, PreserveFrontendAttributesAllToAll) { HloInstruction* done = hlo_module->entry_computation()->root_instruction(); HloInstruction* start = done->mutable_operand(0); EXPECT_TRUE( - done->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(done->frontend_attributes().map().at("_scheduling_group_id"), "0"); + done->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(done->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); EXPECT_TRUE( - start->frontend_attributes().map().contains("_scheduling_group_id")); - EXPECT_EQ(start->frontend_attributes().map().at("_scheduling_group_id"), "0"); + start->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(start->frontend_attributes().map().at(kXlaSchedulingGroupIdAttr), + "0"); } } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc index 9e4ad0e5cb2ba7..8806aa01ee0cc1 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/hlo/transforms/collectives/collective_quantizer.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/hlo_replication_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" @@ -148,7 +148,7 @@ std::vector FindDequantizationSubgraphRecursive( return {}; } - subgraph.emplace_back(instr); + subgraph.push_back(instr); if (Match(instr, ConvertToWiderType())) { return subgraph; } @@ -193,7 +193,7 @@ std::optional IsSupportedDequantization( ScalarBroadcast(&subgraph.scale_bcast))))) { subgraph.unaries = {candidate_subgraph.begin() + 2, candidate_subgraph.end()}; - } else if (candidate_subgraph.size() > 0 && + } else if (!candidate_subgraph.empty() && Match(candidate_subgraph[0], m::Convert(&subgraph.convert))) { subgraph.unaries = {candidate_subgraph.begin() + 1, candidate_subgraph.end()}; @@ -231,7 +231,7 @@ std::optional IsSupportedQuantization( BitcastPreservesElementType(), m::Copy(), m::Reshape(), m::Slice(), m::Multiply(), m::Divide(), m::Clamp()))) { if (instr->user_count() > 0) { - ops.emplace_back(instr); + ops.push_back(instr); instr = instr->users()[0]; continue; } @@ -239,7 +239,7 @@ std::optional IsSupportedQuantization( } if (Match(instr, ConvertToNarrowerType())) { - ops.emplace_back(instr); + ops.push_back(instr); break; } VLOG(5) << "Unsupported instruction."; @@ -265,8 +265,7 @@ std::optional IsSupportedQuantization( ScalarBroadcast(&subgraph.scale_bcast)), ScalarBroadcast(m::Constant())))))) { subgraph.unaries = {ops.begin(), ops.end() - 3}; - } else if (ops.size() > 0 && - Match(ops.back(), m::Convert(&subgraph.convert))) { + } else if (!ops.empty() && Match(ops.back(), m::Convert(&subgraph.convert))) { subgraph.unaries = {ops.begin(), ops.end() - 1}; } else { VLOG(5) << "Did not find type conversion or quantization pattern."; diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc index 1a47be62d9fff1..8e23a37f4f08be 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc @@ -44,10 +44,10 @@ class CollectiveQuantizerTest : public HloHardwareIndependentTestBase { TEST_F(CollectiveQuantizerTest, AllGatherConvert) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,4,8,128] parameter(0) - all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true ROOT convert = f8e4m3fn[8,32,8,128] convert(all-gather) } )"; @@ -63,10 +63,10 @@ TEST_F(CollectiveQuantizerTest, AllGatherConvert) { TEST_F(CollectiveQuantizerTest, AllGatherConvertUnary) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,4,8,128] parameter(0) - all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true reshape = bf16[8,32,1024] reshape(all-gather) slice = bf16[8,32,512] slice(reshape), slice={[0:8], [0:32], [256:768]} ROOT convert = f8e4m3fn[8,32,512] convert(slice) @@ -85,7 +85,7 @@ TEST_F(CollectiveQuantizerTest, AllGatherConvertUnary) { TEST_F(CollectiveQuantizerTest, AllGatherQuantize) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,4,8,128] parameter(0) all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true @@ -114,7 +114,7 @@ TEST_F(CollectiveQuantizerTest, AllGatherQuantize) { TEST_F(CollectiveQuantizerTest, AllToAllQuantize) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,32,8,128] parameter(0) all-to-all = bf16[8,32,8,128] all-to-all(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 @@ -143,7 +143,7 @@ TEST_F(CollectiveQuantizerTest, AllToAllQuantize) { TEST_F(CollectiveQuantizerTest, CollectiveBroadcastQuantize) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,32,8,128] parameter(0) collective-broadcast = bf16[8,32,8,128] collective-broadcast(param), replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 @@ -173,7 +173,7 @@ TEST_F(CollectiveQuantizerTest, CollectiveBroadcastQuantize) { TEST_F(CollectiveQuantizerTest, CollectivePermuteQuantize) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,32,8,128] parameter(0) collective-permute = bf16[8,32,8,128] collective-permute(param), source_target_pairs={{0,1},{2,3},{4,5},{6,7}}, channel_id=1 @@ -203,7 +203,7 @@ TEST_F(CollectiveQuantizerTest, CollectivePermuteQuantize) { TEST_F(CollectiveQuantizerTest, AllGatherQuantizeUnary) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,4,8,128] parameter(0) all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true @@ -234,10 +234,10 @@ TEST_F(CollectiveQuantizerTest, AllGatherQuantizeUnary) { TEST_F(CollectiveQuantizerTest, AllGatherQuantizeMultiUser) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,4,8,128] parameter(0) - all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true scale = bf16[] parameter(1), sharding={replicated} scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} divide = bf16[8,32,8,128] divide(all-gather, scale_bcast) @@ -258,10 +258,10 @@ TEST_F(CollectiveQuantizerTest, AllGatherQuantizeMultiUser) { TEST_F(CollectiveQuantizerTest, AllGatherQuantizeNonReplicatedScale) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = bf16[8,4,8,128] parameter(0) - all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true scale = bf16[] parameter(1) scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} divide = bf16[8,32,8,128] divide(all-gather, scale_bcast) @@ -281,7 +281,7 @@ TEST_F(CollectiveQuantizerTest, AllGatherQuantizeNonReplicatedScale) { TEST_F(CollectiveQuantizerTest, AllGatherQuantizePartialReplication) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -321,7 +321,7 @@ TEST_F(CollectiveQuantizerTest, AllGatherQuantizePartialReplication) { TEST_F(CollectiveQuantizerTest, AllToAllQuantizePartialReplication) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -362,7 +362,7 @@ TEST_F(CollectiveQuantizerTest, AllToAllQuantizePartialReplication) { TEST_F(CollectiveQuantizerTest, AllToAllQuantizePartialReplicationSeparateComputation) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -410,7 +410,7 @@ TEST_F(CollectiveQuantizerTest, TEST_F(CollectiveQuantizerTest, AllGatherQuantizePartialReplicationGroupMismatch) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -447,7 +447,7 @@ TEST_F(CollectiveQuantizerTest, TEST_F(CollectiveQuantizerTest, AllToAllQuantizePartialReplicationGroupMismatchSeparateComputation) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -486,11 +486,11 @@ TEST_F(CollectiveQuantizerTest, TEST_F(CollectiveQuantizerTest, ConvertAllGather) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,4,8,128] parameter(0) convert = bf16[8,4,8,128] convert(param) - ROOT all-gather = bf16[8,32,8,128] all-gather(convert), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + ROOT all-gather = bf16[8,32,8,128] all-gather(convert), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -506,13 +506,13 @@ TEST_F(CollectiveQuantizerTest, ConvertAllGather) { TEST_F(CollectiveQuantizerTest, ConvertAllGatherUnary) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,4,8,128] parameter(0) convert = bf16[8,4,8,128] convert(param) reshape = bf16[8,4,1024] reshape(convert) slice = bf16[8,4,512] slice(reshape), slice={[0:8], [0:4], [256:768]} - ROOT all-gather = bf16[8,32,512] all-gather(slice), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + ROOT all-gather = bf16[8,32,512] all-gather(slice), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1, use_global_device_ids=true } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -529,7 +529,7 @@ TEST_F(CollectiveQuantizerTest, ConvertAllGatherUnary) { TEST_F(CollectiveQuantizerTest, DequantizeAllGather) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,4,8,128] parameter(0) convert = bf16[8,4,8,128] convert(param) @@ -553,7 +553,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeAllGather) { TEST_F(CollectiveQuantizerTest, DequantizeAllToAll) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,32,8,128] parameter(0) convert = bf16[8,32,8,128] convert(param) @@ -577,7 +577,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeAllToAll) { TEST_F(CollectiveQuantizerTest, DequantizeCollectiveBroadcast) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,32,8,128] parameter(0) convert = bf16[8,32,8,128] convert(param) @@ -602,7 +602,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeCollectiveBroadcast) { TEST_F(CollectiveQuantizerTest, DequantizeCollectivePermute) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,32,8,128] parameter(0) convert = bf16[8,32,8,128] convert(param) @@ -626,7 +626,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeCollectivePermute) { TEST_F(CollectiveQuantizerTest, DequantizeAllGatherUnary) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 ENTRY entry { param = f8e4m3fn[8,4,8,128] parameter(0) convert = bf16[8,4,8,128] convert(param) @@ -656,7 +656,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeAllGatherUnary) { TEST_F(CollectiveQuantizerTest, DequantizeAllGatherPartialReplication) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -691,7 +691,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeAllGatherPartialReplication) { TEST_F(CollectiveQuantizerTest, DequantizeAllToAllPartialReplication) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -727,7 +727,7 @@ TEST_F(CollectiveQuantizerTest, DequantizeAllToAllPartialReplication) { TEST_F(CollectiveQuantizerTest, DequantizeAllToAllPartialReplicationSeparateComputation) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -770,7 +770,7 @@ TEST_F(CollectiveQuantizerTest, TEST_F(CollectiveQuantizerTest, DequantizeAllGatherPartialReplicationGroupMismatch) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) @@ -802,7 +802,7 @@ TEST_F(CollectiveQuantizerTest, TEST_F(CollectiveQuantizerTest, DequantizeAllToAllPartialReplicationGroupMismatchSeparateComputation) { absl::string_view hlo_string = R"( - HloModule module + HloModule module, num_partitions=8 max { a = f32[] parameter(0) b = f32[] parameter(1) diff --git a/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc index 4d21c33f0d44e0..f429941314f2e2 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc @@ -38,7 +38,7 @@ namespace { namespace m = xla::testing::opcode_matchers; // Note: The pass only processes modules that are already scheduled. If the test -// does not work as epxected, make sure to check if "is_scheduled=true" is added +// does not work as expected, make sure to check if "is_scheduled=true" is added // to the HLO module string. class ConvertAsyncCollectivesToSyncTest : public HloHardwareIndependentTestBase { diff --git a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc index 3de31a8315ba50..8c1db7e4cc10fa 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/infeed_token_propagation.h" #include -#include #include #include "absl/container/flat_hash_set.h" @@ -24,7 +23,10 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -32,6 +34,7 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/service/call_graph.h" +#include "xla/service/tuple_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -40,6 +43,83 @@ limitations under the License. namespace xla { namespace { +HloInstruction* InfeedToken(HloInstruction* infeed) { + CHECK_EQ(infeed->opcode(), HloOpcode::kInfeed); + for (HloInstruction* user : infeed->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 1) { + return user; + } + } + return nullptr; +} + +HloInstruction* InfeedChainBegin(HloInstruction* infeed) { + CHECK_EQ(infeed->opcode(), HloOpcode::kInfeed); + HloInstruction* begin = infeed; + while (begin->operand(0)->opcode() == HloOpcode::kGetTupleElement && + begin->operand(0)->operand(0)->opcode() == HloOpcode::kInfeed) { + begin = begin->mutable_operand(0)->mutable_operand(0); + } + return begin; +} + +HloInstruction* InfeedChainEnd(HloInstruction* infeed) { + CHECK_EQ(infeed->opcode(), HloOpcode::kInfeed); + HloInstruction* end = infeed; + HloInstruction* token = InfeedToken(end); + while (token != nullptr && token->user_count() == 1) { + if (token->users()[0]->opcode() == HloOpcode::kInfeed) { + end = token->users()[0]; + token = InfeedToken(end); + } else { + break; + } + } + return end; +} + +HloInstruction* OutfeedChainBegin(HloInstruction* outfeed) { + CHECK_EQ(outfeed->opcode(), HloOpcode::kOutfeed); + HloInstruction* begin = outfeed; + while (begin->operand(1)->opcode() == HloOpcode::kOutfeed) { + begin = begin->mutable_operand(1); + } + return begin; +} + +HloInstruction* OutfeedChainEnd(HloInstruction* outfeed) { + CHECK_EQ(outfeed->opcode(), HloOpcode::kOutfeed); + HloInstruction* end = outfeed; + while (end->user_count() == 1 && + end->users()[0]->opcode() == HloOpcode::kOutfeed) { + end = end->users()[0]; + } + return end; +} + +HloInstruction* ChainBegin(HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kInfeed) { + return InfeedChainBegin(instruction); + } else if (instruction->opcode() == HloOpcode::kOutfeed) { + return OutfeedChainBegin(instruction); + } else { + LOG(FATAL) << "Unexpected opcode"; + } + return nullptr; +} + +HloInstruction* ChainEnd(HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kInfeed) { + return InfeedChainEnd(instruction); + } else if (instruction->opcode() == HloOpcode::kOutfeed) { + return OutfeedChainEnd(instruction); + } else { + LOG(FATAL) << "Unexpected opcode"; + } + return nullptr; +} + bool IsDanglingInfeed(HloInstruction* infeed) { CHECK(infeed->opcode() == HloOpcode::kInfeed); if (infeed->has_sharding()) { @@ -48,14 +128,14 @@ bool IsDanglingInfeed(HloInstruction* infeed) { } // Check for dangling input token. - if (const HloInstruction* after_all = infeed->operand(0); + if (const HloInstruction* after_all = ChainBegin(infeed)->operand(0); after_all->opcode() != HloOpcode::kAfterAll || after_all->operand_count() != 0) { return false; } // Check for dangling output token. - for (const HloInstruction* user : infeed->users()) { + for (const HloInstruction* user : ChainEnd(infeed)->users()) { if (user->opcode() == HloOpcode::kGetTupleElement && user->tuple_index() == 1) { return false; @@ -73,34 +153,20 @@ bool IsDanglingOutfeed(HloInstruction* outfeed) { } // Check for dangling input token. - if (const HloInstruction* after_all = outfeed->operand(1); + if (const HloInstruction* after_all = OutfeedChainBegin(outfeed)->operand(1); after_all->opcode() != HloOpcode::kAfterAll || after_all->operand_count() != 0) { return false; } // Check for dangling output token. - if (outfeed->user_count() != 0) { + if (OutfeedChainEnd(outfeed)->user_count() != 0) { return false; } return true; } -HloInstruction* ReconstructTuple(HloInstruction* tuple) { - CHECK(tuple->shape().IsTuple()); - HloComputation* computation = tuple->parent(); - - std::vector gtes; - gtes.resize(tuple->shape().tuple_shapes_size()); - for (int64_t idx = 0; idx < gtes.size(); ++idx) { - gtes[idx] = computation->AddInstruction( - HloInstruction::CreateGetTupleElement(tuple, idx)); - } - - return computation->AddInstruction(HloInstruction::CreateTuple(gtes)); -} - absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, bool add_token_operand) { CHECK(tuple->shape().IsTuple()); @@ -109,7 +175,7 @@ absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, // Recreate the original tuple, we'll need to pass this to all the users. // Trying to use tuple->ReplaceAllUsesWith(original_tuple) cause a cycle. std::vector original_users = tuple->users(); - HloInstruction* original_tuple = ReconstructTuple(tuple); + HloInstruction* original_tuple = TupleUtil::Duplicate(tuple); for (HloInstruction* original_user : original_users) { for (int64_t idx : original_user->operand_indices(tuple)) { TF_RETURN_IF_ERROR( @@ -159,7 +225,7 @@ absl::Status CanonicalizeConditionalInstruction(HloInstruction* conditional) { // Explicitly disjoin computation parameters from branch inputs, so we can // insert tokens into the input tuple. if (branch_tuple->opcode() == HloOpcode::kParameter) { - branch_tuple = ReconstructTuple(branch_tuple); + branch_tuple = TupleUtil::Duplicate(branch_tuple); TF_RETURN_IF_ERROR( conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); } @@ -167,7 +233,7 @@ absl::Status CanonicalizeConditionalInstruction(HloInstruction* conditional) { // Explicitly make the root of the branch a tuple. HloInstruction* root = branch->root_instruction(); if (root->opcode() != HloOpcode::kTuple) { - root = ReconstructTuple(root); + root = TupleUtil::Duplicate(root); branch->set_root_instruction(root); } } @@ -179,7 +245,7 @@ absl::Status CanonicalizeConditionalInstruction(HloInstruction* conditional) { // Explicitly disjoin the conditional from being a computation root, so that // we can insert tokens into, while preserving the original computation shape. if (conditional->IsRoot()) { - HloInstruction* new_root = ReconstructTuple(conditional); + HloInstruction* new_root = TupleUtil::Duplicate(conditional); conditional->parent()->set_root_instruction(new_root); } @@ -239,20 +305,20 @@ absl::Status CanonicalizeWhileInstruction(HloInstruction* loop) { // Explicitly disjoin computation parameters from loop inputs, so we can // insert tokens into the input tuple. if (loop_tuple->opcode() == HloOpcode::kParameter) { - loop_tuple = ReconstructTuple(loop_tuple); + loop_tuple = TupleUtil::Duplicate(loop_tuple); TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple)); } // Explicitly make the root of the body a tuple. if (root->opcode() != HloOpcode::kTuple) { - root = ReconstructTuple(root); + root = TupleUtil::Duplicate(root); body->set_root_instruction(root); } // Explicitly disjoin the loop from being a computation root, so that // we can insert tokens into, while preserving the original computation shape. if (loop->IsRoot()) { - HloInstruction* new_root = ReconstructTuple(loop); + HloInstruction* new_root = TupleUtil::Duplicate(loop); loop->parent()->set_root_instruction(new_root); } @@ -338,6 +404,9 @@ absl::Status InfeedTokenPropagation::PropagateTokenThroughWhileBody() { TF_ASSIGN_OR_RETURN( input_token_, InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); + // Retrieve the actual token added to the tuple. + input_token_ = input_token_->mutable_operand(0)->mutable_operand( + input_token_->tuple_index()); TF_RETURN_IF_ERROR( dangling_instruction_->ReplaceOperandWithDifferentShape(0, while_tuple)); @@ -349,8 +418,42 @@ absl::Status InfeedTokenPropagation::PropagateTokenThroughWhileBody() { return absl::OkStatus(); } -absl::Status InfeedTokenPropagation::PropagateToken() { +absl::Status InfeedTokenPropagation::PropagateToken( + const HloOrdering& ordering) { HloComputation* comp = dangling_instruction_->parent(); + if (dangling_instruction_->opcode() != HloOpcode::kInfeed && + dangling_instruction_->opcode() != HloOpcode::kOutfeed) { + for (HloInstruction* instruction : comp->instructions()) { + if (instruction->opcode() == original_opcode_) { + HloInstruction* begin = ChainBegin(instruction); + HloInstruction* end = ChainEnd(instruction); + if (ordering.ExecutesBefore(end, dangling_instruction_)) { + // Parent infeed happens before child infeed. Stitch via parent result + // token. + CHECK_EQ(begin->opcode(), HloOpcode::kInfeed); + HloInstruction* parent_output_token = comp->AddInstruction( + HloInstruction::CreateGetTupleElement(end, 1)); + TF_RETURN_IF_ERROR( + input_token_->ReplaceAllUsesWith(parent_output_token)); + input_token_ = begin->mutable_operand(0); + } else if (ordering.ExecutesBefore(dangling_instruction_, begin)) { + // Parent outfeed happens after child infeed. Stitch via parent input + // token. + CHECK_EQ(begin->opcode(), HloOpcode::kOutfeed); + TF_RETURN_IF_ERROR(begin->ReplaceOperandWith(1, output_token_)); + output_token_ = end; + } else { + LOG(WARNING) << absl::StrFormat( + "Execution order of %s, %s and %s is undefined. This may lead to " + "incorrect results", + begin->name(), end->name(), dangling_instruction_->name()); + } + // We assume that a well defined HLO graph only contains a single + // infeed chain per computation. + break; + } + } + } if (comp->IsEntryComputation()) { return absl::OkStatus(); } @@ -378,12 +481,12 @@ absl::Status InfeedTokenPropagation::PropagateToken() { return absl::OkStatus(); } - return PropagateToken(); + return PropagateToken(ordering); } absl::StatusOr InfeedTokenPropagation::Run( HloModule* module, - const absl::flat_hash_set& execution_threads) { + const absl::flat_hash_set& execution_threads) { VLOG(5) << "Before InfeedTokenPropagation:"; XLA_VLOG_LINES(5, module->ToString()); @@ -397,10 +500,15 @@ absl::StatusOr InfeedTokenPropagation::Run( IsDanglingInfeed(instruction)) { VLOG(1) << "Found dangling infeed: " << instruction->ToString(); dangling_infeeds.push_back(instruction); - } else if (instruction->opcode() == HloOpcode::kOutfeed && - IsDanglingOutfeed(instruction)) { + break; + } + } + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kOutfeed && + IsDanglingOutfeed(instruction)) { VLOG(1) << "Found dangling outfeed: " << instruction->ToString(); dangling_outfeeds.push_back(instruction); + break; } } } @@ -408,28 +516,43 @@ absl::StatusOr InfeedTokenPropagation::Run( bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); if (changed) { - call_graph_ = CallGraph::Build(module); + call_graph_ = CallGraph::Build(module, execution_threads); if (!call_graph_->IsFlattened()) { return FailedPrecondition( "Call graph must be flattened before infeed token propagation."); } - } - - for (HloInstruction* dangling_infeed : dangling_infeeds) { - dangling_instruction_ = dangling_infeed; - input_token_ = dangling_infeed->mutable_operand(0); - output_token_ = dangling_infeed->AddInstruction( - HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); - TF_RETURN_IF_ERROR(PropagateToken()); - } - for (HloInstruction* dangling_outfeed : dangling_outfeeds) { - dangling_instruction_ = dangling_outfeed; - input_token_ = dangling_outfeed->mutable_operand(1); - output_token_ = dangling_outfeed; - TF_RETURN_IF_ERROR(PropagateToken()); - } + DependencyHloOrdering ordering = DependencyHloOrdering(module); + + for (HloInstruction* dangling_infeed : dangling_infeeds) { + // In the process of token propagation, we might have stitched two + // previously dangling infeeds token, causing both to no longer be + // dangling. + if (!IsDanglingInfeed(dangling_infeed)) { + continue; + } + dangling_instruction_ = dangling_infeed; + original_opcode_ = HloOpcode::kInfeed; + input_token_ = ChainBegin(dangling_infeed)->mutable_operand(0); + output_token_ = + ChainEnd(dangling_infeed) + ->AddInstruction( + HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); + TF_RETURN_IF_ERROR(PropagateToken(ordering)); + } + for (HloInstruction* dangling_outfeed : dangling_outfeeds) { + // In the process of token propagation, we might have stitched two + // previously dangling outfeeds token, causing both to no longer be + // dangling. + if (!IsDanglingOutfeed(dangling_outfeed)) { + continue; + } + dangling_instruction_ = dangling_outfeed; + original_opcode_ = HloOpcode::kOutfeed; + input_token_ = ChainBegin(dangling_outfeed)->mutable_operand(1); + output_token_ = ChainEnd(dangling_outfeed); + TF_RETURN_IF_ERROR(PropagateToken(ordering)); + } - if (changed) { TF_RETURN_IF_ERROR( TupleSimplifier().Run(module, execution_threads).status()); TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); diff --git a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h index f1e3080b7a07e7..d95f218fbc867d 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h +++ b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h @@ -17,14 +17,14 @@ limitations under the License. #define XLA_HLO_TRANSFORMS_COLLECTIVES_INFEED_TOKEN_PROPAGATION_H_ #include -#include -#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/call_graph.h" @@ -38,19 +38,21 @@ namespace xla { // This pass assumes the HLO graph is flattened. class InfeedTokenPropagation : public HloModulePass { public: - std::string_view name() const override { return "infeed-token-propagation"; } + absl::string_view name() const override { return "infeed-token-propagation"; } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, - const absl::flat_hash_set& execution_threads) override; + const absl::flat_hash_set& execution_threads) override; private: - absl::Status PropagateToken(); + absl::Status PropagateToken(const HloOrdering& ordering); absl::Status PropagateTokenThroughWhileBody(); absl::Status PropagateTokenThroughConditionalBranch(); std::unique_ptr call_graph_; + HloInstruction* dangling_instruction_ = nullptr; + HloOpcode original_opcode_; HloInstruction* input_token_ = nullptr; HloInstruction* output_token_ = nullptr; }; diff --git a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc index 2be79575afe8b2..f702afea5b8425 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc @@ -15,11 +15,11 @@ limitations under the License. #include "xla/hlo/transforms/collectives/infeed_token_propagation.h" -#include #include #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" @@ -36,7 +36,7 @@ class InfeedTokenPropagationTest : public HloHardwareIndependentTestBase { }; TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main ENTRY main { @@ -52,7 +52,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main ENTRY main { @@ -70,7 +70,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main true_comp { @@ -124,7 +124,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main true_comp { @@ -178,7 +178,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, ConditionalDuplicateOperand) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main true_comp { @@ -231,7 +231,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, NonTupleConditional) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main true_comp { @@ -286,7 +286,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main true_comp { @@ -340,7 +340,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, WhileInfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main comp { @@ -394,7 +394,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, WhileOutfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main comp { @@ -452,7 +452,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main comp { @@ -508,7 +508,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, NonTupleWhile) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main comp { @@ -563,7 +563,7 @@ ENTRY main { } TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main true_comp { @@ -649,5 +649,180 @@ ENTRY main { HloComputation* false_comp = FindComputation(module.get(), "false_comp"); EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); } + +TEST_F(InfeedTokenPropagationTest, WhileNestedAfterInfeed) { + constexpr absl::string_view hlo = R"( +HloModule main + +body { + ROOT arg.0 = s32[] parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) +} + +cond { + arg.0 = s32[] parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + gte.0 = get-tuple-element(infeed.0), index=0 + gte.1 = get-tuple-element(infeed.0), index=1 + infeed.1 = (s32[], token[]) infeed(gte.1) + gte.2 = get-tuple-element(infeed.1), index=0 + add.0 = add(gte.0, gte.2) + ROOT while.0 = s32[] while(add.0), body=body, condition=cond +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The second infeed should send its token into the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_THAT(loop, op::While(op::Tuple( + op::Add(), + op::GetTupleElement(op::Infeed(op::GetTupleElement( + op::Infeed(op::AfterAll()), 1)), + 1)))); +} + +TEST_F(InfeedTokenPropagationTest, WhileNestedBeforeOutfeed) { + constexpr absl::string_view hlo = R"( +HloModule main + +body { + ROOT arg.0 = s32[] parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[] +} + +cond { + arg.0 = s32[] parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + ROOT while.0 = s32[] while(arg.0), body=body, condition=cond + token.0 = after-all() + outfeed.1 = token[] outfeed(while.0, token.0), outfeed_shape=s32[] + outfeed.2 = token[] outfeed(while.0, outfeed.1), outfeed_shape=s32[] +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The first outfeed should get its token from the loop. + // The second outfeed should get its token from the first outfeed. + HloInstruction* outfeed_2 = FindInstruction(module.get(), "outfeed.2"); + EXPECT_THAT(outfeed_2, + op::Outfeed(op::GetTupleElement(), + op::Outfeed(op::GetTupleElement(), + op::GetTupleElement(op::While(), 1)))); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalNestedAfterInfeed) { + constexpr absl::string_view hlo = R"( +HloModule main + +true_comp { + ROOT arg.0 = (s32[]) parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) +} + +false_comp { + ROOT arg.0 = (s32[]) parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) +} + +ENTRY main { + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + gte.0 = get-tuple-element(infeed.0), index=0 + gte.1 = get-tuple-element(infeed.0), index=1 + infeed.1 = (s32[], token[]) infeed(gte.1) + gte.2 = get-tuple-element(infeed.1), index=0 + add.0 = add(gte.0, gte.2) + tuple.0 = tuple(add.0) + pred.0 = pred[] constant(true) + ROOT cond.0 = (s32[]) conditional(pred.0, tuple.0, tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The conditional should get both its tokens from the second infeed. + // The second infeed should get its token from the first infeed. + HloInstruction* conditional = FindInstruction(module.get(), "cond.0"); + EXPECT_THAT(conditional, + op::Conditional( + op::Constant(), + op::Tuple(op::Add(), op::GetTupleElement( + op::Infeed(op::GetTupleElement( + op::Infeed(op::AfterAll()), 1)), + 1)), + op::Tuple(op::Add(), op::GetTupleElement( + op::Infeed(op::GetTupleElement( + op::Infeed(op::AfterAll()), 1)), + 1)))); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalNestedBeforeOutfeed) { + constexpr absl::string_view hlo = R"( +HloModule main + +true_comp { + ROOT arg.0 = (s32[]) parameter(0) + token.0 = after-all() + gte.0 = get-tuple-element(arg.0), index=0 + outfeed.0 = token[] outfeed(gte.0, token.0), outfeed_shape=s32[] +} + +false_comp { + ROOT arg.0 = (s32[]) parameter(0) + token.0 = after-all() + gte.0 = get-tuple-element(arg.0), index=0 + outfeed.1 = token[] outfeed(gte.0, token.0), outfeed_shape=s32[] +} + +ENTRY main { + arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + pred.0 = pred[] constant(true) + ROOT cond.0 = (s32[]) conditional(pred.0, tuple.0, tuple.0), true_computation=true_comp, false_computation=false_comp + gte.0 = get-tuple-element(cond.0), index=0 + token.0 = after-all() + outfeed.2 = token[] outfeed(gte.0, token.0), outfeed_shape=s32[] + outfeed.3 = token[] outfeed(gte.0, outfeed.2), outfeed_shape=s32[] +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The second outfeed should get its token from the first outfeed. + // The first outfeed should get its token from the conditional. + // Note, there is a quirk - each branch of the of the conditional will produce + // its own token, but the first outfeed can only consume one of those. + // I'm not certain if we deterministically will consume last token in the + // conditional result. + HloInstruction* outfeed_3 = FindInstruction(module.get(), "outfeed.3"); + EXPECT_THAT( + outfeed_3, + op::Outfeed(op::GetTupleElement(), + op::Outfeed(op::GetTupleElement(), + op::GetTupleElement(op::Conditional(), 2)))); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc index 2f9717ae57628c..b268c99dc489f9 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup_test.cc @@ -15,10 +15,9 @@ limitations under the License. #include "xla/hlo/transforms/collectives/while_loop_all_reduce_code_motion_setup.h" -#include - #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" @@ -35,7 +34,7 @@ class ReorderReduceTransposeTest : public HloHardwareIndependentTestBase { }; TEST_F(ReorderReduceTransposeTest, SimpleReduceScatterTransposeInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -82,7 +81,7 @@ ENTRY main { TEST_F(ReorderReduceTransposeTest, ReduceScatterConvertTransposeNotInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -105,7 +104,7 @@ ENTRY main { } TEST_F(ReorderReduceTransposeTest, ReduceScatterConvertTransposeInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -153,7 +152,7 @@ ENTRY main { TEST_F(ReorderReduceTransposeTest, ReduceScatterTransposeReshapeDynamicUpdateSliceInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -208,7 +207,7 @@ class ReorderConvertReduceAddTest : public HloHardwareIndependentTestBase { }; TEST_F(ReorderConvertReduceAddTest, SimpleConvertReduceScatterAddInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -255,7 +254,7 @@ ENTRY main { } TEST_F(ReorderConvertReduceAddTest, ConvertAllReduceAddNotInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -277,7 +276,7 @@ ENTRY main { } TEST_F(ReorderConvertReduceAddTest, ConvertReduceScatterAddInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -324,7 +323,7 @@ ENTRY main { } TEST_F(ReorderConvertReduceAddTest, DisableReduceScatter) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { @@ -361,7 +360,7 @@ ENTRY main { } TEST_F(ReorderConvertReduceAddTest, ConvertAllReduceAddInWhileBody) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main %reduction { diff --git a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc index 6846a186c7e691..570afa9e3d501b 100644 --- a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc @@ -17,17 +17,106 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/side_effect_util.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { +namespace { +absl::StatusOr GetCustomCallTarget( + absl::string_view external_annotation) { + if (external_annotation == + host_memory_offload_annotations::kMemoryTargetPinnedHost || + external_annotation == + host_memory_offload_annotations::kMemoryTargetUnpinnedHost) { + return host_memory_offload_annotations::kMoveToHostCustomCallTarget; + } + if (external_annotation == + host_memory_offload_annotations::kMemoryTargetDevice) { + return host_memory_offload_annotations::kMoveToDeviceCustomCallTarget; + } + if (external_annotation == + host_memory_offload_annotations::kMemoryTargetDeviceSram) { + return host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget; + } + if (external_annotation == + host_memory_offload_annotations::kMemoryTargetPinnedDevice) { + return host_memory_offload_annotations::kPinToDeviceCustomCallTarget; + } + return absl::InvalidArgumentError( + absl::StrCat("Invalid external annotation: ", external_annotation)); +} + +absl::StatusOr +ConvertCustomCallWithExternalAnnotationToInternalAnnotation( + HloComputation* c, HloInstruction* instruction) { + const auto& frontend_attributes = instruction->frontend_attributes(); + const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr); + if (it == frontend_attributes.map().end()) { + return false; + } + // XLA currently does not differentiate between pinned and unpinned host + // memory. + const bool is_to_host_case = + (it->second == host_memory_offload_annotations::kMemoryTargetPinnedHost || + it->second == + host_memory_offload_annotations::kMemoryTargetUnpinnedHost); + const bool is_to_device_case = + (it->second == host_memory_offload_annotations::kMemoryTargetDevice || + it->second == host_memory_offload_annotations::kMemoryTargetDeviceSram || + it->second == + host_memory_offload_annotations::kMemoryTargetPinnedDevice); + if (!is_to_host_case && !is_to_device_case) { + return false; + } + const absl::StatusOr custom_call_target = + GetCustomCallTarget(it->second); + TF_RETURN_IF_ERROR(custom_call_target.status()); + if (is_to_host_case) { + VLOG(1) << "Process forward case: " << instruction->ToString(); + if (instruction->operand_count() != 1) { + return Internal( + "Custom calls with target %s must have exactly one operand. %s " + "has %d.", + host_memory_offload_annotations::kDevicePlacement, + instruction->name(), instruction->operand_count()); + } + HloInstruction* input = instruction->mutable_operand(0); + HloInstruction* move_to_host_custom_call = + c->AddInstruction(HloInstruction::CreateCustomCall( + input->shape(), {input}, *custom_call_target)); + if (instruction->has_sharding()) { + move_to_host_custom_call->set_sharding(instruction->sharding()); + } + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(move_to_host_custom_call)); + TF_RETURN_IF_ERROR(c->RemoveInstructionAndUnusedOperands(instruction)); + return true; + } else if (is_to_device_case) { + VLOG(1) << "Process backward case: " << instruction->ToString(); + HloInstruction* custom_call_operand = instruction->mutable_operand(0); + HloInstruction* new_result = + c->AddInstruction(HloInstruction::CreateCustomCall( + custom_call_operand->shape(), {custom_call_operand}, + *custom_call_target)); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result)); + TF_RETURN_IF_ERROR(c->RemoveInstructionAndUnusedOperands(instruction)); + return true; + } + return false; +} + +} // namespace + absl::StatusOr ConvertMemoryPlacementToInternalAnnotations::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -36,60 +125,11 @@ absl::StatusOr ConvertMemoryPlacementToInternalAnnotations::Run( for (HloInstruction* instruction : c->MakeInstructionPostOrder()) { if (instruction->IsCustomCall( host_memory_offload_annotations::kDevicePlacement)) { - const auto& frontend_attributes = instruction->frontend_attributes(); - const auto it = frontend_attributes.map().find(kXlaBufferPlacementAttr); - if (it == frontend_attributes.map().end()) { - continue; - } - // XLA currently does not differentiate between pinned and unpinned host - // memory. - const bool is_to_host_case = - (it->second == - host_memory_offload_annotations::kMemoryTargetPinnedHost || - it->second == - host_memory_offload_annotations::kMemoryTargetUnpinnedHost); - const bool is_to_device_case = - (it->second == - host_memory_offload_annotations::kMemoryTargetDevice); - if (!is_to_host_case && !is_to_device_case) { - continue; - } - if (is_to_host_case) { - VLOG(1) << "Process forward case: " << instruction->ToString(); - if (instruction->operand_count() != 1) { - return Internal( - "Custom calls with target %s must have exactly one operand. %s " - "has %d.", - host_memory_offload_annotations::kDevicePlacement, - instruction->name(), instruction->operand_count()); - } - HloInstruction* input = instruction->mutable_operand(0); - HloInstruction* move_to_host_custom_call = - c->AddInstruction(HloInstruction::CreateCustomCall( - input->shape(), {input}, - host_memory_offload_annotations:: - kMoveToHostCustomCallTarget)); - if (instruction->has_sharding()) { - move_to_host_custom_call->set_sharding(instruction->sharding()); - } - TF_RETURN_IF_ERROR( - instruction->ReplaceAllUsesWith(move_to_host_custom_call)); - TF_RETURN_IF_ERROR( - c->RemoveInstructionAndUnusedOperands(instruction)); - changed = true; - } else if (is_to_device_case) { - VLOG(1) << "Process backward case: " << instruction->ToString(); - HloInstruction* custom_call_operand = instruction->mutable_operand(0); - HloInstruction* new_result = - c->AddInstruction(HloInstruction::CreateCustomCall( - custom_call_operand->shape(), {custom_call_operand}, - host_memory_offload_annotations:: - kMoveToDeviceCustomCallTarget)); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_result)); - TF_RETURN_IF_ERROR( - c->RemoveInstructionAndUnusedOperands(instruction)); - changed = true; - } + TF_ASSIGN_OR_RETURN( + auto result, + ConvertCustomCallWithExternalAnnotationToInternalAnnotation( + c, instruction)); + changed |= result; } } } diff --git a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc index 88fa3644fbd26b..dab4d055d8f252 100644 --- a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc @@ -17,16 +17,16 @@ #include #include -#include #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -486,7 +486,7 @@ ENTRY main.183 { TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, ConvertOutputPinnedHostTest) { - constexpr std::string_view hlo_string = R"( + constexpr absl::string_view hlo_string = R"( HloModule m, entry_computation_layout={(f32[2,2]{1,0:T(2,128)},f32[2,2]{1,0:T(2,128)})->f32[2,2]{1,0:T(2,128)S(5)}} ENTRY m { x = f32[2,2] parameter(0) @@ -510,5 +510,65 @@ TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, EXPECT_EQ(move_to_host_count, 1); } +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, + ConvertPinToDeviceSramTest) { + constexpr absl::string_view hlo_string = R"( + HloModule jit_f, entry_computation_layout={(s32[8,2]{0,1:T(2,128)S(1)})->s32[8,2]{0,1:T(2,128)}}, allow_spmd_sharding_propagation_to_output={true} + + ENTRY main.8 { + Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}, metadata={op_name="x"} + constant.2 = s32[] constant(2) + broadcast.3 = s32[8,2]{1,0} broadcast(constant.2), dimensions={} + multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=707} + custom-call.5 = s32[8,2]{1,0} custom-call(multiply.4), custom_call_target="Sharding", sharding={devices=[2,1]<=[2]}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + custom-call.6 = s32[8,2]{1,0} custom-call(custom-call.5), custom_call_target="annotate_device_placement", custom_call_has_side_effect=true, frontend_attributes={_xla_buffer_placement="device_sram"}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + ROOT multiply.7 = s32[8,2]{1,0} multiply(custom-call.6, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=709} + } // main.8 )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t pin_todevice_sramcount = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + pin_todevice_sramcount += instr->IsCustomCall( + host_memory_offload_annotations::kPinToDeviceSramCustomCallTarget); + } + } + EXPECT_EQ(pin_todevice_sramcount, 1); +} + +TEST_F(ConvertMemoryPlacementToInternalAnnotationsTest, + ConvertPinToDeviceTest) { + constexpr absl::string_view hlo_string = R"( + HloModule jit_f, entry_computation_layout={(s32[8,2]{0,1:T(2,128)S(1)})->s32[8,2]{0,1:T(2,128)}}, allow_spmd_sharding_propagation_to_output={true} + + ENTRY main.8 { + Arg_0.1 = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]<=[2]}, metadata={op_name="x"} + constant.2 = s32[] constant(2) + broadcast.3 = s32[8,2]{1,0} broadcast(constant.2), dimensions={} + multiply.4 = s32[8,2]{1,0} multiply(Arg_0.1, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=707} + custom-call.5 = s32[8,2]{1,0} custom-call(multiply.4), custom_call_target="Sharding", sharding={devices=[2,1]<=[2]}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + custom-call.6 = s32[8,2]{1,0} custom-call(custom-call.5), custom_call_target="annotate_device_placement", custom_call_has_side_effect=true, frontend_attributes={_xla_buffer_placement="pinned_device"}, metadata={op_name="jit(f)/jit(main)/device_put" source_file="third_party/py/jax/tests/memories_test.py" source_line=708} + ROOT multiply.7 = s32[8,2]{1,0} multiply(custom-call.6, broadcast.3), metadata={op_name="jit(f)/jit(main)/mul" source_file="third_party/py/jax/tests/memories_test.py" source_line=709} + } // main.8 )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + bool changed = + ConvertMemoryPlacementToInternalAnnotations().Run(module.get()).value(); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + int64_t pin_todevice_count = 0; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + pin_todevice_count += instr->IsCustomCall( + host_memory_offload_annotations::kPinToDeviceCustomCallTarget); + } + } + EXPECT_EQ(pin_todevice_count, 1); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/defuser.cc b/third_party/xla/xla/hlo/transforms/defuser.cc index 04d93ef8237743..16f8152a9d15dc 100644 --- a/third_party/xla/xla/hlo/transforms/defuser.cc +++ b/third_party/xla/xla/hlo/transforms/defuser.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc index 9918e34c352386..3cccad769aff47 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc @@ -15,6 +15,11 @@ limitations under the License. #include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" +#include +#include +#include + +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/broadcast.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc index 2b5efab5c6897b..033bd4d5d84cfb 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" +#include + #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc index 2bdb4c18036da9..56794a3985ad38 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc @@ -15,10 +15,18 @@ limitations under the License. #include "xla/hlo/transforms/expanders/cholesky_expander.h" -#include +#include +#include +#include +#include +#include +#include #include +#include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/loops.h" @@ -32,6 +40,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h index 3ee4a26ad2ee2f..868bde43018b9c 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h @@ -16,9 +16,16 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_CHOLESKY_EXPANDER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_CHOLESKY_EXPANDER_H_ +#include +#include +#include + #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc index 0f09ecced1ebaf..61a4305b09d5b9 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc @@ -19,6 +19,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc index a6c25114a4ce19..efa18b8266a000 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/transforms/expanders/convolution_4d_expander.h" #include +#include #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc index 39d7e3ebb9a9c1..3221a01c528689 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc index e7aab8622b75f1..f97744f4b71eb6 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc @@ -15,14 +15,15 @@ limitations under the License. #include "xla/hlo/transforms/expanders/convolution_pred_expander.h" +#include #include #include #include #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc index 1df1743532438b..339165f485110e 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc @@ -23,7 +23,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -34,6 +36,7 @@ limitations under the License. #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc index ad8e6d874fd80d..3a5c5e6112a0e6 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include +#include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -24,9 +26,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc index bf4ecc61bf6361..8472b031859bca 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc @@ -15,11 +15,14 @@ limitations under the License. #include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" -#include +#include +#include +#include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h index 26f68155ac71e6..910b149d136755 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h +++ b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_DYNAMIC_INDEX_SPLITTER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_DYNAMIC_INDEX_SPLITTER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc index b0699e5a07b6fc..a7727224a6ecd8 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc @@ -15,13 +15,15 @@ limitations under the License. #include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/xla.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc index d7900a19fdbce0..b934245d6f336f 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc @@ -16,14 +16,20 @@ limitations under the License. #include "xla/hlo/transforms/expanders/eigh_expander.h" #include +#include #include -#include #include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/types/span.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/comparators.h" #include "xla/hlo/builder/lib/constants.h" @@ -38,6 +44,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" // Parallel two-sided Jacobi symmetric eigendecomposition. diff --git a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h index 54cbee776d9c99..3f47d792183de1 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h @@ -16,7 +16,13 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_EIGH_EXPANDER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_EIGH_EXPANDER_H_ +#include +#include + #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc index 416d29ed6ef8fc..0eab6cdd3e5d2d 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc @@ -15,8 +15,7 @@ limitations under the License. #include "xla/hlo/transforms/expanders/logistic_expander.h" -#include - +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" @@ -35,7 +34,7 @@ absl::StatusOr LogisticExpander::ExpandInstruction( HloInstruction* instruction) { HloInstruction* operand = instruction->mutable_operand(0); const Shape operand_shape = operand->shape(); - // Computing 1.0 / (1.0 - exp(-x)) + // Computing 1.0 / (1.0 + exp(-x)) HloInstruction* one_constant = MakeScalarLike(operand, 1.0f); HloInstruction* exp_instr = MakeUnaryHlo(HloOpcode::kExp, diff --git a/third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc index fb5598524006f6..a2314a50df2825 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc @@ -16,17 +16,18 @@ limitations under the License. #include "xla/hlo/transforms/expanders/logistic_expander.h" #include -#include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/dynamic_padder.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -62,7 +63,7 @@ TEST_F(LogisticExpanderTest, ExpandWith) { } TEST_F(LogisticExpanderTest, DynamicDimensions) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule DynamicDimensions ENTRY main { diff --git a/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h b/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h index 798c6a4ed46c06..c30120ee2370f5 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h +++ b/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_OP_EXPANDER_PASS_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_OP_EXPANDER_PASS_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc index 12908f26c8fbd8..10dcc7a2eef96c 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc @@ -15,6 +15,14 @@ limitations under the License. #include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + namespace xla { absl::StatusOr OptimizationBarrierExpander::Run( diff --git a/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h index f6904ec0ff1b7e..a18b8e9a310239 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_OPTIMIZATION_BARRIER_EXPANDER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_OPTIMIZATION_BARRIER_EXPANDER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc index 1627a6be5e683b..c23bc8279da2d0 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc @@ -15,10 +15,19 @@ limitations under the License. #include "xla/hlo/transforms/expanders/qr_expander.h" -#include +#include +#include +#include +#include +#include +#include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/loops.h" @@ -33,6 +42,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h index 8d7c4a8e90786b..7ff56e28d485d9 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h @@ -16,10 +16,17 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_QR_EXPANDER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_QR_EXPANDER_H_ +#include +#include + #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/builder/lib/qr.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc index 33735a16f25e8b..33ca25fc4dc320 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/hlo/transforms/expanders/real_imag_expander.h" +#include "absl/status/statusor.h" #include "xla/literal_util.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h index 52b50455744b27..e9ae9ce611c331 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_REAL_IMAG_EXPANDER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_REAL_IMAG_EXPANDER_H_ +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc index 7f0042a5169db1..ab5c06f556cbc0 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc @@ -16,18 +16,18 @@ limitations under the License. #include "xla/hlo/transforms/expanders/real_imag_expander.h" #include -#include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc index de795a8f74989a..3b7746cfdb6137 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc @@ -20,7 +20,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h index 46c2e7ddf6e429..22bcabf831ca6f 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc index 997ea50e51b565..e597519e306a02 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc @@ -14,14 +14,13 @@ limitations under the License. ==============================================================================*/ #include "xla/hlo/transforms/expanders/reduce_decomposer.h" -#include -#include #include +#include #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc index ac0b058426a67e..50924428832c5d 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc @@ -15,7 +15,12 @@ limitations under the License. #include "xla/hlo/transforms/expanders/reshape_decomposer.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/service/hlo_creation_utils.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h index 1efa0cbf2c7ef2..f169cdc666a803 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h +++ b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_RESHAPE_DECOMPOSER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_RESHAPE_DECOMPOSER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc index 87cf748818069e..ae937ee77ce135 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc @@ -14,13 +14,15 @@ limitations under the License. ==============================================================================*/ #include "xla/hlo/transforms/expanders/reshape_decomposer.h" -#include #include +#include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h index 15df45060052b5..40057f9fcbbc87 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h @@ -17,7 +17,9 @@ limitations under the License. #define XLA_HLO_TRANSFORMS_EXPANDERS_RNG_BIT_GENERATOR_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc index 2667440674887a..dfcc95c0324f2b 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc @@ -15,13 +15,23 @@ limitations under the License. #include "xla/hlo/transforms/expanders/rng_expander.h" +#include +#include #include - +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/hlo/builder/lib/prng.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h index e6c52cf1143a44..d8f41ec83071e6 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h @@ -16,7 +16,13 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_RNG_EXPANDER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_RNG_EXPANDER_H_ +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc index 3df7d03a2b0024..775fe3ef1cb72d 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc @@ -18,16 +18,20 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h index 210eaeb1a17b74..f6d84fae29a994 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h +++ b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/transforms/expanders/op_expander_pass.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc index a3b40831a24f5e..e577e8c557ba79 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc @@ -15,13 +15,16 @@ limitations under the License. #include "xla/hlo/transforms/expanders/stable_sort_expander.h" +#include + +#include #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" -#include "xla/test.h" #include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc index 1fb054159d7848..7c5ab5fa62a752 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc @@ -16,10 +16,13 @@ limitations under the License. #include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" #include -#include +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h index 835a55be249c7c..e0574e4fa5e85f 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h +++ b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_EXPANDERS_STOCHASTIC_CONVERT_DECOMPOSER_H_ #define XLA_HLO_TRANSFORMS_EXPANDERS_STOCHASTIC_CONVERT_DECOMPOSER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc index 8ebc1b448e09a2..27d19d24bb31b1 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc @@ -15,8 +15,12 @@ limitations under the License. #include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" +#include #include +#include +#include +#include "absl/status/status.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc b/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc index 639e37874ceb4b..5e70dbb26c7d21 100644 --- a/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" diff --git a/third_party/xla/xla/hlo/transforms/host_offload_legalize.h b/third_party/xla/xla/hlo/transforms/host_offload_legalize.h index a5d85fa40a8a5c..e08c842ee0bc68 100644 --- a/third_party/xla/xla/hlo/transforms/host_offload_legalize.h +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize.h @@ -17,8 +17,10 @@ #include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc b/third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc index 4aedc40b8ca2be..12e3c6935cdab2 100644 --- a/third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc @@ -16,12 +16,9 @@ limitations under the License. #include "xla/hlo/transforms/host_offload_legalize.h" #include -#include #include -#include #include -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" @@ -29,9 +26,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/hlo/transforms/host_offloader.cc b/third_party/xla/xla/hlo/transforms/host_offloader.cc index 29073f6bf26eeb..c05659eb5bb21c 100644 --- a/third_party/xla/xla/hlo/transforms/host_offloader.cc +++ b/third_party/xla/xla/hlo/transforms/host_offloader.cc @@ -15,15 +15,10 @@ limitations under the License. #include "xla/hlo/transforms/host_offloader.h" -#include -#include #include #include #include -#include #include -#include -#include #include #include "absl/algorithm/container.h" @@ -35,7 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -56,6 +51,7 @@ limitations under the License. #include "xla/side_effect_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -259,7 +255,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( instruction_and_shape_index.shape_index); CHECK(output_shape.has_layout()) << "Expecting output shape of entry computation to have a layout."; - if (output_shape.layout().memory_space() == kHostMemorySpaceColor) { + if (output_shape.layout().memory_space() == Layout::kHostMemorySpace) { VLOG(2) << absl::StreamFormat( "Memory offloaded starting from %s is output streamed", starting_instruction_and_index.ToString()); @@ -284,7 +280,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( // Finished walking all host memory paths. Now we'll make all the necessary // changes. const bool set_buffers_changed = SetBuffersToMemorySpaceColor( - buffers_to_set_to_host_memory, kHostMemorySpaceColor); + buffers_to_set_to_host_memory, Layout::kHostMemorySpace); changed = changed || set_buffers_changed; for (HloInstruction* dus : dynamic_update_slices) { @@ -353,7 +349,7 @@ absl::StatusOr HostOffloader::HandleInputStreaming( entry_computation_layout.parameter_shape(i), [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.has_layout() && - subshape.layout().memory_space() == kHostMemorySpaceColor) { + subshape.layout().memory_space() == Layout::kHostMemorySpace) { HloInstruction* parameter_instruction = entry_computation->parameter_instruction(i); VLOG(1) << "Host parameter streamed into program with shape: " @@ -399,7 +395,7 @@ absl::StatusOr HostOffloader::HandleMoveToHostCustomCall( HloInstruction* copy_to_host = data_to_copy->parent()->AddInstruction(HloInstruction::CreateUnary( data_to_copy->shape(), HloOpcode::kCopy, data_to_copy)); - SetMemorySpace(copy_to_host->mutable_shape(), kHostMemorySpaceColor); + SetMemorySpace(copy_to_host->mutable_shape(), Layout::kHostMemorySpace); TF_RETURN_IF_ERROR( custom_call_instruction->ReplaceAllUsesWith(copy_to_host)); VLOG(2) << absl::StreamFormat( @@ -491,7 +487,7 @@ absl::StatusOr HostOffloader::InsertCopyBetween( copy_to_host = data_to_copy->parent()->AddInstruction(HloInstruction::CreateUnary( data_to_copy->shape(), HloOpcode::kCopy, data_to_copy)); - SetMemorySpace(copy_to_host->mutable_shape(), kHostMemorySpaceColor); + SetMemorySpace(copy_to_host->mutable_shape(), Layout::kHostMemorySpace); copies_created_after_[data_to_copy] = copy_to_host; } else { // We already have a copy which feeds into this instruction. @@ -623,7 +619,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( SetMemorySpace(ShapeUtil::GetMutableSubshape( instruction_and_shape.instruction->mutable_shape(), instruction_and_shape.shape_index), - kHostMemorySpaceColor); + Layout::kHostMemorySpace); HloInstruction* instruction = instruction_and_shape.instruction; if (instruction->opcode() == HloOpcode::kParameter) { // If this is a parameter of a while_body, we also need to find the @@ -649,7 +645,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( SetMemorySpace(ShapeUtil::GetMutableSubshape( while_condition_parameter->mutable_shape(), instruction_and_shape.shape_index), - kHostMemorySpaceColor); + Layout::kHostMemorySpace); // Walk further down the graph and set the memory spaces of all uses // too. This includes verifying that no compute is done on the buffer. // Another, better way, to do this, is to walk down the graph starting @@ -673,7 +669,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( ShapeUtil::GetMutableSubshape( nested_instruction_and_shape.instruction->mutable_shape(), nested_instruction_and_shape.shape_index), - kHostMemorySpaceColor); + Layout::kHostMemorySpace); TF_ASSIGN_OR_RETURN( const std::vector successors, host_offload_utils::GetSuccessors( @@ -715,7 +711,8 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( VLOG(1) << absl::StreamFormat( "Created new AllocateBuffer instruction \"%s\"", allocate_buffer->ToString()); - SetMemorySpace(allocate_buffer->mutable_shape(), kHostMemorySpaceColor); + SetMemorySpace(allocate_buffer->mutable_shape(), + Layout::kHostMemorySpace); for (int64_t index : operand_indices) { TF_RETURN_IF_ERROR( broadcast_user->ReplaceOperandWith(index, allocate_buffer)); @@ -797,7 +794,7 @@ absl::StatusOr HostOffloader::ApplySchedulingFix( continue; } if (instruction->shape().layout().memory_space() != - kHostMemorySpaceColor) { + Layout::kHostMemorySpace) { continue; } // Replace DynamicUpdateSlice's 1st operand with a copy in case it @@ -859,7 +856,7 @@ absl::StatusOr UpdateMemorySpaceForHostOffloadedOutputs( // If instruction is MoveToHost, we will replace usage. if (instr_and_shape.instruction->IsCustomCall( host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { - to_replace.emplace_back(instr_and_shape); + to_replace.push_back(instr_and_shape); continue; } @@ -1018,7 +1015,7 @@ absl::StatusOr HostOffloader::HandleRedundantCopiesBackToHost( queue.push(successor); host_instrs_tree.mutable_element(output_shape_index) - ->emplace_back(successor); + ->push_back(successor); } } @@ -1064,7 +1061,7 @@ absl::StatusOr HostOffloader::Run( const absl::flat_hash_set& execution_threads) { bool changed = false; - // First remove redundant copies to and from host (conservatively) starting + // Remove redundant copies to and from host (conservatively) starting // from the outputs of the host offloaded computations. Iterate over all // instructions and look for XLA host offload annotations. bool changed_in_loop; diff --git a/third_party/xla/xla/hlo/transforms/host_offloader.h b/third_party/xla/xla/hlo/transforms/host_offloader.h index 765b3c2709856e..5055aa15f10a87 100644 --- a/third_party/xla/xla/hlo/transforms/host_offloader.h +++ b/third_party/xla/xla/hlo/transforms/host_offloader.h @@ -18,8 +18,11 @@ #include #include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" @@ -56,8 +59,7 @@ class HloCostAnalysis; // pass. class HostOffloader : public HloModulePass { public: - explicit HostOffloader(int64_t host_memory_space_color) - : kHostMemorySpaceColor(host_memory_space_color) {} + HostOffloader() = default; ~HostOffloader() override = default; absl::string_view name() const override { return "host-offloader"; } @@ -74,7 +76,6 @@ class HostOffloader : public HloModulePass { // instruction chain) are ignored. absl::StatusOr ProcessNextMoveToHostInstr(HloComputation* computation); - const int64_t kHostMemorySpaceColor; absl::flat_hash_set already_visited_move_to_host_custom_calls_; absl::flat_hash_set dynamic_update_slices_already_allocated_; diff --git a/third_party/xla/xla/hlo/transforms/host_offloader_test.cc b/third_party/xla/xla/hlo/transforms/host_offloader_test.cc index 1452815127f1a7..84e748747b68e1 100644 --- a/third_party/xla/xla/hlo/transforms/host_offloader_test.cc +++ b/third_party/xla/xla/hlo/transforms/host_offloader_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -32,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/host_offload_legalize.h" #include "xla/layout.h" @@ -39,7 +39,6 @@ limitations under the License. #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/host_offload_utils.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -64,7 +63,7 @@ class HostOffloaderTest : public HloHardwareIndependentTestBase { after_layout); TF_ASSIGN_OR_RETURN(bool legal_changed, host_offload_legalize.Run(module)); changed |= legal_changed; - HostOffloader host_offloader(Layout::kHostMemorySpace); + HostOffloader host_offloader; TF_ASSIGN_OR_RETURN(bool offload_changed, host_offloader.Run(module)); changed |= offload_changed; return changed; diff --git a/third_party/xla/xla/hlo/transforms/literal_canonicalizer.cc b/third_party/xla/xla/hlo/transforms/literal_canonicalizer.cc new file mode 100644 index 00000000000000..6b7418447b6ad0 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/literal_canonicalizer.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/literal_canonicalizer.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal_pool.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla { +namespace { + +class LiteralCanonicalizerVisitor : public DfsHloRewriteVisitor { + public: + LiteralCanonicalizerVisitor(LiteralPool* literal_pool, size_t min_size_bytes) + : literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {} + + absl::Status HandleConstant(HloInstruction* hlo) final { + auto* constant = Cast(hlo); + if (constant->HasLiteral() && + constant->literal().size_bytes() >= min_size_bytes_) { + MarkAsMaybeChanged(constant->Canonicalize(literal_pool_)); + } + return absl::OkStatus(); + } + + private: + LiteralPool* literal_pool_; + size_t min_size_bytes_; +}; + +} // namespace + +LiteralCanonicalizer::LiteralCanonicalizer(LiteralPool* literal_pool, + size_t min_size_bytes) + : literal_pool_(literal_pool), min_size_bytes_(min_size_bytes) {} + +absl::StatusOr LiteralCanonicalizer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // Every time we canonicalize literals in a module, we garbage collect expired + // literals from the pool. + size_t num_erased = literal_pool_->GarbageCollect(); + VLOG(3) << "Garbage collected " << num_erased << " expired literals"; + + LiteralCanonicalizerVisitor visitor(literal_pool_, min_size_bytes_); + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&visitor)); + return visitor.changed(); +} + +} // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/literal_canonicalizer.h b/third_party/xla/xla/hlo/transforms/literal_canonicalizer.h new file mode 100644 index 00000000000000..26d1768f374a79 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/literal_canonicalizer.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_ +#define XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/literal_pool.h" + +namespace xla { + +// Canonicalizes literals larger than 'min_size_bytes' in the HLO module using +// the given literal pool. +class LiteralCanonicalizer : public HloModulePass { + public: + LiteralCanonicalizer(LiteralPool* literal_pool, size_t min_size_bytes); + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + absl::string_view name() const override { return "literal-canonicalizer"; } + + protected: + LiteralPool* literal_pool_; + size_t min_size_bytes_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_LITERAL_CANONICALIZER_H_ diff --git a/third_party/xla/xla/hlo/transforms/literal_canonicalizer_test.cc b/third_party/xla/xla/hlo/transforms/literal_canonicalizer_test.cc new file mode 100644 index 00000000000000..6a59ee3eb39b37 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/literal_canonicalizer_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/literal_canonicalizer.h" + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_pool.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +class LiteralCanonicalizerTest : public HloHardwareIndependentTestBase {}; + +TEST_F(LiteralCanonicalizerTest, CanonicalizeConstants) { + absl::string_view hlo_string = R"( + HloModule m + + ENTRY %entry { + ROOT %c0 = f32[4] constant({1.0, 2.0, 3.0, 4.0}) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module0, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module1, + ParseAndReturnVerifiedModule(hlo_string)); + + LiteralPool literal_pool; + LiteralCanonicalizer literal_canonicalizer(&literal_pool, 0); + + EXPECT_FALSE(literal_canonicalizer.Run(module0.get()).value()); + EXPECT_TRUE(literal_canonicalizer.Run(module1.get()).value()); + + auto* c0 = Cast( + module0->entry_computation()->root_instruction()); + auto* c1 = Cast( + module1->entry_computation()->root_instruction()); + + EXPECT_EQ(c0->literal(), c1->literal()); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc b/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc index d0704df0e88af9..3dc14572dc408b 100644 --- a/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc @@ -16,7 +16,11 @@ limitations under the License. #include "xla/hlo/transforms/memory_space_propagation.h" #include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/transforms/memory_space_propagation.h b/third_party/xla/xla/hlo/transforms/memory_space_propagation.h index bb0da70bf1a7fc..b3998f542d39f5 100644 --- a/third_party/xla/xla/hlo/transforms/memory_space_propagation.h +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation.h @@ -16,6 +16,12 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ #define XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc b/third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc index 15cd6c4cd4cbff..a1252d596ee281 100644 --- a/third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/hlo/transforms/memory_space_propagation.h" +#include +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc b/third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc index 8a143b365af618..ed61bb63d2dad6 100644 --- a/third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc +++ b/third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD new file mode 100644 index 00000000000000..9c9a1be6f5718c --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD @@ -0,0 +1,1535 @@ +# Description: +# Implementation of XLA’s HLO simplifier transformations. + +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "tsl_copts") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla/tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_constant_splitter", + srcs = ["hlo_constant_splitter.cc"], + hdrs = ["hlo_constant_splitter.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "hlo_constant_splitter_test", + srcs = ["hlo_constant_splitter_test.cc"], + deps = [ + ":hlo_constant_splitter", + ":hlo_dce", + "//xla:test", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_reduce_folder", + srcs = ["all_reduce_folder.cc"], + hdrs = ["all_reduce_folder.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:all_reduce_key", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "all_reduce_folder_test", + srcs = ["all_reduce_folder_test.cc"], + deps = [ + ":all_reduce_folder", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "broadcast_canonicalizer", + srcs = ["broadcast_canonicalizer.cc"], + hdrs = ["broadcast_canonicalizer.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "broadcast_canonicalizer_test", + srcs = ["broadcast_canonicalizer_test.cc"], + deps = [ + ":broadcast_canonicalizer", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "bfloat16_conversion_folding", + srcs = ["bfloat16_conversion_folding.cc"], + hdrs = ["bfloat16_conversion_folding.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:float_support", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "bfloat16_conversion_folding_test", + srcs = ["bfloat16_conversion_folding_test.cc"], + deps = [ + ":bfloat16_conversion_folding", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:float_support", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "float_normalization", + srcs = ["float_normalization.cc"], + hdrs = ["float_normalization.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "//xla/service:float_support", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "float_normalization_test", + srcs = ["float_normalization_test.cc"], + deps = [ + ":float_normalization", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:float_support", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_verifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_computation_deduplicator", + srcs = ["hlo_computation_deduplicator.cc"], + hdrs = ["hlo_computation_deduplicator.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "hlo_computation_deduplicator_test", + size = "small", + srcs = ["hlo_computation_deduplicator_test.cc"], + deps = [ + ":hlo_computation_deduplicator", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "flatten_call_graph", + srcs = ["flatten_call_graph.cc"], + hdrs = ["flatten_call_graph.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "flatten_call_graph_test", + srcs = ["flatten_call_graph_test.cc"], + deps = [ + ":flatten_call_graph", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:call_graph", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:tuple_points_to_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:buffer_value", + "//xla/service:logical_buffer", + "//xla/service/heap_simulator", + "//xla/tsl/lib/gtl:map_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + ], +) + +xla_cc_test( + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], + deps = [ + ":hlo_dce", + ":hlo_memory_scheduler", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_ordering", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:buffer_value", + "//xla/service:hlo_value", + "//xla/service:logical_buffer", + "//xla/service/heap_simulator", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "algebraic_simplifier", + srcs = ["algebraic_simplifier.cc"], + hdrs = ["algebraic_simplifier.h"], + copts = tsl_copts(), + deps = [ + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_sharding_util", + "//xla/service:gather_scatter_utils", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_module_config", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:host_offload_utils", + "//xla/service:pattern_matcher", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "algebraic_simplifier_test", + srcs = ["algebraic_simplifier_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_constant_folding", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_creation_utils", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:layout_assignment", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:shape_inference", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "simplify_fp_conversions", + srcs = ["simplify_fp_conversions.cc"], + hdrs = ["simplify_fp_conversions.h"], + deps = [ + "//xla:comparison_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "simplify_fp_conversions_test", + srcs = ["simplify_fp_conversions_test.cc"], + deps = [ + ":simplify_fp_conversions", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "batch_dot_simplification", + srcs = ["batch_dot_simplification.cc"], + hdrs = ["batch_dot_simplification.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "batch_dot_simplification_test", + srcs = ["batch_dot_simplification_test.cc"], + deps = [ + ":batch_dot_simplification", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "convolution_group_converter", + srcs = ["convolution_group_converter.cc"], + hdrs = ["convolution_group_converter.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "convolution_group_converter_test", + size = "small", + srcs = ["convolution_group_converter_test.cc"], + deps = [ + ":convolution_group_converter", + "//xla:test", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "dot_dimension_merger", + srcs = ["dot_dimension_merger.cc"], + hdrs = ["dot_dimension_merger.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_dimension_merger_test", + srcs = ["dot_dimension_merger_test.cc"], + deps = [ + ":dot_dimension_merger", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dot_merger", + srcs = ["dot_merger.cc"], + hdrs = ["dot_merger.h"], + deps = [ + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:shape_inference", + "//xla/service/graphcycles", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_merger_test", + srcs = ["dot_merger_test.cc"], + deps = [ + ":algebraic_simplifier", + ":dot_merger", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "convert_mover", + srcs = ["convert_mover.cc"], + hdrs = ["convert_mover.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "convert_mover_test", + srcs = ["convert_mover_test.cc"], + deps = [ + ":convert_mover", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "tuple_simplifier", + srcs = ["tuple_simplifier.cc"], + hdrs = ["tuple_simplifier.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "tuple_simplifier_test", + srcs = ["tuple_simplifier_test.cc"], + deps = [ + ":tuple_simplifier", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dynamic_dimension_simplifier", + srcs = ["dynamic_dimension_simplifier.cc"], + hdrs = ["dynamic_dimension_simplifier.h"], + deps = [ + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "dynamic_dimension_simplifier_test", + srcs = ["dynamic_dimension_simplifier_test.cc"], + deps = [ + ":dynamic_dimension_simplifier", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_creation_utils", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:shape_inference", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "reshape_mover", + srcs = ["reshape_mover.cc"], + hdrs = ["reshape_mover.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "reshape_mover_test", + srcs = ["reshape_mover_test.cc"], + deps = [ + ":algebraic_simplifier", + ":reshape_mover", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_verifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_dce", + srcs = ["hlo_dce.cc"], + hdrs = ["hlo_dce.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "hlo_dce_test", + srcs = ["hlo_dce_test.cc"], + deps = [ + ":hlo_dce", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_rematerialization_test_utils", + testonly = 1, + hdrs = ["hlo_rematerialization_test_utils.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "hlo_rematerialization_test_utils_test", + srcs = ["hlo_rematerialization_test_utils_test.cc"], + deps = [ + ":hlo_rematerialization_test_utils", + "//xla/hlo/ir:hlo", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_rematerialization", + srcs = ["hlo_rematerialization.cc"], + hdrs = ["hlo_rematerialization.h"], + deps = [ + ":hlo_dce", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:tuple_points_to_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "//xla/service:hlo_cost_analysis", + "//xla/service:logical_buffer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "hlo_rematerialization_test", + srcs = ["hlo_rematerialization_test.cc"], + deps = [ + ":hlo_memory_scheduler", + ":hlo_rematerialization", + ":hlo_rematerialization_test_utils", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_cost_analysis", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_constant_folding", + srcs = ["hlo_constant_folding.cc"], + hdrs = ["hlo_constant_folding.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:slow_operation_alarm", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "hlo_constant_folding_test", + srcs = ["hlo_constant_folding_test.cc"], + deps = [ + ":hlo_constant_folding", + "//xla:literal", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +# copybara:uncomment_begin(google-only) +# xla_cc_test( +# name = "hlo_constant_folding_peak_heap_test", +# srcs = ["hlo_constant_folding_peak_heap_test.cc"], +# deps = [ +# ":hlo_constant_folding", +# "@com_google_googletest//:gtest", +# "@com_google_absl//absl/strings:str_format", +# "//xla:test", +# "//xla/hlo/testlib:hlo_hardware_independent_test_base", +# "@local_tsl//tsl/platform:statusor", +# "@local_tsl//tsl/platform:test_main", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "hlo_element_type_converter_test", + srcs = ["hlo_element_type_converter_test.cc"], + deps = [ + ":hlo_element_type_converter", + "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "conditional_canonicalizer", + srcs = ["conditional_canonicalizer.cc"], + hdrs = ["conditional_canonicalizer.h"], + deps = [ + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "conditional_canonicalizer_test", + srcs = ["conditional_canonicalizer_test.cc"], + deps = [ + ":conditional_canonicalizer", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "zero_sized_hlo_elimination", + srcs = ["zero_sized_hlo_elimination.cc"], + hdrs = ["zero_sized_hlo_elimination.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "zero_sized_hlo_elimination_test", + srcs = ["zero_sized_hlo_elimination_test.cc"], + deps = [ + ":zero_sized_hlo_elimination", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "sort_simplifier", + srcs = ["sort_simplifier.cc"], + hdrs = ["sort_simplifier.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "sort_simplifier_test", + srcs = ["sort_simplifier_test.cc"], + deps = [ + ":sort_simplifier", + "//xla:test", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "root_instruction_sinker", + srcs = ["root_instruction_sinker.cc"], + hdrs = ["root_instruction_sinker.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:tuple_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "root_instruction_sinker_test", + srcs = ["root_instruction_sinker_test.cc"], + deps = [ + ":root_instruction_sinker", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "host_memory_transfer_asyncifier", + srcs = ["host_memory_transfer_asyncifier.cc"], + hdrs = ["host_memory_transfer_asyncifier.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_memory_transfer_asyncifier_test", + srcs = ["host_memory_transfer_asyncifier_test.cc"], + deps = [ + ":host_memory_transfer_asyncifier", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "fusion_constant_sinking", + srcs = ["fusion_constant_sinking.cc"], + hdrs = ["fusion_constant_sinking.h"], + deps = [ + ":hlo_dce", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "fusion_constant_sinking_test", + srcs = ["fusion_constant_sinking_test.cc"], + deps = [ + ":fusion_constant_sinking", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "optimize_input_output_buffer_alias", + srcs = ["optimize_input_output_buffer_alias.cc"], + hdrs = ["optimize_input_output_buffer_alias.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "optimize_input_output_buffer_alias_test", + srcs = ["optimize_input_output_buffer_alias_test.cc"], + deps = [ + ":optimize_input_output_buffer_alias", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "ar_crs_combiner", + srcs = ["ar_crs_combiner.cc"], + hdrs = ["ar_crs_combiner.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_replication_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "ar_crs_combiner_test", + srcs = ["ar_crs_combiner_test.cc"], + deps = [ + ":ar_crs_combiner", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "slice_sinker", + srcs = ["slice_sinker.cc"], + hdrs = ["slice_sinker.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "slice_sinker_test", + srcs = ["slice_sinker_test.cc"], + deps = [ + ":hlo_dce", + ":slice_sinker", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "result_caster", + srcs = ["result_caster.cc"], + hdrs = ["result_caster.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "//xla/service:shape_inference", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "result_caster_test", + srcs = ["result_caster_test.cc"], + deps = [ + ":result_caster", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "convert_operand_folding", + srcs = ["convert_operand_folder.cc"], + hdrs = ["convert_operand_folder.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "convert_operand_folding_test", + srcs = ["convert_operand_folder_test.cc"], + deps = [ + ":convert_operand_folding", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduce_window_rewriter", + srcs = ["reduce_window_rewriter.cc"], + hdrs = ["reduce_window_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduce_window_rewriter_test", + srcs = ["reduce_window_rewriter_test.cc"], + deps = [ + ":reduce_window_rewriter", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "gather_simplifier", + srcs = ["gather_simplifier.cc"], + hdrs = ["gather_simplifier.h"], + deps = [ + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "//xla/service:gather_scatter_utils", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gather_simplifier_test", + srcs = ["gather_simplifier_test.cc"], + deps = [ + ":gather_simplifier", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "instruction_hoister", + srcs = ["instruction_hoister.cc"], + hdrs = ["instruction_hoister.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:status", + ], +) + +cc_library( + name = "sub_byte_normalization", + srcs = ["sub_byte_normalization.cc"], + hdrs = ["sub_byte_normalization.h"], + deps = [ + "//xla:shape_layout", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) + +cc_library( + name = "tree_reduction_rewriter", + srcs = ["tree_reduction_rewriter.cc"], + hdrs = ["tree_reduction_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:padding", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index 269284b021d5de..4b96bf2a81d502 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -5939,8 +5939,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( new_operands.push_back(operand); } } - VLOG(4) << "Sinking broadcast after user:" << "\n old broadcast: " - << broadcast->ToString() << "\n old user: " << user->ToString(); + VLOG(4) << "Sinking broadcast after user:" + << "\n old broadcast: " << broadcast->ToString() + << "\n old user: " << user->ToString(); changed_shape = ShapeUtil::ChangeElementType(operand->shape(), user->shape().element_type()); simplifier_->UpdateLayout(&changed_shape); @@ -8233,6 +8234,24 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status AlgebraicSimplifierVisitor::HandleReducePrecision( + HloInstruction* hlo) { + HloReducePrecisionInstruction* reduce_precision = + Cast(hlo); + PrimitiveType element_type = + reduce_precision->operand(0)->shape().element_type(); + if (options_.enable_remove_no_op_reduce_precision() && + reduce_precision->exponent_bits() == + primitive_util::ExponentWidth(element_type) && + reduce_precision->mantissa_bits() + 1 == + primitive_util::SignificandWidth(element_type)) { + return ReplaceInstruction( + /*old_instruction=*/hlo, + /*new_instruction=*/reduce_precision->mutable_operand(0)); + } + return absl::OkStatus(); +} + absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* hlo) { auto* reduce_window = Cast(hlo); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h index 96c50ba251a949..f3ded542605dbf 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h @@ -322,6 +322,16 @@ class AlgebraicSimplifierOptions { return enable_broadcast_degenerate_dimension_; } + void set_enable_remove_no_op_reduce_precision( + bool enable_remove_no_op_reduce_precision) { + enable_remove_no_op_reduce_precision_ = + enable_remove_no_op_reduce_precision; + } + + bool enable_remove_no_op_reduce_precision() const { + return enable_remove_no_op_reduce_precision_; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplifierOptions that can be later used in an @@ -364,6 +374,7 @@ class AlgebraicSimplifierOptions { bool disable_dynamic_slice_to_slice_conversion_{false}; bool enable_fast_math_{false}; bool enable_broadcast_degenerate_dimension_{true}; + bool enable_remove_no_op_reduce_precision_{false}; Metadata metadata_; }; @@ -484,6 +495,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { absl::Status HandleReduce(HloInstruction* hlo) override; + absl::Status HandleReducePrecision(HloInstruction* hlo) override; + absl::Status HandleReduceWindow(HloInstruction* hlo) override; absl::Status HandleReverse(HloInstruction* reverse) override; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 5b0519107ad653..e30822e37f578d 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -12688,5 +12688,36 @@ TEST_F(AlgebraicSimplifierTest, TestNew123) { EXPECT_FALSE(simplifier.Run(module.get()).value()); } +TEST_F(AlgebraicSimplifierTest, + ReducePrecisionWithSamePrecisionAsOperandIsRemovedIfRemoveNoOpIsSet) { + const char* hlo = R"( + HloModule test + ENTRY main { + p0 = bf16[64]{0} parameter(0) + ROOT reduce-precision = bf16[64] reduce-precision(p0), exponent_bits=8, mantissa_bits=7 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + default_options_.set_enable_remove_no_op_reduce_precision(true); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter())); +} + +TEST_F(AlgebraicSimplifierTest, + ReducePrecisionWithDifferentPrecisionFromOperandIsNotModifiedByDefault) { + const char* hlo = R"( + HloModule test + ENTRY main { + p0 = bf16[64]{0} parameter(0) + ROOT reduce-precision = bf16[64] reduce-precision(p0), exponent_bits=7, mantissa_bits=8 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + + default_options_.set_enable_remove_no_op_reduce_precision(true); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc index 49ba41a4cedcdd..078767e8a2112b 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc @@ -137,7 +137,7 @@ std::optional> FoldReplicaGroups( } // Sort the replica groups by the first id for stable behavior. Otherwise, - // groups are formed according to the order in the contributer_set_id map, + // groups are formed according to the order in the contributor_set_id map, // which is not stable. absl::c_sort(new_replica_groups, [](const ReplicaGroup &a, const ReplicaGroup &b) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc index b6d8a532054502..88dbd2781ca60f 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -259,7 +259,7 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( if (allow_excess_precision && user->opcode() == HloOpcode::kConvert && user->shape().element_type() == to && to == HighPrecisionType() && from == LowPrecisionType()) { - conversions_to_simplify.emplace_back(user); + conversions_to_simplify.push_back(user); } else { TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc index 85b8e1c9619589..07565f5f26eff9 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include #include -#include #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -38,7 +38,7 @@ namespace { class HloComputationDeduplicatorTest : public HloHardwareIndependentTestBase { protected: - std::vector RunDeduplicatePass(const std::string_view text, + std::vector RunDeduplicatePass(const absl::string_view text, bool expect_true) { std::unique_ptr module = ParseAndReturnVerifiedModule(text).value(); @@ -54,7 +54,7 @@ class HloComputationDeduplicatorTest : public HloHardwareIndependentTestBase { }; TEST_F(HloComputationDeduplicatorTest, RemoveRegionBandC) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0}, s32[20]{0})->s32[]} region_A { Arg_0.6 = s32[] parameter(0) @@ -97,7 +97,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionBandC) { } TEST_F(HloComputationDeduplicatorTest, RemoveRegionBExactCopy) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A { Arg_0.5 = s32[] parameter(0) @@ -129,7 +129,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionBExactCopy) { } TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_X { Ag_0 = s32[] parameter(0) @@ -193,7 +193,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionsWithSameSubcomp) { } TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionsWithDifferentSubcomp) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_X { Ag_0 = s32[] parameter(0) @@ -272,7 +272,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionsWithDifferentSubcomp) { } TEST_F(HloComputationDeduplicatorTest, RemoveRegionBVarDifferences) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A { Arg_0.5 = s32[] parameter(0) @@ -306,7 +306,7 @@ TEST_F(HloComputationDeduplicatorTest, RemoveRegionBVarDifferences) { } TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A { Arg_0 = s32[] parameter(0) @@ -342,7 +342,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBCommutative) { TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBDifferentExecutionThread) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A { @@ -389,7 +389,7 @@ TEST_F(HloComputationDeduplicatorTest, } TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionLargeConstant) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A { Arg_00 = s32[] parameter(0) @@ -481,7 +481,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionLargeConstant) { } TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBDifferentcomp) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A { Arg_0.5 = s32[] parameter(0) @@ -516,7 +516,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBDifferentcomp) { } TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBDifferentType) { - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s16[15]{0})->s16[]} region_A { Arg_0.5 = s32[] parameter(0) @@ -552,7 +552,7 @@ TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBDifferentType) { TEST_F(HloComputationDeduplicatorTest, DontRemoveRegionBEntryComp) { // Note: this test is hypothetical and just to check dedup. - const std::string_view text = R"( + const absl::string_view text = R"( HloModule DeDupTest, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]} region_A1 { Arg_0.5 = s32[] parameter(0) @@ -637,7 +637,7 @@ TEST_F(HloComputationDeduplicatorTest, LargeSubComputationTest) { TEST_F(HloComputationDeduplicatorTest, DontDeduplicateReduceAllReduce) { // Note: this test is hypothetical and just to check dedup. - const std::string_view text = R"( + const absl::string_view text = R"( HloModule TestModule add.1 { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc index 6a9dc33350c5fd..b1e76bf7a90fbf 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc @@ -236,7 +236,7 @@ TEST_F(HloConstantSplitterTest, NoSplittingSideEffectExpressions) { // The HloConstantSplitter pass duplicates several constant expressions. Then // the DCE pass removes the dead instructions. Although the flag changed is - // true, we do not alter the module in essense. + // true, we do not alter the module in essence. EXPECT_TRUE(changed); EXPECT_EQ(count_before, count_after_dce); int64_t rng_count = 0; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc index 74fcfb4d08106a..8ffc8e3b19c5cb 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -429,7 +428,7 @@ TEST_F(HloSchedulingTest, BFSScheduler) { instructions_by_name[instruction->name()] = instruction; } - auto index = [&](std::string_view name) -> size_t { + auto index = [&](absl::string_view name) -> size_t { const HloInstruction* instruction = instructions_by_name.at(name); return std::distance(sequence.begin(), absl::c_find(sequence, instruction)); }; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc index e40bb1d872ced4..04997baca3642c 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -2898,7 +2897,7 @@ absl::StatusOr HloRematerialization::Run( // at the same time, as that will cause the asynchronous callee usage to be // added to the main thread callers usage. The callee's memory is // preallocated, so the caller doesn't pay for it. - absl::flat_hash_set async_threads; + absl::flat_hash_set async_threads; for (const auto& [computation, _] : options_.async_computation_parallelism) { async_threads.insert(computation->execution_thread()); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc index e9742e2e28e874..6f0a72ce3edfbf 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc @@ -83,7 +83,7 @@ class AsyncRematerializationTest : public RematerializationTestBase { }; TEST_F(AsyncRematerializationTest, AsyncComputation) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule async, is_scheduled=true %offload_computation { diff --git a/third_party/xla/xla/hlo/transforms/tests/BUILD b/third_party/xla/xla/hlo/transforms/tests/BUILD new file mode 100644 index 00000000000000..9b1d8595f3062f --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/tests/BUILD @@ -0,0 +1,39 @@ +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], + licenses = ["notice"], +) + +cc_library( + name = "dummy_passes", + hdrs = ["dummy_passes.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +lit_test_suite( + name = "hlo_opt_tests", + srcs = enforce_glob( + [ + "run_single_pass.hlo", + "run_multiple_passes.hlo", + "algebraic_simplifier.hlo", + ], + include = [ + "*.hlo", + ], + ), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/hlo/tools:hlo-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/hlo/transforms/tests/algebraic_simplifier.hlo b/third_party/xla/xla/hlo/transforms/tests/algebraic_simplifier.hlo new file mode 100644 index 00000000000000..899da94152467d --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/tests/algebraic_simplifier.hlo @@ -0,0 +1,17 @@ +// RUN: hlo-opt %s --passes=algebraic_simplifier | FileCheck %s + +HloModule m +ENTRY test { + // CHECK: %[[p0:.*]] = s32[8]{0} parameter(0) + // CHECK-NEXT: %[[p2:.*]] = s32[8]{0} parameter(2) + // CHECK-NEXT: %[[x:.*]] = s32[8]{0} multiply(s32[8]{0} %[[p0]], s32[8]{0} %[[p2]]) + // CHECK-NEXT: %[[p1:.*]] = s32[8]{0} parameter(1) + // CHECK-NEXT: %[[y:.*]] = s32[8]{0} multiply(s32[8]{0} %[[p1]], s32[8]{0} %[[p2]]) + // CHECK-NEXT: ROOT %[[sum:.*]] = s32[8]{0} add(s32[8]{0} %[[x]], s32[8]{0} %[[y]]) + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) +} diff --git a/third_party/xla/xla/tools/hlo_opt/transforms_example_passes.h b/third_party/xla/xla/hlo/transforms/tests/dummy_passes.h similarity index 93% rename from third_party/xla/xla/tools/hlo_opt/transforms_example_passes.h rename to third_party/xla/xla/hlo/transforms/tests/dummy_passes.h index 1f2d954d78d637..fb1644dc88ec96 100644 --- a/third_party/xla/xla/tools/hlo_opt/transforms_example_passes.h +++ b/third_party/xla/xla/hlo/transforms/tests/dummy_passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TOOLS_HLO_OPT_TRANSFORMS_EXAMPLE_PASSES_H_ -#define XLA_TOOLS_HLO_OPT_TRANSFORMS_EXAMPLE_PASSES_H_ +#ifndef XLA_HLO_TRANSFORMS_TESTS_DUMMY_PASSES_H_ +#define XLA_HLO_TRANSFORMS_TESTS_DUMMY_PASSES_H_ #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -73,4 +73,4 @@ class BarToHelloModulePass : public HloModulePass { } // namespace xla -#endif // XLA_TOOLS_HLO_OPT_TRANSFORMS_EXAMPLE_PASSES_H_ +#endif // XLA_HLO_TRANSFORMS_TESTS_DUMMY_PASSES_H_ diff --git a/third_party/xla/xla/tools/hlo_opt/tests/run_multiple_passes.hlo b/third_party/xla/xla/hlo/transforms/tests/run_multiple_passes.hlo similarity index 100% rename from third_party/xla/xla/tools/hlo_opt/tests/run_multiple_passes.hlo rename to third_party/xla/xla/hlo/transforms/tests/run_multiple_passes.hlo diff --git a/third_party/xla/xla/tools/hlo_opt/tests/run_single_pass.hlo b/third_party/xla/xla/hlo/transforms/tests/run_single_pass.hlo similarity index 100% rename from third_party/xla/xla/tools/hlo_opt/tests/run_single_pass.hlo rename to third_party/xla/xla/hlo/transforms/tests/run_single_pass.hlo diff --git a/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc index 942408086452d4..2391db7f81f5a0 100644 --- a/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc +++ b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc @@ -15,8 +15,9 @@ limitations under the License. #include "xla/hlo/transforms/while_loop_trip_count_annotator.h" +#include #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD index 57c4026256176f..95feac917157d1 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -36,10 +36,12 @@ cc_library( deps = [ ":attribute_importer", ":hlo_utils", + "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", @@ -198,6 +200,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/mlir_hlo", "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc index c71cdabc7b2acb..86098e3a538aa7 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc @@ -17,10 +17,13 @@ limitations under the License. #include #include +#include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -34,6 +37,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -193,25 +198,58 @@ absl::StatusOr ImportSend( attributes.push_back(ConvertChannelHandle(channel_handle, builder)); } - // Return async_start/done for pipelined send. - // - // old-style send returns a bundle of (arg, sync flag, token) to be passed - // along to send-done. - // However, the new-style async ops have a shared bundle - // format of (args, results, scratchpad), so to rewrite the `send` and - // `send-done` ops to use the new-style async API, we need to reorder the - // arguments to be in (args, token, sync flag) order. - auto result_types = result_type.cast().getTypes(); - if (result_types.size() != 3) - return InvalidArgument("send should return a 3-tuple"); - auto async_arg_type = mlir::TupleType::get( - builder->getContext(), {result_types[0], result_types[2]}); - auto async_bundled_tuple = - mlir::TupleType::get(builder->getContext(), - {async_arg_type, result_types[2], result_types[1]}); - return ImportOldStyleAsyncStart( - symbol_table, attributes, operands, loc, async_bundled_tuple, builder, - "send_", [](auto) { return absl::OkStatus(); }); + bool isPipelined = + instruction->users().front()->opcode() != HloOpcode::kSendDone; + if (isPipelined) { + // Consider removing this path and erroring, unclear if support is needed. + + // Return async_start/done for pipelined send. + // + // old-style send returns a bundle of (arg, sync flag, token) to be passed + // along to send-done. + // However, the new-style async ops have a shared bundle + // format of (args, results, scratchpad), so to rewrite the `send` and + // `send-done` ops to use the new-style async API, we need to reorder the + // arguments to be in (args, token, sync flag) order. + auto result_types = result_type.cast().getTypes(); + if (result_types.size() != 3) + return InvalidArgument("send should return a 3-tuple"); + auto async_arg_type = mlir::TupleType::get( + builder->getContext(), {result_types[0], result_types[2]}); + auto async_bundled_tuple = mlir::TupleType::get( + builder->getContext(), + {async_arg_type, result_types[2], result_types[1]}); + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, async_bundled_tuple, builder, + "send_", [](auto) { return absl::OkStatus(); }); + } + + // Otherwise return send op for non-pipelined send. + // Skip empty data in MLIR send(tuple<>, token) --> mhlo.send(token) + auto token = operands[1]; + llvm::ArrayRef args = operands; + if (args.size() == 2 && IsEmptyTuple(args[0].getType())) { + args = args.drop_front(1); + } + auto send = + builder + ->create(loc, token.getType(), args, attributes) + .getOperation(); + if (instruction->has_sharding()) { + const HloSharding& sharding = instruction->sharding(); + if (sharding.IsTuple() && sharding.tuple_elements().size() == 3) { + // Here we are returning a 1-tuple, but HLO send returns a 3-tuple. Need + // to grab a slice of the sharding. All shardings are maximal, so we + // just need 1 of them. + send->setAttr( + kShardingAttr, + mlir::StringAttr::get( + builder->getContext(), + HloSharding::FromProto(sharding.ToProto().tuple_shardings()[0]) + ->ToString())); + } + } + return send; } absl::StatusOr ImportRecv( @@ -223,6 +261,7 @@ absl::StatusOr ImportRecv( auto recv_op = Cast(instruction); attributes.push_back(builder->getNamedAttr( "is_host_transfer", builder->getBoolAttr(recv_op->is_host_transfer()))); + if (recv_op->channel_id().has_value()) { ChannelHandle channel_handle; channel_handle.set_handle(recv_op->channel_id().value()); @@ -232,27 +271,68 @@ absl::StatusOr ImportRecv( attributes.push_back(ConvertChannelHandle(channel_handle, builder)); } - // Old-style `recv` returns a bundle of (result, sync flag, token) to be - // passed along to recv-done. - // However, the new-style async ops have a shared - // bundle format of (args, results, scratchpad), so to rewrite the `recv` - // and `recv-done` ops to use the new-style async API, we need to reorder - // the arguments to be in (token, (result, token), sync flag) order. - // OR (token, token, sync flag) if no result is received. - auto result_types = result_type.cast().getTypes(); + // Currently only consolidates async recv with result, 0-result recv uses old + // style, unclear if this support is needed. + auto result_types = llvm::cast(result_type).getTypes(); if (result_types.size() != 3) return InvalidArgument("recv should return a 3-tuple"); - // Allow recv of no values, only token. - // b/TODO: Allow recv of no values, only token. - auto async_result_type = mlir::TupleType::get( - builder->getContext(), {result_types[0], result_types[2]}); - auto async_bundled_tuple = mlir::TupleType::get( - builder->getContext(), - {result_types[2], async_result_type, result_types[1]}); - return ImportOldStyleAsyncStart( - symbol_table, attributes, operands, loc, async_bundled_tuple, builder, - "recv_", [](auto) { return absl::OkStatus(); }); + bool isPipelined = + instruction->users().front()->opcode() != HloOpcode::kRecvDone; + if (isPipelined) { + // Consider removing this path and erroring, unclear if support is needed. + + // Old-style `recv` returns a bundle of (result, sync flag, token) to be + // passed along to recv-done. + // However, the new-style async ops have a shared + // bundle format of (args, results, scratchpad), so to rewrite the `recv` + // and `recv-done` ops to use the new-style async API, we need to reorder + // the arguments to be in (token, (result, token), sync flag) order. + // OR (token, token, sync flag) if no result is received. + llvm::SmallVector async_result_types = {result_types[0], + result_types[2]}; + auto async_result_type_tuple = builder->getTupleType(async_result_types); + auto async_bundled_tuple = builder->getTupleType( + {result_types[2], async_result_type_tuple, result_types[1]}); + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, async_bundled_tuple, builder, + "recv_", [](auto) { return absl::OkStatus(); }); + } + + // Return recv op for non-pipelined send, skip empty tuple result type + if (!IsEmptyTuple(result_types[0])) { + auto recv = builder->create( + loc, llvm::SmallVector{result_types[0], result_types[2]}, + operands, attributes); + if (instruction->has_sharding()) { + const HloSharding& sharding = instruction->sharding(); + if (sharding.IsTuple() && sharding.tuple_elements().size() == 3) { + // Here we are returning a 2-tuple, but HLO recv returns a 3-tuple. Need + // to grab a slice of the sharding. All shardings are maximal, so we + // just need to 2 of them. + OpSharding sharding_proto = sharding.ToProto(); + auto* tuple_shardings = sharding_proto.mutable_tuple_shardings(); + tuple_shardings->DeleteSubrange(1, 1); + recv->setAttr(kShardingAttr, + mlir::StringAttr::get( + builder->getContext(), + HloSharding::FromProto(sharding_proto)->ToString())); + } + } + return WrapVariadicResultsInTuple(builder, loc, recv); + } + + // Recv with no result, only token. + // To keep parity, if op only returns token, wrap in tuple, token> + auto recv = builder->create( + loc, llvm::SmallVector{result_types[2]}, operands, + attributes); + auto empty_tuple = + builder->create(loc, llvm::ArrayRef{}); + + return builder->create( + loc, + llvm::ArrayRef{empty_tuple.getResult(), recv.getResult(0)}); } // Async Collectives @@ -376,7 +456,14 @@ absl::StatusOr ImportAsyncOpDone( const HloInstruction* instruction, mlir::Location loc, const llvm::SmallVectorImpl& operands, llvm::SmallVectorImpl& attributes, - mlir::Type result_type, mlir::OpBuilder* builder) { + mlir::Type result_type, mlir::OpBuilder* builder, + std::optional consolidate_if_parent) { + // Consolidate if the defining op matches `consolidate_if_parent`, ensuring + // the async communication op is not pipelined. + if (consolidate_if_parent.has_value() && + instruction->operand(0)->opcode() == consolidate_if_parent.value()) { + return operands[0].getDefiningOp(); + } return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, builder); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h index 906f9235f28498..116d17f86c7bc0 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -28,6 +29,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace xla { @@ -81,7 +83,8 @@ absl::StatusOr ImportAsyncOpDone( const HloInstruction* instruction, mlir::Location loc, const llvm::SmallVectorImpl& operands, llvm::SmallVectorImpl& attributes, - mlir::Type result_type, mlir::OpBuilder* builder); + mlir::Type result_type, mlir::OpBuilder* builder, + std::optional consolidate_if_parent = std::nullopt); } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 79211ab918d546..a40e1d571a3395 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -1275,7 +1275,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kSendDone: { return ImportAsyncOpDone(instruction, loc, operands, attributes, - result_type, func_builder); + result_type, func_builder, HloOpcode::kSend); } case HloOpcode::kRecv: { return ImportRecv(instruction, loc, operands, attributes, result_type, @@ -1283,7 +1283,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } case HloOpcode::kRecvDone: { return ImportAsyncOpDone(instruction, loc, operands, attributes, - result_type, func_builder); + result_type, func_builder, HloOpcode::kRecv); } case HloOpcode::kConditional: { llvm::SmallVector rets; diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc index 2398576e09ed0f..2b8f9e3669c34c 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include - #include "absl/status/status.h" #include "absl/status/statusor.h" #include "mlir/IR/Location.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index 564440ac00edcb..f70769ea91abec 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -188,6 +188,26 @@ mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, return tupleOp; } +mlir::Operation* WrapVariadicResultsInTuple(mlir::OpBuilder* builder, + mlir::Location loc, + mlir::Operation* op) { + auto result_types = op->getResultTypes(); + // Consider skipping wrapping result type of size 1. + assert(result_types.size() != 1 || + !llvm::isa(result_types[0]) && + "Cannot wrap single tuple arg in tuple"); + + auto tuple_type = builder->getTupleType(result_types); + return CreateTupleFromOpResults(builder, loc, op, tuple_type); +} + +bool IsEmptyTuple(const mlir::Type& type) { + if (auto tuple_type = llvm::dyn_cast(type)) { + return tuple_type.getTypes().empty(); + } + return false; +} + mlir::TypeRange Untuple(const mlir::Type& type) { if (llvm::isa(type)) { return llvm::dyn_cast(type).getTypes(); diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h index 5c116fd08c9705..34c2a0ecd3f445 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h @@ -195,6 +195,14 @@ mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op, mlir::Type type); +// Create a TupleOp using the results of 'op'. +mlir::Operation* WrapVariadicResultsInTuple(mlir::OpBuilder* builder, + mlir::Location loc, + mlir::Operation* op); + +// Returns true if the type is a tuple with no elements. +bool IsEmptyTuple(const mlir::Type& type); + mlir::TypeRange Untuple(const mlir::Type& type); static std::pair GetLayoutAttribute( diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc index e51b0b9b325c2d..6c7ca18cfcfe72 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD index ff3acf603add8d..19b7a2314245b8 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD @@ -22,7 +22,6 @@ lit_test_suite( "if_conditional.hlo", "import.hlo", "import_async.hlo", - "import_async2.hlo", "import_entry_computation_layout.hlo", "layouts_and_names.hlo", "location.hlo", diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo index 5aa09777f30022..7689434eb8568d 100644 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo @@ -1,18 +1,13 @@ // RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s -// CHECK-LABEL: func.func private @recv_ -// CHECK: %0:2 = "mhlo.recv"(%arg0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> : (!mhlo.token) -> (tensor, !mhlo.token) +// These tests are created from MHLO->HLO of export_async.mlir. -// CHECK-LABEL: func.func private @send_ -// CHECK: %0 = "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> : (tensor, !mhlo.token) -> !mhlo.token - -// CHECK-LABEL: func.func @main -// CHECK-LITERAL: %0 = "mhlo.async_start"(%arg0, %arg1) <{called_computation = @send_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (tensor, !mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> -// CHECK-NEXT-LITERAL: %1 = "mhlo.async_done"(%0) {called_computation = @send_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{maximal device=0}", xla_shape = "token[]"} : (!mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor>) -> !mhlo.token -// CHECK-NEXT-LITERAL: %2 = "mhlo.async_start"(%1) <{called_computation = @recv_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (!mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, tensor> -// CHECK-NEXT-LITERAL: %3:2 = "mhlo.async_done"(%2) {called_computation = @recv_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle, !mhlo.token>, tensor>) -> (tensor, !mhlo.token) HloModule foobar +// CHECK-LABEL: func.func @main(%arg0: tensor, %arg1: !mhlo.token) +// CHECK-NEXT: "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> +// CHECK-NEXT: "mhlo.recv"(%0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> + ENTRY %async_send_recv_test (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) { %arg_0 = s32[] parameter(0) %arg_1 = token[] parameter(1) @@ -41,8 +36,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}} ENTRY %async_all_gather_test (Arg_0.1: f32[128,32]) -> f32[128,128] { %Arg_0.1 = f32[128,32] parameter(0) - %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} - ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} + %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} + ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} } // ----- @@ -52,7 +47,7 @@ HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} %region_1.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { %Arg_0.3 = f32[] parameter(0) %Arg_1.4 = f32[] parameter(1) - ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} + ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} } // CHECK-LABEL: func.func private @all_reduce_ @@ -63,8 +58,8 @@ HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} // CHECK: mhlo.async_done ENTRY %async_all_reduce_test (Arg_0.1: f32[10]) -> f32[10] { %Arg_0.1 = f32[10] parameter(0) - %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} - ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} + %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} + ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} } // ----- @@ -79,30 +74,38 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} // CHECK: mhlo.async_done ENTRY %async_collective_permute_test (Arg_0.1: f32[128,32]) -> f32[128,32] { %Arg_0.1 = f32[128,32] parameter(0) - %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} - ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} + %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} + ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} } // ----- HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} +// CHECK-LABEL: func.func private @copy_(%arg0: tensor<128x32xf32>) +// CHECK-NEXT: mhlo.copy %arg0 {cross_program_prefetch_index = 0 : i32} + +// CHECK-LABEL: func.func @main(%arg0: tensor<128x32xf32>) +// CHECK-NEXT: "mhlo.async_start"(%arg0) <{called_computation = @copy_, execution_thread = "main"}> +// CHECK-NEXT: mhlo.async_done ENTRY %async_copy_test (Arg_0.1: f32[128,32]) -> f32[128,32] { %Arg_0.1 = f32[128,32] parameter(0) - %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} - ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} + %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} + ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} } // ----- HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])} +// CHECK-LABEL: func.func @main(%arg0: !mhlo.token) +// CHECK-NEXT: "mhlo.recv"(%arg0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) { %Arg_0.1 = token[] parameter(0) - %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} - %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} - %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} - %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} + %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} ROOT %tuple.6 = (s32[3,4], token[]) tuple(s32[3,4] %get-tuple-element.4, token[] %get-tuple-element.5) } @@ -110,53 +113,197 @@ ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) { HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]} +// CHECK-LABEL: func.func @main(%arg0: tensor<3x4xi32>, %arg1: !mhlo.token) +// CHECK: "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { %Arg_0.1 = s32[3,4] parameter(0) %Arg_1.2 = token[] parameter(1) - %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} - ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} + %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} + ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} } +// ----- -// BROKEN: b/TODO: Async custom calls? +HloModule main, entry_computation_layout={(token[])->token[]} -// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} +// CHECK-LABEL: func.func @main(%arg0: !mhlo.token) +// CHECK-NEXT: "mhlo.send"(%arg0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> -// ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) { -// %Arg_0.1 = f32[10] parameter(0) -// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} -// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} -// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} -// } +ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] { + %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} + %Arg_0.1 = token[] parameter(0) + %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} + ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} +} -// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} +// ----- -// ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) { -// %Arg_0.1 = f32[10] parameter(0) -// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} -// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} -// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} -// } +HloModule main, entry_computation_layout={(token[])->((), token[])} +// CHECK-LABEL: func.func @main(%arg0: !mhlo.token) +// CHECK-NEXT: "mhlo.recv"(%arg0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> -/////////// +ENTRY %async_recv_test_empty (Arg_0.1: token[]) -> ((), token[]) { + %Arg_0.1 = token[] parameter(0) + %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} + ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} +} -// BROKEN: b/TODO: Empty arg send/recv don't roundtrip +// ----- -// HloModule main, entry_computation_layout={(token[])->token[]} +/// Legacy tests -- These tests are not directly from export_async.mlir. -// ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] { -// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} -// %Arg_0.1 = token[] parameter(0) -// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} -// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} -// } +HloModule foobar + +// CHECK-LABEL: func.func private @all_gather_(%arg0: tensor<128x32xf32>) +// CHECK-NEXT: "mhlo.all_gather" +// CHECK-SAME: all_gather_dim = 1 : i64 +// CHECK-SAME: channel_handle = #mhlo.channel_handle +// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> +// CHECK-SAME: use_global_device_ids + +// CHECK-LABEL: func.func @main +// CHECK-NEXT: %0 = "mhlo.async_start"(%arg0) <{called_computation = @all_gather_, execution_thread = "main"}> +// CHECK-NEXT: "mhlo.async_done" +ENTRY %test_all_gather_start { + input = f32[128,32] parameter(0) + ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true + ROOT ag-done = f32[128,128] all-gather-done(ag-start) +} + +// ----- -// HloModule main, entry_computation_layout={(token[])->((), token[])} +HloModule foobar + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +// CHECK-LABEL: func.func private @all_reduce_(%arg0: tensor<128x32xf32>) +// CHECK-NEXT: "mhlo.all_reduce" +// CHECK-SAME: channel_handle = #mhlo.channel_handle +// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> +// CHECK-SAME: use_global_device_ids + +// CHECK-LABEL: func.func @main +// CHECK-NEXT: [[AR_START:%.*]] = "mhlo.async_start"(%arg0) <{called_computation = @all_reduce_, execution_thread = "main"}> +// CHECK-NEXT: "mhlo.async_done"([[AR_START]]) +%test_all_reduce_start { + input = f32[128,32] parameter(0) + ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true + ROOT ar-done = f32[128,32] all-reduce-done(ar-start) +} -// ENTRY %async_recv_test (Arg_0.1: token[]) -> ((), token[]) { -// %Arg_0.1 = token[] parameter(0) -// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} -// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} +// ----- + +HloModule foobar + +// CHECK-LABEL: func.func private @collective_permute_(%arg0: tensor<128x32xf32>) +// CHECK-NEXT: "mhlo.collective_permute" +// CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> + +// CHECK-LABEL: func @main +// CHECK-NEXT: "mhlo.async_start"(%arg0) <{called_computation = @collective_permute_, execution_thread = "main"}> +// CHECK-NEXT: "mhlo.async_done" +%test_collective_permute (input: f32[128,32]) -> f32[128,32] { + %input = f32[128,32]{1,0} parameter(0) + %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}} + ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start) +} + +// ----- + +HloModule foobar + +// CHECK-LABEL: func.func private @copy_(%arg0: tensor<128x32xf32>) +// CHECK-NEXT: mhlo.copy +// CHECK-SAME: cross_program_prefetch_index + +// CHECK-LABEL: func @main +// CHECK-NEXT: "mhlo.async_start"(%arg0) <{called_computation = @copy_, execution_thread = "main"}> +// CHECK-NEXT: "mhlo.async_done" +%test_copy_start { + input = f32[128,32] parameter(0) + copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0 + ROOT copy-done = f32[128,32] copy-done(copy-start) +} + +// ----- + +HloModule foobar + +// CHECK-LABEL: func.func @main +// CHECK-NEXT: "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> +%test_send_start { + input = f32[128,32] parameter(0) + tok = token[] parameter(1) + send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true + ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true +} + +// ----- + +HloModule foobar + +// CHECK-LABEL: func.func @main +// CHECK-NEXT:"mhlo.recv"(%arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> +%test_recv_start { + input = f32[128,32] parameter(0) + tok = token[] parameter(1) + recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true + recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true + ROOT gte = get-tuple-element(recv-done), index=0 +} + +// ----- + +HloModule foobar + +// CHECK-LABEL: func.func @main +// CHECK-NEXT: "mhlo.recv"(%arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = false}> +%test_recv_dtd_start { + input = f32[128,32] parameter(0) + tok = token[] parameter(1) + recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5 + recv-done = (f32[128,32], token[]) recv-done(recv-start), channel_id=5 + ROOT gte = get-tuple-element(recv-done), index=0 +} + +// ----- + +HloModule foobar + +// CHECK-LABEL: func.func @main +// CHECK-NEXT: "mhlo.recv"(%arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> +// CHECK-SAME{LITERAL}: {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} +// CHECK-SAME: (!mhlo.token) -> (tensor, !mhlo.token) +%test_recv_3_tuple_sharding_to_2_tuple { + input = s32[] parameter(0) + tok = token[] parameter(1) + recv = (s32[], u32[], token[]) recv(token[] tok), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}} + recv-done = (s32[], token[]) recv-done((s32[], u32[], token[]) recv), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}} + ROOT tok2 = s32[] get-tuple-element((s32[], token[]) recv-done), index=0, sharding={maximal device=0} +} + + +// BROKEN: b/TODO: support roundtrip of async custom calls? + +// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} + +// ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) { +// %Arg_0.1 = f32[10] parameter(0) +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} // } +// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} + +// ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) { +// %Arg_0.1 = f32[10] parameter(0) +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} +// } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo deleted file mode 100644 index 7493c958776950..00000000000000 --- a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo +++ /dev/null @@ -1,146 +0,0 @@ -// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s -// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION - -// It would be great to consolidate this test with `import_async.hlo`, but -// this test is very fragile and doesn't run properly in a `-split-input-file` -// mode. - -// NO_DEAD_FUNCTION-NOT: @test - -// CHECK: module @foobar -HloModule foobar - -// Compiler-generated functions - -// CHECK: func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.recv"([[TOK]] - // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = false} - -// CHECK: func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.recv"([[TOK]] - // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = true} - -// CHECK: func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.send"([[INPUT]] - // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = true} - -// CHECK: func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: mhlo.copy [[INPUT]] - // CHECK-SAME: cross_program_prefetch_index - -// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.collective_permute"([[INPUT]]) - // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> - -// CHECK: func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - // CHECK-SAME: use_global_device_ids - // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: mhlo.add [[LHS]], [[RHS]] - -// CHECK: func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.all_gather"([[INPUT]]) - // CHECK-SAME: all_gather_dim = 1 : i64 - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - // CHECK-SAME: use_global_device_ids - -// CHECK: func @main(%arg0: tensor) -> tensor { -ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { - ROOT %Arg_0.1 = f32[] parameter(0) -} - -// Tests - -// CHECK: func private @test_all_gather_start -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) -%test_all_gather_start { - input = f32[128,32] parameter(0) - // CHECK-NEXT: [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]]) - // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main" - ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true - // CHECK-NEXT: "mhlo.async_done"([[AG_START]]) - ROOT ag-done = f32[128,128] all-gather-done(ag-start) -} - -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) -} - -// CHECK: func private @test_all_reduce_start -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) -%test_all_reduce_start { - input = f32[128,32] parameter(0) - // CHECK-NEXT: [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]]) - // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main" - ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true - // CHECK-NEXT: "mhlo.async_done"([[AR_START]]) - ROOT ar-done = f32[128,32] all-reduce-done(ar-start) -} - -// CHECK: func private @test_collective_permute -// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> -%test_collective_permute (input: f32[128,32]) -> f32[128,32] { - %input = f32[128,32]{1,0} parameter(0) - // CHECK-NEXT: [[CP_START:%.*]] = "mhlo.async_start"([[ARG]]) - // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main" - %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}} - // CHECK-NEXT: "mhlo.async_done"([[CP_START]]) - ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start) -} - -// CHECK: func private @test_copy_start -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) -%test_copy_start { - input = f32[128,32] parameter(0) - // CHECK-NEXT: [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]]) - // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main" - copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0 - // CHECK-NEXT: "mhlo.async_done"([[COPY_START]]) - ROOT copy-done = f32[128,32] copy-done(copy-start) -} - -// CHECK: func private @test_send -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) -%test_send_start { - input = f32[128,32] parameter(0) - tok = token[] parameter(1) - // CHECK-NEXT: [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]]) - // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main" - // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> - send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true - // CHECK-NEXT: "mhlo.async_done"([[SEND_START]]) - ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true -} - -// CHECK: func private @test_recv -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) -%test_recv_start { - input = f32[128,32] parameter(0) - tok = token[] parameter(1) - // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]]) - // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main" - // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, tensor> - recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true - // CHECK-NEXT: "mhlo.async_done"([[RECV_START]]) - recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true - ROOT gte = get-tuple-element(recv-done), index=0 -} - -// CHECK: func private @test_recv_dtd -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) -%test_recv_dtd_start { - input = f32[128,32] parameter(0) - tok = token[] parameter(1) - // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]]) - // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main" - // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, tensor> - recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5 - // CHECK-NEXT: "mhlo.async_done"([[RECV_START]]) - recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5 - ROOT gte = get-tuple-element(recv-done), index=0 -} diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD index f4ed22e790935c..f4f1ac8d62f633 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -108,6 +108,7 @@ cc_library( hdrs = ["stack_frame_index_builder.h"], deps = [ "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -138,18 +139,12 @@ cc_library( "//xla:status_macros", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:approx_topk", - "//xla/client/lib:approx_topk_shape", - "//xla/client/lib:matrix", - "//xla/client/lib:quantize", - "//xla/client/lib:slicing", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:approx_topk", "//xla/hlo/builder/lib:approx_topk_shape", "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:quantize", "//xla/hlo/builder/lib:slicing", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 259a396f036c80..e837d47418a141 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -127,11 +127,14 @@ constexpr char kApproxTopK[] = "ApproxTopK"; constexpr char kBackendConfig[] = "backend_config"; constexpr char kCallTargetName[] = "call_target_name"; constexpr char kCalledComputations[] = "called_computations"; +constexpr char kChannelId[] = "channel_id"; constexpr char kHasSideEffect[] = "has_side_effect"; constexpr char kIsFallback[] = "is_fallback"; +constexpr char kRaggedAllToAll[] = "ragged_all_to_all"; constexpr char kRecallTarget[] = "recall_target"; constexpr char kReductionDim[] = "reduction_dim"; constexpr char kReductionInputSizeOverride[] = "reduction_input_size_override"; +constexpr char kReplicaGroups[] = "replica_groups"; constexpr char kTopK[] = "top_k"; // MHLO attributes. Module level attributes require namespacing. @@ -2265,6 +2268,34 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { } BuildGetTupleElementsForTupleResults(op, cc_op, ctx); return success(); + } else if (op.getCallTargetName() == kRaggedAllToAll) { + auto backend_config = + mlir::dyn_cast_or_null(op.getBackendConfigAttr()); + auto isSupportedAttrName = [](NamedAttribute attr) { + auto name = attr.getName(); + return name == kCallTargetName || name == kBackendConfig || + name == kApiVersion || name == kCalledComputations || + name == kHasSideEffect; + }; + for (const auto& attr : op->getAttrs()) { + if (!isSupportedAttrName(attr)) + return op.emitOpError() + << attr.getName().getValue() + << " is not a supported attribute for RaggedAllToAll"; + } + DenseIntElementsAttr replica_groups = + backend_config.getAs(kReplicaGroups); + mlir::mhlo::ChannelHandleAttr channel_handle_attr = + backend_config.getAs(kChannelId); + xla::ChannelHandle channel_handle; + if (channel_handle_attr) { + channel_handle = Convert_channel_handle(channel_handle_attr); + } + xla::XlaOp ragged_all_to_all_op = + RaggedAllToAll(args[0], args[1], args[2], args[3], args[4], args[5], + Convert_replica_groups(replica_groups), channel_handle); + value_map[op.getResult(0)] = ragged_all_to_all_op; + return success(); } if (op.getCalledComputations().size() > 1) @@ -2515,14 +2546,54 @@ LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) { else data_shape = xla::ShapeUtil::MakeTupleShape(subshapes); - token = xla::internal::XlaBuilderFriend::BuildRecv( - ctx.builder, token, data_shape, - Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); - xla::XlaOp xla_result = xla::internal::XlaBuilderFriend::BuildRecvDone( - ctx.builder, token, data_shape, - Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + auto get_sharding = [](const xla::OpSharding& sharding) { + xla::OpSharding ret; + if (sharding.type() != xla::OpSharding::TUPLE) { + ret = sharding; + } else { + ret = sharding.tuple_shardings(0); + } + return ret; + }; + if (ctx.builder->sharding().has_value()) { + // HLO Recv needs a 3-tuple sharding. Get the sharding from the builder and + // make it a 3-tuple sharding. + std::optional sharding = *ctx.builder->sharding(); + xla::OpSharding single_sharding = get_sharding(*sharding); + auto* tuple_shardings = sharding->mutable_tuple_shardings(); + tuple_shardings->Clear(); + for (int i = 0; i < 3; ++i) { + tuple_shardings->Add(xla::OpSharding(single_sharding)); + } + xla::XlaScopedShardingAssignment sharding_scope(ctx.builder, sharding); + token = xla::internal::XlaBuilderFriend::BuildRecv( + ctx.builder, token, data_shape, + Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + } else { + token = xla::internal::XlaBuilderFriend::BuildRecv( + ctx.builder, token, data_shape, + Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + } + + xla::XlaOp xla_result; + { + xla::XlaScopedShardingAssignment sharding_scope(ctx.builder, + ctx.builder->sharding()); + xla_result = xla::internal::XlaBuilderFriend::BuildRecvDone( + ctx.builder, token, data_shape, + Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + } + + xla::XlaOp data_tuple_element; + if (ctx.builder->sharding().has_value()) { + // HLO GetTupleElement needs a single sharding, + xla::XlaScopedShardingAssignment sharding_scope( + ctx.builder, get_sharding(*ctx.builder->sharding())); + data_tuple_element = xla::GetTupleElement(xla_result, 0); + } else { + data_tuple_element = xla::GetTupleElement(xla_result, 0); + } - auto data_tuple_element = xla::GetTupleElement(xla_result, 0); if (subshapes.size() == 1) { value_map[op.getResult(0)] = data_tuple_element; } else { @@ -2757,9 +2828,25 @@ LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) { xla::XlaOp token; if (failed(GetXlaOp(op.getToken(), value_map, &token, op))) return failure(); - token = xla::internal::XlaBuilderFriend::BuildSend( - ctx.builder, operand, token, - Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + // SendOp has 1 result, but HLO Send has 3 results. Convert the sharding to a + // tuple sharding with 3 entries. + if (ctx.builder->sharding().has_value()) { + xla::OpSharding sharding = *ctx.builder->sharding(); + const xla::OpSharding single_sharding = *ctx.builder->sharding(); + sharding.set_type(xla::OpSharding::TUPLE); + auto* tuple_shardings = sharding.mutable_tuple_shardings(); + tuple_shardings->Add(xla::OpSharding(single_sharding)); + tuple_shardings->Add(xla::OpSharding(single_sharding)); + tuple_shardings->Add(xla::OpSharding(single_sharding)); + xla::XlaScopedShardingAssignment sharding_scope(ctx.builder, sharding); + token = xla::internal::XlaBuilderFriend::BuildSend( + ctx.builder, operand, token, + Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + } else { + token = xla::internal::XlaBuilderFriend::BuildSend( + ctx.builder, operand, token, + Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); + } value_map[op] = xla::internal::XlaBuilderFriend::BuildSendDone( ctx.builder, token, Convert_channel_handle(op.getChannelHandle()), op.getIsHostTransfer()); @@ -3474,7 +3561,6 @@ LogicalResult ConvertToHloModule::LowerReturn( /*fast_mem=*/false); if (!reshape.ok()) return inst->emitError() << reshape.status().message(); - returns[index] = reshape.value(); } } diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc index dc96c4192938c3..ab011f018bda99 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include #include -#include #include #include +#include "absl/strings/string_view.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" @@ -29,7 +29,7 @@ limitations under the License. namespace mlir { -int FindId(std::string_view key, std::map &index) { +int FindId(absl::string_view key, std::map &index) { auto entry_iterator = index.find(key); if (entry_iterator == index.end()) { return 0; diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h index b8bed27e2ab091..9e1c34085452db 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include #include +#include "absl/strings/string_view.h" #include "mlir/IR/Location.h" #include "xla/service/hlo.pb.h" @@ -46,8 +46,8 @@ class StackFrameIndexBuilder { xla::StackFrameIndexProto indexes_; - std::map function_name_to_id_; - std::map file_name_to_id_; + std::map function_name_to_id_; + std::map file_name_to_id_; std::map, int> file_location_to_id_; std::map, int> frame_to_id_; }; diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index 17b686cc2f5ebe..a22ec331d93b20 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -814,6 +814,26 @@ func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK-SAME: f32[2,3] custom-call(f32[2,3] [[VAL_1]]) // CHECK-SAME: custom_call_target="SetBound" // CHECK-SAME: literal=s32[] 1 + +// ----- + +// CHECK: HloModule +func.func @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) { + %0 = mhlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32> + return %0 : tensor<6xf32> +} + +// CHECK: ENTRY +// CHECK: [[ARG_0:%.*]] = f32[6] parameter(0) +// CHECK: [[ARG_1:%.*]] = f32[6] parameter(1) +// CHECK: [[ARG_2:%.*]] = s32[3] parameter(2) +// CHECK: [[ARG_3:%.*]] = s32[3] parameter(3) +// CHECK: [[ARG_4:%.*]] = s32[3] parameter(4) +// CHECK: [[ARG_5:%.*]] = s32[3] parameter(5) +// CHECK: ROOT +// CHECK-SAME: f32[6] ragged-all-to-all(f32[6] [[ARG_0]], f32[6] [[ARG_1]], s32[3] [[ARG_2]], s32[3] [[ARG_3]], s32[3] [[ARG_4]], /*index=5*/s32[3] [[ARG_5]]) +// CHECK-SAME{LITERAL}: replica_groups={{0,1,2}} + // ----- // CHECK: HloModule diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir index 70bf10c8d045c8..add453c9a276df 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir @@ -310,3 +310,53 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> return %2 : tensor<20xf32> } + +// ----- + +// Breaking test case where tf2xla lowers to a send with a single manual +// sharding annotation on recv. + +// CHECK: HloModule + +// CHECK: ENTRY +func.func @main() -> tensor<1x2xf32> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "", outputs = "_retval0"}} { + // CHECK: %[[AFTER_ALL:.*]] = token[] after-all() + // CHECK-NEXT: %[[RECV:.*]] = (f32[1,2], u32[], token[]) recv(token[] %[[AFTER_ALL]]), channel_id=2, is_host_transfer=true, + // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} + // CHECK-NEXT: %[[RECV_DONE:.*]] = (f32[1,2], token[]) recv-done((f32[1,2], u32[], token[]) %[[RECV]]), channel_id=2, is_host_transfer=true, + // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} + // CHECK-NEXT: ROOT %[[GET_TUPLE_0:.*]] = f32[1,2] get-tuple-element((f32[1,2], token[]) %[[RECV_DONE]]), index=0, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} + // CHECK-NEXT: %[[GET_TUPLE_1:.*]] = token[] get-tuple-element((f32[1,2], token[]) %[[RECV_DONE]]), index=1, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} + %0 = mhlo.create_token : !mhlo.token + %1:2 = "mhlo.recv"(%0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "host_compute_channel_1_retvals_htod_0"}, mhlo.sharding = "\08\04"} : (!mhlo.token) -> (tensor<1x2xf32>, !mhlo.token) + return %1#0 : tensor<1x2xf32> +} + +// ----- + +// Check: +// - send has a 3 tuple sharding +// - send-done has a single sharding +// - recv has a 3 tuple sharding +// - recv-done has a 2 tuple sharding + +// CHECK: HloModule + +// CHECK: ENTRY +func.func @main(%arg0: tensor<1x2xi64>) -> tensor<1x2xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0", outputs = "_retval0"}} { + // CHECK: %[[ARG0:.*]] = s64[1,2] parameter(0) + // CHECK-NEXT: %[[AFTER_ALL:.*]] = token[] after-all() + // CHECK-NEXT: %[[SEND:.*]] = (s64[1,2], u32[], token[]) send(s64[1,2] %[[ARG0]], token[] %[[AFTER_ALL]]), channel_id=3, is_host_transfer=true, + // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_args_dtoh_0"} + // CHECK-NEXT: %[[SEND_DONE:.*]] = token[] send-done((s64[1,2], u32[], token[]) %[[SEND]]), channel_id=3, is_host_transfer=true, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_args_dtoh_0"} + // CHECK-NEXT: %[[RECV:.*]] = (s64[1,2], u32[], token[]) recv(token[] %[[SEND_DONE]]), channel_id=4, is_host_transfer=true, + // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} + // CHECK-NEXT: %[[RECV_DONE:.*]] = (s64[1,2], token[]) recv-done((s64[1,2], u32[], token[]) %[[RECV]]), channel_id=4, is_host_transfer=true, + // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} + // CHECK-NEXT: ROOT %[[GET_TUPLE_0:.*]] = s64[1,2] get-tuple-element((s64[1,2], token[]) %[[RECV_DONE]]), index=0, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} + // CHECK-NEXT: %[[GET_TUPLE_1:.*]] = token[] get-tuple-element((s64[1,2], token[]) %[[RECV_DONE]]), index=1, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} + %0 = mhlo.create_token : !mhlo.token + %1 = "mhlo.send"(%arg0, %0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "host_compute_channel_0_args_dtoh_0"}, mhlo.sharding = "\08\04"} : (tensor<1x2xi64>, !mhlo.token) -> !mhlo.token + %2:2 = "mhlo.recv"(%1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "host_compute_channel_0_retvals_htod_0"}, mhlo.sharding = "\08\04"} : (!mhlo.token) -> (tensor<1x2xi64>, !mhlo.token) + return %2#0 : tensor<1x2xi64> +} diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index 28471095577319..eff9875b253999 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -57,6 +57,34 @@ absl::Status MhloToStablehlo(mlir::ModuleOp module) { } return absl::OkStatus(); } + +// TODO(b/385393967) Separate createCanonicalizerPass from StableHLO -> HLO +// Translation +absl::Status StablehloToMhlo(mlir::ModuleOp module, bool run_canonicalizer) { + mlir::MLIRContext* context = module->getContext(); + mlir::PassManager pm(context); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass( + mlir::mhlo::createChloLegalizeToHloPass()); + if (run_canonicalizer) { + pm.addNestedPass(mlir::createCanonicalizerPass()); + } + // In order to export to XLA, we must sink constants to control flow + // regions, since XLA uses functional control flow. + pm.addNestedPass( + mlir::mhlo::createSinkConstantsToControlFlowPass()); + mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); + if (failed(pm.run(module))) { + VLOG(1) << "MHLO->HLO lowering passes failed. Module:\n" << module; + return diagnostic_handler.ConsumeStatus(); + } + + VLOG(5) << "MHLO module after lowering, before HLO import, Module:\n" + << module; + + return absl::OkStatus(); +} + } // namespace void RegisterMlirToHloDependentDialects(mlir::DialectRegistry& registry) { @@ -113,29 +141,7 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module, xla::HloProto* hlo_proto) { if (!module) return absl::InvalidArgumentError("Module is null"); - mlir::MLIRContext* context = module->getContext(); - mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); - { - mlir::PassManager pm(context); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); - pm.addNestedPass(mlir::createCanonicalizerPass()); - // In order to export to XLA, we must sink constants to control flow - // regions, since XLA uses functional control flow. - pm.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); - if (failed(pm.run(module))) { - VLOG(1) << "MHLO->HLO lowering passes failed."; - module->dump(); - return diagnostic_handler.ConsumeStatus(); - } - - VLOG(5) << "MHLO module after lowering, before HLO import "; - if (VLOG_IS_ON(5)) { - module->dump(); - } - } + TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/true)); mlir::MlirToHloConversionOptions options; options.return_tuple = false; @@ -144,4 +150,22 @@ absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module, return absl::OkStatus(); } +absl::Status ConvertStablehloWithManyArgsToHloProto(mlir::ModuleOp module, + xla::HloProto* hlo_proto, + bool use_tuple_args) { + if (!module) return absl::InvalidArgumentError("Module is null"); + + TF_RETURN_IF_ERROR(StablehloToMhlo(module, /*run_canonicalizer=*/false)); + + mlir::MlirToHloConversionOptions options; + options.return_tuple = false; + options.use_tuple_args = use_tuple_args; + // Remove attributes introduced by `import_all_computation=true` at + // ConvertHloToStablehlo. + module->removeAttr("mhlo.xla_entry_computation_parameter_layouts"); + module->removeAttr("mhlo.xla_entry_computation_parameter_tiles"); + TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module, hlo_proto, options)); + return absl::OkStatus(); +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/stablehlo.h b/third_party/xla/xla/hlo/translate/stablehlo.h index 933d0c895dd539..1c649344973940 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.h +++ b/third_party/xla/xla/hlo/translate/stablehlo.h @@ -48,6 +48,15 @@ absl::StatusOr> ConvertStablehloToHlo( absl::Status ConvertStablehloToHloProto(mlir::ModuleOp module, xla::HloProto* hlo_proto); +// Convert StableHLO module to HloModuleProto. +// Some platforms run out of memory when the argument list is too long. +// This API wraps the arguments in a tuple (if use_tuple_args = true) +// as a workaround. The long-term solution is to add an HLO pass to do this. +// In general, prefer the other ConvertStablehloToHloProto method. +absl::Status ConvertStablehloWithManyArgsToHloProto( + mlir::ModuleOp module, xla::HloProto* hlo_proto, + bool use_tuple_args = false); + } // namespace xla #endif // XLA_HLO_TRANSLATE_STABLEHLO_H_ diff --git a/third_party/xla/xla/hlo/translate/xla_translate_main.cc b/third_party/xla/xla/hlo/translate/xla_translate_main.cc index e2e0cbc8399ec1..86d44d08b46e08 100644 --- a/third_party/xla/xla/hlo/translate/xla_translate_main.cc +++ b/third_party/xla/xla/hlo/translate/xla_translate_main.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "llvm/Support/CommandLine.h" diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index fb49da2a16b230..9af623898d05be 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -111,7 +111,6 @@ cc_library( deps = [ ":hlo_container_util", "//xla:array", - "//xla:literal", "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", @@ -122,6 +121,7 @@ cc_library( "//xla/service:call_graph", "//xla/service:dot_as_convolution_util", "//xla/service:gather_scatter_utils", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -133,7 +133,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/utils/hlo_matchers.h b/third_party/xla/xla/hlo/utils/hlo_matchers.h index 2c00ddb7b3edfb..1235dcbdd6a0c6 100644 --- a/third_party/xla/xla/hlo/utils/hlo_matchers.h +++ b/third_party/xla/xla/hlo/utils/hlo_matchers.h @@ -284,6 +284,7 @@ HLO_MATCHER(BitcastConvert); HLO_MATCHER(Broadcast); HLO_MATCHER(Call); HLO_MATCHER(Ceil); +HLO_MATCHER(Cholesky); HLO_MATCHER(Clamp); HLO_MATCHER(CollectiveBroadcast); HLO_MATCHER(CollectivePermute); @@ -353,6 +354,7 @@ HLO_MATCHER(Subtract); HLO_MATCHER(Tan); HLO_MATCHER(Tanh); HLO_MATCHER(Transpose); +HLO_MATCHER(TriangularSolve); HLO_MATCHER(Tuple); HLO_MATCHER(While); HLO_MATCHER(Xor); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 0b1b88de4ae9d9..d4d5179061b5f7 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -57,9 +58,9 @@ limitations under the License. #include "xla/service/gather_scatter_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -1513,6 +1514,7 @@ GatherScatterDims GetGatherScatterOperandPassthroughDims( absl::Span offset_or_window_dims, absl::Span slice_size) { GatherScatterDims result; + CHECK(absl::c_is_sorted(offset_or_window_dims)); int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { @@ -1524,12 +1526,6 @@ GatherScatterDims GetGatherScatterOperandPassthroughDims( if (slice_size[i] != operand_shape.dimensions(i)) { continue; } - if (i - collapsed_or_batching > 0 && - offset_or_window_dims[i - collapsed_or_batching] < - offset_or_window_dims[i - collapsed_or_batching - 1]) { - // Output offsets are transposed, we do not support this case. - continue; - } result.operand_dims.push_back(i); result.output_dims.push_back( offset_or_window_dims[i - collapsed_or_batching]); @@ -2942,23 +2938,27 @@ std::shared_ptr CreateTupleSharding( HloSharding::Tuple(shape, sub_shardings)); } -std::optional GetFirstMergeableDimForSortOperand( - const Shape& operand_shape, const HloSharding& operand_sharding, - int64_t sort_dim) { - if (operand_shape.rank() < 2 || operand_shape.dimensions(sort_dim) == 1) { +std::optional GetFirstTargetDimToMoveShardingTiles( + const Shape& shape, const HloSharding& sharding, int64_t source_dim, + std::function can_be_target_dim) { + if (shape.rank() < 2 || shape.dimensions(source_dim) == 1) { return std::nullopt; } - if (!operand_sharding.IsTiled() || - operand_sharding.tile_assignment().dim(sort_dim) == 1) { + if (!sharding.IsTiled() || sharding.tile_assignment().dim(source_dim) == 1) { return std::nullopt; } - for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) { + for (int64_t dim = 0; dim < shape.rank(); ++dim) { + if (dim == source_dim) { + continue; + } + if (!can_be_target_dim(dim)) { + continue; + } const int64_t merged_tile_dims = - operand_sharding.tile_assignment().dim(sort_dim) * - operand_sharding.tile_assignment().dim(dim); - if (dim != sort_dim && - operand_shape.dimensions(dim) % merged_tile_dims == 0) { + sharding.tile_assignment().dim(source_dim) * + sharding.tile_assignment().dim(dim); + if (shape.dimensions(dim) % merged_tile_dims == 0) { return dim; } } diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index f9fefdd3352dde..049ddca5daea09 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_UTILS_HLO_SHARDING_UTIL_H_ #include +#include #include #include #include @@ -24,7 +25,6 @@ limitations under the License. #include #include -#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -34,7 +34,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" -#include "xla/literal.h" #include "xla/service/call_graph.h" #include "xla/service/dot_as_convolution_util.h" #include "xla/shape.h" @@ -503,19 +502,21 @@ HloSharding MergeShardingDimension(const HloSharding& sharding, std::shared_ptr CreateTupleSharding( const Shape& shape, absl::Span elements); -// Returns the first mergeable dimension for the sort operand. A mergeable -// dimension satisfies: -// 1. The sort dimension is sharded. The size of the sort dimension is larger -// than 1. -// 2. The mergeable dimension is not a sort dimension. -// 3. The size of the mergeable dimension is divisible by the merged tile size, -// which is the product of the tile sizes of the sort dim and the picked -// mergeable dim. +// We intend to move the sharding tiles from the source dimension to a target +// dimension. Returns the first target dimension, which satisfies: +// 1. The source dimension is sharded. The size of the source dimension is +// larger than 1. +// 2. The target dimension and source dimension are different. +// 3. The target dimension satisfies the can_be_target_dim predicate. +// 4. The size of the target dimension is divisible by the merged tile size, +// which is the product of the tile sizes of the source dim and the target dim. // // If there is no such dimension, returns std::nullopt. -std::optional GetFirstMergeableDimForSortOperand( - const Shape& operand_shape, const HloSharding& operand_sharding, - int64_t sort_dim); +std::optional GetFirstTargetDimToMoveShardingTiles( + const Shape& shape, const HloSharding& sharding, int64_t source_dim, + std::function can_be_target_dim = [](int64_t) { + return true; + }); // Returns the sharding of an output of an instruction. Some instructions have // special handling like Outfeed and this function takes care of those. diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index ce916f03ae0508..6c4847cf3d3c8d 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -1076,50 +1076,50 @@ TEST(HloShardingUtilTest, IsSubTilingOrEqualShardingShortcut7) { } } -TEST(HloShardingUtilTest, GetFirstMergeableDimForSortOperand1) { +TEST(HloShardingUtilTest, GetFirstTargetDimToMoveShardingTiles1) { Shape shape = ShapeUtil::MakeShape(F32, {1, 8, 128, 128}); HloSharding sharding = HloSharding::IotaTile({8, 1, 2, 16}); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 0).has_value()); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 0).has_value()); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 1).has_value()); - EXPECT_EQ(GetFirstMergeableDimForSortOperand(shape, sharding, 2), 1); - EXPECT_EQ(GetFirstMergeableDimForSortOperand(shape, sharding, 3), 2); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 1).has_value()); + EXPECT_EQ(GetFirstTargetDimToMoveShardingTiles(shape, sharding, 2), 1); + EXPECT_EQ(GetFirstTargetDimToMoveShardingTiles(shape, sharding, 3), 2); } -TEST(HloShardingUtilTest, GetFirstMergeableDimForSortOperand2) { +TEST(HloShardingUtilTest, GetFirstTargetDimToMoveShardingTiles2) { Shape shape = ShapeUtil::MakeShape(F32, {4, 8, 128, 128}); HloSharding sharding = HloSharding::IotaTile({2, 2, 4, 16}); - EXPECT_EQ(GetFirstMergeableDimForSortOperand(shape, sharding, 0), 1); - EXPECT_EQ(GetFirstMergeableDimForSortOperand(shape, sharding, 1), 0); - EXPECT_EQ(GetFirstMergeableDimForSortOperand(shape, sharding, 2), 1); - EXPECT_EQ(GetFirstMergeableDimForSortOperand(shape, sharding, 3), 2); + EXPECT_EQ(GetFirstTargetDimToMoveShardingTiles(shape, sharding, 0), 1); + EXPECT_EQ(GetFirstTargetDimToMoveShardingTiles(shape, sharding, 1), 0); + EXPECT_EQ(GetFirstTargetDimToMoveShardingTiles(shape, sharding, 2), 1); + EXPECT_EQ(GetFirstTargetDimToMoveShardingTiles(shape, sharding, 3), 2); } -TEST(HloShardingUtilTest, GetFirstMergeableDimForSortOperand3) { +TEST(HloShardingUtilTest, GetFirstTargetDimToMoveShardingTiles3) { Shape shape = ShapeUtil::MakeShape(F32, {1, 128}); HloSharding sharding = HloSharding::IotaTile({1, 2}); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 0).has_value()); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 0).has_value()); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 1).has_value()); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 1).has_value()); } -TEST(HloShardingUtilTest, GetFirstMergeableDimForSortOperandRankOne) { +TEST(HloShardingUtilTest, GetFirstTargetDimToMoveShardingTilesRankOne) { Shape shape = ShapeUtil::MakeShape(F32, {1024}); HloSharding sharding = HloSharding::Tile(TileAssignment(std::initializer_list{2})); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 0).has_value()); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 0).has_value()); } -TEST(HloShardingUtilTest, GetFirstMergeableDimForSortOperandReplicated) { +TEST(HloShardingUtilTest, GetFirstTargetDimToMoveShardingTilesReplicated) { Shape shape = ShapeUtil::MakeShape(F32, {8, 128}); HloSharding sharding = HloSharding::Replicate(); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 0).has_value()); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 0).has_value()); EXPECT_FALSE( - GetFirstMergeableDimForSortOperand(shape, sharding, 1).has_value()); + GetFirstTargetDimToMoveShardingTiles(shape, sharding, 1).has_value()); } TEST(HloShardingUtilTest, TileShape) { diff --git a/third_party/xla/xla/index_util_test.cc b/third_party/xla/xla/index_util_test.cc index 333f772f0b4cfb..a312293d32b586 100644 --- a/third_party/xla/xla/index_util_test.cc +++ b/third_party/xla/xla/index_util_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/hlo/testlib/test.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/iterator_util_test.cc b/third_party/xla/xla/iterator_util_test.cc index ac093c3d1bd68d..3a9e9b05553026 100644 --- a/third_party/xla/xla/iterator_util_test.cc +++ b/third_party/xla/xla/iterator_util_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/layout_test.cc b/third_party/xla/xla/layout_test.cc index 46a13cf421b0e2..e26b020ea463a2 100644 --- a/third_party/xla/xla/layout_test.cc +++ b/third_party/xla/xla/layout_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/test.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/layout_util_test.cc b/third_party/xla/xla/layout_util_test.cc index ed2f6ff479d7e2..56f821ce0a0908 100644 --- a/third_party/xla/xla/layout_util_test.cc +++ b/third_party/xla/xla/layout_util_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/layout.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 997f44a4dd0f62..6b5db7f893ec4c 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -87,9 +87,10 @@ void ConvertEndianShort(char* bytes, int64_t size) { } bool LiteralProtoHasValues(const LiteralProto& proto) { - return !proto.s2s().empty() || !proto.s4s().empty() || !proto.s8s().empty() || - !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || - !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || + return !proto.s1s().empty() || !proto.s2s().empty() || !proto.s4s().empty() || + !proto.s8s().empty() || !proto.s16s().empty() || proto.s32s_size() || + proto.s64s_size() || !proto.u1s().empty() || !proto.u2s().empty() || + !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || @@ -2207,6 +2208,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); break; + case U1: + *proto->mutable_u1s() = std::string( + reinterpret_cast(data().data()), size_bytes_dense()); + break; case U2: *proto->mutable_u2s() = std::string( reinterpret_cast(data().data()), size_bytes_dense()); @@ -2233,6 +2238,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case U64: CopyToRepeatedField(proto->mutable_u64s(), data()); break; + case S1: + *proto->mutable_s1s() = std::string( + reinterpret_cast(data().data()), size_bytes_dense()); + break; case S2: *proto->mutable_s2s() = std::string( reinterpret_cast(data().data()), size_bytes_dense()); diff --git a/third_party/xla/xla/literal.h b/third_party/xla/xla/literal.h index 3233126a5efb05..1b76f2effe6a94 100644 --- a/third_party/xla/xla/literal.h +++ b/third_party/xla/xla/literal.h @@ -367,9 +367,9 @@ class LiteralBase { static_assert(sizeof(H) == 0, "Do not use Literal directly as a hash key, because it has " "multiple definitions of equality - layout sensitive or " - "insensitive. Instead, provide an external hash function " - "that uses Literal::Hash which allows you to specify layout " - "sensitivity."); + "insensitive. Instead, use AbslHashable<...>() to create a " + "wrapper with layout sensitivity specified suitable for " + "passing to Absl::Hash"); } // Always use this together with the Equal method and not operator== in order @@ -419,6 +419,17 @@ class LiteralBase { return std::move(state); } + // Templated wrapper struct to control layout sensitivity during Absl::Hash. + template + struct AbslHashable { + const LiteralBase& literal; + explicit AbslHashable(const LiteralBase& l) : literal(l) {} + template + friend H AbslHashValue(H h, const AbslHashable& w) { + return LiteralBase::Hash(std::move(h), w.literal); + } + }; + // Converts this literal to the given shape. Returns an error is the // conversion is not possible. absl::StatusOr ConvertToShape(const Shape& dest_shape) const; @@ -1404,7 +1415,7 @@ class Literal : public MutableLiteralBase { static absl::StatusOr Deserialize(InputIterator begin, InputIterator end); - static absl::StatusOr DeserializeFromString(std::string_view data) { + static absl::StatusOr DeserializeFromString(absl::string_view data) { return Deserialize(data.data(), data.data() + data.size()); } diff --git a/third_party/xla/xla/literal_comparison_test.cc b/third_party/xla/xla/literal_comparison_test.cc index 7713aceaaa3bc5..4dcdad85fd5d43 100644 --- a/third_party/xla/xla/literal_comparison_test.cc +++ b/third_party/xla/xla/literal_comparison_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "xla/error_spec.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" -#include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" diff --git a/third_party/xla/xla/literal_pool.cc b/third_party/xla/xla/literal_pool.cc new file mode 100644 index 00000000000000..e3ce7269621f6b --- /dev/null +++ b/third_party/xla/xla/literal_pool.cc @@ -0,0 +1,114 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/literal_pool.h" + +#include +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "tsl/platform/logging.h" + +namespace xla { + +LiteralPool* LiteralPool::Default() { + static auto* pool = new LiteralPool(); + return pool; +} + +// Erases expired weak pointers from the vector and returns the number of +// elements that were erased. +static size_t EraseExpiredLiterals( + std::vector>& literals) { + auto it = std::remove_if(literals.begin(), literals.end(), + [](auto& ptr) { return ptr.expired(); }); + size_t num_erased = std::distance(it, literals.end()); + + literals.erase(it, literals.end()); + return num_erased; +} + +size_t LiteralPool::GarbageCollect() { + absl::MutexLock lock(&mu_); + size_t num_erased = 0; + + for (auto& [shape, literals] : literals_) { + num_erased += EraseExpiredLiterals(literals); + } + + VLOG(3) << "Garbage collected " << num_erased << " literals"; + return num_erased; +} + +size_t LiteralPool::GarbageCollect(Shape shape) { + absl::MutexLock lock(&mu_); + size_t num_erased = 0; + + if (auto it = literals_.find(shape); it != literals_.end()) { + num_erased = EraseExpiredLiterals(it->second); + } + + VLOG(3) << "Garbage collected " << num_erased << " literals for shape " + << shape.ToString(); + return num_erased; +} + +// Tried to find a canonical literal in the pool. Return nullptr if not found. +static std::shared_ptr FindCanonicalLiteral( + std::vector>& literals, const Literal& literal) { + for (std::weak_ptr& ptr : literals) { + if (auto locked_ptr = ptr.lock()) { + if (locked_ptr->Equal(literal, /*layout_sensitive=*/true)) { + return locked_ptr; + } + } + } + + return nullptr; +} + +std::shared_ptr LiteralPool::GetCanonicalLiteral( + const Literal& literal) { + absl::MutexLock lock(&mu_); + + auto& literals = literals_[literal.shape()]; + if (auto ptr = FindCanonicalLiteral(literals, literal)) { + return ptr; + } + + std::shared_ptr new_literal = literal.CloneToUnique(); + literals.push_back(new_literal); + return new_literal; +} + +std::shared_ptr LiteralPool::GetCanonicalLiteral( + std::shared_ptr literal) { + absl::MutexLock lock(&mu_); + + auto& literals = literals_[literal->shape()]; + if (auto ptr = FindCanonicalLiteral(literals, *literal)) { + return ptr; + } + + literals.push_back(literal); + return literal; +} + +} // namespace xla diff --git a/third_party/xla/xla/literal_pool.h b/third_party/xla/xla/literal_pool.h new file mode 100644 index 00000000000000..4e53181b05e9a6 --- /dev/null +++ b/third_party/xla/xla/literal_pool.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_LITERAL_POOL_H_ +#define XLA_LITERAL_POOL_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/literal.h" +#include "xla/shape.h" + +namespace xla { + +// Literal pool provides a mechanism to deduplicate identical literals and +// share them across multiple HLO modules. +class LiteralPool { + public: + // Returns a default literal pool that can be used across multiple HLO modules + // in a process. + static LiteralPool* Default(); + + // Returns a canonical literal from the pool. If the literal is not in the + // pool, it is added to the pool and returned back. + std::shared_ptr GetCanonicalLiteral(const Literal& literal); + + // Returns a canonical literal from the pool. If the literal is not in the + // pool, it is added to the pool and returned back. + std::shared_ptr GetCanonicalLiteral( + std::shared_ptr literal); + + // Runs garbage collection on all the literals in the pool. Returns the number + // of literals that were garbage collected. + size_t GarbageCollect(); + + // Runs garbage collection on literals with the given shape. Returns the + // number of literals that were garbage collected. + size_t GarbageCollect(Shape shape); + + private: + // We keep weak pointers to the literals in the pool to allow for garbage + // collection when owning HLO modules are destroyed. We run periodic garbage + // collection to clean up the literals that are no longer referenced. + absl::Mutex mu_; + absl::flat_hash_map>> literals_ + ABSL_GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // XLA_LITERAL_POOL_H_ diff --git a/third_party/xla/xla/literal_pool_test.cc b/third_party/xla/xla/literal_pool_test.cc new file mode 100644 index 00000000000000..b655c8c4661f77 --- /dev/null +++ b/third_party/xla/xla/literal_pool_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/literal_pool.h" + +#include "xla/literal_util.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +TEST(LiteralPoolTest, GetCanonicalLiteral) { + LiteralPool pool; + + auto l0 = LiteralUtil::CreateR2({{1., 2.}, {3., 4.}}); + auto l1 = LiteralUtil::CreateR2({{2., 1.}, {4., 3.}}); + + { // Use nested scope to allow garbage collection below. + auto cl0_0 = pool.GetCanonicalLiteral(l0); + auto cl0_1 = pool.GetCanonicalLiteral(l0); + ASSERT_EQ(cl0_0, cl0_1); + + auto cl1_0 = pool.GetCanonicalLiteral(l1); + auto cl1_1 = pool.GetCanonicalLiteral(l1); + ASSERT_NE(cl0_0, cl1_0); + ASSERT_EQ(cl1_0, cl1_1); + } + + ASSERT_EQ(pool.GarbageCollect(), 2); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 44e4acd6a5cef7..5bbddd572c8a64 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" +#include "xla/hlo/testlib/test.h" #include "xla/index_util.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -46,7 +47,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/util.h" @@ -139,12 +139,24 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto false_lit = LiteralUtil::CreateR0(false); EXPECT_EQ("pred[] false", false_lit.ToString()); + auto u1_lit = LiteralUtil::CreateR0(u1(1)); + EXPECT_EQ("u1[] 1", u1_lit.ToString()); + + auto u2_lit = LiteralUtil::CreateR0(u2(0)); + EXPECT_EQ("u2[] 0", u2_lit.ToString()); + auto u4_lit = LiteralUtil::CreateR0(u4(5)); EXPECT_EQ("u4[] 5", u4_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); EXPECT_EQ("u32[] 42", u32_lit.ToString()); + auto s1_lit = LiteralUtil::CreateR0(s1(-1)); + EXPECT_EQ("s1[] -1", s1_lit.ToString()); + + auto s2_lit = LiteralUtil::CreateR0(s2(1)); + EXPECT_EQ("s2[] 1", s2_lit.ToString()); + auto s4_lit = LiteralUtil::CreateR0(s4(-3)); EXPECT_EQ("s4[] -3", s4_lit.ToString()); diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc index 9b5507327789a8..c689d7eb74ad23 100644 --- a/third_party/xla/xla/literal_util.cc +++ b/third_party/xla/xla/literal_util.cc @@ -472,6 +472,11 @@ void PopulateWithRandomIntegralDataWithBounds(Literal* literal, return ConvertType(s32_literal); } +/* static */ Literal LiteralUtil::ConvertS32ToS1( + const LiteralSlice& s32_literal) { + return ConvertType(s32_literal); +} + /* static */ Literal LiteralUtil::CreateToken() { return Literal(ShapeUtil::MakeTokenShape()); } diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index 01af0cea5499b8..db8e958f2340b3 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -252,6 +252,7 @@ class LiteralUtil { static Literal ConvertF64ToBF16(const LiteralSlice& f64_literal); static Literal ConvertF64ToF32(const LiteralSlice& f64_literal); static Literal ConvertS32ToF32(const LiteralSlice& s32_literal); + static Literal ConvertS32ToS1(const LiteralSlice& s32_literal); // Creates a scalar literal whose value is the maximum value of a given // literal slice. @@ -282,6 +283,12 @@ class LiteralUtil { static absl::StatusOr CreateRandomLiteral(const Shape& shape, E* engine, T mean, T stddev); + // Same as the above, but takes mean and stddev as doubles. + template > + static absl::StatusOr CreateRandomLiteral(const Shape& shape, + E* engine, double mean, + double stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -595,6 +602,13 @@ template template /* static */ absl::StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, E* engine, T mean, T stddev) { + return CreateRandomLiteral(shape, engine, static_cast(mean), + static_cast(stddev)); +} + +template +/* static */ absl::StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, double mean, double stddev) { using NativeT = primitive_util::NativeTypeOf; std::normal_distribution generator(mean, stddev); return CreateLiteralWithGenerator( diff --git a/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc b/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc index 7cafdfa3bcb23e..7d9b8fc700767a 100644 --- a/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc +++ b/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include "llvm/ADT/STLExtras.h" @@ -165,7 +164,7 @@ class OutlineWithXLAFrameworkPass patterns.add(ctx); // Set target. - if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { + if (failed(applyPatternsGreedily(m, std::move(patterns)))) { signalPassFailure(); } m->walk([](func::FuncOp f) { diff --git a/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc b/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc index 703b6c9af785d5..c40a7ad1b9aa46 100644 --- a/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc +++ b/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include -#include #include #include "llvm/ADT/ArrayRef.h" diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/framework/interpreter_value.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/framework/interpreter_value.cc index 4962754e41c3ed..21d0bddf7e93cc 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/framework/interpreter_value.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/framework/interpreter_value.cc @@ -22,10 +22,10 @@ limitations under the License. #include #include #include -#include #include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" @@ -40,19 +40,19 @@ namespace interpreter { namespace { struct TypeStr { - static std::string_view Get(bool) { return "i1"; } - static std::string_view Get(int64_t) { return "i64"; } - static std::string_view Get(int32_t) { return "i32"; } - static std::string_view Get(int16_t) { return "i16"; } - static std::string_view Get(int8_t) { return "i8"; } - static std::string_view Get(uint64_t) { return "ui64"; } - static std::string_view Get(uint32_t) { return "ui32"; } - static std::string_view Get(uint16_t) { return "ui16"; } - static std::string_view Get(uint8_t) { return "ui8"; } - static std::string_view Get(float) { return "f32"; } - static std::string_view Get(double) { return "f64"; } - static std::string_view Get(std::complex) { return "complex"; } - static std::string_view Get(std::complex) { return "complex"; } + static absl::string_view Get(bool) { return "i1"; } + static absl::string_view Get(int64_t) { return "i64"; } + static absl::string_view Get(int32_t) { return "i32"; } + static absl::string_view Get(int16_t) { return "i16"; } + static absl::string_view Get(int8_t) { return "i8"; } + static absl::string_view Get(uint64_t) { return "ui64"; } + static absl::string_view Get(uint32_t) { return "ui32"; } + static absl::string_view Get(uint16_t) { return "ui16"; } + static absl::string_view Get(uint8_t) { return "ui8"; } + static absl::string_view Get(float) { return "f32"; } + static absl::string_view Get(double) { return "f64"; } + static absl::string_view Get(std::complex) { return "complex"; } + static absl::string_view Get(std::complex) { return "complex"; } }; struct InterpreterValuePrinter { diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD index 6c0359e089e12a..487f0311a3caff 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD @@ -15,6 +15,7 @@ cc_library( ":compiler_trace_proto_cc", ":compiler_trace_proto_cc_impl", "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc b/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc index 5a32a6fcbe5e92..c789ea6dc05fe0 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/log/log.h" #include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" diff --git a/third_party/xla/xla/mlir/utils/BUILD b/third_party/xla/xla/mlir/utils/BUILD index 4026decfd952d5..f618b26763f427 100644 --- a/third_party/xla/xla/mlir/utils/BUILD +++ b/third_party/xla/xla/mlir/utils/BUILD @@ -18,10 +18,14 @@ cc_library( hdrs = ["error_util.h"], compatible_with = get_compatible_with_portable(), deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:logging", ], ) @@ -33,6 +37,7 @@ cc_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/mlir/utils/error_util.cc b/third_party/xla/xla/mlir/utils/error_util.cc index 9fe4801c95f99b..3c45f3fd9ebcde 100644 --- a/third_party/xla/xla/mlir/utils/error_util.cc +++ b/third_party/xla/xla/mlir/utils/error_util.cc @@ -15,12 +15,15 @@ limitations under the License. #include "xla/mlir/utils/error_util.h" +#include #include -#include -#include "tsl/platform/errors.h" -#include "mlir/IR/BuiltinAttributes.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LLVM.h" +#include "tsl/platform/logging.h" namespace mlir { BaseScopedDiagnosticHandler::BaseScopedDiagnosticHandler(MLIRContext* context, diff --git a/third_party/xla/xla/mlir/utils/error_util_test.cc b/third_party/xla/xla/mlir/utils/error_util_test.cc index 23f214f9658b26..942809105d24ad 100644 --- a/third_party/xla/xla/mlir/utils/error_util_test.cc +++ b/third_party/xla/xla/mlir/utils/error_util_test.cc @@ -15,8 +15,7 @@ limitations under the License. #include "xla/mlir/utils/error_util.h" -#include - +#include #include "absl/status/status.h" #include "absl/strings/match.h" #include "llvm/ADT/Twine.h" diff --git a/third_party/xla/xla/mlir_hlo/bindings/python/MlirHloModule.cc b/third_party/xla/xla/mlir_hlo/bindings/python/MlirHloModule.cc index 386e6b1c6acc9a..bfe3e87894138f 100644 --- a/third_party/xla/xla/mlir_hlo/bindings/python/MlirHloModule.cc +++ b/third_party/xla/xla/mlir_hlo/bindings/python/MlirHloModule.cc @@ -18,9 +18,11 @@ limitations under the License. #include "bindings/c/Passes.h" #include "bindings/c/Types.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/vector.h" // IWYU pragma: keep -namespace py = pybind11; +namespace nb = nanobind; namespace { // Returns a vector containing integers extracted from an attribute using the @@ -38,12 +40,12 @@ std::vector attributePropertyVector( } auto toPyString(MlirStringRef mlirStringRef) { - return py::str(mlirStringRef.data, mlirStringRef.length); + return nb::str(mlirStringRef.data, mlirStringRef.length); } } // namespace -PYBIND11_MODULE(_mlirHlo, m) { +NB_MODULE(_mlirHlo, m) { m.doc() = "mlir-hlo main python extension"; // @@ -59,7 +61,7 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirDialectHandleLoadDialect(mhloDialect, context); } }, - py::arg("context"), py::arg("load") = true); + nb::arg("context"), nb::arg("load") = true); // // Passes. @@ -71,14 +73,14 @@ PYBIND11_MODULE(_mlirHlo, m) { // Types. // - mlir::python::adaptors::mlir_type_subclass(m, "TokenType", - mlirMhloTypeIsAToken) + mlir::python::nanobind_adaptors::mlir_type_subclass(m, "TokenType", + mlirMhloTypeIsAToken) .def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirMhloTokenTypeGet(ctx)); }, - py::arg("cls"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("context").none() = nb::none(), "Creates a Token type."); // @@ -91,11 +93,11 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem); }; - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ScatterDimensionNumbers", mlirMhloAttributeIsAScatterDimensionNumbers) .def_classmethod( "get", - [](py::object cls, const std::vector &updateWindowDims, + [](nb::object cls, const std::vector &updateWindowDims, const std::vector &insertedWindowDims, const std::vector &inputBatchingDims, const std::vector &scatterIndicesBatchingDims, @@ -110,11 +112,11 @@ PYBIND11_MODULE(_mlirHlo, m) { scatteredDimsToOperandDims.size(), scatteredDimsToOperandDims.data(), indexVectorDim)); }, - py::arg("cls"), py::arg("update_window_dims"), - py::arg("inserted_window_dims"), py::arg("input_batching_dims"), - py::arg("scatter_indices_batching_dims"), - py::arg("scattered_dims_to_operand_dims"), - py::arg("index_vector_dim"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("update_window_dims"), + nb::arg("inserted_window_dims"), nb::arg("input_batching_dims"), + nb::arg("scatter_indices_batching_dims"), + nb::arg("scattered_dims_to_operand_dims"), + nb::arg("index_vector_dim"), nb::arg("context").none() = nb::none(), "Creates a ScatterDimensionNumbers with the given dimension " "configuration.") .def_property_readonly( @@ -153,11 +155,11 @@ PYBIND11_MODULE(_mlirHlo, m) { return mlirMhloDimensionNumbersGetIndexVectorDim(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "GatherDimensionNumbers", mlirMhloAttributeIsAGatherDimensionNumbers) .def_classmethod( "get", - [](py::object cls, const std::vector &offsetDims, + [](nb::object cls, const std::vector &offsetDims, const std::vector &collapsedSliceDims, const std::vector &operandBatchingDims, const std::vector &startIndicesBatchingDims, @@ -171,10 +173,10 @@ PYBIND11_MODULE(_mlirHlo, m) { startIndicesBatchingDims.data(), startIndexMap.size(), startIndexMap.data(), indexVectorDim)); }, - py::arg("cls"), py::arg("offset_dims"), - py::arg("collapsed_slice_dims"), py::arg("operand_batching_dims"), - py::arg("start_indices_batching_dims"), py::arg("start_index_map"), - py::arg("index_vector_dim"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("offset_dims"), + nb::arg("collapsed_slice_dims"), nb::arg("operand_batching_dims"), + nb::arg("start_indices_batching_dims"), nb::arg("start_index_map"), + nb::arg("index_vector_dim"), nb::arg("context").none() = nb::none(), "Creates a GatherDimensionNumbers attribute with the given dimension " "configuration.") .def_property_readonly( @@ -217,11 +219,11 @@ PYBIND11_MODULE(_mlirHlo, m) { return mlirMhloGatherDimensionNumbersGetIndexVectorDim(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "DotDimensionNumbers", mlirMhloAttributeIsADotDimensionNumbers) .def_classmethod( "get", - [](py::object cls, const std::vector &lhsBatchingDims, + [](nb::object cls, const std::vector &lhsBatchingDims, const std::vector &rhsBatchingDims, const std::vector &lhsContractingDims, const std::vector &rhsContractingDims, MlirContext ctx) { @@ -231,11 +233,11 @@ PYBIND11_MODULE(_mlirHlo, m) { lhsContractingDims.size(), lhsContractingDims.data(), rhsContractingDims.size(), rhsContractingDims.data())); }, - py::arg("cls"), py::arg("lhs_batching_dimensions"), - py::arg("rhs_batching_dimensions"), - py::arg("lhs_contracting_dimensions"), - py::arg("rhs_contracting_dimensions"), - py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("lhs_batching_dimensions"), + nb::arg("rhs_batching_dimensions"), + nb::arg("lhs_contracting_dimensions"), + nb::arg("rhs_contracting_dimensions"), + nb::arg("context").none() = nb::none(), "Creates a DotDimensionNumbers attribute with the given dimension " "configuration.") .def_property_readonly( @@ -268,11 +270,11 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirMhloDotDimensionNumbersGetRhsContractingDimensionsElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ConvDimensionNumbers", mlirMhloAttributeIsAConvDimensionNumbers) .def_classmethod( "get", - [](py::object cls, int64_t inputBatchDimension, + [](nb::object cls, int64_t inputBatchDimension, int64_t inputFeatureDimension, const std::vector inputSpatialDimensions, int64_t kernelInputFeatureDimension, @@ -290,15 +292,16 @@ PYBIND11_MODULE(_mlirHlo, m) { outputSpatialDimensions.size(), outputSpatialDimensions.data())); }, - py::arg("cls"), py::arg("input_batch_dimension"), - py::arg("input_feature_dimension"), - py::arg("input_spatial_dimensions"), - py::arg("kernel_input_feature_dimension"), - py::arg("kernel_output_feature_dimension"), - py::arg("kernel_spatial_dimensions"), - py::arg("output_batch_dimension"), - py::arg("output_feature_dimension"), - py::arg("output_spatial_dimensions"), py::arg("ctx") = py::none(), + nb::arg("cls"), nb::arg("input_batch_dimension"), + nb::arg("input_feature_dimension"), + nb::arg("input_spatial_dimensions"), + nb::arg("kernel_input_feature_dimension"), + nb::arg("kernel_output_feature_dimension"), + nb::arg("kernel_spatial_dimensions"), + nb::arg("output_batch_dimension"), + nb::arg("output_feature_dimension"), + nb::arg("output_spatial_dimensions"), + nb::arg("ctx").none() = nb::none(), "Creates a ConvDimensionNumbers attribute with the given dimension " "configuration.") .def_property_readonly( @@ -356,11 +359,11 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "OutputOperandAlias", mlirMhloAttributeIsAOutputOperandAlias) .def_classmethod( "get", - [](py::object cls, const std::vector outputTupleIndices, + [](nb::object cls, const std::vector outputTupleIndices, int64_t operandIndex, const std::vector operandTupleIndices, MlirContext ctx) { return cls(mlirMhloOutputOperandAliasGet( @@ -368,9 +371,9 @@ PYBIND11_MODULE(_mlirHlo, m) { operandIndex, operandTupleIndices.size(), operandTupleIndices.data())); }, - py::arg("cls"), py::arg("output_tuple_indices"), - py::arg("operand_index"), py::arg("operand_tuple_indices"), - py::arg("ctx") = py::none(), + nb::arg("cls"), nb::arg("output_tuple_indices"), + nb::arg("operand_index"), nb::arg("operand_tuple_indices"), + nb::arg("ctx").none() = nb::none(), "Creates a OutputOperandAlias attribute with the given tuple index.") .def_property_readonly( "output_tuple_indices", @@ -390,143 +393,153 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirMhloOutputOperandAliasGetOperandTupleIndicesElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ComparisonDirectionAttr", mlirMhloAttributeIsAComparisonDirectionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloComparisonDirectionAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a ComparisonDirection attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloComparisonDirectionAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ComparisonTypeAttr", mlirMhloAttributeIsAComparisonTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloComparisonTypeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a ComparisonType attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloComparisonTypeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "PrecisionAttr", mlirMhloAttributeIsAPrecisionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloPrecisionAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a Precision attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloPrecisionAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "FftTypeAttr", mlirMhloAttributeIsAFftTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloFftTypeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a FftType attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloFftTypeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "DequantizeModeAttr", mlirMhloAttributeIsADequantizeModeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloDequantizeModeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a DequantizeMode attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloDequantizeModeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TransposeAttr", mlirMhloAttributeIsATransposeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloTransposeAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a Transpose attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloTransposeAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "FusionKindAttr", mlirMhloAttributeIsAFusionKindAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloFusionKindAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a FusionKind attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { return toPyString(mlirMhloFusionKindAttrGetValue(self)); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "RngDistributionAttr", mlirMhloAttributeIsARngDistributionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloRngDistributionAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a RngDistribution attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { auto value = mlirMhloRngDistributionAttrGetValue(self); - return py::str(value.data, value.length); + return nb::str(value.data, value.length); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "RngAlgorithmAttr", mlirMhloAttributeIsARngAlgorithmAttr) .def_classmethod( "get", - [](py::object cls, const std::string &value, MlirContext ctx) { + [](nb::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloRngAlgorithmAttrGet( ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), "Creates a RngAlgorithm attribute with the given value.") .def_property_readonly("value", [](MlirAttribute self) { auto value = mlirMhloRngAlgorithmAttrGetValue(self); - return py::str(value.data, value.length); + return nb::str(value.data, value.length); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "ChannelHandle", mlirMhloAttributeIsChannelHandle) .def_classmethod( "get", - [](py::object cls, int64_t handle, int64_t type, MlirContext ctx) { + [](nb::object cls, int64_t handle, int64_t type, MlirContext ctx) { return cls(mlirMhloChannelHandleGet(ctx, handle, type)); }, - py::arg("cls"), py::arg("handle"), py::arg("type"), - py::arg("context") = py::none(), "Creates a ChannelHandle attribute.") + nb::arg("cls"), nb::arg("handle"), nb::arg("type"), + nb::arg("context").none() = nb::none(), + "Creates a ChannelHandle attribute.") .def_property_readonly("handle", [](MlirAttribute self) { return mlirMhloChannelHandleGetHandle(self); @@ -535,16 +548,17 @@ PYBIND11_MODULE(_mlirHlo, m) { return mlirMhloChannelHandleGetType(self); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TypeExtensions", mlirMhloAttributeIsTypeExtensions) .def_classmethod( "get", - [](py::object cls, const std::vector &bounds, + [](nb::object cls, const std::vector &bounds, MlirContext ctx) { return cls( mlirMhloTypeExtensionsGet(ctx, bounds.size(), bounds.data())); }, - py::arg("cls"), py::arg("bounds"), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("bounds"), + nb::arg("context").none() = nb::none(), "Creates a TypeExtensions with the given bounds.") .def_property_readonly("bounds", [](MlirAttribute self) { return attributePropertyVector(self, @@ -552,16 +566,16 @@ PYBIND11_MODULE(_mlirHlo, m) { mlirMhloTypeExtensionsGetBoundsElem); }); - mlir::python::adaptors::mlir_attribute_subclass( + mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "SparsityDescriptor", mlirMhloAttributeIsASparsityDescriptor) .def_classmethod( "get", - [](py::object cls, const int64_t dimension, const int64_t n, + [](nb::object cls, const int64_t dimension, const int64_t n, const int64_t m, MlirContext ctx) { return cls(mlirMhloSparsityDescriptorGet(ctx, dimension, n, m)); }, - py::arg("cls"), py::arg("dimension"), py::arg("n"), py::arg("m"), - py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("dimension"), nb::arg("n"), nb::arg("m"), + nb::arg("context").none() = nb::none(), "Creates a SparseDescriptor attribute with the given sparsity " "configurations.") .def_property_readonly( diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 064978aec3982b..3a09b6e3b33814 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -536,7 +536,7 @@ struct BufferReusePass : public impl::BufferReusePassBase { eliminateCopies(block, /*root=*/block); do { // Eliminate dead code. - (void)applyPatternsAndFoldGreedily(getOperation(), {}); + (void)applyPatternsGreedily(getOperation(), {}); // Only coalesce dealloc/alloc pairs that are immediate neighbors, to // make sure we don't accidentally extend the live range of a buffer. result = reuseBuffers(block, BufferReuseMode::CONSERVATIVE); @@ -547,7 +547,7 @@ struct BufferReusePass : public impl::BufferReusePassBase { // Now we can also coalesce distant dealloc/alloc pairs. reuseBuffers(block, BufferReuseMode::AGGRESSIVE); promoteBuffers(block); - (void)applyPatternsAndFoldGreedily(getOperation(), {}); + (void)applyPatternsGreedily(getOperation(), {}); } }; diff --git a/third_party/xla/xla/mlir_hlo/deallocation/utils/util.cc b/third_party/xla/xla/mlir_hlo/deallocation/utils/util.cc index 5c383b357e7f6f..a59c231ebaf0c3 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/utils/util.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/utils/util.cc @@ -15,8 +15,6 @@ limitations under the License. #include "deallocation/utils/util.h" -#include - #include "mlir/Dialect/SCF/IR/SCF.h" namespace mlir { diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index f33701336bbb64..4eb95ef326a659 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2877,7 +2877,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape", let arguments = (ins MHLO_AnyTensor:$operand); - let results = (outs MHLO_StaticShapeTensor); + let results = (outs MHLO_AnyTensor); let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc index da27173913f81e..c8268e4335dca2 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc @@ -439,8 +439,8 @@ struct BroadcastPropagationPass GreedyRewriteConfig config; config.useTopDownTraversal = false; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc index 6a5300a484f2e1..de4beac80cc2aa 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include "mhlo/IR/hlo_ops.h" @@ -23,9 +23,12 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" @@ -57,7 +60,8 @@ struct ChloLegalizeToHighLevelMhloPass // Consider the mhlo dialect legal for tests. Also add helper dialects // that are needed by the patterns. conversionTarget.addLegalDialect(); - conversionTarget.addIllegalOp(); + conversionTarget + .addIllegalOp(); if (failed(applyPartialConversion(getOperation(), conversionTarget, std::move(conversionPatterns)))) { @@ -94,6 +98,64 @@ struct ChloLegalizeToHloPass } }; +struct RaggedDotChloToMhlo : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(chlo::RaggedDotOp raggedDotOp, + PatternRewriter &rewriter) const override { + auto moduleOp = raggedDotOp->getParentOfType(); + + OpBuilder builder(moduleOp.getBodyRegion()); + builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front()); + + auto chloRaggedDotDimNums = raggedDotOp.getRaggedDotDimensionNumbers(); + auto dotDimNums = mhlo::DotDimensionNumbersAttr::get( + builder.getContext(), chloRaggedDotDimNums.getLhsBatchingDimensions(), + chloRaggedDotDimNums.getRhsBatchingDimensions(), + chloRaggedDotDimNums.getLhsContractingDimensions(), + chloRaggedDotDimNums.getRhsContractingDimensions()); + auto raggedDotDimNums = mhlo::RaggedDotDimensionNumbersAttr::get( + builder.getContext(), dotDimNums, + chloRaggedDotDimNums.getLhsRaggedDimensions(), + chloRaggedDotDimNums.getRhsGroupDimensions()); + + auto mhloPrecision = + [](chlo::Precision precision) -> std::optional { + switch (precision) { + case chlo::Precision::DEFAULT: + return mhlo::Precision::DEFAULT; + case chlo::Precision::HIGH: + return mhlo::Precision::HIGH; + case chlo::Precision::HIGHEST: + return mhlo::Precision::HIGHEST; + } + }; + ArrayAttr precisionConfig = rewriter.getArrayAttr({}); + if (raggedDotOp.getPrecisionConfig().has_value()) { + SmallVector vector; + for (auto configValue : raggedDotOp.getPrecisionConfig() + .value() + .getAsRange()) { + vector.push_back( + PrecisionAttr::get(raggedDotOp.getContext(), + mhloPrecision(configValue.getValue()).value())); + } + precisionConfig = rewriter.getArrayAttr(vector); + } + + rewriter.replaceOp( + raggedDotOp, + rewriter + .create( + raggedDotOp.getLoc(), raggedDotOp.getResult().getType(), + raggedDotOp.getLhs(), raggedDotOp.getRhs(), + raggedDotOp.getGroupSizes(), raggedDotDimNums, precisionConfig) + .getOperation()); + + return success(); + } +}; + } // namespace } // namespace mhlo @@ -106,6 +168,7 @@ namespace { void populateChloToHighLevelMhloOpPatterns(MLIRContext *, RewritePatternSet *patterns) { + patterns->add(patterns->getContext()); populateWithGenerated(*patterns); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc index 60fcd198853911..cbe532ba959f76 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc @@ -92,8 +92,7 @@ struct CollapseElementwiseMapPass MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc index e986bdc5ad694c..79e55a4c9f3d53 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc @@ -68,8 +68,7 @@ struct LegalizeDotToDotGeneralPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateDotToDotGeneralPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index c35ce560146dcb..e861dec331848c 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -179,8 +179,7 @@ struct LegalizeEinsumToDotGeneralPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateEinsumToDotGeneralPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc index 8cc65ea23f04c2..865c07fc316d89 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -139,8 +139,7 @@ struct LegalizeTorchIndexSelectToGatherPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTorchIndexSelectToGatherPatterns(&getContext(), &patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc index 2e7018e2fd17c3..ccf2ed1151ccc7 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc @@ -172,8 +172,7 @@ struct LegalizeTrigonometricToApproximationPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTrigonometricToApproximationPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc index 185b2c9d7caa18..d6c4b4767297d6 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc @@ -434,8 +434,8 @@ struct MergeAssumingOpsPass mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns); GreedyRewriteConfig config; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc index deccadf230d5a3..b86038624c4c24 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc @@ -132,8 +132,7 @@ class FlattenTuplePass : public impl::FlattenTuplePassBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc index d5932234c5f003..bd45785c2c5bec 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc @@ -165,8 +165,8 @@ class GatherIsSlice : public OpRewritePattern { } // end anonymous namespace -void populateOptimizeMhloPatterns(MLIRContext* context, - RewritePatternSet* patterns) { +static void populateOptimizeMhloPatterns(MLIRContext* context, + RewritePatternSet* patterns) { patterns->add(context); } } // end namespace mhlo diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc index dfd370298bd862..a4fa95071d1283 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include "llvm/ADT/STLExtras.h" #include "mhlo/IR/hlo_ops.h" diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc index b96370f71cf23c..1747bd93b492ef 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc @@ -242,7 +242,7 @@ struct ShapeSimplification ExtractFromBroadcastedTensorCanonicalizationPattern>(context); auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc index 20808e4d12d9e7..961e512d239686 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc @@ -793,8 +793,8 @@ class SymbolicShapeOptimizationPass final shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx); shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc index 8bd3bbc1409610..d585ea0b9d1592 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc @@ -95,8 +95,7 @@ struct TestInferShapedTypeMethodsPass RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc index 7409def78d770f..285f056008da72 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc @@ -43,8 +43,7 @@ struct TestUnfuseBatchNormPass RewritePatternSet patterns(&getContext()); populateUnfuseBatchNormInferencePattern(&getContext(), &patterns); populateUnfuseBatchNormTrainingPattern(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc b/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc index 27aa4efc2ea6f0..0c53644a2f031e 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc @@ -110,6 +110,8 @@ RemoveSignTypeConverter::RemoveSignTypeConverter() { LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() { addArgumentMaterialization(scalarToTensor); + addSourceMaterialization(scalarToTensor); + addTargetMaterialization(scalarToTensor); } } // namespace mhlo diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp index 0ad3029f96ccf6..9cd3e90e6f5dfb 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp @@ -200,8 +200,7 @@ struct StablehloCanonicalizeDynamismPass patterns.add(&getContext()); auto funcOp = getOperation(); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) { funcOp.emitError("Failed to converge StablehloCanonicalizeDynamism in ") << config.maxIterations << " iterations"; return signalPassFailure(); diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp index 7f630f0e11eea0..37effdeadd65af 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp @@ -13,9 +13,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LogicalResult.h" @@ -138,32 +140,20 @@ struct StablehloRefineShapesPass auto func = stablehlo::getStablehloRefineShapesTarget(getOperation()); if (!func) return signalPassFailure(); - // The algorithm behind this pass consists of a single traversal of the - // function. This is sufficient because we only support one function per - // program at the moment. - // TODO(#1048): Find out why .maxIterations = 1 no longer works. - // There have been recent refactors to applyPatternsAndFoldGreedily - // upstream, and that might be the reason. - GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; - config.maxIterations = 3; - config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; - config.strictMode = GreedyRewriteStrictness::AnyOp; - - RewritePatternSet patterns(&getContext()); - stablehlo::populateStablehloRefineShapesPatterns(&patterns, &getContext()); - stablehlo::populateStablehloShapeFolderPatterns(&patterns, &getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - if (failed( - applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { - func.emitError() - << "Greedy rewriter in StablehloRefineShapes does not converge after " - << config.maxIterations << " iterations."; + // Start with empty state, and no dim args / token args. + MLIRContext* context = func.getContext(); + + // Populate additional patterns for StableHLO extensions. + std::function additionalPatternsFn = + [&](RewritePatternSet* patterns) { + patterns->add(context); + patterns->add(context); + patterns->add(context); + }; + + if (failed(stablehlo::refineEntryFunction(*context, func, + additionalPatternsFn))) return signalPassFailure(); - } } }; diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 3e67fb3c3ed8bb..9f588b0bb18c91 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -3642,3 +3642,57 @@ func.func @erf_inv_wide(%arg0 : tensor<16x16xf64>) { %0 = chlo.erf_inv %arg0 : tensor<16x16xf64> -> tensor<16x16xf64> return } + +// ----- + +func.func @ragged_dot_non_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { + // CHECK-HIGH-LEVEL: mhlo.ragged_dot + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> + func.return %0 : tensor<2x11x7xf32> +} + +// ----- + +func.func @ragged_dot_contracting(%lhs : tensor<2x11x5xf32>, %rhs : tensor<2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x2x11x7xf32> { + // CHECK-HIGH-LEVEL: mhlo.ragged_dot + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [2], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<2x5x7xf32>, tensor<3xi64>) -> tensor<3x2x11x7xf32> + func.return %0 : tensor<3x2x11x7xf32> +} + +// ----- + +func.func @ragged_dot_batch(%lhs : tensor<3x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<3x11x7xf32> { + // CHECK-HIGH-LEVEL: mhlo.ragged_dot + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [] + >, + precision_config = [#chlo, #chlo] + } : (tensor<3x11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<3x11x7xf32> + func.return %0 : tensor<3x11x7xf32> +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 70a27eabd67856..92b59bda4c1c05 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -2143,6 +2143,15 @@ func.func @op_fusion(%arg0: tensor) -> tensor { // ----- +func.func @reshape_with_dynamic_size_convert(%arg0: tensor>) -> tensor> { + // expected-error@+1 {{'stablehlo.reshape' op result #0 must be statically shaped tensor}} + %0 = "mhlo.reshape"(%arg0) : (tensor>) + -> tensor> + return %0 : tensor> +} + +// ----- + func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> tensor { // expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}} %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 85eeb2c22a44f1..12b16bc1fad215 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -2920,6 +2920,18 @@ func.func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> // ----- +// CHECK-LABEL: func @reshape_can_have_dynamic_dimensions +func.func @reshape_can_have_dynamic_dimensions() -> tensor> { + %0 = "mhlo.constant"() {value = dense<[[1],[2],[3],[4],[5],[6],[7]]> : tensor<7x1xi64>} : () -> tensor<7x1xi64> + %size = builtin.unrealized_conversion_cast to tensor + %1 = "mhlo.set_dimension_size"(%0, %size) <{dimension = 0 : i64}> : (tensor<7x1xi64>, tensor) -> tensor> + %2 = "mhlo.reshape"(%1) : (tensor>) + -> tensor> + return %2 : tensor> +} + +// ----- + // CHECK-LABEL: func @reverse func.func @reverse(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "mhlo.reverse"(%operand) { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index c8687cfe3ff0da..fdf12a56cefb08 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1355,6 +1355,13 @@ func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { func.return %0 : tensor<4x4xf32> } +// CHECK-LABEL: "op_reshape_dynamic" +func.func @op_reshape_dynamic(%arg0: tensor>) -> tensor<7xi64> { + // CHECK: "mhlo.reshape"({{.*}}) : (tensor>) -> tensor<7xi64> + %0 = "stablehlo.reshape"(%arg0) : (tensor>) -> tensor<7xi64> + return %0 : tensor<7xi64> +} + // CHECK-LABEL: "op_return" func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.case"([[ARG0:%arg[0-9]+]]) ({ diff --git a/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir index 85d3c97dcaf581..63560cf04a3e36 100644 --- a/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir @@ -40,3 +40,23 @@ func.func @refine_dynamic_top_k(%arg0: tensor<16xf32>) -> (tensor, tensor %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor, tensor) return %1#0, %1#1 : tensor, tensor } + +// ----- + +// CHECK-LABEL: module @refine_call +module @refine_call { + // CHECK: func.func @main{{.*}}-> (tensor<4xf32>, tensor<4xi32>) + func.func @main(%arg1: tensor<16xf32>) -> (tensor, tensor) { + %0 = stablehlo.bitcast_convert %arg1 : (tensor<16xf32>) -> tensor + // CHECK: refine_call_callee{{.*}}-> (tensor<4xf32>, tensor<4xi32>) + %2:2 = call @refine_call_callee(%0) : (tensor) -> (tensor, tensor) + return %2#0, %2#1 : tensor, tensor + } + // CHECK: refine_call_callee(%arg0: tensor<16xf32>) -> (tensor<4xf32>, tensor<4xi32>) + func.func @refine_call_callee(%arg0: tensor) -> (tensor, tensor) { + // CHECK: stablehlo.dynamic_top_k{{.*}} -> (tensor<4xf32>, tensor<4xi32>) + %k = stablehlo.constant dense<4> : tensor + %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor, tensor) -> (tensor, tensor) + return %1#0, %1#1 : tensor, tensor + } +} diff --git a/third_party/xla/xla/mlir_hlo/transforms/detensorize_scf_ops.cc b/third_party/xla/xla/mlir_hlo/transforms/detensorize_scf_ops.cc index 2a8be4e6b09ae0..12d8b3814646e7 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/detensorize_scf_ops.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/detensorize_scf_ops.cc @@ -120,7 +120,7 @@ struct DetensorizeScfOpsPass patterns.add, RegionOpPattern, RegionOpPattern>(&getContext()); - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + if (failed(applyPatternsGreedily(f, std::move(patterns)))) { signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc b/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc index 9df69afbaf55aa..8cd4bf99f5133d 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/generic_host_to_llvm.cc @@ -86,7 +86,7 @@ class GenericHostToLLVMPass // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } LLVMConversionTarget target(*ctx); diff --git a/third_party/xla/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc b/third_party/xla/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc index 3e22aa55888327..d490588de4508b 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc @@ -96,7 +96,7 @@ void GpuKernelToNVVMPass::runOnOperation() { { RewritePatternSet patterns(&getContext()); populateAllCommonVectorProgressiveLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } RewritePatternSet patterns(&getContext()); diff --git a/third_party/xla/xla/mlir_hlo/transforms/lower_index_cast_pass.cc b/third_party/xla/xla/mlir_hlo/transforms/lower_index_cast_pass.cc index 489d8fb4cb811e..b773792e67b5c4 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/lower_index_cast_pass.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/lower_index_cast_pass.cc @@ -64,8 +64,7 @@ struct LowerIndexCastPass patterns.add, IndexCastConverter>( patterns.getContext()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc b/third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc index 55ab2fbb2e0ee5..a13f0396a85e63 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/naive_copy_removal.cc @@ -80,7 +80,7 @@ struct NaiveCopyRemovalPass RewritePatternSet patterns(ctx); patterns.add(removeCopy); memref::AllocOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/xla/xla/mlir_hlo/transforms/tile_loops_pass.cc b/third_party/xla/xla/mlir_hlo/transforms/tile_loops_pass.cc index ee3b935cff2771..d6efd72d2437c0 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/tile_loops_pass.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/tile_loops_pass.cc @@ -127,7 +127,7 @@ void TileLoopsPass::runOnOperation() { getContext() .getOrLoadDialect() ->getCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc b/third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc index 1b68cd8b28b74e..5650e83be0c2d4 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/vectorize_copy.cc @@ -215,7 +215,7 @@ struct VectorizeCopyPass RewritePatternSet patterns(ctx); patterns.add( ctx, /*numElementsThreshold = */ 8); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/mlir_hlo/utils/cycle_detector_test.cc b/third_party/xla/xla/mlir_hlo/utils/cycle_detector_test.cc index dd0fdacfb3f9df..18bdefb50b5eab 100644 --- a/third_party/xla/xla/mlir_hlo/utils/cycle_detector_test.cc +++ b/third_party/xla/xla/mlir_hlo/utils/cycle_detector_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "utils/cycle_detector.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" class GraphCyclesTest : public ::testing::Test { public: diff --git a/third_party/xla/xla/permutation_util_test.cc b/third_party/xla/xla/permutation_util_test.cc index 9597da742f09da..99266509404763 100644 --- a/third_party/xla/xla/permutation_util_test.cc +++ b/third_party/xla/xla/permutation_util_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/permutation_util.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 9ee6a77af1097c..b01d79cec6febb 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -73,7 +73,7 @@ xla_cc_test( srcs = ["semaphore_test.cc"], deps = [ ":semaphore", - "//xla:test", + "//xla/hlo/testlib:test", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", @@ -123,10 +123,10 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:test", "//xla:util", "//xla/client:client_library", "//xla/client:local_client", + "//xla/hlo/testlib:test", "//xla/service:cpu_plugin", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/log", @@ -247,11 +247,11 @@ cc_library( ":pjrt_compiler", "//xla:cpu_function_runtime", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test", "//xla/tests:literal_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -413,11 +413,9 @@ cc_library( deps = [ "//xla:shape_util", "//xla/hlo/parser:hlo_parser", - "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:statusor", ], ) @@ -505,6 +503,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/hlo/builder:xla_computation", @@ -513,9 +512,11 @@ cc_library( "//xla/service:compiler", "//xla/service:computation_layout", "//xla/service:computation_placer", + "//xla/service:dump", "//xla/service:executable", "//xla/service:generic_transfer_manager", "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_proto_cc", "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", @@ -565,10 +566,10 @@ xla_cc_test( "//xla:literal_comparison", "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", "//xla/service:cpu_plugin", "//xla/service:platform_util", "//xla/tsl/concurrency:async_value", @@ -644,7 +645,7 @@ xla_cc_test( srcs = ["mlir_to_hlo_test.cc"], deps = [ ":mlir_to_hlo", - "//xla:test", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", @@ -722,7 +723,7 @@ xla_cc_test( srcs = ["lru_cache_test.cc"], deps = [ ":lru_cache", - "//xla:test", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -766,8 +767,8 @@ xla_cc_test( "//xla:array", "//xla:permutation_util", "//xla:shape_util", - "//xla:test", "//xla:util", + "//xla/hlo/testlib:test", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/numeric:int128", @@ -806,6 +807,7 @@ cc_library( "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_helpers", "//xla/pjrt/c:pjrt_c_api_layouts_extension_hdrs", + "//xla/pjrt/c:pjrt_c_api_memory_descriptions_extension_hdrs", "//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "//xla/pjrt/c:pjrt_c_api_stream_extension_hdrs", "//xla/pjrt/distributed:key_value_store_interface", @@ -813,6 +815,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_proto_cc", "//xla/tsl/framework:allocator", + "//xla/tsl/platform:status", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -849,6 +852,7 @@ xla_cc_test( ":pjrt_c_api_client", ":pjrt_client", ":pjrt_compiler", + ":pjrt_device_description", ":pjrt_executable", "//xla:cpu_function_runtime", "//xla:literal_util", @@ -860,6 +864,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index ca3f5be88989c7..40ffd72e0c3222 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -119,6 +119,15 @@ cc_library( ], ) +cc_library( + name = "pjrt_c_api_memory_descriptions_extension_hdrs", + hdrs = ["pjrt_c_api_memory_descriptions_extension.h"], + visibility = ["//visibility:public"], + deps = [ + ":pjrt_c_api_hdrs", + ], +) + cc_library( name = "pjrt_c_api_wrapper_impl", srcs = ["pjrt_c_api_wrapper_impl.cc"], @@ -128,6 +137,7 @@ cc_library( ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", ":pjrt_c_api_layouts_extension_hdrs", + ":pjrt_c_api_memory_descriptions_extension_hdrs", "//xla:literal", "//xla:shape_util", "//xla:util", @@ -151,6 +161,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -177,12 +188,14 @@ cc_library( deps = [ ":pjrt_c_api_hdrs", ":pjrt_c_api_layouts_extension_hdrs", + ":pjrt_c_api_memory_descriptions_extension_hdrs", ":pjrt_c_api_profiler_extension_hdrs", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt/distributed:key_value_store_interface", @@ -399,6 +412,7 @@ xla_test( "//xla/tests:literal_test_util", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -455,6 +469,7 @@ cc_library( deps = [ ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", + ":pjrt_c_api_memory_descriptions_extension_hdrs", ":pjrt_c_api_test_base", "//xla:literal", "//xla:literal_util", @@ -466,6 +481,7 @@ cc_library( "//xla/hlo/parser:hlo_parser", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_future", "//xla/service:computation_placer_hdr", "//xla/service:hlo_proto_cc", diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index 6034d631634e02..d56741eb3500b0 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,16 @@ # PJRT C API changelog +## 0.61 +* Added ``PJRT_KeyValueTryGet`` to the KV store interface, + which is non-blocking and immediately returns an error if the + key is not found. + +## 0.60 +* Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. + +## 0.59 +* Added ``PJRT_MemoryDescriptions_Extension``. + ## 0.57 * Rearranged fields in the PJRT_Api * Update outdated struct sizes from previous changes to diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index 85c1903e648117..f2fc3b1c507a3c 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -45,6 +45,7 @@ typedef enum { PJRT_Extension_Type_Stream, PJRT_Extension_Type_Layouts, PJRT_Extension_Type_FFI, + PJRT_Extension_Type_MemoryDescriptions, } PJRT_Extension_Type; // PJRT_Extension_Base contains a type and a pointer to next @@ -79,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 57 +#define PJRT_API_MINOR 61 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -307,11 +308,14 @@ typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); typedef struct PJRT_Client PJRT_Client; typedef struct PJRT_Device PJRT_Device; typedef struct PJRT_Memory PJRT_Memory; +typedef struct PJRT_ShapeSpec PJRT_ShapeSpec; typedef struct PJRT_DeviceDescription PJRT_DeviceDescription; typedef struct PJRT_TopologyDescription PJRT_TopologyDescription; typedef struct PJRT_Executable PJRT_Executable; typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable; typedef struct PJRT_Buffer PJRT_Buffer; +typedef struct PJRT_AsyncHostToDeviceTransferManager + PJRT_AsyncHostToDeviceTransferManager; // The caller of PJRT_Client_Create can optionally provide a key-value store // accessible across nodes and/or processes. KV store access may be necessary to @@ -347,6 +351,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( PJRT_KeyValueGetCallback_Args* args); +// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is +// not found. +typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value); + +struct PJRT_KeyValueTryGetCallback_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + const char* key; + size_t key_size; + PJRT_CallbackError* callback_error; + void* user_arg; + char* value; // out + size_t value_size; // out + // The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to + // delete the value returned by PJRT_KeyValueTryGetCallback. The + // implementation is responsible for copying `value` and then calling + // value_deleter_callback. + PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args, + value_deleter_callback); + +// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe. +// (2) The caller that provides the two callbacks is responsible for avoiding +// key collisions between different users of key-value store (i.e. between +// different plugins, but not between different nodes in one plugin). +typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)( + PJRT_KeyValueTryGetCallback_Args* args); + struct PJRT_KeyValuePutCallback_Args { size_t struct_size; PJRT_Extension_Base* extension_start; @@ -385,8 +418,15 @@ struct PJRT_Client_Create_Args { void* kv_put_user_arg; PJRT_Client* client; // out + + // Key-value try-get callback provided by the caller of PJRT_Client_Create. + // Same as key-value get callback, but returns `NotFoundError` immediately if + // the key is not found. + PJRT_KeyValueTryGetCallback kv_try_get_callback; + // Will be passed to `kv_try_get_callback` as `user_arg` argument. + void* kv_try_get_user_arg; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg); // Creates and initializes a new PJRT_Client and returns in `client`. typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); @@ -592,6 +632,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_DefaultDeviceAssignment_Args, typedef PJRT_Error* PJRT_Client_DefaultDeviceAssignment( PJRT_Client_DefaultDeviceAssignment_Args* args); +struct PJRT_AsyncHostToDeviceTransferManager_Destroy_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_Destroy_Args, + transfer_manager); + +// Frees `transfer_manager`. `transfer_manager` can be nullptr. +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_TransferData_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + int buffer_index; + const void* data; + int64_t offset; + int64_t transfer_size; + bool is_last_transfer; + PJRT_Event* done_with_h2d_transfer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args, + done_with_h2d_transfer); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args); + typedef enum { // Invalid primitive type to serve as default. PJRT_Buffer_Type_INVALID, @@ -819,6 +888,31 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer); typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( PJRT_Client_CreateViewOfDeviceBuffer_Args* args); +struct PJRT_ShapeSpec { + size_t struct_size; + PJRT_Extension_Base* extension_start; + const int64_t* dims; + size_t num_dims; + PJRT_Buffer_Type element_type; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_ShapeSpec, element_type); + +struct PJRT_Client_CreateBuffersForAsyncHostToDevice_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_Client* client; + PJRT_ShapeSpec* shape_specs; + size_t num_shape_specs; + PJRT_Buffer_MemoryLayout** device_layouts; // optional + size_t num_device_layouts; + PJRT_Memory* memory; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateBuffersForAsyncHostToDevice_Args, + transfer_manager); +typedef PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice( + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args); + // -------------------------- Device Descriptions ------------------------------ // Device descriptions may be associated with an actual device @@ -2265,10 +2359,14 @@ typedef struct PJRT_Api { _PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Create); _PJRT_API_STRUCT_FIELD(PJRT_ExecuteContext_Destroy); _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyRawToHost); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_TransferData); + _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateBuffersForAsyncHostToDevice); } PJRT_Api; enum { - PJRT_Api_STRUCT_SIZE = PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_CopyRawToHost) + PJRT_Api_STRUCT_SIZE = + PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice) }; #undef _PJRT_API_STRUCT_FIELD diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index 0f0b2d4071a89d..2e5cbe3d412027 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -60,12 +60,16 @@ const PJRT_Api* GetCpuPjrtApi() { static PJRT_Layouts_Extension layouts_extension = pjrt::CreateLayoutsExtension(nullptr); + static PJRT_MemoryDescriptions_Extension memory_descriptions_extension = + pjrt::CreateMemoryDescriptionsExtension( + reinterpret_cast(&layouts_extension)); + static const PJRT_Api pjrt_api = pjrt::CreatePjrtApi( pjrt::cpu_plugin::PJRT_Client_Create, pjrt::cpu_plugin::PJRT_ExecuteContext_Create, pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create, pjrt::PJRT_Plugin_Initialize_NoOp, - reinterpret_cast(&layouts_extension), + reinterpret_cast(&memory_descriptions_extension), pjrt::PJRT_Plugin_Attributes_Xla); return &pjrt_api; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc index 4ee3722f94f47f..0375b39d0b9a0d 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_ffi_internal.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_ffi_internal.h" -#include - #include "absl/status/status.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/type_id_registry.h" @@ -36,7 +34,7 @@ static PJRT_Error* PJRT_FFI_TypeID_Register( PJRT_ASSIGN_OR_RETURN( auto type_id, xla::ffi::TypeIdRegistry::RegisterExternalTypeId( - std::string_view(args->type_name, args->type_name_size))); + absl::string_view(args->type_name, args->type_name_size))); args->type_id = type_id.value(); return nullptr; } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 17995b811ce695..68d36fdb7f5c86 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_store = - pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, - args->kv_put_callback, args->kv_put_user_arg); + options.kv_store = pjrt::ToCppKeyValueStore( + args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback, + args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, @@ -399,12 +399,16 @@ const PJRT_Api* GetGpuPjrtApi() { static PJRT_FFI_Extension ffi_extension = pjrt::CreateFfiExtension( reinterpret_cast(&layouts_extension)); + static PJRT_MemoryDescriptions_Extension memory_descriptions_extension = + pjrt::CreateMemoryDescriptionsExtension( + reinterpret_cast(&ffi_extension)); + static const PJRT_Api pjrt_api = pjrt::CreatePjrtApi( pjrt::gpu_plugin::PJRT_Client_Create, pjrt::gpu_plugin::PJRT_ExecuteContext_Create, pjrt::gpu_plugin::PJRT_GpuDeviceTopology_Create, pjrt::PJRT_Plugin_Initialize_NoOp, - reinterpret_cast(&ffi_extension), + reinterpret_cast(&memory_descriptions_extension), pjrt::PJRT_Plugin_Attributes_Xla); return &pjrt_api; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 43bbaf5056aa4b..ae12a1684c23ee 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include #include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -272,6 +274,71 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) { api_->PJRT_ExecuteContext_Destroy(&destroy_args); } +TEST_F(PjrtCApiGpuTest, CreateBuffersWithMemorytForH2DAndTransfer) { + xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + xla::F32, /*dimensions=*/{2, 2, 2}, /*minor_to_major=*/{1, 0, 2}); + std::vector float_data = {1, 2, 3, 4, 5, 6, 7, 8}; + + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args; + args.struct_size = + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.client = client_; + PJRT_ShapeSpec c_shape_spec; + c_shape_spec.element_type = + pjrt::ConvertToPjRtBufferType(xla::PrimitiveType::F32); + c_shape_spec.dims = host_shape.dimensions().data(); + c_shape_spec.num_dims = host_shape.dimensions().size(); + args.shape_specs = &c_shape_spec; + args.num_shape_specs = 1; + TF_ASSERT_OK_AND_ASSIGN(pjrt::BufferMemoryLayoutData c_layout_data, + ConvertToBufferMemoryLayoutData(host_shape.layout())); + std::vector device_layout_list(1); + device_layout_list[0] = &(c_layout_data.c_layout); + args.device_layouts = device_layout_list.data(); + args.num_device_layouts = device_layout_list.size(); + PJRT_Client_AddressableMemories_Args memory_args; + memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; + memory_args.extension_start = nullptr; + memory_args.client = client_; + + PJRT_Error* memory_error = + api_->PJRT_Client_AddressableMemories(&memory_args); + ASSERT_EQ(memory_error, nullptr); + ASSERT_NE(memory_args.addressable_memories, nullptr); + ASSERT_GT(memory_args.num_addressable_memories, 0); + args.memory = memory_args.addressable_memories[0]; + PJRT_Error* error = + api_->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args); + ASSERT_EQ(error, nullptr); + + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args transfer_args; + transfer_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE; + transfer_args.extension_start = nullptr; + transfer_args.transfer_manager = args.transfer_manager; + transfer_args.buffer_index = 0; + transfer_args.data = float_data.data(); + transfer_args.offset = 0; + transfer_args.transfer_size = float_data.size(); + transfer_args.is_last_transfer = true; + + PJRT_Error* transfer_error = + PJRT_AsyncHostToDeviceTransferManager_TransferData(&transfer_args); + ASSERT_EQ(transfer_error, nullptr); + std::unique_ptr done_with_h2d_transfer_event( + transfer_args.done_with_h2d_transfer, MakeEventDeleter(api_)); + + // Destroy the transfer manager. + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args; + destroy_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.transfer_manager = args.transfer_manager; + LogFatalIfPjrtError( + api_->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api_); +} + absl::StatusOr BuildCreateArg( ::pjrt::PJRT_KeyValueCallbackData* kv_callback_data, std::vector& c_options) { @@ -284,6 +351,8 @@ absl::StatusOr BuildCreateArg( args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; args.kv_put_callback = kv_callback_data->c_kv_put; args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; + args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func; + args.kv_try_get_callback = kv_callback_data->c_kv_try_get; args.client = nullptr; return args; } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index ca09cc4ec8856f..c5113d1766ef66 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -38,10 +37,12 @@ limitations under the License. #include "xla/layout.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_layouts_extension.h" +#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/primitive_util.h" @@ -75,6 +76,20 @@ PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) { }; } +PJRT_AsyncHostToDeviceTransferManagerDeleter +MakeAsyncHostToDeviceTransferManagerDeleter(const PJRT_Api* api) { + return [api]( + PJRT_AsyncHostToDeviceTransferManager* transfer_manager) -> void { + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args; + destroy_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.transfer_manager = transfer_manager; + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api); + }; +} + PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) { return [api](PJRT_Error* error) -> void { PJRT_Error_Destroy_Args destroy_args; @@ -766,7 +781,7 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValueGetCallback_Args* args) -> PJRT_Error* { absl::StatusOr output = - kv_store->Get(std::string_view(args->key, args->key_size), + kv_store->Get(absl::string_view(args->key, args->key_size), absl::Milliseconds(args->timeout_in_ms)); if (!output.ok()) { absl::string_view message = output.status().message(); @@ -782,12 +797,31 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( }; } +static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc( + xla::KeyValueStoreInterface* kv_store) { + return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { + absl::StatusOr output = + kv_store->TryGet(absl::string_view(args->key, args->key_size)); + if (!output.ok()) { + absl::string_view message = output.status().message(); + return (*args->callback_error)( + StatusCodeToPjrtErrorCode(output.status().code()), message.data(), + message.size()); + } + args->value = new char[output->size()]; + std::copy(output->begin(), output->end(), args->value); + args->value_size = output->size(); + args->value_deleter_callback = &PjRtValueDeleterCallback; + return nullptr; + }; +} + static PJRT_KeyValuePutCFunc ToKVPutCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { absl::Status status = - kv_store->Set(std::string_view(args->key, args->key_size), - std::string_view(args->value, args->value_size)); + kv_store->Set(absl::string_view(args->key, args->key_size), + absl::string_view(args->value, args->value_size)); if (!status.ok()) { absl::string_view message = status.message(); return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), @@ -813,6 +847,22 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( }; } +static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback( + PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) { + return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { + PJRT_KeyValueTryGetCFunc* kv_try_get_c_func = + reinterpret_cast(args->user_arg); + if (kv_try_get_c_func == nullptr) { + absl::Status status = xla::InvalidArgument( + "got nullptr for PJRT_KeyValueTryGet_Args.user_arg"); + return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), + status.message().data(), + status.message().size()); + } + return (*kv_try_get_c_func)(args); + }; +} + static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func) { return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -833,9 +883,12 @@ std::unique_ptr ConvertToCKeyValueCallbacks( std::shared_ptr kv_store) { auto kv_callback_data = std::make_unique(); kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); + kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get()); kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); + kv_callback_data->c_kv_try_get = + ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); kv_callback_data->kv_store = std::move(kv_store); @@ -1065,4 +1118,65 @@ PJRT_Profiler_Extension CreatePjrtProfilerExtension( return profiler_extension; } +PJRT_ShapeSpec ConvertToPjRtShapeSpec( + const xla::PjRtClient::ShapeSpec& shape_spec) { + PJRT_ShapeSpec c_shape_spec; + c_shape_spec.struct_size = PJRT_ShapeSpec_STRUCT_SIZE; + c_shape_spec.extension_start = nullptr; + c_shape_spec.element_type = + pjrt::ConvertToPjRtBufferType(shape_spec.element_type); + c_shape_spec.dims = shape_spec.dims.data(); + c_shape_spec.num_dims = shape_spec.dims.size(); + return c_shape_spec; +} + +xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec( + PJRT_ShapeSpec c_shape_spec) { + xla::PjRtClient::ShapeSpec shape_spec; + shape_spec.element_type = + pjrt::ConvertFromPjRtBufferType(c_shape_spec.element_type); + + shape_spec.dims = xla::DimensionVector( + c_shape_spec.dims, c_shape_spec.dims + c_shape_spec.num_dims); + return shape_spec; +} + +std::vector GetMemorySpaceDescriptions( + PJRT_DeviceDescription* device_description, const PJRT_Api* c_api, + absl::StatusOr* default_memory) { + const PJRT_MemoryDescriptions_Extension* extension = + pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_MemoryDescriptions); + if (!extension) return {}; + + PJRT_DeviceDescription_MemoryDescriptions_Args mem_desc_args; + mem_desc_args.struct_size = + PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE; + mem_desc_args.extension_start = nullptr; + mem_desc_args.device_description = device_description; + pjrt::LogFatalIfPjrtError( + extension->PJRT_DeviceDescription_MemoryDescriptions(&mem_desc_args), + c_api); + + std::vector memory_space_descriptions; + for (int i = 0; i < mem_desc_args.num_memory_descriptions; i++) { + PJRT_MemoryDescription_Kind_Args kind_args; + kind_args.struct_size = PJRT_MemoryDescription_Kind_Args_STRUCT_SIZE; + kind_args.extension_start = nullptr; + kind_args.memory_description = mem_desc_args.memory_descriptions[i]; + pjrt::LogFatalIfPjrtError( + extension->PJRT_MemoryDescription_Kind(&kind_args), c_api); + xla::PjRtMemorySpaceDescription description( + std::string(kind_args.kind, kind_args.kind_size), kind_args.kind_id); + memory_space_descriptions.push_back(description); + } + *default_memory = {}; + for (int i = 0; i < mem_desc_args.num_memory_descriptions; i++) { + if (mem_desc_args.default_memory_index == i && default_memory) { + *default_memory = &memory_space_descriptions[i]; + } + } + return memory_space_descriptions; +} + } // namespace pjrt diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h index 759569123456ee..44b56cc1b7f4fb 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h @@ -66,6 +66,14 @@ using PJRT_ClientDeleter = std::function; // The lifetime of the Api pointed to must be longer than the client. PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api); +using PJRT_AsyncHostToDeviceTransferManagerDeleter = + std::function; + +// Pass in an API pointer; receive a custom deleter for smart pointers. +// The lifetime of the Api pointed to must be longer than the transfer manager. +PJRT_AsyncHostToDeviceTransferManagerDeleter +MakeAsyncHostToDeviceTransferManagerDeleter(const PJRT_Api* api); + using PJRT_ErrorDeleter = std::function; // Pass in an API pointer; receive a custom deleter for smart pointers. @@ -210,6 +218,9 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc); using PJRT_KeyValueGetCFunc = std::function; +using PJRT_KeyValueTryGetCFunc = + std::function; + using PJRT_KeyValuePutCFunc = std::function; @@ -220,17 +231,21 @@ struct PJRT_KeyValueCallbackData { std::shared_ptr kv_store; - // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. + // kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to + // kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; - // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and - // kv_put_c_func. + // c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func, + // kv_try_get_c_func and kv_put_c_func. PJRT_KeyValueGetCallback c_kv_get; PJRT_KeyValuePutCallback c_kv_put; + pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func; + PJRT_KeyValueTryGetCallback c_kv_try_get; }; -// The returned &kv_get_c_func and &kv_put_c_func must be set as -// PJRT_Client_Create_Args.kv_get_user_arg and +// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be +// set as PJRT_Client_Create_Args.kv_get_user_arg, +// PJRT_Client_Create_Args.kv_try_get_user_arg and // PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put // may be called. @@ -296,6 +311,12 @@ absl::Span DeviceDescriptions( absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable); +PJRT_ShapeSpec ConvertToPjRtShapeSpec( + const xla::PjRtClient::ShapeSpec& shape_spec); + +xla::PjRtClient::ShapeSpec ConvertFromPjrtShapeSpec( + PJRT_ShapeSpec c_shape_spec); + // Creates a PJRT_Profiler_Extension and adds a producer trace with // the given name. The created PJRT_Profiler_Extension will be used in argument // structs to pass the producer traceme context id to add a corresponding @@ -336,6 +357,10 @@ int64_t GetTracemeContextId(InputType* args) { return traceme_context_id; } +std::vector GetMemorySpaceDescriptions( + PJRT_DeviceDescription* device_description, const PJRT_Api* c_api, + absl::StatusOr* default_memory); + } // namespace pjrt #endif // XLA_PJRT_C_PJRT_C_API_HELPERS_H_ diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 4b8a59287589ed..6dfce81a1e4514 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -108,14 +108,22 @@ TEST(PjRtCApiHelperTest, Callback) { auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); auto converted_kv_store = ToCppKeyValueStore( kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, + kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); + auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1)); + EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status(); + auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); + + auto v_2 = converted_kv_store->TryGet("key"); + TF_EXPECT_OK(v.status()); + EXPECT_EQ(*v, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h new file mode 100644 index 00000000000000..91f61961dd1630 --- /dev/null +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_ +#define XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_ + +#include + +#include "xla/pjrt/c/pjrt_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Optional and experimental extension. +// This extension allows to retrieve all supported types of memory +// supported by a given device description. This is useful for specifying +// non-default memories in AOT computations (as opposed to the +// physically-present memories associated with a PJRT_Client). + +#define PJRT_API_MEMORY_DESCRIPTIONS_EXTENSION_VERSION 1 + +typedef struct PJRT_MemoryDescription PJRT_MemoryDescription; + +struct PJRT_DeviceDescription_MemoryDescriptions_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_DeviceDescription* device_description; + const PJRT_MemoryDescription* const* memory_descriptions; // out + size_t num_memory_descriptions; // out + // Index into memory_descriptions. -1 if there's no default: + size_t default_memory_index; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_MemoryDescriptions_Args, + default_memory_index); + +// Returns all memory descriptions attached to this device. +// The memories are in no particular order. +typedef PJRT_Error* PJRT_DeviceDescription_MemoryDescriptions( + PJRT_DeviceDescription_MemoryDescriptions_Args* args); + +struct PJRT_MemoryDescription_Kind_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + const PJRT_MemoryDescription* memory_description; + // `kind` has same lifetime as `memory_description`. + const char* kind; // out + size_t kind_size; // out + int kind_id; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_MemoryDescription_Kind_Args, kind_id); + +// Returns the kind of a given memory space description. This is a +// platform-dependent string and numeric ID that uniquely identifies the kind of +// memory space among those possible on this platform. +typedef PJRT_Error* PJRT_MemoryDescription_Kind( + PJRT_MemoryDescription_Kind_Args* args); + +typedef struct PJRT_MemoryDescriptions_Extension { + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; + PJRT_DeviceDescription_MemoryDescriptions* + PJRT_DeviceDescription_MemoryDescriptions; + PJRT_MemoryDescription_Kind* PJRT_MemoryDescription_Kind; +} PJRT_MemoryDescriptions_Extension; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_MemoryDescriptions_Extension, + PJRT_MemoryDescription_Kind); + +#ifdef __cplusplus +} +#endif + +#endif // XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_ diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc index fa6c1b7cb46cec..0d9030380f35b9 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc @@ -44,9 +44,11 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h" #include "xla/pjrt/c/pjrt_c_api_test_base.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_future.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" @@ -551,6 +553,28 @@ TEST_F(PjrtCApiTest, DeviceLocalHardwareId) { CHECK_EQ(args.local_hardware_id, 0); } +TEST_F(PjrtCApiTest, DeviceDescriptionAndMemoryDescriptionss) { + PJRT_Device_GetDescription_Args get_description = + PJRT_Device_GetDescription_Args{ + .struct_size = PJRT_Device_GetDescription_Args_STRUCT_SIZE, + .extension_start = nullptr, + .device = GetClientDevices()[0], + }; + PJRT_Error* error = api_->PJRT_Device_GetDescription(&get_description); + EXPECT_EQ(error, nullptr); + + absl::StatusOr default_memory; + std::vector memory_descriptions = + GetMemorySpaceDescriptions(get_description.device_description, api_, + &default_memory); + + EXPECT_TRUE(default_memory.ok()); + for (int i = 0; i < memory_descriptions.size(); i++) { + EXPECT_NE(memory_descriptions[i].kind_id(), 0); + EXPECT_NE(memory_descriptions[i].kind().size(), 0); + } +} + // ---------------------------------- Buffers ---------------------------------- class PjrtCApiBufferTest : public PjrtCApiTest { @@ -891,6 +915,12 @@ FieldOffsetsAndSizesForVersion(int major_version, int minor_version) { if (minor_version >= 57) { add_field("PJRT_Buffer_CopyRawToHost", kFnPtrSize); } + if (minor_version >= 58) { + add_field("PJRT_AsyncHostToDeviceTransferManager_Destroy", kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_TransferData", + kFnPtrSize); + add_field("PJRT_Client_CreateBuffersForAsyncHostToDevice", kFnPtrSize); + } return version_offsets_and_sizes; } LOG(FATAL) << "Unsupported API version: " << major_version << "." @@ -1219,6 +1249,17 @@ TEST_F(PjrtCAbiTestBase, FieldOffsetsAndSizes) { {"PJRT_Buffer_CopyRawToHost", {offsetof(PJRT_Api, PJRT_Buffer_CopyRawToHost), sizeof(PJRT_Api::PJRT_Buffer_CopyRawToHost)}}, + {"PJRT_AsyncHostToDeviceTransferManager_Destroy", + {offsetof(PJRT_Api, PJRT_AsyncHostToDeviceTransferManager_Destroy), + sizeof(PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_Destroy)}}, + {"PJRT_AsyncHostToDeviceTransferManager_TransferData", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_TransferData), + sizeof( + PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_TransferData)}}, + {"PJRT_Client_CreateBuffersForAsyncHostToDevice", + {offsetof(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice), + sizeof(PJRT_Api::PJRT_Client_CreateBuffersForAsyncHostToDevice)}}, }; ASSERT_EQ(api_->pjrt_api_version.major_version, PJRT_API_MAJOR); ASSERT_EQ(api_->pjrt_api_version.minor_version, PJRT_API_MINOR); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc index 9602813c573c52..f867846ebcbd54 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -47,9 +47,11 @@ PJRT_Client* CreateClient(const PJRT_Api* api) { create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; + create_args.kv_get_user_arg = nullptr; create_args.kv_put_callback = nullptr; create_args.kv_put_user_arg = nullptr; - create_args.kv_get_user_arg = nullptr; + create_args.kv_try_get_callback = nullptr; + create_args.kv_try_get_user_arg = nullptr; PJRT_Error* error = api->PJRT_Client_Create(&create_args); CHECK_EQ(error, nullptr); CHECK_NE(create_args.client, nullptr); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 1f8fc2de7498a2..906223b3159319 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include #include -#include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -235,13 +235,17 @@ static absl::Status PopulateExecutableOutputMemoryKinds( class CApiKeyValueStore : public xla::KeyValueStoreInterface { public: CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, + void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) : c_get_callback_(c_get_callback), get_user_arg_(get_user_arg), + c_try_get_callback_(c_try_get_callback), + try_get_user_arg_(try_get_user_arg), c_put_callback_(c_put_callback), put_user_arg_(put_user_arg) {} - absl::StatusOr Get(std::string_view key, + absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -264,7 +268,28 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { return result; } - absl::Status Set(std::string_view key, std::string_view value) override { + absl::StatusOr TryGet(absl::string_view key) override { + PJRT_CallbackError callback_error = [](PJRT_Error_Code code, + const char* message, + size_t message_size) { + return new PJRT_Error{absl::Status(static_cast(code), + std::string(message, message_size))}; + }; + PJRT_KeyValueTryGetCallback_Args args; + args.key = key.data(); + args.key_size = key.size(); + args.callback_error = &callback_error; + args.user_arg = try_get_user_arg_; + std::unique_ptr error(c_try_get_callback_(&args)); + if (error != nullptr) { + return error->status; + } + auto result = std::string(args.value, args.value_size); + args.value_deleter_callback(args.value); + return result; + } + + absl::Status Set(absl::string_view key, absl::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, size_t message_size) { @@ -288,18 +313,23 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { private: PJRT_KeyValueGetCallback c_get_callback_; void* get_user_arg_; + PJRT_KeyValueTryGetCallback c_try_get_callback_; + void* try_get_user_arg_; PJRT_KeyValuePutCallback c_put_callback_; void* put_user_arg_; }; std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { - if (c_get_callback == nullptr || c_put_callback == nullptr) { + if (c_get_callback == nullptr || c_try_get_callback == nullptr || + c_put_callback == nullptr) { return nullptr; } - return std::make_shared(c_get_callback, get_user_arg, - c_put_callback, put_user_arg); + return std::make_shared( + c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- @@ -479,6 +509,67 @@ PJRT_Error* PJRT_Client_AddressableMemories( return nullptr; } +PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice( + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Client_CreateBuffersForAsyncHostToDevice_Args", + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE, + args->struct_size)); + std::vector> device_layouts; + absl::InlinedVector shape_specs; + shape_specs.reserve(args->num_shape_specs); + for (int i = 0; i < args->num_shape_specs; ++i) { + shape_specs.push_back(pjrt::ConvertFromPjrtShapeSpec(args->shape_specs[i])); + } + std::optional>> + arg_device_layouts; + if (args->num_device_layouts == 0) { + arg_device_layouts = std::nullopt; + } else { + device_layouts.reserve(args->num_device_layouts); + for (int i = 0; i < args->num_device_layouts; ++i) { + std::optional optional_layout; + if (args->device_layouts[i] != nullptr) { + xla::Layout cpp_layout; + PJRT_Buffer_MemoryLayout* layout = args->device_layouts[i]; + switch (layout->type) { + case PJRT_Buffer_MemoryLayout_Type:: + PJRT_Buffer_MemoryLayout_Type_Tiled: { + PJRT_ASSIGN_OR_RETURN(cpp_layout, ConvertToLayout(layout->tiled)); + break; + } + case PJRT_Buffer_MemoryLayout_Type:: + PJRT_Buffer_MemoryLayout_Type_Strides: { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + "PJRT_Buffer_MemoryLayout_Type_Strides is not supported to be " + "converted to a xla::Layout.")); + break; + } + default: { + PJRT_RETURN_IF_ERROR(absl::InvalidArgumentError( + absl::StrCat("Unexpected PJRT_Buffer_MemoryLayout_Type type: ", + layout->type))); + } + } + device_layouts.push_back(cpp_layout); + } else { + device_layouts.push_back(std::nullopt); + } + } + arg_device_layouts = absl::MakeSpan(device_layouts); + } + + PJRT_ASSIGN_OR_RETURN( + std::unique_ptr + transfer_manager, + args->client->client->CreateBuffersForAsyncHostToDevice( + absl::MakeSpan(shape_specs), arg_device_layouts, + args->memory->memory_space)); + args->transfer_manager = new PJRT_AsyncHostToDeviceTransferManager{ + std::move(transfer_manager), args->client}; + return nullptr; +} + // Searches `device_list` for a PJRT_Device* that wraps a provided // `xla::PjRtDevice *` (`cpp_device`). If a match is found, that PJRT_Device* // is returned. Otherwise, returns nullptr. @@ -531,6 +622,36 @@ static void PopulatePjrtExecutableAddressableDevices( } } +//-------------------- AsyncHostToDeviceTransferManager --------------------- + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_Destroy_Args", + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE, + args->struct_size)); + delete args->transfer_manager; + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_TransferData_Args", + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE, + args->struct_size)); + xla::PjRtFuture<>::Promise promise = xla::PjRtFuture<>::CreatePromise(); + absl::AnyInvocable on_done_with_d2h_transfer = + [promise]() mutable { promise.Set(); }; + PJRT_RETURN_IF_ERROR( + args->transfer_manager->transfer_manager->TransferRawDataToSubBuffer( + args->buffer_index, args->data, args->offset, args->transfer_size, + args->is_last_transfer, std::move(on_done_with_d2h_transfer))); + args->done_with_h2d_transfer = + new PJRT_Event{xla::PjRtFuture<>(std::move(promise))}; + return nullptr; +} + namespace { absl::StatusOr ParseCompileOptions( @@ -833,6 +954,35 @@ PJRT_Error* PJRT_DeviceDescription_DebugString( return nullptr; } +PJRT_Error* PJRT_DeviceDescription_MemoryDescriptions( + PJRT_DeviceDescription_MemoryDescriptions_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_DeviceDescription_MemoryDescriptions_Args", + PJRT_DeviceDescription_MemoryDescriptions_Args_STRUCT_SIZE, + args->struct_size)); + + absl::Span memory_spaces = + args->device_description->device_description->memory_spaces(); + + // We pass each xla::PjRtMemorySpaceDescriptions to the caller through an + // opaque pointer. + args->memory_descriptions = + reinterpret_cast( + memory_spaces.data()); + + absl::StatusOr default_memory = + args->device_description->device_description->default_memory_space(); + args->default_memory_index = -1; + for (int i = 0; i < memory_spaces.size(); i++) { + if (default_memory.ok() && *default_memory == memory_spaces[i]) { + args->default_memory_index = i; + } + } + + args->num_memory_descriptions = memory_spaces.size(); + return nullptr; +} + PJRT_Error* PJRT_DeviceDescription_ToString( PJRT_DeviceDescription_ToString_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( @@ -845,6 +995,19 @@ PJRT_Error* PJRT_DeviceDescription_ToString( return nullptr; } +PJRT_Error* PJRT_MemoryDescription_Kind( + PJRT_MemoryDescription_Kind_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_MemoryDescription_Kind_Args", + PJRT_MemoryDescription_Kind_Args_STRUCT_SIZE, args->struct_size)); + absl::string_view kind = + args->memory_description->memory_space_description.kind(); + args->kind = kind.data(); + args->kind_size = kind.size(); + args->kind_id = args->memory_description->memory_space_description.kind_id(); + return nullptr; +} + PJRT_Error* PJRT_Device_GetDescription(PJRT_Device_GetDescription_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Device_GetDescription_Args", @@ -1673,12 +1836,9 @@ PJRT_Error* PJRT_Buffer_GetMemoryLayout( absl::MutexLock lock(&args->buffer->mu); if (!layout_data.has_value()) { // TODO(skyewm): change PJRT C API to also use opaque layout type - std::unique_ptr pjrt_layout = + std::shared_ptr pjrt_layout = args->buffer->buffer->layout(); - xla::PjRtXlaLayout* pjrt_xla_layout = - tensorflow::down_cast(pjrt_layout.get()); - CHECK(pjrt_xla_layout != nullptr) << "Got unexpected layout type"; - const xla::Layout& xla_layout = pjrt_xla_layout->xla_layout(); + const xla::Layout& xla_layout = pjrt_layout->xla_layout(); PJRT_ASSIGN_OR_RETURN(BufferMemoryLayoutData data, ConvertToBufferMemoryLayoutData(xla_layout)); @@ -2159,7 +2319,7 @@ PJRT_Error* PJRT_Layouts_PJRT_Client_GetDefaultLayout( args->client->client->GetDefaultLayout( pjrt::ConvertFromPjRtBufferType(args->type), {args->dims, args->num_dims})); - auto pjrt_xla_layout = std::make_unique(xla_layout); + auto pjrt_xla_layout = std::make_shared(xla_layout); args->layout = new PJRT_Layouts_MemoryLayout{std::move(pjrt_xla_layout)}; return nullptr; } @@ -2530,6 +2690,12 @@ PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, /*PJRT_ExecuteContext_Create=*/execute_context_create_fn, /*PJRT_ExecuteContext_Destroy=*/pjrt::PJRT_ExecuteContext_Destroy, /*PJRT_Buffer_CopyRawToHost=*/pjrt::PJRT_Buffer_CopyRawToHost, + /*PJRT_AsyncHostToDeviceTransferManager_Destroy=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_Destroy, + /*PJRT_AsyncHostToDeviceTransferManager_TransferData=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_TransferData, + /*PJRT_Client_CreateBuffersForAsyncHostToDevice=*/ + pjrt::PJRT_Client_CreateBuffersForAsyncHostToDevice, }; } @@ -2549,4 +2715,17 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) { }; } +PJRT_MemoryDescriptions_Extension CreateMemoryDescriptionsExtension( + PJRT_Extension_Base* next) { + return PJRT_MemoryDescriptions_Extension{ + /*struct_size=*/PJRT_MemoryDescriptions_Extension_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type_MemoryDescriptions, + /*next=*/next, + /*PJRT_DeviceDescription_MemorySpaces=*/ + pjrt::PJRT_DeviceDescription_MemoryDescriptions, + /*PJRT_MemoryDescription_Kind=*/ + pjrt::PJRT_MemoryDescription_Kind, + }; +} + } // namespace pjrt diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 00a0d16d6b4f47..27b1cac051dbd0 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -32,6 +32,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_layouts_extension.h" +#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" @@ -86,6 +87,18 @@ struct PJRT_Client { explicit PJRT_Client(std::unique_ptr cpp_client); }; +struct PJRT_MemoryDescription { + xla::PjRtMemorySpaceDescription memory_space_description; +}; + +// PJRT_AsyncHostToDeviceTransferManager is owned by its corresponding +// PJRT_Client. +struct PJRT_AsyncHostToDeviceTransferManager { + std::unique_ptr + transfer_manager; + PJRT_Client* client; +}; + // PJRT_DeviceDescriptions are owned by their corresponding PJRT_Device. struct PJRT_DeviceDescription { // The xla::PjRtDeviceDescription* is owned transitively by the @@ -205,7 +218,7 @@ struct PJRT_CopyToDeviceStream { }; struct PJRT_Layouts_MemoryLayout { - std::unique_ptr layout; + std::shared_ptr layout; }; struct PJRT_Layouts_SerializedLayout { @@ -249,7 +262,12 @@ PJRT_Error* PJRT_Client_BufferFromHostBuffer( PJRT_Client_BufferFromHostBuffer_Args* args); PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( PJRT_Client_CreateViewOfDeviceBuffer_Args* args); - +PJRT_Error* PJRT_Client_CreateBuffersForAsyncHostToDevice( + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( + PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args); PJRT_Error* PJRT_DeviceDescription_Id(PJRT_DeviceDescription_Id_Args* args); PJRT_Error* PJRT_DeviceDescription_ProcessIndex( PJRT_DeviceDescription_ProcessIndex_Args* args); @@ -446,6 +464,7 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does not nothing other than returning a nullptr. Can be used as @@ -456,6 +475,9 @@ PJRT_Error* PJRT_Plugin_Initialize_NoOp(PJRT_Plugin_Initialize_Args* args); PJRT_Layouts_Extension CreateLayoutsExtension( PJRT_Extension_Base* next = nullptr); +PJRT_MemoryDescriptions_Extension CreateMemoryDescriptionsExtension( + PJRT_Extension_Base* next = nullptr); + // Creates a PJRT_Api with create_fn from the input and other functions in // pjrt_c_api_wrapper_impl. PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 9ff2f1b02e3deb..806a746ed413ce 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -1,7 +1,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/pjrt/cpu:package_groups.bzl", "xla_cpu_internal_packages") -load("//xla/tsl:tsl.bzl", "if_oss", "internal_visibility") -load("//xla/tsl/platform:build_config.bzl", "tf_proto_library") +load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -105,36 +104,33 @@ cc_library( ], ) -tf_proto_library( - name = "cpu_topology_proto", - srcs = ["cpu_topology.proto"], - visibility = ["//visibility:public"], -) - cc_library( - name = "cpu_topology", - srcs = ["cpu_topology.cc"], - hdrs = ["cpu_topology.h"], + name = "cpu_device", + srcs = ["cpu_device.cc"], + hdrs = ["cpu_device.h"], visibility = internal_visibility(["//xla/pjrt/cpu:legacy_cpu_topology_users"]), deps = [ - ":cpu_topology_proto_cc", + "//xla:literal", + "//xla/pjrt:host_memory_spaces", + "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_future", + "//xla/pjrt:semaphore", + "//xla/pjrt/plugin/xla_cpu:cpu_device_description", + "//xla/service/cpu:cpu_xfeed", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], ) -xla_cc_test( - name = "cpu_topology_test", - srcs = ["cpu_topology_test.cc"], - deps = [ - ":cpu_topology", - ":cpu_topology_proto_cc", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - cc_library( name = "cpu_client", srcs = ["cpu_client.cc"], @@ -142,7 +138,7 @@ cc_library( visibility = internal_visibility(["//xla/pjrt/cpu:legacy_cpu_client_users"]), deps = [ ":abstract_tfrt_cpu_buffer", - ":cpu_topology", + ":cpu_device", ":tracked_tfrt_cpu_device_buffer", "//xla:array", "//xla:cpu_function_runtime", @@ -155,6 +151,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/cpu/codegen:cpu_features", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/backends/cpu/runtime:buffer_allocations", "//xla/backends/cpu/runtime:thread_pool_task_runner", "//xla/backends/cpu/runtime:thunk", @@ -168,13 +165,15 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", - "//xla/pjrt:pjrt_device_description", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt:semaphore", "//xla/pjrt:transpose", "//xla/pjrt:utils", "//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "//xla/pjrt/plugin/xla_cpu:cpu_device_description", + "//xla/pjrt/plugin/xla_cpu:cpu_topology", + "//xla/pjrt/plugin/xla_cpu:cpu_topology_description", "//xla/service:buffer_assignment", "//xla/service:compiler", "//xla/service:computation_placer_hdr", @@ -188,18 +187,15 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:maybe_owning_device_memory", - "//xla/service/cpu:collectives_interface", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_event", "//xla/service/cpu:cpu_executable", "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_runtime", - "//xla/service/cpu:cpu_xfeed", + "//xla/service/llvm_ir:llvm_command_line_options", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", - "//xla/tsl/lib/strings:proto_serialization", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", @@ -265,125 +261,3 @@ xla_cc_test( "@local_tsl//tsl/platform:test_main", ], ) - -cc_library( - name = "gloo_kv_store", - srcs = ["gloo_kv_store.cc"], - hdrs = ["gloo_kv_store.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = [ - "//xla/pjrt/cpu:legacy_cpu_internal_users", - ], - deps = [ - "//xla/pjrt:status_casters", - "//xla/pjrt/distributed:key_value_store_interface", - "@com_google_absl//absl/time", - "@gloo", - ], -) - -cc_library( - name = "gloo_collectives", - srcs = ["gloo_collectives.cc"], - hdrs = ["gloo_collectives.h"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = [ - "//xla/pjrt/cpu:legacy_cpu_internal_users", - ], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", - "//xla/service/cpu:collectives_interface", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@gloo", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "gloo_collectives_test", - srcs = ["gloo_collectives_test.cc"], - linkstatic = True, - deps = [ - ":gloo_collectives", - ":gloo_kv_store", - "//xla:executable_run_options", - "//xla:xla_data_proto_cc", - "//xla/pjrt/distributed:in_memory_key_value_store", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", - "//xla/service/cpu:collectives_interface", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", - ] + select({ - # Gloo's transport_tcp is not available on MacOS - "//xla/tsl:macos": [ - "@gloo//:transport_uv", - ], - "//conditions:default": [ - "@gloo//:transport_tcp", - ], - }), -) - -cc_library( - name = "mpi_collectives", - srcs = if_oss(["mpi_collectives.cc"]), - hdrs = if_oss(["mpi_collectives.h"]), - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - visibility = [ - "//xla/pjrt/cpu:legacy_cpu_internal_users", - ], - deps = if_oss([ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", - "//xla/service/cpu:collectives_interface", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@mpitrampoline", - ]), -) diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index fec7d7d1e9ff3e..67180480233659 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -27,10 +27,10 @@ limitations under the License. #include #include #include +#include #include #include -#include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "absl/base/dynamic_annotations.h" #include "absl/container/flat_hash_map.h" @@ -49,6 +49,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "xla/array.h" #include "xla/backends/cpu/codegen/cpu_features.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" @@ -67,29 +68,29 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" -#include "xla/pjrt/cpu/cpu_topology.h" +#include "xla/pjrt/cpu/cpu_device.h" #include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" -#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology_description.h" #include "xla/pjrt/semaphore.h" #include "xla/pjrt/transpose.h" #include "xla/pjrt/utils.h" #include "xla/service/buffer_assignment.h" #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/cpu_event.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/cpu/cpu_runtime.h" -#include "xla/service/cpu/cpu_xfeed.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/dump.h" @@ -99,6 +100,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_value.h" +#include "xla/service/llvm_ir/llvm_command_line_options.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -106,7 +108,6 @@ limitations under the License. #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" -#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -114,6 +115,7 @@ limitations under the License. #include "tsl/platform/denormal.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/fingerprint.h" #include "tsl/platform/setround.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" @@ -150,8 +152,6 @@ absl::StatusOr> AllocateDestinationBufferAndAvs( tensorflow::down_cast(device), client); } -const char kCpuPlatformName[] = "cpu"; - void EnqueueWork(tsl::thread::ThreadPool* pool, absl::AnyInvocable callee) { // TSL TheadPool expects std::function that must be copyable, so we are @@ -253,127 +253,19 @@ class TfrtCpuAsyncHostToDeviceTransferManager TfrtCpuDevice* device_; }; -} // namespace - -TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int process_id, - int local_device_id) - : id_(PackCpuDeviceId(process_id, local_device_id)), - process_index_(process_id), - local_hardware_id_(local_device_id) { - debug_string_ = absl::StrCat("TFRT_CPU_", id_.value()); - to_string_ = absl::StrCat("CpuDevice(id=", id_.value(), ")"); -} - -absl::string_view TfrtCpuDeviceDescription::device_kind() const { - return kCpuPlatformName; -} - -absl::string_view TfrtCpuDeviceDescription::DebugString() const { - return debug_string_; -} - -absl::string_view TfrtCpuDeviceDescription::ToString() const { - return to_string_; -} - -/*static*/ TfrtCpuTopologyDescription TfrtCpuTopologyDescription::Create( - PjRtPlatformId platform_id, absl::string_view platform_name, - absl::string_view platform_version, - absl::Span> devices, - absl::Span machine_attributes) { - std::vector cpu_devices; - cpu_devices.reserve(devices.size()); - for (auto& device : devices) { - cpu_devices.push_back(CpuTopology::CpuDevice{ - device->process_index(), device->local_hardware_id().value()}); - } - return TfrtCpuTopologyDescription(platform_id, platform_name, - platform_version, cpu_devices, - machine_attributes); -} - -absl::StatusOr TfrtCpuTopologyDescription::GetDefaultLayout( - PrimitiveType element_type, absl::Span dims) const { - Shape shape = ShapeUtil::MakeShape(element_type, dims); - return LayoutUtil::GetWithDefaultLayout(shape).layout(); -} - -absl::StatusOr TfrtCpuTopologyDescription::Serialize() const { - std::string result; - if (!tsl::SerializeToStringDeterministic(cpu_topology_.ToProto(), &result)) { - return absl::InternalError("Failed to serialize cpu_topology"); - } - return result; -} - -std::vector> -TfrtCpuTopologyDescription::DeviceDescriptions() const { - std::vector> devices; - devices.reserve(cpu_topology_.number_of_devices()); - for (const CpuTopology::CpuDevice& device : cpu_topology_.devices()) { - devices.push_back(std::make_unique( - device.process_id, device.local_device_id)); - } - return devices; -} - -TfrtCpuDevice::TfrtCpuDevice(int process_id, int local_device_id, - int max_inflight_computations) - : description_(process_id, local_device_id), - max_inflight_computations_semaphore_( - /*capacity=*/max_inflight_computations) {} - -absl::Status TfrtCpuDevice::TransferToInfeed(const LiteralSlice& literal) { - return TransferLiteralToInfeedOnCpu(local_hardware_id().value(), literal); -} - -absl::Status TfrtCpuDevice::TransferFromOutfeed( - MutableBorrowingLiteral literal) { - return TransferLiteralFromOutfeedOnCpu(local_hardware_id().value(), literal); -} - -void TfrtCpuDevice::AttachMemorySpace(PjRtMemorySpace* memory_space) { - CHECK(memory_space != nullptr); - CHECK(client_ == memory_space->client()) << absl::StrFormat( - "Could not attach a TfrtCpuDevice to a PjRtMemorySpace owned by a " - "different client, the device's client: %s, the memory space's client: " - "%s.", - client_->platform_name(), memory_space->client()->platform_name()); - - memory_spaces_.push_back(memory_space); - memory_spaces_by_id_.emplace(memory_space->kind_id(), memory_space); -} - -absl::Span TfrtCpuDevice::memory_spaces() const { - return memory_spaces_; -} - -absl::StatusOr TfrtCpuDevice::default_memory_space() const { - return memory_space_by_kind_id(UnpinnedHostMemorySpace::kKindId); +// Converts a const span of unique_ptr to a const span of +// unique_ptr. This is a safe operation because the resulting span +// only permits access to elements via pointer dereference, and unique_ptr +// values remain immutable. +absl::Span> GetPjRtDeviceSpan( + absl::Span> devices) { + static_assert(std::is_base_of_v); + return absl::Span>( + reinterpret_cast*>(devices.data()), + devices.size()); } -absl::StatusOr TfrtCpuDevice::memory_space_by_kind( - absl::string_view memory_space_kind) const { - auto it = - absl::c_find_if(memory_spaces_, [memory_space_kind](PjRtMemorySpace* ms) { - return ms->kind() == memory_space_kind; - }); - if (it != memory_spaces_.end()) { - return *it; - } - return absl::InternalError( - absl::StrCat("No memory space found (kind: ", memory_space_kind, ")")); -} - -absl::StatusOr TfrtCpuDevice::memory_space_by_kind_id( - int id) const { - auto it = memory_spaces_by_id_.find(id); - if (it == memory_spaces_by_id_.end()) { - return absl::InternalError( - absl::StrCat("No memory space found (kind_id: ", id, ")")); - } - return it->second; -} +} // namespace static int CpuDeviceCount() { // By default we fix the number of devices to one. However we do let the user @@ -419,7 +311,7 @@ static tsl::ThreadOptions GetThreadOptions() { TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, - std::shared_ptr collectives, size_t num_threads, + std::shared_ptr collectives, size_t num_threads, bool asynchronous, std::function customize_hlo_module_config) : process_index_(process_index), @@ -440,9 +332,9 @@ TfrtCpuClient::TfrtCpuClient( tsl::MakeAvailableAsyncValueRef()), transpose_cache_(1024), collectives_(std::move(collectives)), - topology_(TfrtCpuTopologyDescription::Create( - platform_id(), platform_name(), platform_version(), owned_devices_, - cpu::DetectMachineAttributes())), + topology_(CpuTopologyDescription::Create( + platform_id(), platform_name(), platform_version(), + GetPjRtDeviceSpan(owned_devices_), cpu::DetectMachineAttributes())), asynchronous_(asynchronous), customize_hlo_module_config_(std::move(customize_hlo_module_config)) { for (const std::unique_ptr& device : owned_devices_) { @@ -743,6 +635,11 @@ static absl::StatusOr> JitCompile( static constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations"; DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName); + // RunHloPasses and RunBackend both look at the LLVM command line options. + auto llvm_options = llvm_ir::ExtractXlaBackendExtraOptions( + hlo_module->config().debug_options().xla_backend_extra_options()); + llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_options); + // Run Hlo Passes cpu::CpuCompiler compiler; TF_ASSIGN_OR_RETURN(hlo_module, compiler.RunHloPasses(std::move(hlo_module), @@ -1183,17 +1080,19 @@ TfrtCpuExecutable::TfrtCpuExecutable( computation_layout.parameter_shape(0).tuple_shapes(i))); } } + + // Compute fingerprint of the executable from the HloModule. + tsl::Fprint128 fingerprint = tsl::Fingerprint128(fingerprint_); + fingerprint = tsl::FingerprintCat128( + tsl::Fingerprint128(fingerprint_), + tsl::Fingerprint128(cpu_executable_->module().ToString())); + fingerprint_ = absl::StrCat(fingerprint.low64, fingerprint.high64); } void TfrtCpuExecutable::Delete() {} bool TfrtCpuExecutable::IsDeleted() { return false; } -absl::StatusOr> TfrtCpuExecutable::Fingerprint() - const { - return std::optional(); -} - absl::Status TfrtCpuExecutable::SetUpDonation(bool tuple_inputs) { TF_ASSIGN_OR_RETURN(parameters_that_must_be_donated_, ComputeParametersThatMustBeDonated( diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h index b94591a447f70e..f4074534e9ff5a 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.h +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -38,26 +38,26 @@ limitations under the License. #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" #include "mlir/IR/BuiltinOps.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/executable_run_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h" -#include "xla/pjrt/cpu/cpu_topology.h" +#include "xla/pjrt/cpu/cpu_device.h" #include "xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" -#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" -#include "xla/pjrt/semaphore.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_device_description.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology_description.h" #include "xla/pjrt/transpose.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_event.h" #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" @@ -73,195 +73,12 @@ limitations under the License. namespace xla { -class TfrtCpuDevice; // forward declare - -class TfrtCpuDeviceDescription final : public PjRtDeviceDescription { - public: - explicit TfrtCpuDeviceDescription(int process_id, int local_device_id); - - int id() const override { return id_.value(); } - - int process_index() const override { return process_index_; } - - int local_hardware_id() const { return local_hardware_id_; } - - absl::string_view device_kind() const override; - - absl::string_view DebugString() const override; - - absl::string_view ToString() const override; - - const absl::flat_hash_map& Attributes() - const override { - return attributes_; - } - - private: - PjRtGlobalDeviceId id_; - int process_index_; - int local_hardware_id_; - std::string debug_string_; - std::string to_string_; - absl::flat_hash_map attributes_ = {}; -}; - -class TfrtCpuTopologyDescription : public PjRtTopologyDescription { - public: - static TfrtCpuTopologyDescription Create( - PjRtPlatformId platform_id, absl::string_view platform_name, - absl::string_view platform_version, - absl::Span> devices, - absl::Span machine_attributes); - - // `cpu_device_ids` is the list of logical device ids for the CPU devices and - // will be used to initialize the CPU topology. - TfrtCpuTopologyDescription( - const PjRtPlatformId platform_id, const absl::string_view platform_name, - const absl::string_view platform_version, - const std::vector cpu_devices, - absl::Span machine_attributes) - : platform_id_(platform_id), - platform_name_(platform_name), - platform_version_(platform_version), - cpu_topology_(std::move(cpu_devices), - std::vector(machine_attributes.begin(), - machine_attributes.end())) {} - - bool operator==(const TfrtCpuTopologyDescription& other) const { - return this->platform_id() == other.platform_id() && - this->platform_name() == other.platform_name() && - this->platform_version() == other.platform_version() && - this->cpu_topology().devices() == other.cpu_topology().devices(); - } - - PjRtPlatformId platform_id() const override { return platform_id_; } - - absl::string_view platform_name() const override { return platform_name_; } - - absl::string_view platform_version() const override { - return platform_version_; - } - - std::vector> DeviceDescriptions() - const override; - - const CpuTopology& cpu_topology() const { return cpu_topology_; } - const CpuTopology* cpu_topology_ptr() const { return &cpu_topology_; } - - // No subslice is supported. - bool is_subslice_topology() const override { return false; } - - // TODO(b/319478189): We support multi-host CPU computations and should - // correctly report process count. - absl::StatusOr ProcessCount() const override { return 1; } - - absl::StatusOr CoreCountOfDefaultType() const override { - return cpu_topology_.number_of_devices(); - } - - absl::StatusOr LogicalDeviceCountOfDefaultType() const override { - return cpu_topology_.number_of_devices(); - } - - absl::StatusOr CoreCountOfDefaultTypePerProcess() const override { - return cpu_topology_.number_of_devices(); - } - - absl::StatusOr CoreCountOfDefaultTypePerChip() const override { - return 1; - } - - absl::StatusOr Serialize() const override; - - // Returns vendor specific attributes about the topology. - const absl::flat_hash_map& Attributes() - const override { - return attributes_; - } - - absl::StatusOr GetDefaultLayout( - PrimitiveType element_type, - absl::Span dims) const override; - - private: - const PjRtPlatformId platform_id_; - const std::string platform_name_; - const std::string platform_version_; - const CpuTopology cpu_topology_; - absl::flat_hash_map attributes_; -}; - -class TfrtCpuDevice final : public PjRtDevice { - public: - explicit TfrtCpuDevice(int process_id, int local_device_id, - int max_inflight_computations = 32); - - const TfrtCpuDeviceDescription& description() const override { - return description_; - } - - void SetClient(PjRtClient* client) { - CHECK(client_ == nullptr); - client_ = client; - } - - PjRtClient* client() const override { return client_; } - - bool IsAddressable() const override { - return process_index() == client()->process_index(); - } - - PjRtLocalDeviceId local_device_id() const override { - return PjRtLocalDeviceId(local_hardware_id().value()); - } - - PjRtLocalHardwareId local_hardware_id() const override { - return PjRtLocalHardwareId(description_.local_hardware_id()); - } - - absl::Status TransferToInfeed(const LiteralSlice& literal) override; - - absl::Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; - - void AttachMemorySpace(PjRtMemorySpace* memory_space); - - absl::Span memory_spaces() const override; - - absl::StatusOr default_memory_space() const override; - - absl::StatusOr memory_space_by_kind( - absl::string_view memory_space_kind) const override; - - absl::StatusOr memory_space_by_kind_id(int id) const; - - // Returns a semaphore for admission control on inflight computations. - Semaphore& max_inflight_computations_semaphore() { - return max_inflight_computations_semaphore_; - } - - std::unique_ptr CreateAsyncTrackingEvent( - absl::string_view description) const override { - return nullptr; - } - - private: - PjRtClient* client_ = nullptr; - TfrtCpuDeviceDescription description_; - absl::InlinedVector memory_spaces_; - absl::flat_hash_map memory_spaces_by_id_; - - // TODO(zhangqiaorjc): Optimize semaphore related overhead. - // Semaphore used to limit how many programs can be enqueued by the host - // ahead of the device. - Semaphore max_inflight_computations_semaphore_; -}; - class TfrtCpuClient final : public PjRtClient { public: TfrtCpuClient( int process_index, std::vector> devices, - std::shared_ptr collectives, - size_t num_threads, bool asynchronous, + std::shared_ptr collectives, size_t num_threads, + bool asynchronous, std::function customize_hlo_module_config); ~TfrtCpuClient() override; @@ -385,13 +202,6 @@ class TfrtCpuClient final : public PjRtClient { std::function on_delete_callback, std::optional stream) override; - absl::StatusOr CreateChannelHandle() override { - return Unimplemented("CreateChannelHandle not implemented."); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); - } - absl::Status Defragment() override { return Unimplemented("Defragment not implemented."); } @@ -478,9 +288,9 @@ class TfrtCpuClient final : public PjRtClient { absl::Mutex transpose_mu_; TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_); - std::shared_ptr collectives_; + std::shared_ptr collectives_; - xla::TfrtCpuTopologyDescription topology_; + xla::CpuTopologyDescription topology_; // Used to control whether asynchronous computation dispatch is available for // this client. Only applies to non-parallel computations. @@ -626,12 +436,14 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { bool IsReturnedFutureSupported() const override { return true; } - absl::StatusOr> Fingerprint() const; - std::shared_ptr cpu_executable() const { return cpu_executable_; } + absl::StatusOr> Fingerprint() const { + return fingerprint_; + } + absl::StatusOr FingerprintExecutable() const override { - return Unimplemented("Fingerprinting executable is not supported."); + return fingerprint_; } absl::StatusOr GetCompileOptions() const override { @@ -697,6 +509,8 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { // Cached result of comparing HloCostAnalysis FLOP estimate for execute // critical path. bool cheap_computation_; + + std::string fingerprint_; }; absl::StatusOr> ABSL_DEPRECATED( diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc index 7c0f6eff91b1be..b01ee3a279bb40 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc @@ -122,6 +122,10 @@ ENTRY DonationWithExecutionError() -> f32[2, 2] { TF_ASSERT_OK_AND_ASSIGN(auto pjrt_executable, client->Compile(xla_computation, {})); + TF_ASSERT_OK_AND_ASSIGN(auto fingerprint, + pjrt_executable->FingerprintExecutable()); + ASSERT_TRUE(!fingerprint.empty()); + std::vector data(4, 0); Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); TF_ASSERT_OK_AND_ASSIGN( diff --git a/third_party/xla/xla/pjrt/cpu/cpu_device.cc b/third_party/xla/xla/pjrt/cpu/cpu_device.cc new file mode 100644 index 00000000000000..4e7bf57efd9fdd --- /dev/null +++ b/third_party/xla/xla/pjrt/cpu/cpu_device.cc @@ -0,0 +1,91 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/cpu_device.h" + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/host_memory_spaces.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/service/cpu/cpu_xfeed.h" + +namespace xla { + +TfrtCpuDevice::TfrtCpuDevice(int process_id, int local_device_id, + int max_inflight_computations) + : description_(process_id, local_device_id), + max_inflight_computations_semaphore_( + /*capacity=*/max_inflight_computations) {} + +absl::Status TfrtCpuDevice::TransferToInfeed(const LiteralSlice& literal) { + return TransferLiteralToInfeedOnCpu(local_hardware_id().value(), literal); +} + +absl::Status TfrtCpuDevice::TransferFromOutfeed( + MutableBorrowingLiteral literal) { + return TransferLiteralFromOutfeedOnCpu(local_hardware_id().value(), literal); +} + +void TfrtCpuDevice::AttachMemorySpace(PjRtMemorySpace* memory_space) { + CHECK(memory_space != nullptr); + CHECK(client_ == memory_space->client()) << absl::StrFormat( + "Could not attach a TfrtCpuDevice to a PjRtMemorySpace owned by a " + "different client, the device's client: %s, the memory space's client: " + "%s.", + client_->platform_name(), memory_space->client()->platform_name()); + + memory_spaces_.push_back(memory_space); + memory_spaces_by_id_.emplace(memory_space->kind_id(), memory_space); +} + +absl::Span TfrtCpuDevice::memory_spaces() const { + return memory_spaces_; +} + +absl::StatusOr TfrtCpuDevice::default_memory_space() const { + return memory_space_by_kind_id(UnpinnedHostMemorySpace::kKindId); +} + +absl::StatusOr TfrtCpuDevice::memory_space_by_kind( + absl::string_view memory_space_kind) const { + auto it = + absl::c_find_if(memory_spaces_, [memory_space_kind](PjRtMemorySpace* ms) { + return ms->kind() == memory_space_kind; + }); + if (it != memory_spaces_.end()) { + return *it; + } + return absl::InternalError( + absl::StrCat("No memory space found (kind: ", memory_space_kind, ")")); +} + +absl::StatusOr TfrtCpuDevice::memory_space_by_kind_id( + int id) const { + auto it = memory_spaces_by_id_.find(id); + if (it == memory_spaces_by_id_.end()) { + return absl::InternalError( + absl::StrCat("No memory space found (kind_id: ", id, ")")); + } + return it->second; +} + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/cpu/cpu_device.h b/third_party/xla/xla/pjrt/cpu/cpu_device.h new file mode 100644 index 00000000000000..c6b5c8f7f3f3ce --- /dev/null +++ b/third_party/xla/xla/pjrt/cpu/cpu_device.h @@ -0,0 +1,104 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_CPU_CPU_DEVICE_H_ +#define XLA_PJRT_CPU_CPU_DEVICE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_device_description.h" +#include "xla/pjrt/semaphore.h" + +namespace xla { + +class TfrtCpuDevice final : public PjRtDevice { + public: + explicit TfrtCpuDevice(int process_id, int local_device_id, + int max_inflight_computations = 32); + + const CpuDeviceDescription& description() const override { + return description_; + } + + void SetClient(PjRtClient* client) { + CHECK(client_ == nullptr); + client_ = client; + } + + PjRtClient* client() const override { return client_; } + + bool IsAddressable() const override { + return process_index() == client()->process_index(); + } + + PjRtLocalDeviceId local_device_id() const override { + return PjRtLocalDeviceId(local_hardware_id().value()); + } + + PjRtLocalHardwareId local_hardware_id() const override { + return PjRtLocalHardwareId(description_.local_hardware_id()); + } + + absl::Status TransferToInfeed(const LiteralSlice& literal) override; + + absl::Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; + + void AttachMemorySpace(PjRtMemorySpace* memory_space); + + absl::Span memory_spaces() const override; + + absl::StatusOr default_memory_space() const override; + + absl::StatusOr memory_space_by_kind( + absl::string_view memory_space_kind) const override; + + absl::StatusOr memory_space_by_kind_id(int id) const; + + // Returns a semaphore for admission control on inflight computations. + Semaphore& max_inflight_computations_semaphore() { + return max_inflight_computations_semaphore_; + } + + std::unique_ptr CreateAsyncTrackingEvent( + absl::string_view description) const override { + return nullptr; + } + + private: + PjRtClient* client_ = nullptr; + CpuDeviceDescription description_; + absl::InlinedVector memory_spaces_; + absl::flat_hash_map memory_spaces_by_id_; + + // TODO(zhangqiaorjc): Optimize semaphore related overhead. + // Semaphore used to limit how many programs can be enqueued by the host + // ahead of the device. + Semaphore max_inflight_computations_semaphore_; +}; + +} // namespace xla + +#endif // XLA_PJRT_CPU_CPU_DEVICE_H_ diff --git a/third_party/xla/xla/pjrt/cpu/cpu_topology.proto b/third_party/xla/xla/pjrt/cpu/cpu_topology.proto deleted file mode 100644 index 667d0159fdc4f7..00000000000000 --- a/third_party/xla/xla/pjrt/cpu/cpu_topology.proto +++ /dev/null @@ -1,13 +0,0 @@ -syntax = "proto3"; - -package xla; - -// A proto used to serialize CpuTopology instances. -message CpuTopologyProto { - message CpuDevice { - int32 process_index = 2; - int32 local_hardware_id = 3; - } - repeated CpuDevice cpu_devices = 1; - repeated string machine_attributes = 4; -} diff --git a/third_party/xla/xla/pjrt/cpu/gloo_collectives.h b/third_party/xla/xla/pjrt/cpu/gloo_collectives.h deleted file mode 100644 index 432a86c4d0acac..00000000000000 --- a/third_party/xla/xla/pjrt/cpu/gloo_collectives.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ -#define XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ - -#include -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "gloo/context.h" -#include "gloo/rendezvous/store.h" -#include "gloo/transport/device.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" -#include "xla/service/global_device_id.h" -#include "xla/xla_data.pb.h" - -namespace xla::cpu { - -class GlooCollectivesCommunicator : public CollectivesCommunicator { - public: - explicit GlooCollectivesCommunicator(std::shared_ptr context); - ~GlooCollectivesCommunicator() override; - - absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t num_elements, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, - std::optional source_rank, - absl::Span target_ranks, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, - absl::Duration timeout) override; - absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - absl::Status ReduceScatter(const RendezvousKey& key, - ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - - private: - std::shared_ptr context_; -}; - -class GlooCollectives : public CollectivesInterface { - public: - GlooCollectives(std::unique_ptr store, - std::shared_ptr device); - ~GlooCollectives() override; - - // Thread-safe. - absl::StatusOr> GetCommunicator( - absl::Span devices, int rank) override; - - private: - std::unique_ptr store_; - std::shared_ptr device_; - absl::Mutex mu_; - struct Context { - absl::Mutex mu; - std::shared_ptr communicator; - }; - absl::flat_hash_map, int>, - std::unique_ptr> - contexts_ ABSL_GUARDED_BY(mu_); -}; - -} // namespace xla::cpu - -#endif // XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ diff --git a/third_party/xla/xla/pjrt/cpu/mpi_collectives.h b/third_party/xla/xla/pjrt/cpu/mpi_collectives.h deleted file mode 100644 index fdf6ec81b6dc6b..00000000000000 --- a/third_party/xla/xla/pjrt/cpu/mpi_collectives.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_ -#define XLA_PJRT_CPU_MPI_COLLECTIVES_H_ - -#include -#include -#include -#include -#include - -#include "mpi.h" // NOLINT -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" -#include "xla/service/global_device_id.h" -#include "xla/xla_data.pb.h" - -namespace xla::cpu { - -class MpiCollectivesCommunicator : public CollectivesCommunicator { - public: - explicit MpiCollectivesCommunicator(int color, int key); - ~MpiCollectivesCommunicator() override; - - absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t num_elements, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, - std::optional source_rank, - absl::Span target_ranks, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, - absl::Duration timeout) override; - absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - absl::Status ReduceScatter(const RendezvousKey& key, - ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - - private: - MPI_Comm comm_; - int mpi_rank_; - int mpi_size_; -}; - -class MpiCollectives : public CollectivesInterface { - public: - /* - The user has to explicitly call Init() and Finalize() before and - after use. - For example, using the Python client, this can be achieved with: - - collectives = xla_client._xla.make_mpi_collectives() - collectives.Init() - atexit.register(collectives.Finalize) - */ - void Init(); - void Finalize(); - - absl::StatusOr> GetCommunicator( - absl::Span global_devices, int rank) override; - - private: - absl::Status ExchangeGlobalDeviceIds( - absl::Span global_devices, int rank); - - int mpi_world_rank_; - int mpi_world_size_; - absl::flat_hash_map, int>, - std::shared_ptr> - contexts_; -}; - -} // namespace xla::cpu - -#endif // XLA_PJRT_CPU_MPI_COLLECTIVES_H_ diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 85a8bc4ac3de86..09c0b1b0ecadcd 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -49,8 +49,8 @@ xla_cc_test( ":in_memory_key_value_store", ":protocol_proto_cc", ":topology_util", - "//xla:test_helpers", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", @@ -115,7 +115,6 @@ cc_library( ":key_value_store_interface", ":protocol_proto_cc", "//xla:util", - "//xla/pjrt:pjrt_client", "//xla/pjrt:utils", "//xla/pjrt/gpu:gpu_topology_proto_cc", "@com_google_absl//absl/container:flat_hash_map", @@ -166,6 +165,7 @@ cc_library( deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", ], ) diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index 69ebb5f99775f1..305afe7ae4c6d4 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -27,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -53,14 +53,15 @@ class DistributedRuntimeCoordinationServiceClient absl::Status Connect() override; absl::Status Shutdown() override; absl::StatusOr BlockingKeyValueGet( - std::string_view key, absl::Duration timeout) override; + absl::string_view key, absl::Duration timeout) override; + absl::StatusOr KeyValueTryGet(absl::string_view key) override; absl::StatusOr>> - KeyValueDirGet(std::string_view key) override; - absl::Status KeyValueSet(std::string_view key, - std::string_view value) override; - absl::Status KeyValueSet(std::string_view key, std::string_view value, + KeyValueDirGet(absl::string_view key) override; + absl::Status KeyValueSet(absl::string_view key, + absl::string_view value) override; + absl::Status KeyValueSet(absl::string_view key, absl::string_view value, bool allow_overwrite) override; - absl::Status KeyValueDelete(std::string_view key) override; + absl::Status KeyValueDelete(absl::string_view key) override; absl::Status WaitAtBarrier( std::string barrier_id, absl::Duration timeout, std::optional> process_ids) override; @@ -141,13 +142,19 @@ absl::Status DistributedRuntimeCoordinationServiceClient::Shutdown() { absl::StatusOr DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( - std::string_view key, absl::Duration timeout) { + absl::string_view key, absl::Duration timeout) { return coord_agent_->GetKeyValue(key, timeout); } +absl::StatusOr +DistributedRuntimeCoordinationServiceClient::KeyValueTryGet( + absl::string_view key) { + return coord_agent_->TryGetKeyValue(key); +} + absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( - std::string_view key) { + absl::string_view key) { TF_ASSIGN_OR_RETURN(const auto results, coord_agent_->GetKeyValueDir(key)); std::vector> kvs; @@ -162,17 +169,17 @@ DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( } absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete( - std::string_view key) { + absl::string_view key) { return coord_agent_->DeleteKeyValue(key); } absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet( - std::string_view key, std::string_view value) { + absl::string_view key, absl::string_view value) { return KeyValueSet(key, value, /*allow_overwrite=*/false); } absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet( - std::string_view key, std::string_view value, bool allow_overwrite) { + absl::string_view key, absl::string_view value, bool allow_overwrite) { return coord_agent_->InsertKeyValue(key, value, allow_overwrite); } @@ -212,12 +219,16 @@ class DistributedKeyValueStore : public KeyValueStoreInterface { std::string prefix) : client_(std::move(client)), prefix_(std::move(prefix)) {} - absl::StatusOr Get(std::string_view key, + absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override { return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); } - absl::Status Set(std::string_view key, std::string_view value) override { + absl::StatusOr TryGet(absl::string_view key) override { + return client_->KeyValueTryGet(absl::StrCat(prefix_, key)); + } + + absl::Status Set(absl::string_view key, absl::string_view value) override { return client_->KeyValueSet(absl::StrCat(prefix_, key), value); } diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 0654522bb78818..58f4fe367681d2 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -21,13 +21,13 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -115,7 +115,10 @@ class DistributedRuntimeClient { // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). virtual absl::StatusOr BlockingKeyValueGet( - std::string_view key, absl::Duration timeout) = 0; + absl::string_view key, absl::Duration timeout) = 0; + + // Returns `NotFoundError` immediately if the key is not found. + virtual absl::StatusOr KeyValueTryGet(absl::string_view key) = 0; // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with @@ -123,16 +126,17 @@ class DistributedRuntimeClient { // This is not a blocking call. If no keys are found, an empty vector is // returned immediately. virtual absl::StatusOr>> - KeyValueDirGet(std::string_view key) = 0; + KeyValueDirGet(absl::string_view key) = 0; - virtual absl::Status KeyValueSet(std::string_view key, - std::string_view value) = 0; - virtual absl::Status KeyValueSet(std::string_view key, std::string_view value, + virtual absl::Status KeyValueSet(absl::string_view key, + absl::string_view value) = 0; + virtual absl::Status KeyValueSet(absl::string_view key, + absl::string_view value, bool allow_overwrite) = 0; // Delete the key-value. If the key is a directory, recursively clean // up all key-values under the directory. - virtual absl::Status KeyValueDelete(std::string_view key) = 0; + virtual absl::Status KeyValueDelete(absl::string_view key) = 0; // Blocks until all nodes (or the ones specified in `nodes`) are at the // barrier or the barrier times out. `barrier_id` should be unique across diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index da164607f8c667..baec103eced933 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/log/log.h" @@ -1030,6 +1029,20 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) { EXPECT_EQ(result.value(), "overwritten_value"); } +TEST_F(ClientServerTest, KeyValueTryGet) { + StartService(/*num_nodes=*/1); + auto client = GetClient(/*node_id=*/0); + TF_ASSERT_OK(client->Connect()); + + ASSERT_THAT(client->KeyValueTryGet("test_key").status(), + StatusIs(absl::StatusCode::kNotFound)); + + TF_ASSERT_OK(client->KeyValueSet("test_key", "value")); + auto result = client->KeyValueTryGet("test_key"); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(result.value(), "value"); +} + TEST_F(ClientServerTest, KeyValueDelete) { StartService(/*num_nodes=*/1); auto client = GetClient(/*node_id=*/0); diff --git a/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.cc b/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.cc index 8140bb9bd80eac..49fc73ec87f163 100644 --- a/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.cc +++ b/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.cc @@ -16,17 +16,17 @@ limitations under the License. #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" namespace xla { -absl::StatusOr InMemoryKeyValueStore::Get(std::string_view key, +absl::StatusOr InMemoryKeyValueStore::Get(absl::string_view key, absl::Duration timeout) { absl::MutexLock lock(&mu_); auto cond = [&]() { @@ -41,8 +41,19 @@ absl::StatusOr InMemoryKeyValueStore::Get(std::string_view key, return kv_store_.find(key)->second; } -absl::Status InMemoryKeyValueStore::Set(std::string_view key, - std::string_view value) { +absl::StatusOr InMemoryKeyValueStore::TryGet( + absl::string_view key) { + absl::MutexLock lock(&mu_); + auto it = kv_store_.find(key); + if (it == kv_store_.end()) { + return absl::NotFoundError( + absl::StrCat(key, " is not found in the kv store.")); + } + return it->second; +} + +absl::Status InMemoryKeyValueStore::Set(absl::string_view key, + absl::string_view value) { absl::MutexLock lock(&mu_); kv_store_[key] = value; return absl::OkStatus(); diff --git a/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.h b/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.h index 680abc5b4c9c0b..13f50c722bd125 100644 --- a/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.h +++ b/third_party/xla/xla/pjrt/distributed/in_memory_key_value_store.h @@ -17,22 +17,25 @@ limitations under the License. #define XLA_PJRT_DISTRIBUTED_IN_MEMORY_KEY_VALUE_STORE_H_ #include -#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "xla/pjrt/distributed/key_value_store_interface.h" namespace xla { class InMemoryKeyValueStore : public KeyValueStoreInterface { public: - absl::StatusOr Get(std::string_view key, + absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override; - absl::Status Set(std::string_view key, std::string_view value) override; + absl::StatusOr TryGet(absl::string_view key) override; + + absl::Status Set(absl::string_view key, absl::string_view value) override; private: absl::Mutex mu_; diff --git a/third_party/xla/xla/pjrt/distributed/key_value_store_interface.h b/third_party/xla/xla/pjrt/distributed/key_value_store_interface.h index a5b68fa1aa8a7c..312ebb8abb6463 100644 --- a/third_party/xla/xla/pjrt/distributed/key_value_store_interface.h +++ b/third_party/xla/xla/pjrt/distributed/key_value_store_interface.h @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" namespace xla { @@ -38,12 +38,19 @@ class KeyValueStoreInterface { virtual ~KeyValueStoreInterface() = default; // Blocking Get(). + // Useful for listening for a key-value pair that may be set later on. // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). - virtual absl::StatusOr Get(std::string_view key, + virtual absl::StatusOr Get(absl::string_view key, absl::Duration timeout) = 0; - virtual absl::Status Set(std::string_view key, std::string_view value) = 0; + // Returns `NotFoundError` immediately if the key is not found. + // Useful for checking key existence. + // There are no concurrency guarantees. To avoid a race / impose an ordering + // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). + virtual absl::StatusOr TryGet(absl::string_view key) = 0; + + virtual absl::Status Set(absl::string_view key, absl::string_view value) = 0; }; struct MultiProcessKeyValueStore { diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.cc b/third_party/xla/xla/pjrt/distributed/topology_util.cc index e3926dcb39cd5a..ca08bbb530f2c8 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util.cc @@ -16,11 +16,12 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" #include +#include +#include #include #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -29,13 +30,13 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" #include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/protocol.pb.h" -#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/utils.h" #include "xla/util.h" #include "tsl/platform/env.h" @@ -46,6 +47,34 @@ limitations under the License. namespace xla { +namespace { +bool SameDevice(const DeviceProto& a, const DeviceProto& b) { + return (a.name() == b.name() && a.vendor() == b.vendor() && + a.local_device_ordinal() == b.local_device_ordinal() && + a.core_count() == b.core_count() && + a.device_kind() == b.device_kind() && + a.slice_index() == b.slice_index() && + // Global device ID Might not be set for LocalTopologyProto, still + // check it for default value. + a.global_device_id() == b.global_device_id() && + a.compute_capability() == b.compute_capability()); +} + +bool SameLocalTopology(const LocalTopologyProto& a, + const LocalTopologyProto& b) { + if (a.node_id() != b.node_id() || a.devices_size() != b.devices_size()) { + return false; + } + for (int i = 0; i < a.devices_size(); ++i) { + if (!SameDevice(a.devices(i), b.devices(i))) { + return false; + } + } + return true; +} + +} // namespace + // Exists on Linux systems. Unique per OS kernel restart. static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id"; @@ -68,16 +97,17 @@ absl::StatusOr GetBootIdString() { return boot_id_str; } -static std::string GetLocalTopologyKey(std::string_view platform, int node_id) { +static std::string GetLocalTopologyKey(absl::string_view platform, + int node_id) { return absl::StrCat("local_topology/", platform, "/", node_id); } -static std::string GetGlobalTopologyKey(std::string_view platform) { +static std::string GetGlobalTopologyKey(absl::string_view platform) { return absl::StrCat("global_topology/", platform); } static absl::StatusOr> GetAllLocalTopologies( - std::string_view platform, int num_nodes, KeyValueStoreInterface* kv_store, + absl::string_view platform, int num_nodes, KeyValueStoreInterface* kv_store, absl::Duration timeout) { std::vector> local_topology_strs(num_nodes); @@ -136,7 +166,7 @@ GlobalTopologyProto BuildGlobalTopology( absl::flat_hash_map boot_id_to_slice_index; for (LocalTopologyProto& local : local_topologies) { // Every new boot_id seen is treated as a new host/slice. - std::string_view boot_id = local.boot_id(); + absl::string_view boot_id = local.boot_id(); auto [it, inserted] = boot_id_to_slice_index.try_emplace(boot_id, next_slice_index); if (inserted) { @@ -160,7 +190,7 @@ GlobalTopologyProto BuildGlobalTopology( return global_topology; } -absl::Status ExchangeTopologies(std::string_view platform, int node_id, +absl::Status ExchangeTopologies(absl::string_view platform, int node_id, int num_nodes, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout, @@ -179,8 +209,34 @@ absl::Status ExchangeTopologies(std::string_view platform, int node_id, return absl::OkStatus(); } CHECK(kv_store != nullptr); - TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id), - local_topology.SerializeAsString())); + const std::string local_topology_key = GetLocalTopologyKey(platform, node_id); + const std::string serialized_local_topology = + local_topology.SerializeAsString(); + + absl::StatusOr existing_local_topology = + kv_store->TryGet(local_topology_key); + printf("existing_local_topology status: %s\n", + existing_local_topology.status().ToString().c_str()); + + if (existing_local_topology.ok()) { + printf("existing topology found"); + // Local topology has been set previously from the same node before + // restart. + LocalTopologyProto existing_local_topology_proto; + existing_local_topology_proto.ParseFromString(*existing_local_topology); + if (!SameLocalTopology(existing_local_topology_proto, local_topology)) { + return absl::InternalError(absl::Substitute( + "Different local topology for node $0 has been set previously, " + "possibly before a restart.\nBefore: $1\nAfter: $2", + node_id, existing_local_topology_proto.DebugString(), + local_topology.DebugString())); + } + } else if (absl::IsNotFound(existing_local_topology.status())) { + TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id), + serialized_local_topology)); + } else { + return existing_local_topology.status(); + } // The lead node gets all local topologies, builds the global topology and // puts it to the key-value store. diff --git a/third_party/xla/xla/pjrt/distributed/topology_util.h b/third_party/xla/xla/pjrt/distributed/topology_util.h index ec902d72efd63a..2e492d9c907398 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util.h +++ b/third_party/xla/xla/pjrt/distributed/topology_util.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_PJRT_DISTRIBUTED_TOPOLOGY_UTIL_H_ #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -40,7 +39,7 @@ absl::StatusOr GetBootIdString(); // topology in the order they appear in the input. Otherwise leaves the global // IDs as they were in the local topologies.. // TODO(phawkins): deprecate and remove assign_global_device_ids. -absl::Status ExchangeTopologies(std::string_view platform, int node_id, +absl::Status ExchangeTopologies(absl::string_view platform, int node_id, int num_nodes, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout, diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index ad63a3071d3b41..06464dc9b1b1b3 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -16,14 +16,13 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" #include -#include #include +#include "absl/status/status.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/distributed/protocol.pb.h" -#include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" @@ -32,6 +31,7 @@ limitations under the License. namespace xla { namespace { +using tsl::testing::StatusIs; TEST(TopologyTest, BuildGlobalTopology) { std::vector locals(2); @@ -87,6 +87,94 @@ TEST(TopologyTest, ExchangeTopology) { } } +TEST(TopologyTest, ExchangeTopology_Twice_Succeeds) { + int num_nodes = 2; + std::vector locals(num_nodes); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(0); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + + InMemoryKeyValueStore kv_store; + std::vector globals(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool", + num_nodes); + for (int i = 0; i < num_nodes; i++) { + thread_pool.Schedule([&, i] { + TF_ASSERT_OK(ExchangeTopologies( + /*platform=*/"cuda", /*node_id=*/i, num_nodes, + /*get_local_topology_timeout=*/ + absl::Seconds(10), /*get_global_topology_timeout=*/ + absl::Seconds(10), &kv_store, locals[i], &globals[i], + /*assign_global_device_ids=*/true)); + // Simulate node 1 restarting and exchanging topologies again. + if (i == 1) { + TF_ASSERT_OK(ExchangeTopologies( + /*platform=*/"cuda", /*node_id=*/i, num_nodes, + /*get_local_topology_timeout=*/ + absl::Seconds(10), /*get_global_topology_timeout=*/ + absl::Seconds(10), &kv_store, locals[i], &globals[i], + /*assign_global_device_ids=*/true)); + } + }); + } + } + for (const GlobalTopologyProto& global : globals) { + EXPECT_EQ(global.nodes_size(), 2); + EXPECT_EQ(global.nodes()[0].devices_size(), 2); + EXPECT_EQ(global.nodes()[1].devices_size(), 2); + } +} + +TEST(TopologyTest, ExchangeTopology_TwiceWithDifferentLocalTopology_Fails) { + int num_nodes = 2; + std::vector locals(num_nodes); + DeviceProto* d0 = locals[0].add_devices(); + d0->set_local_device_ordinal(0); + DeviceProto* d1 = locals[0].add_devices(); + d1->set_local_device_ordinal(0); + DeviceProto* d2 = locals[1].add_devices(); + d2->set_local_device_ordinal(0); + DeviceProto* d3 = locals[1].add_devices(); + d3->set_local_device_ordinal(1); + + InMemoryKeyValueStore kv_store; + std::vector globals(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool", + num_nodes); + for (int i = 0; i < num_nodes; i++) { + thread_pool.Schedule([&, i] { + TF_ASSERT_OK(ExchangeTopologies( + /*platform=*/"cuda", /*node_id=*/i, num_nodes, + /*get_local_topology_timeout=*/ + absl::Seconds(10), /*get_global_topology_timeout=*/ + absl::Seconds(10), &kv_store, locals[i], &globals[i], + /*assign_global_device_ids=*/true)); + // Simulate node 1 restarting with different devices. + if (i == 1) { + DeviceProto* d4 = locals[1].add_devices(); + d4->set_local_device_ordinal(2); + // This should fail because the local topology is unexpectedly + // different. + EXPECT_THAT(ExchangeTopologies( + /*platform=*/"cuda", /*node_id=*/i, num_nodes, + /*get_local_topology_timeout=*/ + absl::Seconds(10), /*get_global_topology_timeout=*/ + absl::Seconds(10), &kv_store, locals[i], &globals[i], + /*assign_global_device_ids=*/true), + StatusIs(absl::StatusCode::kInternal)); + } + }); + } + } +} + TEST(TopologyTest, BuildGpuTopology) { std::string slice_0_boot_id = "foo"; std::string slice_1_boot_id = "bar"; diff --git a/third_party/xla/xla/pjrt/event_pool.h b/third_party/xla/xla/pjrt/event_pool.h index 65a55bb1ac2a8e..a0b33e55b5e014 100644 --- a/third_party/xla/xla/pjrt/event_pool.h +++ b/third_party/xla/xla/pjrt/event_pool.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PJRT_EVENT_POOL_H_ #define XLA_PJRT_EVENT_POOL_H_ +#include #include #include diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 8ec969d26d554b..f8f0d0aee8d99d 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -79,10 +79,12 @@ cc_library( "//xla/pjrt:stream_executor_executable_proto_cc", "//xla/pjrt:tracked_device_buffer", "//xla/pjrt:utils", + "//xla/pjrt:worker_thread", "//xla/pjrt/distributed:client", "//xla/pjrt/distributed:in_memory_key_value_store", "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:topology_util", + "//xla/pjrt/plugin/xla_gpu:xla_gpu_allocator_config", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:compiler", "//xla/service:computation_placer_hdr", @@ -420,7 +422,6 @@ xla_test( ":se_gpu_pjrt_compiler_impl", "//xla:literal", "//xla:literal_util", - "//xla/client:xla_computation", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", "//xla/mlir_hlo", diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc index c89f333d209e28..e604f771ebacf0 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -103,9 +102,7 @@ absl::StatusOr> CreateBFCAllocator( executor, tsl::PlatformDeviceId(device_ordinal), /*memory_type=*/ enable_unified_memory ? stream_executor::MemoryType::kUnified - : stream_executor::MemoryType::kDevice, - /*alloc_visitors=*/std::vector(), - /*free_visitors=*/std::vector()); + : stream_executor::MemoryType::kDevice); int64_t free_memory; int64_t total_memory; @@ -147,9 +144,7 @@ absl::StatusOr> CreateCollectiveBFCAllocator( int device_ordinal = executor->device_ordinal(); auto sub_allocator = std::make_unique( executor, tsl::PlatformDeviceId(device_ordinal), - /*memory_type=*/stream_executor::MemoryType::kCollective, - /*alloc_visitors=*/std::vector(), - /*free_visitors=*/std::vector()); + /*memory_type=*/stream_executor::MemoryType::kCollective); int64_t free_memory; int64_t total_memory; @@ -236,7 +231,7 @@ int TopologySizes::GetDeviceCount() { // static absl::StatusOr TopologySizes::FromString( - std::string_view topology_string) { + absl::string_view topology_string) { TopologySizes sizes; std::vector topology_components = absl::StrSplit(topology_string, 'x'); diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h index 9807967654593d..3f6472d628a383 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h @@ -65,7 +65,7 @@ struct TopologySizes { // " x x " // and returns the parsed components on success. static absl::StatusOr FromString( - std::string_view topology_string); + absl::string_view topology_string); }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/gpu_topology.h b/third_party/xla/xla/pjrt/gpu/gpu_topology.h index 9636432c17d2a9..609c7fbab610b1 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_topology.h +++ b/third_party/xla/xla/pjrt/gpu/gpu_topology.h @@ -57,7 +57,7 @@ class GpuTopology { const GpuTopologyProto& proto); GpuTopologyProto ToProto() const; - std::string_view platform_version() const { return platform_version_; } + absl::string_view platform_version() const { return platform_version_; } int32_t num_slices() const { return num_slices_; } int32_t num_hosts_per_slice() const { return num_hosts_per_slice_; } int32_t num_devices_per_host() const { return num_devices_per_host_; } diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 7de31ac3a0090e..cd54e96ce4a91e 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include -#include #include #include #include @@ -39,6 +39,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -65,8 +66,11 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_stream_executor_client.h" +#include "xla/pjrt/plugin/xla_gpu/xla_gpu_allocator_config.h" +#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/pjrt/stream_executor_executable.h" #include "xla/pjrt/tracked_device_buffer.h" +#include "xla/pjrt/worker_thread.h" #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" @@ -114,7 +118,6 @@ limitations under the License. #endif #include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "xla/util.h" @@ -1073,12 +1076,12 @@ void NameDeviceAndLauncherThread(const LocalTopologyProto& node, } // namespace absl::StatusOr BuildDistributedDevices( - std::string_view platform_name, + absl::string_view platform_name, std::map> local_device_states, int node_id, int num_nodes, gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, - std::optional mock_gpu_topology, + std::optional mock_gpu_topology, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout) { std::vector> devices; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 6109620a0c2257..a60a65c4bf3dde 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -271,12 +271,12 @@ std::vector> BuildLocalDevices( std::string MakeComputeCapabilityString(const se::DeviceDescription* desc); absl::StatusOr BuildDistributedDevices( - std::string_view platform_name, + absl::string_view platform_name, std::map> local_device_states, int node_id, int num_nodes, gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, - std::optional mock_gpu_topology = std::nullopt, + std::optional mock_gpu_topology = std::nullopt, absl::Duration get_local_topology_timeout = absl::Minutes(2), absl::Duration get_global_topology_timeout = absl::Minutes(5)); diff --git a/third_party/xla/xla/pjrt/interpreter/BUILD b/third_party/xla/xla/pjrt/interpreter/BUILD new file mode 100644 index 00000000000000..750e580497b81b --- /dev/null +++ b/third_party/xla/xla/pjrt/interpreter/BUILD @@ -0,0 +1,67 @@ +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) + +cc_library( + name = "interpreter_client", + srcs = ["interpreter_client.cc"], + hdrs = ["interpreter_client.h"], + visibility = internal_visibility(["//xla:friends"]), + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla/backends/interpreter:compiler", + "//xla/client:executable_build_options", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:cholesky_expander", + "//xla/hlo/transforms:dynamic_index_splitter", + "//xla/hlo/transforms:eigh_expander", + "//xla/hlo/transforms:qr_expander", + "//xla/pjrt:layout_mode", + "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", + "//xla/pjrt:pjrt_executable", + "//xla/pjrt:pjrt_future", + "//xla/pjrt:utils", + "//xla/service:batchnorm_expander", + "//xla/service:computation_placer_hdr", + "//xla/service:custom_call_target_registry", + "//xla/service:dynamic_dimension_inference", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_module_config", + "//xla/service:hlo_module_util", + "//xla/service:layout_assignment", + "//xla/service:topk_rewriter", + "//xla/service:triangular_solve_expander", + "//xla/tsl/platform:errors", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:fingerprint", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc b/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc new file mode 100644 index 00000000000000..fea857e6d89a1a --- /dev/null +++ b/third_party/xla/xla/pjrt/interpreter/interpreter_client.cc @@ -0,0 +1,539 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/interpreter/interpreter_client.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" +#include "xla/client/executable_build_options.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/expanders/cholesky_expander.h" +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" +#include "xla/hlo/transforms/expanders/eigh_expander.h" +#include "xla/hlo/transforms/expanders/qr_expander.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/layout_mode.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/utils.h" +#include "xla/service/batchnorm_expander.h" +#include "xla/service/computation_placer.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/dynamic_dimension_inference.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_module_util.h" +#include "xla/service/layout_assignment.h" +#include "xla/service/topk_rewriter.h" +#include "xla/service/triangular_solve_expander.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +bool ShapesMatch(const Shape& expected_shape, const Shape& actual_shape) { + if (expected_shape.is_dynamic()) { + return ShapeUtil::DynamicArrayShapeIsCompatible(actual_shape, + expected_shape); + } + return Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, + actual_shape); +} + +absl::StatusOr ChooseCompactLayoutForShape(const Shape& shape) { + return LayoutUtil::GetWithDefaultLayout(shape); +} + +// Handles custom_call ops during evaluation by routing them through the global +// CPU registry used by other CPU-based backends. +absl::StatusOr HandleEvaluatorCustomCall( + const HloInstruction* custom_call, absl::Span operands) { + // Find the target C function in the global registry. + CustomCallTargetRegistry* const registry = CustomCallTargetRegistry::Global(); + void* const target_fn = + registry->Lookup(custom_call->custom_call_target(), "Host"); + if (target_fn == nullptr) { + return NotFound("Custom call target '%s' was not registered", + custom_call->custom_call_target()); + } + + // Populate pointers to operand and output literal data. + std::vector operand_data; + operand_data.reserve(operands.size()); + for (const Literal* const literal : operands) { + operand_data.push_back(literal->untyped_data()); + } + Literal output = Literal::CreateFromShape(custom_call->shape()); + void* const output_data = output.untyped_data(); + + // Call the target function matching the C ABI used by the CPU backends. + auto* typed_fn = reinterpret_cast(target_fn); + (*typed_fn)(output_data, operand_data.data()); + + return std::move(output); +} + +// Extract the input literals from the provided buffers. +// +// If there is a tupled argument and the arguments are not tupled, the extracted +// literals will be reconstituted into a tuple. The second element of the +// returned tuple is storage for the tupled literal, if required. Otherwise it +// is nullptr. +absl::StatusOr, std::unique_ptr>> +ExtractInterpreterInputLiteralsFromBuffers( + const absl::Span buffers, + const HloComputation& entry_computation, + const bool parameter_is_tupled_arguments, const bool arguments_are_tupled) { + std::vector literals; + for (PjRtBuffer* const buffer : buffers) { + InterpreterLiteralWrapperBuffer* interpreter_buffer = + dynamic_cast(buffer); + if (interpreter_buffer == nullptr) { + return absl::InvalidArgumentError( + "Interpreter only supports InterpreterLiteralWrapperBuffers"); + } + literals.push_back(&interpreter_buffer->mutable_literal()); + } + + // Return early if arguments don't need to be re-tupled. + if (!parameter_is_tupled_arguments || arguments_are_tupled) { + return std::make_tuple(std::move(literals), nullptr); + } + + if (entry_computation.num_parameters() != 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Interpreter expected a single tupled entry parameter, but got %d.", + entry_computation.num_parameters())); + } + + // Re-tuple input arguments. PjRt is commonly used in a mode where the input + // tuple (if present) is flattened and passed as a vector of argument + // buffers. The HloEvaluator expects the input to be tupled in these cases. + // + // This process invalidates the input literals and thus the input buffers + // themselves. + std::vector shapes; + shapes.reserve(literals.size()); + for (const Literal* literal : literals) { + shapes.push_back(literal->shape()); + } + auto tupled_arg_literal = + std::make_unique(ShapeUtil::MakeTupleShape(shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < literals.size(); ++i) { + TF_RETURN_IF_ERROR(tupled_arg_literal->MoveFrom(std::move(*literals[i]), + /*dest_shape_index=*/{i})); + } + + // Replace arg literals with the tupled literal. + literals.clear(); + literals.push_back(tupled_arg_literal.get()); + return std::make_tuple(std::move(literals), std::move(tupled_arg_literal)); +} + +// The interpreter is a 1 replica, 1 partition = 1 device system. +inline DeviceAssignment MakeInterpreterDeviceAssignment() { + DeviceAssignment assignment(1, 1); + assignment(0, 0) = 0; + return assignment; +} +} // namespace + +const InterpreterDescription& InterpreterDescription::Singleton() { + static const InterpreterDescription* singleton = new InterpreterDescription; + return *singleton; +} + +absl::StatusOr>>> +InterpreterLoadedExecutable::Execute( + absl::Span> argument_handles, + const ExecuteOptions& options, + std::optional>>& returned_futures) { + if (device_assignment_ == nullptr) { + return absl::InvalidArgumentError( + "Execute expects a non-null device_assignment"); + } + if (argument_handles.size() != addressable_devices_.size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Attempted to execute with %d argument lists when device count is %d " + "(total replica count: %d, partition count: %d)", + argument_handles.size(), addressable_devices_.size(), num_replicas(), + num_partitions())); + } + if (addressable_devices_.size() != 1) { + return absl::InvalidArgumentError( + absl::StrFormat("Attempted to execute with %d devices, but interpreter " + "only supports single device execution.", + addressable_devices_.size())); + } + + std::optional> returned_future; + TF_ASSIGN_OR_RETURN( + std::vector> replica_result, + ExecuteSharded(argument_handles[0], addressable_devices_[0], options, + returned_future, returned_futures.has_value())); + std::vector>> result; + result.push_back(std::move(replica_result)); + if (returned_futures.has_value()) { + CHECK(returned_future.has_value()) + << "returned_future must be set because ExecuteSharded was called with " + "fill_future=true."; + returned_futures = std::vector>({*std::move(returned_future)}); + } + return result; +} + +absl::StatusOr>> +InterpreterLoadedExecutable::ExecuteSharded( + absl::Span argument_handles, PjRtDevice* device, + const ExecuteOptions& options, std::optional>& returned_future, + bool fill_future) { + if (device_assignment_ == nullptr) { + return absl::InvalidArgumentError( + "ExecuteSharded expects a non-null device_assignment"); + } + // Since there is only one device, the device should always be the same. Check + // anyways just to be sure. + if (!absl::c_any_of( + addressable_devices_, + [needle = device](PjRtDevice* const d) { return d == needle; })) { + return absl::InvalidArgumentError(absl::StrFormat( + "ExecuteShard attempted to execute on device id %d, which is not " + "addressable by this client.", + device->global_device_id().value())); + } + + // Extract the literals from the arguments. + const HloComputation& computation = *hlo_module_->entry_computation(); + TF_ASSIGN_OR_RETURN(const auto literals_and_storage, + ExtractInterpreterInputLiteralsFromBuffers( + argument_handles, computation, + compile_options_.parameter_is_tupled_arguments, + options.arguments_are_tupled)); + const absl::Span literals = + std::get<0>(literals_and_storage); + if (computation.num_parameters() != literals.size()) { + return absl::InternalError(absl::StrFormat( + "Mismatch between argument count (%d) and graph parameter count (%d).", + literals.size(), computation.num_parameters())); + } + + // Check that the args have the right shape. + for (int64_t i = 0; i < computation.num_parameters(); ++i) { + const Shape& expected_shape = computation.parameter_instruction(i)->shape(); + const Shape& actual_shape = literals[i]->shape(); + if (!ShapesMatch(expected_shape, actual_shape)) { + return absl::InvalidArgumentError(absl::StrFormat( + "Shape mismatch on parameter %d. Expected %s but was %s.", i, + ShapeUtil::HumanStringWithLayout(expected_shape), + ShapeUtil::HumanStringWithLayout(actual_shape))); + } + } + + TF_ASSIGN_OR_RETURN(Literal result_literal, Evaluate(computation, literals)); + // Shrink the generated dynamic shape into static shape. + result_literal = result_literal.ToStatic(); + if (fill_future) { + returned_future = PjRtFuture<>(absl::OkStatus()); + } + + // Transform the result literal back into a one or more + // InterpreterLiteralWrapperBuffer. + std::vector> result; + // Untuple result if requested. + if (options.untuple_result && result_literal.shape().IsTuple()) { + const int tuple_count = result_literal.shape().tuple_shapes_size(); + result.reserve(tuple_count); + // DecomposeTuple invalidates result_literal. move(...) to make it obvious. + std::vector tuple_elements = + std::move(result_literal).DecomposeTuple(); + CHECK(tuple_count == tuple_elements.size()) + << "DecomposedTuple returned the wrong number of elements."; + for (int i = 0; i < tuple_count; ++i) { + result.push_back(std::make_unique( + client_, device, std::move(tuple_elements[i]))); + } + } else { + result.push_back(std::make_unique( + client_, device, std::move(result_literal))); + } + return result; +} + +absl::StatusOr>> +InterpreterLoadedExecutable::ExecutePortable( + absl::Span argument_handles, PjRtDevice* device, + const ExecuteOptions& options, std::optional>& returned_future, + bool fill_future) { + return absl::UnimplementedError("ExecutePortable is not implemented"); +} + +absl::StatusOr InterpreterLoadedExecutable::Evaluate( + const HloComputation& computation, + absl::Span arg_literals) { + absl::MutexLock lock(&hlo_evaluator_lock_); + return hlo_evaluator_->Evaluate(computation, arg_literals); +} + +absl::StatusOr InterpreterClient::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + if (num_replicas != 1 || num_partitions != 1) { + return absl::UnimplementedError( + "Interpreter only supports num_replicas=1 and num_partitions=1."); + } + return MakeInterpreterDeviceAssignment(); +} + +absl::StatusOr InterpreterClient::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) { + // This is all the GenericTransferManager::ChooseCompactLayoutForShape does. + Shape shape = ShapeUtil::MakeShape(element_type, dims); + LayoutUtil::SetToDefaultLayout(&shape); + return shape.layout(); +} + +absl::StatusOr> +InterpreterClient::Compile(const XlaComputation& computation, + CompileOptions options) { + std::vector argument_layout_pointers; + const ExecutableBuildOptions& build_options = + options.executable_build_options; + const bool allow_auto_layout = + build_options.has_debug_options() && + build_options.debug_options().xla_pjrt_allow_auto_layout_in_hlo(); + TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( + computation, + [allow_auto_layout](Shape shape) -> absl::StatusOr { + if (allow_auto_layout && !shape.has_layout()) { + return shape; + } + return ChooseCompactLayoutForShape(shape); + }, + options.argument_layouts, &options.executable_build_options, + &argument_layout_pointers)); + return CompileInternal(computation, argument_layout_pointers, + /*layout_canonicalization_callback=*/nullptr, options); +} + +absl::StatusOr> +InterpreterClient::Compile(mlir::ModuleOp module, CompileOptions options) { + XlaComputation xla_computation; + const ExecutableBuildOptions& exec_build_options = + options.executable_build_options; + TF_RETURN_IF_ERROR(MlirToXlaComputation( + module, xla_computation, + /*use_tuple_args=*/options.parameter_is_tupled_arguments, + /*return_tuple=*/false, exec_build_options.use_shardy_partitioner())); + + // If the compile options specify argument layout, then let's + // fall back to using the options to determine layouts. + if (options.argument_layouts) { + return Compile(xla_computation, options); + } + + TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, + GetArgLayoutModes(module)); + TF_ASSIGN_OR_RETURN(std::vector out_layout_modes, + GetOutputLayoutModes(module)); + TF_ASSIGN_OR_RETURN(std::vector arg_memory_spaces, + GetArgMemoryKinds(module)); + TF_ASSIGN_OR_RETURN(std::vector out_memory_spaces, + GetOutputMemoryKinds(module)); + + // If auto-sharding modifies shapes of arguments and/or result, + // we get a callback to restore the layouts. Let us restore the layouts + // according to the attributes we parsed from MLIR. + auto layout_callback = [&arg_layout_modes, &out_layout_modes, + &arg_memory_spaces, + &out_memory_spaces](const HloModule& module) + -> absl::StatusOr, Shape>> { + XlaComputation xla_computation(XlaComputation(module.ToProto())); + return LayoutModesToXlaShapes( + xla_computation, arg_layout_modes, out_layout_modes, arg_memory_spaces, + out_memory_spaces, ChooseCompactLayoutForShape); + }; + + // This call will update result_layout in options.executable_build_options. + TF_ASSIGN_OR_RETURN( + auto arg_layouts_and_pointers, + LayoutModesToXla(xla_computation, arg_layout_modes, out_layout_modes, + arg_memory_spaces, out_memory_spaces, + ChooseCompactLayoutForShape, + options.executable_build_options)); + return CompileInternal(xla_computation, arg_layouts_and_pointers.second, + layout_callback, options); +} + +absl::StatusOr> +InterpreterClient::BufferFromHostLiteral(const LiteralSlice& literal, + PjRtDevice* device) { + return std::make_unique(device->client(), + device, literal); +} + +absl::StatusOr> +InterpreterClient::BufferFromHostLiteral(const LiteralSlice& literal, + PjRtDevice* device, + const Layout* device_layout) { + if (device_layout == nullptr) { + return BufferFromHostLiteral(literal, device); + } + Literal device_literal = literal.Relayout(*device_layout); + return std::make_unique( + device->client(), device, std::move(device_literal)); +} + +absl::StatusOr> +InterpreterClient::CompileInternal( + const XlaComputation& computation, + const std::vector& argument_shapes, + LayoutCanonicalizationCallback layout_canonicalization_callback, + CompileOptions options) { + CompileOptions input_options = options; + TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides()); + if (layout_canonicalization_callback != nullptr) { + options.executable_build_options.set_layout_canonicalization_callback( + layout_canonicalization_callback); + } + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + + const ExecutableBuildOptions& build_options = + options.executable_build_options; + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, &program_shape); + + // Unoptimized HloModuleConfig. + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module_config, + CreateModuleConfig(program_shape, argument_shapes, &execution_options, + execution_options.num_replicas(), + /*num_threads=*/std::nullopt, + /*aot_options=*/nullptr)); + // Unoptimized HloModule. + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation.proto(), *hlo_module_config)); + + if (build_options.num_partitions() != 1) { + return absl::UnimplementedError( + "For the time being, only num_partitions=1 is supported."); + } + + if (!build_options.run_backend_only()) { + TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module))); + } + + return RunBackend(std::move(hlo_module), options); +} + +absl::StatusOr> InterpreterClient::RunHloPasses( + std::unique_ptr hlo_module) { + HloPassPipeline pipeline("Interpreter"); + + // The TopkDecomposer generates a compare op with type=TOTALORDER and must + // run before the ComparisonExpander which rewrites such comparisons. + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout()); + + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module.get()).status()); + return hlo_module; +} + +absl::StatusOr> +InterpreterClient::RunBackend(std::unique_ptr hlo_module, + CompileOptions& options) { + TF_ASSIGN_OR_RETURN( + DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run( + hlo_module.get(), + /*op_supports_dynamism_handler=*/[&](HloInstruction* hlo) { + return OpDynamismSupport::kOptional; + })); + auto evaluator = std::make_unique(); + evaluator->set_use_fast_path( + hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); + evaluator->set_custom_call_handler(HandleEvaluatorCustomCall); + + std::shared_ptr device_assignment = nullptr; + std::vector + addressable_device_logical_ids; + std::vector addressable_devices; + int num_replicas = 0, num_partitions = 0; + TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions( + options.compile_portable_executable, &options.executable_build_options, + [this](int num_replicas, int num_partitions) { + return GetDefaultDeviceAssignment(num_replicas, num_partitions); + }, + &num_replicas, &num_partitions, &device_assignment)); + if (device_assignment == nullptr) { + return absl::InternalError("device_assignment is nullptr"); + } + if (num_replicas != 1 || num_partitions != 1) { + return absl::InvalidArgumentError( + absl::StrFormat("num_replicas and num_partitions must be 1. " + "num_replicas: %d, num_partitions: %d", + num_replicas, num_partitions)); + } + PjRtLoadedExecutable::LogicalDeviceIds logical_device_ids; + logical_device_ids.replica = 0; + logical_device_ids.partition = 0; + addressable_device_logical_ids.push_back(std::move(logical_device_ids)); + addressable_devices.push_back(&interpreter_device_); + + return std::make_unique( + this, std::move(hlo_module), std::move(evaluator), + dynamic_dimension_inference, std::move(device_assignment), options, + std::move(addressable_device_logical_ids), + std::move(addressable_devices)); +} + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/interpreter/interpreter_client.h b/third_party/xla/xla/pjrt/interpreter/interpreter_client.h new file mode 100644 index 00000000000000..aab0506500a647 --- /dev/null +++ b/third_party/xla/xla/pjrt/interpreter/interpreter_client.h @@ -0,0 +1,454 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_INTERPRETER_INTERPRETER_CLIENT_H_ +#define XLA_PJRT_INTERPRETER_INTERPRETER_CLIENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/service/computation_placer.h" +#include "xla/service/dynamic_dimension_inference.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/fingerprint.h" + +namespace xla { + +class InterpreterDescription final : public PjRtDeviceDescription { + public: + static const InterpreterDescription& Singleton(); + + int id() const override { return 0; } + + int process_index() const override { return 0; } + + absl::string_view device_kind() const override { return "interpreter"; } + + absl::string_view DebugString() const override { return "interpreter:0"; } + + absl::string_view ToString() const override { + return "InterpreterDevice(id=0)"; + } + + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + private: + InterpreterDescription() = default; + absl::flat_hash_map attributes_; +}; + +class InterpreterDevice final : public PjRtDevice { + public: + explicit InterpreterDevice(absl::Nonnull client) + : client_(ABSL_DIE_IF_NULL(client)) {} + + // Return the client that owns this device. + PjRtClient* client() const override { return client_; } + + bool IsAddressable() const override { return true; }; + + const InterpreterDescription& description() const override { + return InterpreterDescription::Singleton(); + } + + PjRtLocalDeviceId local_device_id() const override { + return PjRtLocalDeviceId(0); + } + + PjRtLocalHardwareId local_hardware_id() const override { + return PjRtLocalHardwareId(0); + } + + std::unique_ptr CreateAsyncTrackingEvent( + absl::string_view description) const override { + return nullptr; + } + + absl::Status TransferToInfeed(const LiteralSlice& literal) override { + return Unimplemented("Interpreter does not suppot transfer to infeed."); + } + + absl::Status TransferFromOutfeed(MutableBorrowingLiteral literal) override { + return Unimplemented("Interpreter does not support transfer from outfeed."); + } + + absl::Span memory_spaces() const override { + return {}; + } + + absl::StatusOr default_memory_space() const override { + return Unimplemented("default_memory_space not implemented"); + } + + private: + PjRtClient* client_ = nullptr; +}; + +// A buffer that wraps a Literal. +class InterpreterLiteralWrapperBuffer final : public PjRtBuffer { + public: + InterpreterLiteralWrapperBuffer(absl::Nonnull client, + absl::Nonnull device, + const LiteralSlice& literal) + : client_(client), device_(device), literal_(literal.Clone()) {} + InterpreterLiteralWrapperBuffer(absl::Nonnull client, + absl::Nonnull device, + Literal literal) + : client_(client), device_(device), literal_(std::move(literal)) {} + + const Shape& on_device_shape() const override { return literal_.shape(); } + + PjRtMemorySpace* memory_space() const override { return nullptr; } + + PjRtDevice* device() const override { return device_; } + + PjRtClient* client() const override { return client_; } + + absl::StatusOr> AcquireExternalReference() + override { + return absl::UnimplementedError( + "AcquireExternalReference not supported by " + "InterpreterLiteralWrapperBuffer."); + } + + PjRtFuture<> ToLiteral(MutableLiteralBase* literal) override { + return PjRtFuture<>(ShapeUtil::ForEachSubshapeWithStatus( + literal_.shape(), + [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + const int64_t src_size = literal_.size_bytes(index); + const int64_t dst_size = literal->size_bytes(index); + if (src_size < dst_size) { + return absl::FailedPreconditionError(absl::StrFormat( + "Cannot copy more data than available: Tried to copy %d bytes, " + "but only %d bytes are available (%d < %d).", + dst_size, src_size, src_size, dst_size)); + } + std::memcpy(/*dst=*/literal->untyped_data(index), + /*src=*/literal_.untyped_data(index), dst_size); + return absl::OkStatus(); + })); + } + + PjRtFuture<> LazyToLiteral( + absl::AnyInvocable() &&> generator) + override { + // Underlying buffer is always ready, so we can immediately call the + // generator. + absl::StatusOr literal = std::move(generator)(); + if (!literal.ok()) { + return PjRtFuture<>(literal.status()); + } + return ToLiteral(*literal); + } + + absl::StatusOr GetOnDeviceSizeInBytes() const override { + return literal_.size_bytes(); + } + + PjRtFuture<> CopyRawToHost(void* dst, int64_t offset, + int64_t transfer_size) override { + return PjRtFuture<>(absl::UnimplementedError( + "CopyRawToHost not supported by InterpreterLiteralWrapperBuffer.")); + } + + void Delete() override { + // Delete does not need to do anything for this type of buffer. + // + // This buffer does not support ownership transfers of the underlying + // buffer. The buffer memory is owned by the Literal field, deleted when + // this buffer's object is deleted. + is_deleted_ = true; + } + + absl::StatusOr> + ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) override { + return absl::UnimplementedError( + "ReleaseDeviceMemoryOwnership not supported by " + "InterpreterLiteralWrapperBuffer."); + } + + bool IsDeleted() override { return is_deleted_; } + + absl::StatusOr> CopyToDevice( + PjRtDevice* dst_device) override { + return absl::UnimplementedError( + "CopyToDevice not supported by InterpreterLiteralWrapperBuffer."); + } + + absl::StatusOr> CopyToMemorySpace( + PjRtMemorySpace* dst_memory_space) override { + return absl::UnimplementedError( + "CopyToMemorySpace not supported by InterpreterLiteralWrapperBuffer."); + } + + void CopyToRemoteDevice(PjRtFuture serialized_descriptor, + RemoteSendCallback on_done) override { + LOG(ERROR) << "InterpreterLiteralWrapperBuffer::CopyToRemoteDevice was " + "called but is not implemented."; + } + + void CopyToRemoteDeviceScattered( + PjRtFuture> serialized_descriptors, + std::vector callbacks, + const ScatterDetails& scatter_details) override { + LOG(ERROR) + << "InterpreterLiteralWrapperBuffer::CopyToRemoteDeviceScattered " + "was called but is not implemented."; + } + + PjRtFuture<> GetReadyFuture() override { + return PjRtFuture<>(absl::OkStatus()); + } + + bool IsOnCpu() const override { return true; } + + const Literal& literal() const { return literal_; } + Literal& mutable_literal() { return literal_; } + + private: + PjRtClient* client_ = nullptr; + PjRtDevice* device_ = nullptr; + Literal literal_; + bool is_deleted_ = false; +}; + +class InterpreterLoadedExecutable final : public PjRtLoadedExecutable { + public: + explicit InterpreterLoadedExecutable( + absl::Nonnull client, std::unique_ptr hlo_module, + std::unique_ptr hlo_evaluator, + std::optional dynamic_dimension_inference, + std::shared_ptr device_assignment, + CompileOptions compile_options, + std::vector addressable_device_logical_ids, + std::vector addressable_devices) + : client_(ABSL_DIE_IF_NULL(client)), + hlo_module_(std::move(hlo_module)), + hlo_evaluator_(std::move(hlo_evaluator)), + dynamic_dimension_inference_(std::move(dynamic_dimension_inference)), + device_assignment_(std::move(device_assignment)), + compile_options_(std::move(compile_options)), + addressable_device_logical_ids_( + std::move(addressable_device_logical_ids)), + addressable_devices_(std::move(addressable_devices)) { + if (dynamic_dimension_inference_.has_value()) { + hlo_evaluator_->set_dynamic_dimension_inference( + &dynamic_dimension_inference_.value()); + } + } + + int num_replicas() const override { + return hlo_module_->config().replica_count(); + } + + int num_partitions() const override { + return hlo_module_->config().num_partitions(); + } + + int64_t SizeOfGeneratedCodeInBytes() const override { return -1; } + + absl::string_view name() const override { return hlo_module_->name(); } + + absl::StatusOr>> GetHloModules() + const override { + std::vector> hlo_modules; + hlo_modules.push_back(hlo_module_); + return hlo_modules; + } + + absl::StatusOr>> + GetOutputMemoryKinds() const override { + return absl::UnimplementedError("GetOutputMemoryKinds is not supported."); + } + + PjRtClient* client() const override { return client_; } + + const DeviceAssignment& device_assignment() const override { + return *device_assignment_; + } + + absl::Span addressable_device_logical_ids() + const override { + return addressable_device_logical_ids_; + } + + absl::Span addressable_devices() const override { + return addressable_devices_; + } + + absl::StatusOr>>> Execute( + absl::Span> argument_handles, + const ExecuteOptions& options, + std::optional>>& returned_futures) override; + + absl::StatusOr>> ExecuteSharded( + absl::Span argument_handles, PjRtDevice* device, + const ExecuteOptions& options, + std::optional>& returned_future, bool fill_future) override; + + absl::StatusOr>> ExecutePortable( + absl::Span argument_handles, PjRtDevice* device, + const ExecuteOptions& options, + std::optional>& returned_future, bool fill_future) override; + + void Delete() override { hlo_module_ = nullptr; } + + bool IsDeleted() override { return hlo_module_ == nullptr; } + + private: + absl::StatusOr Evaluate( + const HloComputation& computation, + absl::Span arg_literals) + ABSL_LOCKS_EXCLUDED(hlo_evaluator_lock_); + + PjRtClient* client_ = nullptr; + std::shared_ptr hlo_module_; + mutable absl::Mutex hlo_evaluator_lock_; + std::unique_ptr hlo_evaluator_ + ABSL_PT_GUARDED_BY(hlo_evaluator_lock_); + std::optional dynamic_dimension_inference_; + std::shared_ptr device_assignment_; + CompileOptions compile_options_; + std::vector addressable_device_logical_ids_; + std::vector addressable_devices_; +}; + +class InterpreterClient final : public PjRtClient { + public: + InterpreterClient() + : interpreter_device_{this}, devices_({&interpreter_device_}) {} + // Not copyable or movable + InterpreterClient(const InterpreterClient&) = delete; + InterpreterClient& operator=(const InterpreterClient&) = delete; + InterpreterClient(InterpreterClient&&) = delete; + InterpreterClient& operator=(InterpreterClient&&) = delete; + + static Shape DeviceShapeRepresentation(const Shape& shape) { return shape; } + + static int64_t ShapeSizeBytes(const Shape& shape) { + if (shape.IsOpaque()) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + + int process_index() const override { return 0; } + + int device_count() const override { return devices().size(); } + + int addressable_device_count() const override { + return addressable_devices().size(); + } + + absl::Span devices() const override { return devices_; } + + absl::Span addressable_devices() const override { + return devices_; + } + + absl::Span memory_spaces() const override { + return interpreter_device_.memory_spaces(); + } + + PjRtPlatformId platform_id() const override { + static const PjRtPlatformId kPlatformId = tsl::Fingerprint64("interpreter"); + return kPlatformId; + } + + absl::string_view platform_name() const override { return "interpreter"; } + + absl::string_view platform_version() const override { return ""; } + + absl::StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + + absl::StatusOr GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) override; + + absl::StatusOr> GetHloCostAnalysis() + const override { + return std::make_unique(ShapeSizeBytes); + } + + absl::StatusOr> Compile( + const XlaComputation& computation, CompileOptions options) override; + + absl::StatusOr> Compile( + mlir::ModuleOp module, CompileOptions options) override; + + absl::StatusOr> BufferFromHostLiteral( + const LiteralSlice& literal, PjRtDevice* device) override; + + absl::StatusOr> BufferFromHostLiteral( + const LiteralSlice& literal, PjRtDevice* device, + const Layout* device_layout) override; + + private: + absl::StatusOr> CompileInternal( + const XlaComputation& computation, + const std::vector& argument_shapes, + LayoutCanonicalizationCallback layout_canonicalization_callback, + CompileOptions options); + absl::StatusOr> RunHloPasses( + std::unique_ptr hlo_module); + absl::StatusOr> RunBackend( + std::unique_ptr hlo_module, CompileOptions& options); + + InterpreterDevice interpreter_device_; + // Pointer array of devices (just one) so that we can create a span of it. + std::array devices_; +}; +} // namespace xla + +#endif // XLA_PJRT_INTERPRETER_INTERPRETER_CLIENT_H_ diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index 51b4257bfff965..152c87844fbb30 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -128,6 +128,16 @@ LocalDeviceState::~LocalDeviceState() { if (!status.ok()) { LOG(ERROR) << "Error when closing device: " << status; } + + // Explicitly delete all the streams to ensure that their callbacks are + // executed before the destruction of the LocalDeviceState and its callback + // threads. + external_ready_event_streams_.clear(); + fixed_size_pool_usage_streams_.clear(); + device_to_device_streams_.clear(); + device_to_host_streams_.clear(); + host_to_device_stream_.reset(); + compute_stream_.reset(); } absl::Status LocalDeviceState::SynchronizeAllActivity() { diff --git a/third_party/xla/xla/pjrt/lru_cache_test.cc b/third_party/xla/xla/pjrt/lru_cache_test.cc index 1c091bb1188a3f..c731d4a5e1627f 100644 --- a/third_party/xla/xla/pjrt/lru_cache_test.cc +++ b/third_party/xla/xla/pjrt/lru_cache_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.cc b/third_party/xla/xla/pjrt/mlir_to_hlo.cc index 830e10f4502093..88084e290ab30f 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo.cc @@ -82,6 +82,9 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module, mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); { mlir::PassManager pm(context); + // Expand stablehlo complex math functions such as log_plus_one, etc. + pm.addNestedPass( + mlir::stablehlo::createStablehloComplexMathExpanderPass()); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass( mlir::mhlo::createChloLegalizeToHloPass()); @@ -107,7 +110,7 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module, if (use_tuple_args && use_shardy) { // Shardy can't handle tuple args when round-tripping. So delay using // tuples until after Shardy is run. - sdy::addFrontendAttribute(module, sdy::kUseTupleArgs, + sdy::setFrontendAttribute(module, sdy::kUseTupleArgs, mlir::StringAttr::get(context, "t")); use_tuple_args = false; } @@ -223,6 +226,10 @@ absl::StatusOr SerializeUsingVersionedStablehlo( // Legalize CHLO -> [StableHLO+Shape] -> StableHLO // Preserve higher-level ops with XLA support. To be replaced by composites. mlir::PassManager pm(context); + // Expand stablehlo complex math functions such as log_plus_one, etc. + pm.addNestedPass( + mlir::stablehlo::createStablehloComplexMathExpanderPass()); + xla::sdy::addSdyRoundTripExportPipeline(pm); pm.addNestedPass( mlir::mhlo::createChloLegalizeToHighLevelMhloPass()); diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc b/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc index 4e7b2610f4bcbe..21c98138ff82f4 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "stablehlo/api/PortableApi.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index ca1066d46db6e0..1ef07811cf53df 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -52,6 +52,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_layouts_extension.h" +#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" #include "xla/pjrt/c/pjrt_c_api_stream_extension.h" #include "xla/pjrt/compile_options.pb.h" @@ -70,6 +71,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/status.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -167,7 +169,6 @@ void PjRtCApiClient::InitDevicesAndMemorySpaces() { } // Initialize addressable memory spaces. - // TODO(yueshengys): Initialize global memory spaces when supported. PJRT_Client_AddressableMemories_Args memory_args; memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; memory_args.extension_start = nullptr; @@ -688,15 +689,219 @@ absl::StatusOr PjRtCApiClient::GetDefaultLayout( std::string serialized_layout(serialize_args.serialized_bytes, serialize_args.serialized_bytes_size); - TF_ASSIGN_OR_RETURN(PjRtXlaLayout pjrt_xla_layout, - PjRtXlaLayout::Deserialize(serialized_layout)); + TF_ASSIGN_OR_RETURN(std::shared_ptr pjrt_layout, + PjRtLayout::Deserialize(serialized_layout)); - return pjrt_xla_layout.xla_layout(); + return pjrt_layout->xla_layout(); +} + +class PjRtCApiAsyncHostToDeviceTransferManager + : public PjRtClient::AsyncHostToDeviceTransferManager { + public: + PjRtCApiAsyncHostToDeviceTransferManager( + PjRtCApiClient* client, + PJRT_AsyncHostToDeviceTransferManager* c_transfer_manager) + : c_client_(client), c_transfer_manager_(std::move(c_transfer_manager)) {} + + size_t buffer_count() const override { + LOG(FATAL) << "PJRT C API does not support buffer_count. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + PjRtDevice* device() const override { + LOG(FATAL) << "PJRT C API does not support device. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + std::unique_ptr RetrieveBuffer(int buffer_index) override { + LOG(FATAL) << "PJRT C API does not support RetrieveBuffer. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + absl::Status TransferLiteralToBuffer( + int buffer_index, const LiteralSlice& literal, + absl::AnyInvocable on_done) override { + return Unimplemented( + "PJRT C API does not support TransferLiteralToBuffer. Please report an " + "issue at https://github.com/google/jax/issues if you need this " + "feature."); + } + + size_t buffer_size(int buffer_index) const override { + LOG(FATAL) + << "PJRT C API does not support buffer_size. Please report an " + "issue at https://github.com/google/jax/issues if you need this " + "feature."; + } + + absl::Status TransferRawDataToBuffer( + int buffer_index, absl::string_view data, + absl::AnyInvocable on_done) override { + return TransferRawDataToSubBuffer(buffer_index, data.data(), 0, data.size(), + /*is_last_transfer=*/true, + std::move(on_done)); + } + + absl::Status TransferRawDataToSubBuffer( + int buffer_index, const void* data, int64_t offset, int64_t transfer_size, + bool is_last_transfer, absl::AnyInvocable on_done) override { + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + args.buffer_index = buffer_index; + args.data = data; + args.offset = offset; + args.transfer_size = transfer_size; + args.is_last_transfer = is_last_transfer; + const PJRT_Api* api = c_client_->pjrt_c_api(); + RETURN_STATUS_IF_PJRT_ERROR( + api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api); + std::unique_ptr event( + args.done_with_h2d_transfer, ::pjrt::MakeEventDeleter(api)); + if (on_done) { + PJRT_Event_OnReady_Args event_args; + event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; + event_args.extension_start = nullptr; + event_args.event = event.get(); + event_args.user_arg = new absl::AnyInvocable( + [on_done = std::move(on_done), + c_api = api](PJRT_Error* error) mutable { + if (error) { + ::pjrt::MakeErrorDeleter(c_api)(error); + } + std::move(on_done)(); + }); + event_args.callback = [](PJRT_Error* error, void* args) { + auto* on_done_with_d2h_transfer = + reinterpret_cast*>(args); + (*on_done_with_d2h_transfer)(error); + delete on_done_with_d2h_transfer; + }; + + RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Event_OnReady(&event_args), api); + } + return absl::OkStatus(); + } + + void SetBufferError(int buffer_index, absl::Status error) override { + LOG(FATAL) << "PJRT C API does not support SetBufferError. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + using TransferMetadata = absl::flat_hash_map; + void AddTransferMetadata(const TransferMetadata& metadata) override { + LOG(FATAL) << "PJRT C API does not support AddTransferMetadata. Please " + "report an issue at https://github.com/google/jax/issues if " + "you need " + "this feature."; + } + + private: + PjRtCApiClient* c_client_; + std::unique_ptr + c_transfer_manager_; +}; + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtMemorySpace* memory_space) { + const PJRT_Api* c_api = pjrt_c_api(); + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args; + args.struct_size = + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.client = c_client_.get(); + args.num_shape_specs = shape_specs.size(); + args.shape_specs = new PJRT_ShapeSpec[shape_specs.size()]; + absl::Cleanup cleanup = + absl::MakeCleanup([&args] { delete[] args.shape_specs; }); + const ShapeSpec* iterator = shape_specs.begin(); + for (int i = 0; i < shape_specs.size(); ++i) { + args.shape_specs[i] = pjrt::ConvertToPjRtShapeSpec(*(iterator++)); + } + if (device_layouts.has_value()) { + args.num_device_layouts = device_layouts->size(); + auto device_layout_list = + std::make_unique>( + device_layouts->size()); + for (int i = 0; i < device_layouts->size(); ++i) { + if (device_layouts.has_value() && (*device_layouts)[i].has_value()) { + const Layout& layout = (*device_layouts)[i].value(); + TF_ASSIGN_OR_RETURN(pjrt::BufferMemoryLayoutData c_layout_data, + pjrt::ConvertToBufferMemoryLayoutData(layout)); + device_layout_list->emplace_back(&(c_layout_data.c_layout)); + } else { + device_layout_list->emplace_back(nullptr); + } + } + args.device_layouts = device_layout_list->data(); + } else { + args.num_device_layouts = 0; + args.device_layouts = nullptr; + } + args.memory = + tensorflow::down_cast(memory_space)->c_memory(); + + RETURN_STATUS_IF_PJRT_ERROR( + c_api->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args), c_api); + return std::make_unique( + this, args.transfer_manager); +} + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtDevice* device) { + TF_ASSIGN_OR_RETURN(auto memory_space, device->default_memory_space()); + return CreateBuffersForAsyncHostToDevice(shape_specs, device_layouts, + memory_space); +} + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shapes, PjRtDevice* device) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, device); +} + +absl::StatusOr> +PjRtCApiClient::CreateBuffersForAsyncHostToDevice( + absl::Span shapes, PjRtMemorySpace* memory_space) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, memory_space); } const PJRT_Api* PjRtCApiClient::pjrt_c_api() const { return c_api_; } -// --------------------------------- Devices ----------------------------------- +// --------------------------------- Device Descriptions ----------------------- PjRtCApiDeviceDescription::PjRtCApiDeviceDescription( const PJRT_Api* c_api, PJRT_DeviceDescription* device_description) @@ -809,6 +1014,40 @@ absl::string_view PjRtCApiDeviceDescription::ToString() const { return to_string; } +void PjRtCApiDeviceDescription::InitMemoryDescriptions() const { + const PJRT_MemoryDescriptions_Extension* extension = + pjrt::FindExtension( + c_api_, PJRT_Extension_Type::PJRT_Extension_Type_MemoryDescriptions); + if (!extension) return; + + if (memory_space_description_pointers_.empty()) { + memory_space_descriptions_ = pjrt::GetMemorySpaceDescriptions( + device_description_, c_api_, &default_memory_space_description_); + for (int i = 0; i < memory_space_descriptions_.size(); i++) { + memory_space_description_pointers_.push_back( + &memory_space_descriptions_[i]); + } + } +} + +absl::Span +PjRtCApiDeviceDescription::memory_spaces() const { + if (memory_space_description_pointers_.empty()) { + InitMemoryDescriptions(); + } + return memory_space_description_pointers_; +} + +absl::StatusOr +PjRtCApiDeviceDescription::default_memory_space() const { + if (memory_space_description_pointers_.empty()) { + InitMemoryDescriptions(); + } + return default_memory_space_description_; +} + +// ------------------------------- Devices ------------------------------------- + PjRtCApiDevice::PjRtCApiDevice(PJRT_Device* device, PjRtCApiClient* client) : client_(client), device_(device), @@ -1797,16 +2036,17 @@ absl::Span PjRtCApiBuffer::dimensions() const { return absl::Span(args.dims, args.num_dims); } -std::unique_ptr PjRtCApiBuffer::layout() const { +std::shared_ptr PjRtCApiBuffer::layout() const { { absl::MutexLock lock(&mu_); - if (!layout_.has_value()) { + if (layout_ == nullptr) { const PJRT_Api* c_api = pjrt_c_api(); PJRT_Layouts_Extension* extension = pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts); if (extension == nullptr) { - layout_.emplace(LayoutUtil::MakeDescendingLayout(dimensions().size())); + layout_ = std::make_shared( + LayoutUtil::MakeDescendingLayout(dimensions().size())); } else { std::unique_ptr @@ -1831,14 +2071,14 @@ std::unique_ptr PjRtCApiBuffer::layout() const { std::string serialized_layout(serialize_args.serialized_bytes, serialize_args.serialized_bytes_size); - absl::StatusOr pjrt_xla_layout = - PjRtXlaLayout::Deserialize(serialized_layout); - TF_CHECK_OK(pjrt_xla_layout.status()); - layout_.emplace(*pjrt_xla_layout); + absl::StatusOr> pjrt_layout = + PjRtLayout::Deserialize(serialized_layout); + TF_CHECK_OK(pjrt_layout.status()); + layout_ = *std::move(pjrt_layout); } } } - return std::make_unique(*layout_); + return layout_; } bool PjRtCApiBuffer::has_dynamic_dimensions() const { @@ -2374,6 +2614,8 @@ absl::StatusOr> WrapClientAroundCApi( kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store); init_args.kv_get_callback = kv_callback_data->c_kv_get; init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; + init_args.kv_try_get_callback = kv_callback_data->c_kv_try_get; + init_args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func; init_args.kv_put_callback = kv_callback_data->c_kv_put; init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; } diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 3897b023427169..0c8500c5818b10 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -80,15 +80,28 @@ class PjRtCApiDeviceDescription : public PjRtDeviceDescription { const absl::flat_hash_map& Attributes() const override; + absl::Span memory_spaces() + const override; + + absl::StatusOr default_memory_space() + const override; + private: const PJRT_Api* c_api_; // `device_description_` is owned by the `PJRT_Client` wrapped by `client_` PJRT_DeviceDescription* device_description_; // Device specific attributes with corresponding values. absl::flat_hash_map attributes_; + mutable std::vector memory_space_descriptions_; + mutable std::vector + memory_space_description_pointers_; + mutable absl::StatusOr + default_memory_space_description_; // Initializes device specific attributes. void InitAttributes(); + // Initialize device specific memory descriptions. + void InitMemoryDescriptions() const; }; class PjRtCApiMemorySpace : public PjRtMemorySpace { @@ -318,23 +331,25 @@ class PjRtCApiClient : public PjRtClient { absl::StatusOr GetTopologyDescription() const override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtDevice* device) override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional>> device_layouts, + PjRtMemorySpace* memory_space) override; + + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) override { - return Unimplemented( - "PJRT C API does not support CreateBuffersForAsyncHostToDevice. Please " - "report an issue at https://github.com/google/jax/issues if you need " - "this feature."); - } + PjRtDevice* device) override; absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) override { - return Unimplemented( - "PJRT C API does not support CreateBuffersForAsyncHostToDevice. Please " - "report an issue at https://github.com/google/jax/issues if you need " - "this feature."); - } + PjRtMemorySpace* memory_space) override; absl::StatusOr> BufferFromHostBuffer( const void* data, PrimitiveType type, absl::Span dims, @@ -393,28 +408,12 @@ class PjRtCApiClient : public PjRtClient { "this feature."); } - absl::StatusOr CreateChannelHandle() override { - return Unimplemented( - "PJRT C API does not support CreateChannelHandle. Please report an " - "issue at https://github.com/google/jax/issues if you need this " - "feature."); - } - - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return Unimplemented( - "PJRT C API does not support CreateDeviceToHostChannelHandle. Please " - "report an issue at https://github.com/google/jax/issues if you need " - "this feature."); - } - absl::Status Defragment() override { return Unimplemented( "PJRT C API does not support Defragment. Please report an issue at " "https://github.com/google/jax/issues if you need this feature."); } - bool SupportsSendRecvCallbacks() const override { return true; } - const PJRT_Api* pjrt_c_api() const; PJRT_Client* pjrt_c_client() { return c_client_.get(); } @@ -456,8 +455,6 @@ class PjRtCApiClient : public PjRtClient { std::vector addressable_devices_; absl::flat_hash_map c_to_cpp_device_map_; std::vector> owned_memory_spaces_; - // TODO(yueshengys): Add a `memory_spaces_` member when global memories are - // supported. std::vector addressable_memory_spaces_; absl::flat_hash_map c_to_cpp_memory_map_; // There may be an error fetching the topology desc via the C API @@ -479,7 +476,7 @@ class PjRtCApiBuffer : public PjRtBuffer { absl::Span dimensions() const override; - std::unique_ptr layout() const override; + std::shared_ptr layout() const override; // PJRT C API doesn't support tuple buffers. bool IsTuple() const override { return false; } @@ -577,7 +574,7 @@ class PjRtCApiBuffer : public PjRtBuffer { // we set on `readiness_event` modifies `readiness_promise_`. std::shared_ptr::Promise> readiness_promise_; // Set and cached the first time layout() is called. - mutable std::optional layout_; + mutable std::shared_ptr layout_; // Set and cached the first time is_dynamic_dimension() is called. mutable std::optional> is_dynamic_dimension_; diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc index 58b8eead2be920..8dfcc5b07e5499 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -39,6 +40,7 @@ limitations under the License. #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -114,7 +116,7 @@ TEST(PjRtCApiClientTest, PlatformId) { EXPECT_EQ(client->platform_id(), xla::CpuId()); } -TEST(PjRtCApiClientTest, EmptyExecutableFingerprint) { +TEST(PjRtCApiClientTest, NonEmptyExecutableFingerprint) { SetUpCpuPjRtApi(); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, GetCApiClient("cpu")); @@ -130,8 +132,9 @@ TEST(PjRtCApiClientTest, EmptyExecutableFingerprint) { PjRtCApiClient* c_client = dynamic_cast(client.get()); ASSERT_NE(c_client, nullptr); - if (c_client->pjrt_c_api()->pjrt_api_version.minor_version >= 35) { - // Empty executable should return an error status. + if (c_client->pjrt_c_api()->pjrt_api_version.minor_version >= 58) { + EXPECT_TRUE(executable->FingerprintExecutable().ok()); + } else if (c_client->pjrt_c_api()->pjrt_api_version.minor_version >= 35) { EXPECT_FALSE(executable->FingerprintExecutable().ok()); } else { // TODO(yeounoh): To be removed after 01/20/2024. @@ -140,6 +143,19 @@ TEST(PjRtCApiClientTest, EmptyExecutableFingerprint) { } } +TEST(PjRtCApiClientTest, CreateBuffersForAsyncHostToDeviceWithShape) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + xla::PrimitiveType::F32, /*dimensions=*/{2, 2, 2}, + /*minor_to_major=*/{1, 0, 2}); + std::vector host_shapes = {host_shape}; + auto status_or_transfer_manager = client->CreateBuffersForAsyncHostToDevice( + absl::MakeSpan(host_shapes), client->addressable_devices()[0]); + EXPECT_FALSE(status_or_transfer_manager.ok()); +} + TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) { SetUpCpuPjRtApi(); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, @@ -196,6 +212,25 @@ TEST(PjRtClientTest, CompileUsesStableHloVersion) { const_cast(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig; } +TEST(PjRtClientTest, CanQueryMemoryDescriptions) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + TF_ASSERT_OK_AND_ASSIGN(const PjRtTopologyDescription* topology, + client->GetTopologyDescription()); + std::vector> devices = + topology->DeviceDescriptions(); + for (std::unique_ptr& device : devices) { + for (const PjRtMemorySpaceDescription* memory : device->memory_spaces()) { + // TODO: CPU doesn't currently have memory descriptions, so the + // code below doesn't get triggered yet. + EXPECT_NE(memory, nullptr); + EXPECT_GT(memory->kind().size(), 0); + EXPECT_GE(memory->kind_id(), 0); + } + } +} + TEST(PjRtCApiClientTest, WrapClientAroundCApi) { const PJRT_Api* c_api = ::pjrt::cpu_plugin::GetCpuPjrtApi(); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index 1162e172a39d5b..7beb4c1a9da921 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -1070,25 +1070,12 @@ class PjRtClient { "MakeCrossHostReceiveBuffersForGather is not implemented."); } - // Create ChannelHandles for XLA send/recv. - virtual absl::StatusOr CreateChannelHandle() { - return Unimplemented("CreateChannelHandle is not implemented."); - } - virtual absl::StatusOr CreateDeviceToHostChannelHandle() { - return Unimplemented("CreateDeviceToHostChannelHandle is not implemented."); - } - // TODO(zhangqiaorjc): Experimental API to be removed. // Defragment device memory. virtual absl::Status Defragment() { return Unimplemented("Defragment is not implemented."); } - // If false, this client does not support send/recv host callbacks, and - // callers should not set the `send_callbacks` and `recv_callbacks` arguments - // in ExecuteOptions. - virtual bool SupportsSendRecvCallbacks() const { return false; } - // Return the PjRtHostMemoryForDeviceManager for this client. It can be // nullptr if the implementation does not provide one. virtual PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager() @@ -1121,12 +1108,12 @@ class PjRtBuffer { return on_device_shape().dimensions(); } - // The on-device memory layout of this buffer. Returned via unique_ptr to make + // The on-device memory layout of this buffer. Returned via shared_ptr to make // memory management easier -- PjRtLayout is an abstract base class, so cannot // be easily copied. - virtual std::unique_ptr layout() const { + virtual std::shared_ptr layout() const { CHECK(on_device_shape().has_layout()); - return std::make_unique(on_device_shape().layout()); + return std::make_shared(on_device_shape().layout()); } // PjRtBuffers can either represent a single array buffer or a tuple of array @@ -1244,9 +1231,13 @@ class PjRtBuffer { } else { literal_dims = dimensions(); } - device_shape = ShapeUtil::MakeShape(element_type(), literal_dims); - // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout - *device_shape.mutable_layout() = GetXlaLayoutUnsafe(layout()); + if (element_type() == TOKEN) { + device_shape = ShapeUtil::MakeTokenShape(); + } else { + device_shape = ShapeUtil::MakeShape(element_type(), literal_dims); + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + *device_shape.mutable_layout() = layout()->xla_layout(); + } } else { // TODO(skyewm): does anything need to create tuple literals? The PJRT C // API doesn't support tuples or {logical_}on_device_shape(), so we prefer diff --git a/third_party/xla/xla/pjrt/pjrt_client_test.cc b/third_party/xla/xla/pjrt/pjrt_client_test.cc index 64e3552ded666a..c9e4369f6fdeaa 100644 --- a/third_party/xla/xla/pjrt/pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_client_test.cc @@ -30,11 +30,11 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -350,9 +350,11 @@ TEST_P(PjRtClientTest, ExecuteWithConcurrentUsageAndDonation) { auto& results = *results_or; CHECK_EQ(results.size(), 1); CHECK_EQ(results[0].size(), 1); - auto literal = results[0][0]->ToLiteralSync().value(); - CHECK(LiteralTestUtil::Equal(LiteralUtil::CreateR1(expected), - *literal)); + auto literal_or = results[0][0]->ToLiteralSync(); + if (literal_or.ok()) { + CHECK(LiteralTestUtil::Equal(LiteralUtil::CreateR1(expected), + *literal_or.value())); + } } blocking_counter.DecrementCount(); }); diff --git a/third_party/xla/xla/pjrt/pjrt_device_description.h b/third_party/xla/xla/pjrt/pjrt_device_description.h index 77107fdc495c71..95e2367a757268 100644 --- a/third_party/xla/xla/pjrt/pjrt_device_description.h +++ b/third_party/xla/xla/pjrt/pjrt_device_description.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_PJRT_PJRT_DEVICE_DESCRIPTION_H_ #include -#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -37,7 +36,7 @@ class PjRtMemorySpaceDescription { // A platform-dependent string that uniquely identifies the kind of the // memory space. - absl::string_view kind() const { return kind_; } + absl::string_view kind() const { return absl::string_view(kind_); } // An ID uniquely identifies the kind of the memory space among those attached // to the same `PjRtClient`. The IDs assigned to a kind is implementation @@ -45,7 +44,7 @@ class PjRtMemorySpaceDescription { int kind_id() const { return kind_id_; } private: - absl::string_view kind_; + std::string kind_; int kind_id_; }; @@ -68,15 +67,15 @@ class PjRtDeviceDescription { // A vendor-dependent string that uniquely identifies the kind of device, // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are // compatible compilation. - virtual std::string_view device_kind() const = 0; + virtual absl::string_view device_kind() const = 0; // Debug string suitable for logging when errors occur. Should be verbose // enough to describe the current device unambiguously. - virtual std::string_view DebugString() const = 0; + virtual absl::string_view DebugString() const = 0; // Debug string suitable for reading by end users, should be reasonably terse, // for example: "CpuDevice(id=0)". - virtual std::string_view ToString() const = 0; + virtual absl::string_view ToString() const = 0; // Returns vendor specific attributes about the device. For example the model // number of a GPU, or the mesh coordinates of a TPU device. The returned diff --git a/third_party/xla/xla/pjrt/pjrt_executable.cc b/third_party/xla/xla/pjrt/pjrt_executable.cc index 79fea677871222..bec43a0487ac62 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.cc +++ b/third_party/xla/xla/pjrt/pjrt_executable.cc @@ -208,7 +208,7 @@ absl::StatusOr ExecuteOptions::FromProto( return options; } -CompiledMemoryStatsProto CompiledMemoryStats::ToProto() { +CompiledMemoryStatsProto CompiledMemoryStats::ToProto() const { CompiledMemoryStatsProto proto; proto.set_generated_code_size_in_bytes(generated_code_size_in_bytes); proto.set_argument_size_in_bytes(argument_size_in_bytes); @@ -422,7 +422,7 @@ PjRtExecutable::GetOutputDimensions() const { return output_dimensions; } -absl::StatusOr>> +absl::StatusOr>> PjRtExecutable::GetParameterLayouts() const { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, GetHloModules()); @@ -439,15 +439,15 @@ PjRtExecutable::GetParameterLayouts() const { ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); TF_ASSIGN_OR_RETURN(std::vector layouts, comp_layout.FlattenedParameterLayouts()); - std::vector> result; + std::vector> result; result.reserve(layouts.size()); for (const Layout& layout : layouts) { - result.push_back(std::make_unique(layout)); + result.push_back(std::make_shared(layout)); } return result; } -absl::StatusOr>> +absl::StatusOr>> PjRtExecutable::GetOutputLayouts() const { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, GetHloModules()); @@ -464,10 +464,10 @@ PjRtExecutable::GetOutputLayouts() const { ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); TF_ASSIGN_OR_RETURN(std::vector layouts, comp_layout.FlattenedResultLayouts()); - std::vector> result; + std::vector> result; result.reserve(layouts.size()); for (const Layout& layout : layouts) { - result.push_back(std::make_unique(layout)); + result.push_back(std::make_shared(layout)); } return result; } @@ -667,6 +667,10 @@ absl::Status CompileOptions::ApplyOptionFromString( } return absl::OkStatus(); } else { + if (value.empty() && field->is_repeated()) { + reflection->ClearField(&debug_options, field); + return absl::OkStatus(); + } auto enum_desc = field->enum_type()->FindValueByName(value); if (enum_desc != nullptr) { if (field->is_repeated()) { diff --git a/third_party/xla/xla/pjrt/pjrt_executable.h b/third_party/xla/xla/pjrt/pjrt_executable.h index f5f4aa89dcece3..1244039ede0cd1 100644 --- a/third_party/xla/xla/pjrt/pjrt_executable.h +++ b/third_party/xla/xla/pjrt/pjrt_executable.h @@ -101,7 +101,9 @@ struct CompileOptions { // Key-value string pairs, parsed in order to set miscellaneous options, // overriding if appropriate. using OptionOverride = std::variant; - std::vector> env_option_overrides; + using EnvironmentOptionOverrides = + std::vector>; + EnvironmentOptionOverrides env_option_overrides; std::optional target_config; @@ -295,7 +297,7 @@ struct CompiledMemoryStats { std::string serialized_hlo_proto = ""; std::string DebugString() const; - CompiledMemoryStatsProto ToProto(); + CompiledMemoryStatsProto ToProto() const; static CompiledMemoryStats FromProto(const CompiledMemoryStatsProto& proto); @@ -335,11 +337,11 @@ class PjRtExecutable { GetOutputDimensions() const; // Returns the layout of each input parameter. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const; // Returns the layout of each output. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const; // Returns a list of lists of memory kind strings for output. The returned diff --git a/third_party/xla/xla/pjrt/pjrt_layout.h b/third_party/xla/xla/pjrt/pjrt_layout.h index eea9b861690860..e4318102bf7c1c 100644 --- a/third_party/xla/xla/pjrt/pjrt_layout.h +++ b/third_party/xla/xla/pjrt/pjrt_layout.h @@ -20,93 +20,54 @@ limitations under the License. #include #include -#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" -#include "tsl/platform/casts.h" #include "tsl/platform/statusor.h" namespace xla { -// Abstract class representing the memory layout of a PjRtBuffer. +// Represents the memory layout of a PjRtBuffer. class PjRtLayout { public: - virtual ~PjRtLayout() = default; - - // Returns the serialized layout as a string. - // TODO(b/328671718): add generic deserialize method to PjRtClient and/or - // PjRtCompiler. - virtual std::string Serialize() const = 0; - - // Human-readable string for error messages, user introspection, etc. - virtual std::string ToString() const = 0; - - virtual bool operator==(const PjRtLayout& other) const = 0; - - template - friend H AbslHashValue(H state, const PjRtLayout& layout) { - layout.Hash(absl::HashState::Create(&state)); - return std::move(state); - } - - protected: - virtual void Hash(absl::HashState state) const = 0; -}; - -// PjRtLayout backed by an xla::Layout. This is a convenience class for PJRT -// implementations that use XLA. PJRT users should use the PjRtLayout interface -// to be compatible with all implementations, e.g. PjRtCApiClient which doesn't -// have access to full xla::Layouts. -class PjRtXlaLayout : public PjRtLayout { - public: - explicit PjRtXlaLayout(Layout layout) : xla_layout_(std::move(layout)) { + explicit PjRtLayout(Layout layout) : xla_layout_(std::move(layout)) { // Strip memory space and set it to the default. PJRT tracks memory space // separately from layout. xla_layout_.set_memory_space(xla::Layout::kDefaultMemorySpace); } - std::string Serialize() const override { return xla_layout_.ToString(); } + PjRtLayout(PjRtLayout& other) = delete; + PjRtLayout& operator=(const PjRtLayout& other) = delete; - static absl::StatusOr Deserialize( + static absl::StatusOr> Deserialize( absl::string_view serialized) { TF_ASSIGN_OR_RETURN(Layout xla_layout, ParseLayout(serialized)); - return PjRtXlaLayout(std::move(xla_layout)); + return std::make_shared(std::move(xla_layout)); } - std::string ToString() const override { return xla_layout_.ToString(); } + const Layout& xla_layout() const { return xla_layout_; } - bool operator==(const PjRtLayout& other) const override { - auto xla_other = dynamic_cast(&other); - if (xla_other == nullptr) { - return false; - } - return xla_layout_ == xla_other->xla_layout_; - }; + // Returns the serialized layout as a string. + std::string Serialize() const { return xla_layout_.ToString(); } - const Layout& xla_layout() const { return xla_layout_; } + // Human-readable string for error messages, user introspection, etc. + std::string ToString() const { return xla_layout_.ToString(); } - protected: - void Hash(absl::HashState state) const override { - absl::HashState::combine(std::move(state), xla_layout_); + bool operator==(const PjRtLayout& other) const { + return xla_layout_ == other.xla_layout_; + } + + template + friend H AbslHashValue(H state, const PjRtLayout& layout) { + return H::combine(std::move(state), layout.xla_layout_); } private: Layout xla_layout_; }; -// TODO(b/327524065): make callers use PjRtLayout directly instead of assuming -// an xla::Layout and get rid of this function. -inline Layout GetXlaLayoutUnsafe( - const std::unique_ptr& pjrt_layout) { - PjRtXlaLayout* xla_layout = - tensorflow::down_cast(pjrt_layout.get()); - CHECK(xla_layout != nullptr) << "Got unexpected layout type"; - return xla_layout->xla_layout(); -} - } // namespace xla #endif // XLA_PJRT_PJRT_LAYOUT_H_ diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index aa9dec51117d15..a5b0790b107b76 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -120,8 +120,10 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/compiler.h" #include "xla/service/computation_layout.h" +#include "xla/service/dump.h" #include "xla/service/executable.h" #include "xla/service/generic_transfer_manager.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" @@ -345,32 +347,11 @@ void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) { // after the usage of device_buffer was enqueued. // usage_stream: the stream the operation using device_buffer // was enqueued on. -// prefer_to_retain_reference: relevant only for the compute synchronous -// allocation model. If true, retain a reference -// to device_buffer until after the operation -// completes. If false then the compute stream -// will have to be synchronized past event before -// device_buffer can be freed. -// -// prefer_to_retain_reference encodes a heuristic set by the caller for the -// compute synchronous model: -// -// Generally when a buffer is the destination of a copy to a device, it will -// subsequently be used on the device's compute stream before being freed. In -// that case, there is no need to retain a reference to the buffer. If the -// buffer is freed before being used on the compute stream, the free will be -// delayed until the host knows that event has completed, but this is expected -// to be uncommon. -// -// When a buffer is the source of a copy from a device, we need to either retain -// a reference to the buffer until the copy completes or serialize the compute -// stream behind the copy. It is often better to retain a reference since while -// that keeps memory alive longer, it avoids stalling the compute stream. void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, LocalDeviceState* buffer_local_device, LocalDeviceState* stream_local_device, std::shared_ptr event, - se::Stream* usage_stream, bool prefer_to_retain_reference, + se::Stream* usage_stream, std::vector>* buffers_to_release = nullptr) { tsl::profiler::TraceMe traceme("RecordUsage"); @@ -380,11 +361,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, (stream_local_device != buffer_local_device) || // In the synchronous allocation model, always retain a reference. (stream_local_device->allocation_model() == - LocalDeviceState::kSynchronous) || - // In the compute synchronous model, use the caller's heuristic. - (stream_local_device->allocation_model() == - LocalDeviceState::kComputeSynchronized && - prefer_to_retain_reference); + LocalDeviceState::kSynchronous); if (retain_buffer_until_completion) { if (buffers_to_release) { buffers_to_release->push_back(device_buffer.buffer()); @@ -413,15 +390,8 @@ absl::Status AddDestinationBufferSynchronization( } definition_event->SetSequencingEvent(std::move(event_or).value(), copy_stream); - // prefer_to_retain_reference=false means don't retain a memory reference - // until the transfer is complete when using the ComputeSynchronized - // allocation model. This is a heuristic because in the common case - // destination buffers will be used on the compute stream and therefore don't - // require any synchronization before being freed. If the buffer is allocated - // and never used, the free will take longer and this is assumed to be ok. RecordUsage(std::move(device_buffer), local_device, local_device, - definition_event, copy_stream, - /*prefer_to_retain_reference=*/false); + definition_event, copy_stream); return absl::OkStatus(); } @@ -514,7 +484,7 @@ AllocateDestinationBuffer( // put it as the first definition event so that we can guarantee only the // first one might not have event recorded. if (definition_event) { - definition_events.emplace_back(definition_event); + definition_events.push_back(definition_event); } if (local_device->allocation_model() == LocalDeviceState::kComputeSynchronized) { @@ -532,7 +502,7 @@ AllocateDestinationBuffer( // We have at least one definition event, for the copy completing to // the device buffers. if (definition_event) { - definition_events.emplace_back(definition_event); + definition_events.push_back(definition_event); } else { definition_events.emplace_back( std::make_shared(client->thread_pool())); @@ -581,16 +551,9 @@ AllocateDestinationBuffer( if (on_device_shape.IsTuple()) { // Add a usage hold for the tuple table write and immediately convert it to - // the appropriate form of synchronization. prefer_to_retain_reference=false - // means don't retain a memory reference until the transfer is complete when - // using the ComputeSynchronized allocation model. This is a heuristic - // because in the common case destination buffers will be used on the - // compute stream and therefore don't require any synchronization before - // being freed. If the buffer is allocated and never used, the free will - // take longer and this is assumed to be ok. + // the appropriate form of synchronization. RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device, - definition_events.back(), tuple_table_stream, - /*prefer_to_retain_reference=*/false); + definition_events.back(), tuple_table_stream); } return py_buffer; @@ -1952,8 +1915,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( std::move(async_copy_to_device)); RecordUsage(std::move(dst_device_buffer), transfer_local_device, - transfer_local_device, copy_event, transfer_stream, - /*prefer_to_retain_reference=*/false); + transfer_local_device, copy_event, transfer_stream); return std::pair, std::shared_ptr>( @@ -2037,12 +1999,6 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace( std::unique_ptr& buffer = buffer_and_event.first; std::shared_ptr& event = buffer_and_event.second; - // prefer_to_retain_reference=*/true means that, when using the - // ComputeSynchronized allocation model, retain a reference to the - // src_device_buffer until the copy completes. This is a heuristic; the - // alternative is to ensure, before freeing the buffer, that the compute - // stream is synchronized past the transfer, but it seems better to hold onto - // the buffer too long than to stall the compute stream. src_device_buffer.ConvertUsageHold(transfer_stream, event, /*reference_held=*/true); @@ -2338,7 +2294,7 @@ absl::StatusOr> OutputBufferHelper( memory_space); RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device, definition_event, local_device->compute_stream(), - /*prefer_to_retain_reference=*/false, &buffers_to_release); + &buffers_to_release); return std::unique_ptr(std::move(pjrt_buffer)); } @@ -3014,6 +2970,9 @@ static absl::Status GetFirstInputError( auto* buffer = tensorflow::down_cast(handle); PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithUsageHold(); + if (!hold.ok()) { + return hold.status(); + } for (const auto& event : hold->definition_events()) { if (event->IsPredeterminedError()) { return event->GetDefinedStatus(); @@ -3116,14 +3075,9 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( buffers_to_release)); for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) { - // prefer_to_retain_reference=false because when using the - // ComputeSynchronized allocation model we don't need to retain a reference - // to the device_buffer during execution because by definition the compute - // stream is synchronized past the execution. if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) { RecordUsage(std::move(b), device_state, device_state, definition_event, - stream, - /*prefer_to_retain_reference=*/false, &buffers_to_release); + stream, &buffers_to_release); } else { CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation); b.ConfirmDonation(); @@ -3175,6 +3129,21 @@ PjRtStreamExecutorLoadedExecutable::Execute( if (device_assignment_ == nullptr) { return InvalidArgument("Execute expects a non-null device_assignment"); } + if (input_hlo_snapshot_bits_.has_value()) { + HloUnoptimizedSnapshot hlo_snapshot; + *hlo_snapshot.mutable_hlo_module() = input_hlo_snapshot_bits_->hlo_module; + for (const auto& argument_handle : argument_handles) { + HloInputs hlo_inputs; + for (const auto& buffer : argument_handle) { + TF_ASSIGN_OR_RETURN(auto literal, buffer->ToLiteralSync()); + *hlo_inputs.add_arguments() = literal->ToProto(); + } + *hlo_snapshot.add_partitions() = std::move(hlo_inputs); + + DumpHloUnoptimizedSnapshotIfEnabled( + hlo_snapshot, input_hlo_snapshot_bits_->debug_options); + } + } RunId run_id; tsl::profiler::TraceMeProducer activity( @@ -3566,6 +3535,12 @@ PjRtStreamExecutorClient::CompileInternal( TF_RETURN_IF_ERROR( executable->SetUpDonation(options.parameter_is_tupled_arguments)); + const auto& ex_options = options.executable_build_options; + if (ex_options.has_debug_options() && + ex_options.debug_options().xla_gpu_dump_hlo_unoptimized_snapshots()) { + executable->SetInputHloSnapshotBits( + computation.proto(), options.executable_build_options.debug_options()); + } return std::unique_ptr(std::move(executable)); } diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index c06417928bb6a7..f753df6d6fcc29 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -61,6 +61,7 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/executable.h" #include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" @@ -69,6 +70,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/threadpool.h" @@ -392,13 +394,6 @@ class PjRtStreamExecutorClient : public PjRtClient { std::function on_delete_callback, std::optional stream) override; - absl::StatusOr CreateChannelHandle() override { - return client()->CreateChannelHandle(); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return client()->CreateDeviceToHostChannelHandle(); - } - // TODO(zhangqiaorjc): Experimental. Will be removed. absl::Status Defragment() override { return Unimplemented("Defragment not implemented"); @@ -1012,6 +1007,13 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { return fingerprint_; }; + void SetInputHloSnapshotBits(HloModuleProto hlo_module, + DebugOptions debug_options) { + input_hlo_snapshot_bits_ = + std::make_optional(InputHloSnapshotBits{ + HloModuleProto(std::move(hlo_module)), std::move(debug_options)}); + } + protected: bool parameter_is_tupled_arguments() const { return parameter_is_tupled_arguments_; @@ -1093,6 +1095,14 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable { // unique_ptrs to play well with the Python bindings (see xla.cc). std::vector addressable_devices_; std::string fingerprint_; + + struct InputHloSnapshotBits { + HloModuleProto hlo_module; + DebugOptions debug_options; + }; + + // The unoptimized (unsharded) HloModule. Primarily used for debugging. + std::optional input_hlo_snapshot_bits_; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc index a25125ceb9a2c6..0c742a62aa86cf 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/client/client_library.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_comparison.h" #include "xla/literal_util.h" @@ -32,7 +33,6 @@ limitations under the License. #include "xla/pjrt/pjrt_future.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD b/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD index 35365d04821249..cd18a91dacbb41 100644 --- a/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD +++ b/third_party/xla/xla/pjrt/plugin/example_plugin/BUILD @@ -27,6 +27,7 @@ xla_cc_test( deps = [ ":myplugin_cpp_pjrt", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", ], @@ -56,6 +57,7 @@ xla_cc_test( ":myplugin_c_pjrt", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt_test.cc b/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt_test.cc index 65d927b9060a82..2aaa7a3a6ffee6 100644 --- a/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt_test.cc +++ b/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_c_pjrt_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/pjrt/plugin/example_plugin/myplugin_c_pjrt.h" +#include +#include #include "xla/pjrt/c/pjrt_c_api.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_cpp_pjrt_test.cc b/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_cpp_pjrt_test.cc index 750e19df161da5..4fde7ae2161b22 100644 --- a/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_cpp_pjrt_test.cc +++ b/third_party/xla/xla/pjrt/plugin/example_plugin/myplugin_cpp_pjrt_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/pjrt/plugin/example_plugin/myplugin_cpp_pjrt.h" +#include +#include #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD b/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD index 3972441ecc90c7..7cacc483b635a5 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/BUILD @@ -1,4 +1,5 @@ load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl/platform:build_config.bzl", "tf_proto_library") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -39,7 +40,69 @@ cc_library( srcs = [], hdrs = ["cpu_client_options.h"], deps = [ + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/service:hlo_module_config", - "//xla/service/cpu:collectives_interface", + ], +) + +cc_library( + name = "cpu_device_description", + srcs = ["cpu_device_description.cc"], + hdrs = ["cpu_device_description.h"], + deps = [ + ":cpu_topology", + "//xla/pjrt:pjrt_common", + "//xla/pjrt:pjrt_device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "cpu_topology_description", + srcs = ["cpu_topology_description.cc"], + hdrs = ["cpu_topology_description.h"], + deps = [ + ":cpu_device_description", + ":cpu_topology", + "//xla:shape_util", + "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_compiler", + "//xla/pjrt:pjrt_device_description", + "//xla/tsl/lib/strings:proto_serialization", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +tf_proto_library( + name = "cpu_topology_proto", + srcs = ["cpu_topology.proto"], +) + +cc_library( + name = "cpu_topology", + srcs = ["cpu_topology.cc"], + hdrs = ["cpu_topology.h"], + deps = [ + ":cpu_topology_proto_cc", + "//xla/pjrt:pjrt_common", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "cpu_topology_test", + srcs = ["cpu_topology_test.cc"], + deps = [ + ":cpu_topology", + ":cpu_topology_proto_cc", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_client_options.h b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_client_options.h index bed88b8ae68e5e..aec801763e1404 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_client_options.h +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_client_options.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "xla/service/cpu/collectives_interface.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/service/hlo_module_config.h" namespace xla { @@ -45,7 +45,7 @@ struct CpuClientOptions { // Distributed collectives implementation. Optional. If not provided, an // in-process collectives implementation will be used. - std::shared_ptr collectives; + std::shared_ptr collectives; // If defined this function will be called on the HloModuleConfig before // compilation, and allows users to set custom flags. diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_device_description.cc b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_device_description.cc new file mode 100644 index 00000000000000..d907259ed28e9e --- /dev/null +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_device_description.cc @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/plugin/xla_cpu/cpu_device_description.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h" + +namespace xla { + +namespace { + +constexpr char kCpuPlatformName[] = "cpu"; + +} + +CpuDeviceDescription::CpuDeviceDescription(int process_id, int local_device_id) + : id_(PackCpuDeviceId(process_id, local_device_id)), + process_index_(process_id), + local_hardware_id_(local_device_id) { + debug_string_ = absl::StrCat("TFRT_CPU_", id_.value()); + to_string_ = absl::StrCat("CpuDevice(id=", id_.value(), ")"); +} + +absl::string_view CpuDeviceDescription::device_kind() const { + return kCpuPlatformName; +} + +absl::string_view CpuDeviceDescription::DebugString() const { + return debug_string_; +} + +absl::string_view CpuDeviceDescription::ToString() const { return to_string_; } + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_device_description.h b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_device_description.h new file mode 100644 index 00000000000000..0ea1861e7b936d --- /dev/null +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_device_description.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_DEVICE_DESCRIPTION_H_ +#define XLA_PJRT_PLUGIN_XLA_CPU_CPU_DEVICE_DESCRIPTION_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_device_description.h" + +namespace xla { + +class CpuDeviceDescription final : public PjRtDeviceDescription { + public: + explicit CpuDeviceDescription(int process_id, int local_device_id); + + int id() const override { return id_.value(); } + + int process_index() const override { return process_index_; } + + int local_hardware_id() const { return local_hardware_id_; } + + absl::string_view device_kind() const override; + + absl::string_view DebugString() const override; + + absl::string_view ToString() const override; + + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + private: + PjRtGlobalDeviceId id_; + int process_index_; + int local_hardware_id_; + std::string debug_string_; + std::string to_string_; + absl::flat_hash_map attributes_ = {}; +}; + +} // namespace xla + +#endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_DEVICE_DESCRIPTION_H_ diff --git a/third_party/xla/xla/pjrt/cpu/cpu_topology.cc b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.cc similarity index 95% rename from third_party/xla/xla/pjrt/cpu/cpu_topology.cc rename to third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.cc index f9729ff093bac9..5eca7c0d07760c 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_topology.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/cpu/cpu_topology.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h" #include #include @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "xla/pjrt/cpu/cpu_topology.pb.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.pb.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/cpu/cpu_topology.h b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.h similarity index 91% rename from third_party/xla/xla/pjrt/cpu/cpu_topology.h rename to third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.h index eb337325758788..24c5e1c93e637a 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_topology.h +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PJRT_CPU_CPU_TOPOLOGY_H_ -#define XLA_PJRT_CPU_CPU_TOPOLOGY_H_ +#ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_ +#define XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_ #include #include @@ -22,8 +22,8 @@ limitations under the License. #include #include "absl/types/span.h" -#include "xla/pjrt/cpu/cpu_topology.pb.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.pb.h" namespace xla { class CpuTopology { @@ -71,4 +71,4 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) { } // namespace xla -#endif // XLA_PJRT_CPU_CPU_TOPOLOGY_H_ +#endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_ diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.proto b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.proto new file mode 100644 index 00000000000000..bd258a822bfcc7 --- /dev/null +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology.proto @@ -0,0 +1,28 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +// A proto used to serialize CpuTopology instances. +message CpuTopologyProto { + message CpuDevice { + int32 process_index = 2; + int32 local_hardware_id = 3; + } + repeated CpuDevice cpu_devices = 1; + repeated string machine_attributes = 4; +} diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc new file mode 100644 index 00000000000000..60a9054588d6c8 --- /dev/null +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/plugin/xla_cpu/cpu_topology_description.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_device_description.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" + +namespace xla { + +/*static*/ CpuTopologyDescription CpuTopologyDescription::Create( + PjRtPlatformId platform_id, absl::string_view platform_name, + absl::string_view platform_version, + absl::Span> devices, + absl::Span machine_attributes) { + std::vector cpu_devices; + cpu_devices.reserve(devices.size()); + for (const auto& device : devices) { + cpu_devices.push_back(CpuTopology::CpuDevice{ + device->process_index(), device->local_hardware_id().value()}); + } + return CpuTopologyDescription(platform_id, platform_name, platform_version, + cpu_devices, machine_attributes); +} + +absl::StatusOr CpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + +absl::StatusOr CpuTopologyDescription::Serialize() const { + std::string result; + if (!tsl::SerializeToStringDeterministic(cpu_topology_.ToProto(), &result)) { + return absl::InternalError("Failed to serialize cpu_topology"); + } + return result; +} + +std::vector> +CpuTopologyDescription::DeviceDescriptions() const { + std::vector> devices; + devices.reserve(cpu_topology_.number_of_devices()); + for (const CpuTopology::CpuDevice& device : cpu_topology_.devices()) { + devices.push_back(std::make_unique( + device.process_id, device.local_device_id)); + } + return devices; +} + +} // namespace xla diff --git a/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h new file mode 100644 index 00000000000000..545644c0c7eaec --- /dev/null +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_description.h @@ -0,0 +1,125 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_DESCRIPTION_H_ +#define XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_DESCRIPTION_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/layout.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_device_description.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h" + +namespace xla { + +class CpuTopologyDescription : public PjRtTopologyDescription { + public: + static CpuTopologyDescription Create( + PjRtPlatformId platform_id, absl::string_view platform_name, + absl::string_view platform_version, + absl::Span> devices, + absl::Span machine_attributes); + + // `cpu_device_ids` is the list of logical device ids for the CPU devices and + // will be used to initialize the CPU topology. + CpuTopologyDescription(const PjRtPlatformId platform_id, + const absl::string_view platform_name, + const absl::string_view platform_version, + const std::vector cpu_devices, + absl::Span machine_attributes) + : platform_id_(platform_id), + platform_name_(platform_name), + platform_version_(platform_version), + cpu_topology_(std::move(cpu_devices), + std::vector(machine_attributes.begin(), + machine_attributes.end())) {} + + bool operator==(const CpuTopologyDescription& other) const { + return this->platform_id() == other.platform_id() && + this->platform_name() == other.platform_name() && + this->platform_version() == other.platform_version() && + this->cpu_topology().devices() == other.cpu_topology().devices(); + } + + PjRtPlatformId platform_id() const override { return platform_id_; } + + absl::string_view platform_name() const override { return platform_name_; } + + absl::string_view platform_version() const override { + return platform_version_; + } + + std::vector> DeviceDescriptions() + const override; + + const CpuTopology& cpu_topology() const { return cpu_topology_; } + const CpuTopology* cpu_topology_ptr() const { return &cpu_topology_; } + + // No subslice is supported. + bool is_subslice_topology() const override { return false; } + + // TODO(b/319478189): We support multi-host CPU computations and should + // correctly report process count. + absl::StatusOr ProcessCount() const override { return 1; } + + absl::StatusOr CoreCountOfDefaultType() const override { + return cpu_topology_.number_of_devices(); + } + + absl::StatusOr LogicalDeviceCountOfDefaultType() const override { + return cpu_topology_.number_of_devices(); + } + + absl::StatusOr CoreCountOfDefaultTypePerProcess() const override { + return cpu_topology_.number_of_devices(); + } + + absl::StatusOr CoreCountOfDefaultTypePerChip() const override { + return 1; + } + + absl::StatusOr Serialize() const override; + + // Returns vendor specific attributes about the topology. + const absl::flat_hash_map& Attributes() + const override { + return attributes_; + } + + absl::StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + + private: + const PjRtPlatformId platform_id_; + const std::string platform_name_; + const std::string platform_version_; + const CpuTopology cpu_topology_; + absl::flat_hash_map attributes_; +}; + +} // namespace xla + +#endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_DESCRIPTION_H_ diff --git a/third_party/xla/xla/pjrt/cpu/cpu_topology_test.cc b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_test.cc similarity index 94% rename from third_party/xla/xla/pjrt/cpu/cpu_topology_test.cc rename to third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_test.cc index 46574d47a867e7..3ac9b18fe52a66 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_topology_test.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_cpu/cpu_topology_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/cpu/cpu_topology.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h" #include -#include "xla/pjrt/cpu/cpu_topology.pb.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_topology.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD b/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD index 3a8110fea36876..1dcbfd5e150456 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/BUILD @@ -51,6 +51,7 @@ xla_test( ":xla_gpu_pjrt_client", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc index d0e9661264c548..13ea2f0a799822 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h" +#include +#include #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/plugin/xla_tpu/BUILD b/third_party/xla/xla/pjrt/plugin/xla_tpu/BUILD index d2b16893f209af..6a924d400539a6 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_tpu/BUILD +++ b/third_party/xla/xla/pjrt/plugin/xla_tpu/BUILD @@ -17,6 +17,9 @@ cc_library( deps = [ "//xla/pjrt:pjrt_c_api_client", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", + "//xla/pjrt/distributed:key_value_store_interface", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", ], ) @@ -27,7 +30,9 @@ cc_test( tags = ["no_oss"], deps = [ ":xla_tpu_pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.cc b/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.cc index 4858780276986d..81f2f4b509400f 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.cc @@ -16,16 +16,23 @@ limitations under the License. #include "xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h" #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" const char kTpuPjrtName[] = "tpu"; namespace xla { -absl::StatusOr> GetXlaPjrtTpuClient() { - return GetCApiClient(kTpuPjrtName); +absl::StatusOr> GetXlaPjrtTpuClient( + const absl::flat_hash_map& create_options, + std::shared_ptr kv_store) { + return GetCApiClient(kTpuPjrtName, create_options, kv_store); } } // namespace xla diff --git a/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h b/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h index f5fa9637522d90..39050706e325a4 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h @@ -17,14 +17,20 @@ limitations under the License. #define XLA_PJRT_PLUGIN_XLA_TPU_XLA_TPU_PJRT_CLIENT_H_ #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" namespace xla { -// Public entry point to get an XLA:TPU PjRtClient -absl::StatusOr> GetXlaPjrtTpuClient(); +// Public entry point to get an XLA:TPU PjRtClient with default options +absl::StatusOr> GetXlaPjrtTpuClient( + const absl::flat_hash_map& create_options = {}, + std::shared_ptr kv_store = nullptr); } // namespace xla diff --git a/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client_test.cc index 5fb666670c975e..7c7b8c587d45bb 100644 --- a/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client_test.cc @@ -15,11 +15,23 @@ limitations under the License. #include "xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h" +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/pjrt/pjrt_common.h" #include "tsl/platform/test.h" namespace xla { -TEST(XlaCpuPjrtClientTest, GetXlaPjrtTpuClient) { +TEST(XlaCpuPjrtClientTest, GetXlaPjrtTpuClientWithDefaultOptions) { + ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtTpuClient()); + EXPECT_EQ(client->platform_name(), "tpu"); +} + +TEST(XlaCpuPjrtClientTest, GetXlaPjrtTpuClientWithInvalidOptions) { + absl::flat_hash_map create_options; + create_options.insert({"invalid_option", true}); + ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtTpuClient()); EXPECT_EQ(client->platform_name(), "tpu"); } diff --git a/third_party/xla/xla/pjrt/semaphore_test.cc b/third_party/xla/xla/pjrt/semaphore_test.cc index 624265b773e99f..51413f132b8694 100644 --- a/third_party/xla/xla/pjrt/semaphore_test.cc +++ b/third_party/xla/xla/pjrt/semaphore_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/synchronization/notification.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index 8933a2482c8683..49b8d5db5e92ec 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -340,12 +340,6 @@ class TfPjRtClient : public PjRtClient { return wrapped_->MakeCrossHostReceiveBuffersForGather( shapes, std::move(gather_details), device, std::move(notifier)); } - absl::StatusOr CreateChannelHandle() override { - return wrapped_->CreateChannelHandle(); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return wrapped_->CreateDeviceToHostChannelHandle(); - } absl::StatusOr GetTopologyDescription() const override { return wrapped_->GetTopologyDescription(); diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index 5c710344eeb655..b8d4b61a75dd4f 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/pjrt_client.h" @@ -34,7 +35,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/test.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/transpose_test.cc b/third_party/xla/xla/pjrt/transpose_test.cc index 7d7ed774c0ce9f..0530c9eb4547c8 100644 --- a/third_party/xla/xla/pjrt/transpose_test.cc +++ b/third_party/xla/xla/pjrt/transpose_test.cc @@ -30,9 +30,9 @@ limitations under the License. #include "absl/numeric/int128.h" #include "unsupported/Eigen/CXX11/Tensor" #include "xla/array.h" +#include "xla/hlo/testlib/test.h" #include "xla/permutation_util.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index a7462b60700fa5..fcec108940b134 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -251,7 +251,7 @@ static absl::StatusOr> MlirAttrsToMemoryKinds( if (attr != nullptr) { TF_ASSIGN_OR_RETURN(MemorySpaceColor memory_space, GetMemorySpaceColor(attr.getValue().str())); - result.emplace_back(memory_space); + result.push_back(memory_space); } else { result.emplace_back(xla::Layout::kDefaultMemorySpace); } @@ -420,7 +420,7 @@ GetMemoryKindsFromFrontendAttr(absl::string_view attr) { for (const std::string& str_mem_space : str_memory_spaces) { MemorySpaceColor memory_space; CHECK(absl::SimpleAtoi(str_mem_space, &memory_space)); - result.emplace_back(memory_space); + result.push_back(memory_space); } return result; } @@ -480,9 +480,9 @@ absl::StatusOr> GetOutputMemoryKinds( return GetMemoryKinds(computation, "out_memory_spaces", num_outputs); } -static absl::StatusOr LayoutModeToXlaShape( +absl::StatusOr LayoutModeToXlaShape( const LayoutMode& layout_mode, const Shape& unsharded_shape, - const Shape& sharded_shape, + const Shape& sharded_shape, MemorySpaceColor memory_space, std::function(Shape)> choose_compact_layout_for_shape_function) { if (unsharded_shape.IsToken() || unsharded_shape.IsOpaque()) { @@ -516,6 +516,10 @@ static absl::StatusOr LayoutModeToXlaShape( break; } } + // When layout is AUTO, memory space can't be set since it will be partial. + if (result.has_layout()) { + result.mutable_layout()->set_memory_space(memory_space); + } return result; } @@ -587,12 +591,8 @@ absl::StatusOr, Shape>> LayoutModesToXlaShapes( TF_ASSIGN_OR_RETURN( Shape layout, LayoutModeToXlaShape(arg_layout_modes[i], unsharded_arg_shapes[i], - sharded_arg_shapes[i], + sharded_arg_shapes[i], arg_memory_spaces[i], choose_compact_layout_for_shape_function)); - // When layout is AUTO, memory space can't be set since it will be partial. - if (layout.has_layout()) { - layout.mutable_layout()->set_memory_space(arg_memory_spaces[i]); - } flat_arg_layouts.emplace_back(std::move(layout)); } @@ -606,12 +606,8 @@ absl::StatusOr, Shape>> LayoutModesToXlaShapes( TF_ASSIGN_OR_RETURN( Shape layout, LayoutModeToXlaShape(out_layout_modes[i], unsharded_out_shapes[i], - sharded_out_shapes[i], + sharded_out_shapes[i], out_memory_spaces[i], choose_compact_layout_for_shape_function)); - // When layout is AUTO, memory space can't be set since it will be partial. - if (layout.has_layout()) { - layout.mutable_layout()->set_memory_space(out_memory_spaces[i]); - } flat_out_layouts.emplace_back(std::move(layout)); } diff --git a/third_party/xla/xla/pjrt/utils.h b/third_party/xla/xla/pjrt/utils.h index 3470bd164d72a7..d726ecd2745669 100644 --- a/third_party/xla/xla/pjrt/utils.h +++ b/third_party/xla/xla/pjrt/utils.h @@ -90,6 +90,13 @@ absl::StatusOr> GetArgMemoryKinds( absl::StatusOr> GetOutputMemoryKinds( const XlaComputation& computation); +// Returns xla shape with layout set to reflect the given layout mode. +absl::StatusOr LayoutModeToXlaShape( + const LayoutMode& layout_mode, const Shape& unsharded_shape, + const Shape& sharded_shape, MemorySpaceColor memory_space, + std::function(Shape)> + choose_compact_layout_for_shape_function); + // Returns (arg shapes, output shape) with properly-set Layouts that can // be passed to XLA to reflect arg_layout_modes and out_layout_modes. absl::StatusOr, Shape>> LayoutModesToXlaShapes( diff --git a/third_party/xla/xla/primitive_util.cc b/third_party/xla/xla/primitive_util.cc index b70ba275a1f47f..f09b9b7a1edb50 100644 --- a/third_party/xla/xla/primitive_util.cc +++ b/third_party/xla/xla/primitive_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/primitive_util.h" +#include #include #include #include @@ -132,22 +133,25 @@ xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) { class PrimitiveTypeNameGenerator { public: PrimitiveTypeNameGenerator() { - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (i == static_cast(OPAQUE_TYPE)) { - lowercase_name_[i] = "opaque"; - } else if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = absl::AsciiStrToLower( - PrimitiveType_Name(static_cast(i))); + for (size_t idx = 0; idx < std::size(lowercase_name_); ++idx) { + PrimitiveType t = static_cast(idx + PrimitiveType_MIN); + if (t == OPAQUE_TYPE) { + lowercase_name_[idx] = "opaque"; + } else if (PrimitiveType_IsValid(t)) { + lowercase_name_[idx] = absl::AsciiStrToLower(PrimitiveType_Name(t)); } } } const std::string& LowercaseName(PrimitiveType t) { - CHECK_LT(t, PrimitiveType_ARRAYSIZE); - return lowercase_name_[static_cast(t)]; + CHECK_GE(t, PrimitiveType_MIN); + CHECK_LE(t, PrimitiveType_MAX); + CHECK(PrimitiveType_IsValid(t)) + << "Invalid PrimitiveType: " << static_cast(t); + return lowercase_name_[t - PrimitiveType_MIN]; } private: - std::string lowercase_name_[PrimitiveType_ARRAYSIZE]; + std::string lowercase_name_[PrimitiveType_MAX - PrimitiveType_MIN + 1]; }; const std::string& LowercasePrimitiveTypeName(PrimitiveType s) { diff --git a/third_party/xla/xla/primitive_util.h b/third_party/xla/xla/primitive_util.h index de5ee4fde11d7b..b9c1c978bc620e 100644 --- a/third_party/xla/xla/primitive_util.h +++ b/third_party/xla/xla/primitive_util.h @@ -93,6 +93,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { } // Unsigned integer +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return U1; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return U2; @@ -124,6 +129,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { } // Signed integer +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return S1; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return S2; @@ -234,6 +244,11 @@ struct PrimitiveTypeToNative { }; // Unsigned integer +template <> +struct PrimitiveTypeToNative { + using type = u1; +}; + template <> struct PrimitiveTypeToNative { using type = u2; @@ -265,6 +280,11 @@ struct PrimitiveTypeToNative { }; // Signed integer +template <> +struct PrimitiveTypeToNative { + using type = s1; +}; + template <> struct PrimitiveTypeToNative { using type = s2; @@ -397,13 +417,13 @@ constexpr bool IsComplexType(PrimitiveType type) { } constexpr bool IsSignedIntegralType(PrimitiveType type) { - return type == S2 || type == S4 || type == S8 || type == S16 || type == S32 || - type == S64; + return type == S1 || type == S2 || type == S4 || type == S8 || type == S16 || + type == S32 || type == S64; } constexpr bool IsUnsignedIntegralType(PrimitiveType type) { - return type == U2 || type == U4 || type == U8 || type == U16 || type == U32 || - type == U64; + return type == U1 || type == U2 || type == U4 || type == U8 || type == U16 || + type == U32 || type == U64; } constexpr bool IsIntegralType(PrimitiveType type) { @@ -414,6 +434,8 @@ template constexpr R IntegralTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsIntegralType(type))) { switch (type) { + case S1: + return std::forward(f)(PrimitiveTypeConstant()); case S2: return std::forward(f)(PrimitiveTypeConstant()); case S4: @@ -426,6 +448,8 @@ constexpr R IntegralTypeSwitch(F&& f, PrimitiveType type) { return std::forward(f)(PrimitiveTypeConstant()); case S64: return std::forward(f)(PrimitiveTypeConstant()); + case U1: + return std::forward(f)(PrimitiveTypeConstant()); case U2: return std::forward(f)(PrimitiveTypeConstant()); case U4: @@ -602,6 +626,8 @@ inline constexpr int ByteWidth(PrimitiveType type) { constexpr PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth) { switch (src_bitwidth) { + case 1: + return xla::U1; case 2: return xla::U2; case 4: diff --git a/third_party/xla/xla/primitive_util_test.cc b/third_party/xla/xla/primitive_util_test.cc index 850203f17379a4..e4abeb4ff7ac9b 100644 --- a/third_party/xla/xla/primitive_util_test.cc +++ b/third_party/xla/xla/primitive_util_test.cc @@ -15,11 +15,10 @@ limitations under the License. #include "xla/primitive_util.h" -#include #include -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -56,39 +55,56 @@ TEST(PrimitiveUtilTest, FloatTypes) { TEST(PrimitiveUtilTest, CastPreservesValues) { bool expecteds[PrimitiveType_ARRAYSIZE][PrimitiveType_ARRAYSIZE]; - expecteds[PRED][PRED] = true; - expecteds[PRED][S2] = true; - expecteds[PRED][S4] = true; - expecteds[PRED][S8] = true; - expecteds[PRED][S16] = true; - expecteds[PRED][S32] = true; - expecteds[PRED][S64] = true; - expecteds[PRED][U2] = true; - expecteds[PRED][U4] = true; - expecteds[PRED][U8] = true; - expecteds[PRED][U16] = true; - expecteds[PRED][U32] = true; - expecteds[PRED][U64] = true; - expecteds[PRED][F16] = true; - expecteds[PRED][F32] = true; - expecteds[PRED][F64] = true; - expecteds[PRED][C64] = true; - expecteds[PRED][BF16] = true; - expecteds[PRED][C128] = true; - expecteds[PRED][F8E5M2] = true; - expecteds[PRED][F8E4M3] = true; - expecteds[PRED][F8E4M3FN] = true; - expecteds[PRED][F8E4M3B11FNUZ] = true; - expecteds[PRED][F8E5M2FNUZ] = true; - expecteds[PRED][F8E4M3FNUZ] = true; - expecteds[PRED][F8E3M4] = true; + expecteds[PRED][PRED] = expecteds[PRED][S1] = true; + expecteds[PRED][S2] = expecteds[PRED][S4] = true; + expecteds[PRED][S8] = expecteds[PRED][S16] = true; + expecteds[PRED][S32] = expecteds[PRED][S64] = true; + expecteds[PRED][U1] = expecteds[PRED][U2] = true; + expecteds[PRED][U4] = expecteds[PRED][U8] = true; + expecteds[PRED][U16] = expecteds[PRED][U32] = true; + expecteds[PRED][U64] = expecteds[PRED][F16] = true; + expecteds[PRED][F32] = expecteds[PRED][F64] = true; + expecteds[PRED][C64] = expecteds[PRED][BF16] = true; + expecteds[PRED][C128] = expecteds[PRED][F8E5M2] = true; + expecteds[PRED][F8E4M3] = expecteds[PRED][F8E4M3FN] = true; + expecteds[PRED][F8E4M3B11FNUZ] = expecteds[PRED][F8E5M2FNUZ] = true; + expecteds[PRED][F8E4M3FNUZ] = expecteds[PRED][F8E3M4] = true; + expecteds[S1][PRED] = false; expecteds[S2][PRED] = false; - expecteds[S2][S2] = true; - expecteds[S2][S4] = true; + expecteds[S1][S1] = true; + expecteds[S1][S2] = true; + expecteds[S1][S4] = true; + expecteds[S1][S8] = true; + expecteds[S1][S16] = true; + expecteds[S1][S32] = true; + expecteds[S1][S64] = true; + expecteds[S1][U1] = false; + expecteds[S1][U2] = false; + expecteds[S1][U4] = false; + expecteds[S1][U8] = false; + expecteds[S1][U16] = false; + expecteds[S1][U32] = false; + expecteds[S1][U64] = false; + expecteds[S1][F16] = true; + expecteds[S1][F32] = true; + expecteds[S1][F64] = true; + expecteds[S1][C64] = true; + expecteds[S1][BF16] = true; + expecteds[S1][C128] = true; + expecteds[S1][F8E5M2] = true; + expecteds[S1][F8E4M3] = true; + expecteds[S1][F8E4M3FN] = true; + expecteds[S1][F8E4M3B11FNUZ] = true; + expecteds[S1][F8E5M2FNUZ] = true; + expecteds[S1][F8E4M3FNUZ] = true; + expecteds[S1][F8E3M4] = true; + expecteds[S2][S1] = false; + expecteds[S2][S2] = expecteds[S2][S4] = true; expecteds[S2][S8] = true; expecteds[S2][S16] = true; expecteds[S2][S32] = true; expecteds[S2][S64] = true; + expecteds[S2][U1] = false; expecteds[S2][U2] = false; expecteds[S2][U4] = false; expecteds[S2][U8] = false; @@ -109,12 +125,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][F8E4M3FNUZ] = true; expecteds[S2][F8E3M4] = true; expecteds[S4][PRED] = false; + expecteds[S4][S1] = false; expecteds[S4][S2] = false; expecteds[S4][S4] = true; expecteds[S4][S8] = true; expecteds[S4][S16] = true; expecteds[S4][S32] = true; expecteds[S4][S64] = true; + expecteds[S4][U1] = false; expecteds[S4][U2] = false; expecteds[S4][U4] = false; expecteds[S4][U8] = false; @@ -135,12 +153,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E4M3FNUZ] = true; expecteds[S4][F8E3M4] = true; expecteds[S8][PRED] = false; + expecteds[S8][S1] = false; expecteds[S8][S2] = false; expecteds[S8][S4] = false; expecteds[S8][S8] = true; expecteds[S8][S16] = true; expecteds[S8][S32] = true; expecteds[S8][S64] = true; + expecteds[S8][U1] = false; expecteds[S8][U2] = false; expecteds[S8][U4] = false; expecteds[S8][U8] = false; @@ -161,12 +181,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E4M3FNUZ] = false; expecteds[S8][F8E3M4] = false; expecteds[S16][PRED] = false; + expecteds[S16][S1] = false; expecteds[S16][S2] = false; expecteds[S16][S4] = false; expecteds[S16][S8] = false; expecteds[S16][S16] = true; expecteds[S16][S32] = true; expecteds[S16][S64] = true; + expecteds[S16][U1] = false; expecteds[S16][U2] = false; expecteds[S16][U4] = false; expecteds[S16][U8] = false; @@ -187,12 +209,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E4M3FNUZ] = false; expecteds[S16][F8E3M4] = false; expecteds[S32][PRED] = false; + expecteds[S32][S1] = false; expecteds[S32][S2] = false; expecteds[S32][S4] = false; expecteds[S32][S8] = false; expecteds[S32][S16] = false; expecteds[S32][S32] = true; expecteds[S32][S64] = true; + expecteds[S32][U1] = false; expecteds[S32][U2] = false; expecteds[S32][U4] = false; expecteds[S32][U8] = false; @@ -213,12 +237,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E4M3FNUZ] = false; expecteds[S32][F8E3M4] = false; expecteds[S64][PRED] = false; + expecteds[S64][S1] = false; expecteds[S64][S2] = false; expecteds[S64][S4] = false; expecteds[S64][S8] = false; expecteds[S64][S16] = false; expecteds[S64][S32] = false; expecteds[S64][S64] = true; + expecteds[S64][U1] = false; expecteds[S64][U2] = false; expecteds[S64][U4] = false; expecteds[S64][U8] = false; @@ -238,7 +264,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; expecteds[S64][F8E3M4] = false; + expecteds[U1][PRED] = false; + expecteds[U1][S1] = false; + expecteds[U1][S2] = true; + expecteds[U1][S4] = true; + expecteds[U1][S8] = true; + expecteds[U1][S16] = true; + expecteds[U1][S32] = true; + expecteds[U1][S64] = true; + expecteds[U1][U1] = true; + expecteds[U1][U2] = true; + expecteds[U1][U4] = true; + expecteds[U1][U8] = true; + expecteds[U1][U16] = true; + expecteds[U1][U32] = true; + expecteds[U1][U64] = true; + expecteds[U1][F16] = true; + expecteds[U1][F32] = true; + expecteds[U1][F64] = true; + expecteds[U1][C64] = true; + expecteds[U1][BF16] = true; + expecteds[U1][C128] = true; + expecteds[U1][BF16] = true; + expecteds[U1][C128] = true; + expecteds[U1][F8E5M2] = true; + expecteds[U1][F8E4M3] = true; + expecteds[U1][F8E4M3FN] = true; + expecteds[U1][F8E4M3B11FNUZ] = true; + expecteds[U1][F8E5M2FNUZ] = true; + expecteds[U1][F8E4M3FNUZ] = true; + expecteds[U1][F8E3M4] = true; expecteds[U2][PRED] = false; + expecteds[U2][U1] = expecteds[U2][S1] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; expecteds[U2][S8] = true; @@ -267,12 +324,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][F8E4M3FNUZ] = true; expecteds[U2][F8E3M4] = true; expecteds[U4][PRED] = false; + expecteds[U4][S1] = false; expecteds[U4][S2] = false; expecteds[U4][S4] = false; expecteds[U4][S8] = true; expecteds[U4][S16] = true; expecteds[U4][S32] = true; expecteds[U4][S64] = true; + expecteds[U4][U1] = false; expecteds[U4][U2] = false; expecteds[U4][U4] = true; expecteds[U4][U8] = true; @@ -295,12 +354,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E4M3FNUZ] = true; expecteds[U4][F8E3M4] = true; expecteds[U8][PRED] = false; + expecteds[U8][S1] = false; expecteds[U8][S2] = false; expecteds[U8][S4] = false; expecteds[U8][S8] = false; expecteds[U8][S16] = true; expecteds[U8][S32] = true; expecteds[U8][S64] = true; + expecteds[U8][U1] = false; expecteds[U8][U2] = false; expecteds[U8][U4] = false; expecteds[U8][U8] = true; @@ -323,12 +384,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E4M3FNUZ] = false; expecteds[U8][F8E3M4] = false; expecteds[U16][PRED] = false; + expecteds[U16][S1] = false; expecteds[U16][S2] = false; expecteds[U16][S4] = false; expecteds[U16][S8] = false; expecteds[U16][S16] = false; expecteds[U16][S32] = true; expecteds[U16][S64] = true; + expecteds[U16][U1] = false; expecteds[U16][U2] = false; expecteds[U16][U4] = false; expecteds[U16][U8] = false; @@ -349,12 +412,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E4M3FNUZ] = false; expecteds[U16][F8E3M4] = false; expecteds[U32][PRED] = false; + expecteds[U32][S1] = false; expecteds[U32][S2] = false; expecteds[U32][S4] = false; expecteds[U32][S8] = false; expecteds[U32][S16] = false; expecteds[U32][S32] = false; expecteds[U32][S64] = true; + expecteds[U32][U1] = false; expecteds[U32][U2] = false; expecteds[U32][U4] = false; expecteds[U32][U8] = false; @@ -375,12 +440,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E4M3FNUZ] = false; expecteds[U32][F8E3M4] = false; expecteds[U64][PRED] = false; + expecteds[U64][S1] = false; expecteds[U64][S2] = false; expecteds[U64][S4] = false; expecteds[U64][S8] = false; expecteds[U64][S16] = false; expecteds[U64][S32] = false; expecteds[U64][S64] = false; + expecteds[U64][U1] = false; expecteds[U64][U2] = false; expecteds[U64][U4] = false; expecteds[U64][U8] = false; @@ -401,12 +468,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E4M3FNUZ] = false; expecteds[U64][F8E3M4] = false; expecteds[F16][PRED] = false; + expecteds[F16][S1] = false; expecteds[F16][S2] = false; expecteds[F16][S4] = false; expecteds[F16][S8] = false; expecteds[F16][S16] = false; expecteds[F16][S32] = false; expecteds[F16][S64] = false; + expecteds[F16][U1] = false; expecteds[F16][U2] = false; expecteds[F16][U4] = false; expecteds[F16][U8] = false; @@ -427,12 +496,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E4M3FNUZ] = false; expecteds[F16][F8E3M4] = false; expecteds[F32][PRED] = false; + expecteds[F32][S1] = false; expecteds[F32][S2] = false; expecteds[F32][S4] = false; expecteds[F32][S8] = false; expecteds[F32][S16] = false; expecteds[F32][S32] = false; expecteds[F32][S64] = false; + expecteds[F32][U1] = false; expecteds[F32][U2] = false; expecteds[F32][U4] = false; expecteds[F32][U8] = false; @@ -453,12 +524,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E4M3FNUZ] = false; expecteds[F32][F8E3M4] = false; expecteds[F64][PRED] = false; + expecteds[F64][S1] = false; expecteds[F64][S2] = false; expecteds[F64][S4] = false; expecteds[F64][S8] = false; expecteds[F64][S16] = false; expecteds[F64][S32] = false; expecteds[F64][S64] = false; + expecteds[F64][U1] = false; expecteds[F64][U2] = false; expecteds[F64][U4] = false; expecteds[F64][U8] = false; @@ -479,12 +552,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E4M3FNUZ] = false; expecteds[F64][F8E3M4] = false; expecteds[C64][PRED] = false; + expecteds[C64][S1] = false; expecteds[C64][S2] = false; expecteds[C64][S4] = false; expecteds[C64][S8] = false; expecteds[C64][S16] = false; expecteds[C64][S32] = false; expecteds[C64][S64] = false; + expecteds[C64][U1] = false; expecteds[C64][U2] = false; expecteds[C64][U4] = false; expecteds[C64][U8] = false; @@ -505,12 +580,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][F8E4M3FNUZ] = false; expecteds[C64][F8E3M4] = false; expecteds[BF16][PRED] = false; + expecteds[BF16][S1] = false; expecteds[BF16][S2] = false; expecteds[BF16][S4] = false; expecteds[BF16][S8] = false; expecteds[BF16][S16] = false; expecteds[BF16][S32] = false; expecteds[BF16][S64] = false; + expecteds[BF16][U1] = false; expecteds[BF16][U2] = false; expecteds[BF16][U4] = false; expecteds[BF16][U8] = false; @@ -531,12 +608,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E4M3FNUZ] = false; expecteds[BF16][F8E3M4] = false; expecteds[C128][PRED] = false; + expecteds[C128][S1] = false; expecteds[C128][S2] = false; expecteds[C128][S4] = false; expecteds[C128][S8] = false; expecteds[C128][S16] = false; expecteds[C128][S32] = false; expecteds[C128][S64] = false; + expecteds[C128][U1] = false; expecteds[C128][U2] = false; expecteds[C128][U4] = false; expecteds[C128][U8] = false; @@ -557,12 +636,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E4M3FNUZ] = false; expecteds[C128][F8E3M4] = false; expecteds[F8E5M2][PRED] = false; + expecteds[F8E5M2][S1] = false; expecteds[F8E5M2][S2] = false; expecteds[F8E5M2][S4] = false; expecteds[F8E5M2][S8] = false; expecteds[F8E5M2][S16] = false; expecteds[F8E5M2][S32] = false; expecteds[F8E5M2][S64] = false; + expecteds[F8E5M2][U1] = false; expecteds[F8E5M2][U2] = false; expecteds[F8E5M2][U4] = false; expecteds[F8E5M2][U8] = false; @@ -583,12 +664,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E5M2][F8E3M4] = false; expecteds[F8E4M3][PRED] = false; + expecteds[F8E4M3][S1] = false; expecteds[F8E4M3][S2] = false; expecteds[F8E4M3][S4] = false; expecteds[F8E4M3][S8] = false; expecteds[F8E4M3][S16] = false; expecteds[F8E4M3][S32] = false; expecteds[F8E4M3][S64] = false; + expecteds[F8E4M3][U1] = false; expecteds[F8E4M3][U2] = false; expecteds[F8E4M3][U4] = false; expecteds[F8E4M3][U8] = false; @@ -609,12 +692,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3][F8E3M4] = false; expecteds[F8E4M3FN][PRED] = false; + expecteds[F8E4M3FN][S1] = false; expecteds[F8E4M3FN][S2] = false; expecteds[F8E4M3FN][S4] = false; expecteds[F8E4M3FN][S8] = false; expecteds[F8E4M3FN][S16] = false; expecteds[F8E4M3FN][S32] = false; expecteds[F8E4M3FN][S64] = false; + expecteds[F8E4M3FN][U1] = false; expecteds[F8E4M3FN][U2] = false; expecteds[F8E4M3FN][U4] = false; expecteds[F8E4M3FN][U8] = false; @@ -635,12 +720,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][F8E3M4] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; + expecteds[F8E4M3B11FNUZ][S1] = false; expecteds[F8E4M3B11FNUZ][S2] = false; expecteds[F8E4M3B11FNUZ][S4] = false; expecteds[F8E4M3B11FNUZ][S8] = false; expecteds[F8E4M3B11FNUZ][S16] = false; expecteds[F8E4M3B11FNUZ][S32] = false; expecteds[F8E4M3B11FNUZ][S64] = false; + expecteds[F8E4M3B11FNUZ][U1] = false; expecteds[F8E4M3B11FNUZ][U2] = false; expecteds[F8E4M3B11FNUZ][U4] = false; expecteds[F8E4M3B11FNUZ][U8] = false; @@ -661,12 +748,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E3M4] = false; expecteds[F8E5M2FNUZ][PRED] = false; + expecteds[F8E5M2FNUZ][S1] = false; expecteds[F8E5M2FNUZ][S2] = false; expecteds[F8E5M2FNUZ][S4] = false; expecteds[F8E5M2FNUZ][S8] = false; expecteds[F8E5M2FNUZ][S16] = false; expecteds[F8E5M2FNUZ][S32] = false; expecteds[F8E5M2FNUZ][S64] = false; + expecteds[F8E5M2FNUZ][U1] = false; expecteds[F8E5M2FNUZ][U2] = false; expecteds[F8E5M2FNUZ][U4] = false; expecteds[F8E5M2FNUZ][U8] = false; @@ -687,12 +776,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][F8E3M4] = false; expecteds[F8E4M3FNUZ][PRED] = false; + expecteds[F8E4M3FNUZ][S1] = false; expecteds[F8E4M3FNUZ][S2] = false; expecteds[F8E4M3FNUZ][S4] = false; expecteds[F8E4M3FNUZ][S8] = false; expecteds[F8E4M3FNUZ][S16] = false; expecteds[F8E4M3FNUZ][S32] = false; expecteds[F8E4M3FNUZ][S64] = false; + expecteds[F8E4M3FNUZ][U1] = false; expecteds[F8E4M3FNUZ][U2] = false; expecteds[F8E4M3FNUZ][U4] = false; expecteds[F8E4M3FNUZ][U8] = false; @@ -713,12 +804,14 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; expecteds[F8E4M3FNUZ][F8E3M4] = false; expecteds[F8E3M4][PRED] = false; + expecteds[F8E3M4][S1] = false; expecteds[F8E3M4][S2] = false; expecteds[F8E3M4][S4] = false; expecteds[F8E3M4][S8] = false; expecteds[F8E3M4][S16] = false; expecteds[F8E3M4][S32] = false; expecteds[F8E3M4][S64] = false; + expecteds[F8E3M4][U1] = false; expecteds[F8E3M4][U2] = false; expecteds[F8E3M4][U4] = false; expecteds[F8E3M4][U8] = false; diff --git a/third_party/xla/xla/protobuf_util.cc b/third_party/xla/xla/protobuf_util.cc index 4c6815d9396491..c6744b19507df0 100644 --- a/third_party/xla/xla/protobuf_util.cc +++ b/third_party/xla/xla/protobuf_util.cc @@ -15,10 +15,10 @@ limitations under the License. #include "xla/protobuf_util.h" +#include #include #include "absl/hash/hash.h" -#include "absl/status/status.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/protobuf_util.h b/third_party/xla/xla/protobuf_util.h index b763d7ddaeff1c..4ba58f2f91388b 100644 --- a/third_party/xla/xla/protobuf_util.h +++ b/third_party/xla/xla/protobuf_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PROTOBUF_UTIL_H_ #define XLA_PROTOBUF_UTIL_H_ +#include #include #include diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index ea3d40c543d048..0a15d5d52da54f 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -12,6 +12,7 @@ load( "//xla/tsl:tsl.bzl", "if_cuda_or_rocm", "if_google", + "if_oss", "internal_visibility", ) load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_pybind_extension") @@ -263,6 +264,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@nanobind", @@ -309,6 +311,7 @@ cc_library( "py_program.cc", "py_values.cc", "sharding.cc", + "to_ifrt_sharding.cc", ], hdrs = [ "py_array.h", @@ -323,6 +326,7 @@ cc_library( "py_values.h", "sharded_device_array.h", "sharding.h", + "to_ifrt_sharding.h", ], compatible_with = [], copts = [ @@ -386,6 +390,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", + "//xla/pjrt:host_memory_spaces", "//xla/pjrt:lru_cache", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", @@ -420,6 +425,7 @@ cc_library( "//xla/tsl/concurrency:ref_count", "//xla/tsl/framework:allocator", "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + "//xla/tsl/platform:statusor", "//xla/tsl/python/lib/core:numpy", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:errors", @@ -502,12 +508,15 @@ cc_library( "//xla:comparison_util", "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", + "//xla/pjrt:transpose", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@nanobind", ] + if_rocm( @@ -589,6 +598,7 @@ cc_library( "@nanobind", "@local_config_python//:python_headers", # build_cleaner: keep "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_layout", "//xla/pjrt:status_casters", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/profiler/lib:traceme", @@ -600,6 +610,7 @@ cc_library( srcs = ["inspect_sharding.cc"], hdrs = ["inspect_sharding.h"], deps = [ + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:custom_call_sharding_helper", "//xla/service/spmd:spmd_partitioner", @@ -631,6 +642,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -695,6 +709,7 @@ cc_library( "//xla/hlo/builder/lib:sorting", "//xla/hlo/builder/lib:svd", "//xla/pjrt:status_casters", + "//xla/service:hlo_proto_cc", ], ) @@ -724,7 +739,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", @@ -1013,6 +1027,9 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/lib:traceme", + "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", + "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -1078,6 +1095,7 @@ cc_library( "@nanobind", "@local_config_python//:python_headers", "//xla/pjrt:lru_cache", + "//xla/tsl/platform:logging", ], ) @@ -1125,9 +1143,9 @@ cc_library( "//xla/hlo/ir:hlo_module_group", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_executable", @@ -1292,6 +1310,7 @@ tsl_pybind_extension( "//xla:shape_util", "//xla:types", "//xla:util", + "//xla/backends/cpu/collectives:cpu_collectives", "//xla/ffi:ffi_api", "//xla/pjrt:exceptions", "//xla/pjrt:mlir_to_hlo", @@ -1318,7 +1337,6 @@ tsl_pybind_extension( "//xla/python/pjrt_ifrt", "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", "//xla/python/pjrt_ifrt:xla_ifrt", - "//xla/service/cpu:collectives_interface", "//xla/tsl/concurrency:ref_count", "//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", "//xla/tsl/platform/cloud:gcs_file_system", @@ -1330,22 +1348,21 @@ tsl_pybind_extension( ] + select({ # gloo tcp transport only builds on linux "//xla/tsl:macos": [ - "//xla/pjrt/cpu:gloo_collectives", - "//xla/pjrt/cpu:gloo_kv_store", + "//xla/backends/cpu/collectives:gloo_collectives", + "//xla/backends/cpu/collectives:gloo_kv_store", "@gloo//:transport_uv", ], "//xla/tsl:windows": [], "//conditions:default": [ - "//xla/pjrt/cpu:gloo_collectives", - "//xla/pjrt/cpu:gloo_kv_store", + "//xla/backends/cpu/collectives:gloo_collectives", + "//xla/backends/cpu/collectives:gloo_kv_store", "@gloo//:transport_tcp", ], }) + select({ # mpitrampoline does not build on windows "//xla/tsl:windows": [], - "//conditions:default": [ - "//xla/pjrt/cpu:mpi_collectives", - ], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["//xla/backends/cpu/collectives:mpi_collectives"]), }), ) @@ -1379,8 +1396,8 @@ xla_cc_test( srcs = ["xplane_to_profile_instructions_test.cc"], deps = [ ":xplane_to_profile_instructions", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_proto_cc", - "//xla/tests:verified_hlo_module", "//xla/tsl/profiler/convert:xla_op_utils", "//xla/tsl/profiler/rpc/client:save_profile", "//xla/tsl/profiler/utils:file_system_utils", @@ -1422,10 +1439,12 @@ cc_library( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ - "//xla/tsl/python/lib/core:numpy", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_config_python//:python_headers", "@nanobind", + # copybara:uncomment "//third_party/py/numpy:multiarray", + "@local_config_python//:python_headers", + "//xla/tsl/python/lib/core:numpy", ], ) diff --git a/third_party/xla/xla/python/callback.cc b/third_party/xla/xla/python/callback.cc index 9d0f707b71d2e7..5f4675df6ccb2c 100644 --- a/third_party/xla/xla/python/callback.cc +++ b/third_party/xla/xla/python/callback.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -32,10 +31,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/nb_numpy.h" @@ -127,7 +126,7 @@ absl::StatusOr CpuCallback::Call(nb::tuple args) { if (!PyTuple_Check(result_object.ptr())) { return absl::InternalError( absl::StrFormat("CPU callback expected a tuple result, got %s", - nb::cast(nb::repr(result_object)))); + nb::cast(nb::repr(result_object)))); } if (PyTuple_Size(result_object.ptr()) != results_.size()) { return absl::InternalError( @@ -142,7 +141,7 @@ absl::StatusOr CpuCallback::Call(nb::tuple args) { if (!output.is_none()) { return absl::InternalError(absl::StrFormat( "Token output from Python callback should be None, got %s", - nb::cast(nb::repr(output)))); + nb::cast(nb::repr(output)))); } continue; } diff --git a/third_party/xla/xla/python/custom_call_sharding.cc b/third_party/xla/xla/python/custom_call_sharding.cc index e25fdf835955e0..0bc424c9c13bee 100644 --- a/third_party/xla/xla/python/custom_call_sharding.cc +++ b/third_party/xla/xla/python/custom_call_sharding.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -93,7 +92,7 @@ class PyCustomCallPartitionerCallbacks { xla::Shape result_shape = std::move(std::get<2>(args_tuple)); std::optional result_sharding = std::move(std::get<3>(args_tuple)); - std::string_view backend_config = std::move(std::get<4>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); { nb::gil_scoped_acquire gil; @@ -118,7 +117,7 @@ class PyCustomCallPartitionerCallbacks { return xla::Internal( "Shardings returned from partitioning: expected " "Tuple[bytes, List[HloSharding], HloSharding] got: %s", - nb::cast(nb::repr(py_result))); + nb::cast(nb::repr(py_result))); } } catch (const nb::python_error& e) { return xla::Internal("custom_partitioner: %s", e.what()); @@ -136,7 +135,7 @@ class PyCustomCallPartitionerCallbacks { std::vector> arg_shardings = std::move(std::get<1>(args_tuple)); xla::Shape result_shape = std::move(std::get<2>(args_tuple)); - std::string_view backend_config = std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); std::optional result; nb::gil_scoped_acquire gil; @@ -161,7 +160,7 @@ class PyCustomCallPartitionerCallbacks { TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); xla::Shape result_shape = std::move(std::get<1>(args_tuple)); - std::string_view backend_config = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); nb::gil_scoped_acquire gil; try { @@ -229,7 +228,7 @@ void BuildCustomCallShardingPybindAPI(nb::module_& m) { return; } - if (std::string_view(c_api->name()) != "pjrt_c_api") { + if (absl::string_view(c_api->name()) != "pjrt_c_api") { throw absl::InvalidArgumentError( "Argument to register_custom_call_partitioner was not a " "pjrt_c_api capsule."); diff --git a/third_party/xla/xla/python/custom_calls_testlib.cc b/third_party/xla/xla/python/custom_calls_testlib.cc index c8563e00f62795..2c57fbd7e52fde 100644 --- a/third_party/xla/xla/python/custom_calls_testlib.cc +++ b/third_party/xla/xla/python/custom_calls_testlib.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "nanobind/nanobind.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" @@ -64,11 +67,40 @@ XLA_FFI_DEFINE_HANDLER(kSubtractCst, SubtractCst, .Ret>() .Attr("cst")); +// XLA FFI calls can also be stateful. +struct TestFfiState { + static TypeId id; + explicit TestFfiState(int32_t value) : value(value) {} + int32_t value; +}; +TypeId TestFfiState::id = {}; + +static ErrorOr> StateInstantiate() { + return std::make_unique(42); +} + +static Error StateExecute(TestFfiState* state, + Result> out) { + *out->typed_data() = state->value; + return Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate, + Ffi::BindInstantiate()); +XLA_FFI_DEFINE_HANDLER( + kStateExecute, StateExecute, + Ffi::Bind().Ctx>().Ret>()); + template static auto BindFunction(T* fn) { return nb::capsule(reinterpret_cast(fn)); } +template +static auto BindTypeId(T* typeId) { + return nb::capsule(reinterpret_cast(typeId)); +} + // Custom calls registration library that exports function pointers to XLA FFI // handlers to the python users. NB_MODULE(custom_calls_testlib, m) { @@ -78,8 +110,19 @@ NB_MODULE(custom_calls_testlib, m) { dict["always_succeed"] = BindFunction(kAlwaysSucceed); dict["subtract_f32"] = BindFunction(kSubtract); dict["subtract_f32_cst"] = BindFunction(kSubtractCst); + + nb::dict bundle; + bundle["instantiate"] = BindFunction(kStateInstantiate); + bundle["execute"] = BindFunction(kStateExecute); + dict["stateful"] = bundle; + return dict; }); + m.def("type_ids", []() { + nb::dict type_ids; + type_ids["test_ffi_state"] = BindTypeId(&TestFfiState::id); + return type_ids; + }); } } // namespace xla::ffi diff --git a/third_party/xla/xla/python/custom_partition_callback.cc b/third_party/xla/xla/python/custom_partition_callback.cc index df49dfc1e37bc4..3349385ffa43e2 100644 --- a/third_party/xla/xla/python/custom_partition_callback.cc +++ b/third_party/xla/xla/python/custom_partition_callback.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -31,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -46,8 +46,11 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include "xla/service/call_inliner.h" #include "xla/service/custom_call_sharding_helper.h" -#include "xla/service/spmd/spmd_partitioner_util.h" +#include "xla/service/spmd/spmd_partitioner.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -202,8 +205,8 @@ void SetCAPIString(JAX_CustomCallPartitioner_string& out, std::string result, out.size = scratch.back().size(); } -std::string_view ToStringView(JAX_CustomCallPartitioner_string data) { - return std::string_view(data.data, data.size); +absl::string_view ToStringView(JAX_CustomCallPartitioner_string data) { + return absl::string_view(data.data, data.size); } void SetCAPIAval(JAX_CustomCallPartitioner_aval& result, @@ -343,7 +346,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args, absl::StatusOr, std::vector>, - xla::Shape, std::optional, std::string_view>> + xla::Shape, std::optional, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) { std::vector shapes; std::vector> shardings; @@ -369,14 +372,14 @@ ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) { } return std::tuple, std::vector>, xla::Shape, - std::optional, std::string_view>( + std::optional, absl::string_view>( std::move(shapes), std::move(shardings), std::move(result_shape), std::move(result_sharding), ToStringView(args->backend_config)); } absl::StatusOr, std::vector>, - xla::Shape, std::string_view>> + xla::Shape, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { std::vector shapes; std::vector> shardings; @@ -397,9 +400,9 @@ ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape)); return std::tuple, std::vector>, xla::Shape, - std::string_view>(std::move(shapes), std::move(shardings), - std::move(result_shape), - ToStringView(args->backend_config)); + absl::string_view>(std::move(shapes), std::move(shardings), + std::move(result_shape), + ToStringView(args->backend_config)); } PartitionScratch PopulateArgs( @@ -455,11 +458,11 @@ absl::StatusOr> ConsumeResults( return ReadHloSharding(args->result_sharding); } -absl::StatusOr> +absl::StatusOr> ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape)); TF_ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding)); - return std::tuple( + return std::tuple( std::move(sharding), std::move(shape), ToStringView(args->backend_config)); } diff --git a/third_party/xla/xla/python/custom_partition_callback.h b/third_party/xla/xla/python/custom_partition_callback.h index 33cc31e75fc9bf..6ba1789a038daa 100644 --- a/third_party/xla/xla/python/custom_partition_callback.h +++ b/third_party/xla/xla/python/custom_partition_callback.h @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "xla/hlo/ir/hlo_instruction.h" @@ -37,7 +36,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args, const xla::HloInstruction* instruction); absl::StatusOr, std::vector>, - xla::Shape, std::optional, std::string_view>> + xla::Shape, std::optional, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args); void PopulateResults( absl::StatusOr, @@ -50,7 +49,7 @@ ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args); absl::StatusOr, std::vector>, - xla::Shape, std::string_view>> + xla::Shape, absl::string_view>> ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); PartitionScratch PopulateArgs( JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args, @@ -61,7 +60,7 @@ void PopulateResults( absl::StatusOr> ConsumeResults( JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); -absl::StatusOr> +absl::StatusOr> ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); PartitionScratch PopulateArgs( JAX_CustomCallPartitioner_PropagateUserSharding_Args* args, diff --git a/third_party/xla/xla/python/dlpack.cc b/third_party/xla/xla/python/dlpack.cc index 2848fc20827b18..dfe30f0dda6cd3 100644 --- a/third_party/xla/xla/python/dlpack.cc +++ b/third_party/xla/xla/python/dlpack.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -419,7 +418,7 @@ absl::StatusOr BufferToDLPackManagedTensor( pjrt_buffer->dimensions().end()); // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout - Layout xla_layout = GetXlaLayoutUnsafe(pjrt_buffer->layout()); + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout); @@ -458,11 +457,11 @@ absl::StatusOr DLPackManagedTensorToBuffer( auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; - if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", - std::string_view(tensor.name())); + absl::string_view(tensor.name())); } DLManagedTensor* dlmt = static_cast(tensor.data()); if (dlmt->dl_tensor.ndim < 0) { @@ -552,11 +551,11 @@ absl::StatusOr DLPackManagedTensorToBuffer( "DLPack is only supported for devices addressable by the current " "process."); } - if (std::string_view(tensor.name()) != kDlTensorCapsuleName) { + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " "Note that a DLPack tensor may be consumed at most once.", - std::string_view(tensor.name())); + absl::string_view(tensor.name())); } DLManagedTensor* dlmt = static_cast(tensor.data()); if (dlmt->dl_tensor.ndim < 0) { diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index c67816cb6534f5..aa2754ae424e9c 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -410,8 +410,8 @@ cc_library( deps = [ ":attribute_map", ":ifrt", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_layout", "//xla/tsl/concurrency:ref_count", @@ -594,6 +594,8 @@ xla_cc_test( ":device_test_util", ":ifrt", ":sharding_serdes", + "//xla:shape_util", + "//xla/pjrt:pjrt_layout", "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", @@ -620,6 +622,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/python/ifrt/array.h b/third_party/xla/xla/python/ifrt/array.h index 2a4ff23b1fdb1d..e31a2600352324 100644 --- a/third_party/xla/xla/python/ifrt/array.h +++ b/third_party/xla/xla/python/ifrt/array.h @@ -76,7 +76,7 @@ class Array : public llvm::RTTIExtends { // The device memory layout for each shard of the Array. All shards are // assumed to have the same layout. Cannot be nullptr; implementations should // return UNIMPLEMENTED instead. - virtual absl::StatusOr> layout() const = 0; + virtual absl::StatusOr> layout() const = 0; // Breaks an array up into per-device arrays. This is the elimination // counterpart of `Client::AssembleArrayFromSingleDeviceArrays()`. diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index b8ef7caed58dec..622a80a35d8366 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -311,9 +311,8 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferReplicated) { std::iota(data->begin(), data->end(), 0); absl::Span devices = client->addressable_devices(); std::shared_ptr sharding = ConcreteEvenSharding::Create( - BasicDeviceList::Create( - BasicDeviceList::Devices(devices.begin(), devices.end())), - MemoryKind(), shape, /*shard_shape=*/shape, /*is_fully_replicated=*/true); + BasicDeviceList::Create(devices), MemoryKind(), shape, + /*shard_shape=*/shape, /*is_fully_replicated=*/true); TF_ASSERT_OK_AND_ASSIGN( auto array, @@ -376,9 +375,9 @@ TEST(ArrayImplTest, AssembleArray) { std::vector> arrays({array0, array1}); Shape assembled_shape({4, 3}); std::shared_ptr assembled_sharding = OpaqueSharding::Create( - BasicDeviceList::Create(BasicDeviceList::Devices( + BasicDeviceList::Create( {array0->sharding().devices()->devices().front(), - array1->sharding().devices()->devices().front()})), + array1->sharding().devices()->devices().front()}), MemoryKind()); TF_ASSERT_OK_AND_ASSIGN( auto assembled_array, @@ -424,9 +423,9 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) { Shape assembled_shape({4, 3}); ShardingParam sharding_param( /*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 1}}); - auto ifrt_device_list = BasicDeviceList::Create(BasicDeviceList::Devices( + auto ifrt_device_list = BasicDeviceList::Create( {array0->sharding().devices()->devices().front(), - array1->sharding().devices()->devices().front()})); + array1->sharding().devices()->devices().front()}); TF_ASSERT_OK_AND_ASSIGN( std::shared_ptr sharding_param_sharding, ShardingParamSharding::Create(std::move(sharding_param), ifrt_device_list, @@ -537,9 +536,8 @@ TEST(ArrayImplTest, CopyToSameDevices) { TEST(ArrayImplTest, CopyToDifferentDevice) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); - tsl::RCReference devices = BasicDeviceList::Create( - BasicDeviceList::Devices(client->addressable_devices().begin(), - client->addressable_devices().end())); + tsl::RCReference devices = + BasicDeviceList::Create(client->addressable_devices()); DType dtype(DType::kF32); Shape shape({2, 3}); @@ -639,8 +637,7 @@ TEST(ArrayImplTest, CopyMixedSourceDevices) { Device* new_device = client->addressable_devices().at(1); EXPECT_THAT(client ->CopyArrays(absl::MakeSpan(arrays), - BasicDeviceList::Create( - BasicDeviceList::Devices({new_device})), + BasicDeviceList::Create({new_device}), MemoryKind(), ArrayCopySemantics::kAlwaysCopy) .status(), StatusIs(absl::StatusCode::kInvalidArgument)); @@ -674,8 +671,7 @@ TEST(ArrayImplTest, CopyMixedSourceMemoryKind) { Device* new_device = client->addressable_devices().at(1); EXPECT_THAT(client ->CopyArrays(absl::MakeSpan(arrays), - BasicDeviceList::Create( - BasicDeviceList::Devices({new_device})), + BasicDeviceList::Create({new_device}), MemoryKind(), ArrayCopySemantics::kAlwaysCopy) .status(), StatusIs(absl::StatusCode::kInvalidArgument)); diff --git a/third_party/xla/xla/python/ifrt/array_spec.cc b/third_party/xla/xla/python/ifrt/array_spec.cc index b8b8d5b1f872dd..46023a3d87e5d3 100644 --- a/third_party/xla/xla/python/ifrt/array_spec.cc +++ b/third_party/xla/xla/python/ifrt/array_spec.cc @@ -15,11 +15,13 @@ limitations under the License. #include "xla/python/ifrt/array_spec.h" +#include #include #include #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array_spec.pb.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" @@ -36,8 +38,16 @@ absl::StatusOr ArraySpec::FromProto( TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape())); TF_ASSIGN_OR_RETURN(auto sharding, Sharding::FromProto(lookup_device, proto.sharding())); - return ArraySpec{/*dtype=*/dtype, /*shape=*/std::move(shape), - /*sharding=*/std::move(sharding)}; + std::shared_ptr layout; + if (proto.has_layout()) { + TF_ASSIGN_OR_RETURN(layout, xla::PjRtLayout::Deserialize(proto.layout())); + } + return ArraySpec{ + /*dtype=*/dtype, + /*shape=*/std::move(shape), + /*sharding=*/std::move(sharding), + /*layout=*/std::move(layout), + }; } absl::StatusOr ArraySpec::ToProto() const { @@ -45,13 +55,17 @@ absl::StatusOr ArraySpec::ToProto() const { *proto.mutable_dtype() = dtype.ToProto(); *proto.mutable_shape() = shape.ToProto(); TF_ASSIGN_OR_RETURN(*proto.mutable_sharding(), sharding->ToProto()); + if (layout != nullptr) { + proto.set_layout(layout->Serialize()); + } return proto; } std::string ArraySpec::DebugString() const { - return absl::StrCat("ArraySpec(dtype=", dtype.DebugString(), - ",shape=", shape.DebugString(), - ",sharding=", sharding->DebugString(), ")"); + return absl::StrCat( + "ArraySpec(dtype=", dtype.DebugString(), ",shape=", shape.DebugString(), + ",sharding=", sharding->DebugString(), + ",layout=", (layout != nullptr ? layout->ToString() : ""), ")"); } } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/array_spec.h b/third_party/xla/xla/python/ifrt/array_spec.h index 9261c187483f79..329ef0ab17685d 100644 --- a/third_party/xla/xla/python/ifrt/array_spec.h +++ b/third_party/xla/xla/python/ifrt/array_spec.h @@ -22,8 +22,8 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array_spec.pb.h" -#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" @@ -39,8 +39,7 @@ struct ArraySpec { DType dtype; Shape shape; absl::Nonnull> sharding; - // TODO(hyeontaek): Add `layout` once expressing the default layout can be - // done in a symbolic manner. + absl::Nullable> layout; // Constructs `ArraySpec` from `ArraySpecProto`. static absl::StatusOr FromProto( diff --git a/third_party/xla/xla/python/ifrt/array_spec.proto b/third_party/xla/xla/python/ifrt/array_spec.proto index 6d61b71a004039..411cd9ac3bc0b7 100644 --- a/third_party/xla/xla/python/ifrt/array_spec.proto +++ b/third_party/xla/xla/python/ifrt/array_spec.proto @@ -26,4 +26,5 @@ message ArraySpecProto { DTypeProto dtype = 1; ShapeProto shape = 2; ShardingProto sharding = 3; + optional bytes layout = 4; } diff --git a/third_party/xla/xla/python/ifrt/client.h b/third_party/xla/xla/python/ifrt/client.h index 441aa66781a462..13d797ecc74bc0 100644 --- a/third_party/xla/xla/python/ifrt/client.h +++ b/third_party/xla/xla/python/ifrt/client.h @@ -237,13 +237,13 @@ class Client : public llvm::RTTIExtends { virtual absl::StatusOr> GetTopologyForDevices( const tsl::RCReference& devices) const = 0; - // Returns the default layout on `device` for a buffer with `dtype` and - // single-shard dimensions `dims`. + // Returns the default layout on `device` with `memory_kind` for a buffer with + // `dtype` and single-shard dimensions `dims`. // TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of // single-shard dimensions and device. - virtual absl::StatusOr> - GetDefaultLayoutForDevice(DType dtype, absl::Span dims, - Device* device) const = 0; + virtual absl::StatusOr> GetDefaultLayout( + DType dtype, absl::Span dims, Device* device, + xla::ifrt::MemoryKind memory_kind) const = 0; static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/device_list.cc b/third_party/xla/xla/python/ifrt/device_list.cc index 35b37b5ec1a1dd..76e7de9e8e8551 100644 --- a/third_party/xla/xla/python/ifrt/device_list.cc +++ b/third_party/xla/xla/python/ifrt/device_list.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -65,6 +66,16 @@ tsl::RCReference BasicDeviceList::Create(Devices devices) { return tsl::MakeRef(std::move(devices)); } +tsl::RCReference BasicDeviceList::Create( + absl::Span devices) { + return Create(Devices(devices.begin(), devices.end())); +} + +tsl::RCReference BasicDeviceList::Create( + std::initializer_list devices) { + return Create(Devices(devices.begin(), devices.end())); +} + BasicDeviceList::BasicDeviceList(Devices devices) : devices_(std::move(devices)), hash_(kUnsetHash) {} @@ -109,8 +120,7 @@ std::string BasicDeviceList::ToString() const { return absl::StrCat("BasicDeviceList([", absl::StrJoin(devices_, ",", [](std::string* out, Device* device) { - absl::StrAppend(out, - device->DebugString()); + absl::StrAppend(out, device->ToString()); }), "])"); } diff --git a/third_party/xla/xla/python/ifrt/device_list.h b/third_party/xla/xla/python/ifrt/device_list.h index b10dad716e76eb..27479428aa3aff 100644 --- a/third_party/xla/xla/python/ifrt/device_list.h +++ b/third_party/xla/xla/python/ifrt/device_list.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -130,6 +131,9 @@ class BasicDeviceList : public llvm::RTTIExtends { // Constructor with a pre-populated `devices`. static tsl::RCReference Create(Devices devices); + static tsl::RCReference Create(absl::Span devices); + static tsl::RCReference Create( + std::initializer_list devices); ~BasicDeviceList() override = default; diff --git a/third_party/xla/xla/python/ifrt/dtype.cc b/third_party/xla/xla/python/ifrt/dtype.cc index a79240f51a7e23..ed68a1d11403c2 100644 --- a/third_party/xla/xla/python/ifrt/dtype.cc +++ b/third_party/xla/xla/python/ifrt/dtype.cc @@ -214,6 +214,10 @@ std::string DType::DebugString() const { return "INVALID"; case kPred: return "PRED"; + case kS2: + return "S2"; + case kS4: + return "S4"; case kS8: return "S8"; case kS16: @@ -222,6 +226,10 @@ std::string DType::DebugString() const { return "S32"; case kS64: return "S64"; + case kU2: + return "U2"; + case kU4: + return "U4"; case kU8: return "U8"; case kU16: @@ -246,6 +254,20 @@ std::string DType::DebugString() const { return "TOKEN"; case kOpaque: return "OPAQUE"; + case kF8E3M4: + return "F8E3M4"; + case kF8E4M3: + return "F8E4M3"; + case kF8E4M3FN: + return "F8E4M3FN"; + case kF8E4M3B11FNUZ: + return "F8E4M3B11FNUZ"; + case kF8E4M3FNUZ: + return "F8E4M3FNUZ"; + case kF8E5M2: + return "F8E5M2"; + case kF8E5M2FNUZ: + return "F8E5M2FNUZ"; case kString: return "STRING"; default: diff --git a/third_party/xla/xla/python/ifrt/executable.h b/third_party/xla/xla/python/ifrt/executable.h index 5332768c885b9c..7e8ecc3b0daba6 100644 --- a/third_party/xla/xla/python/ifrt/executable.h +++ b/third_party/xla/xla/python/ifrt/executable.h @@ -78,10 +78,10 @@ class Executable : public llvm::RTTIExtends { // Returns a list of output `OpSharding`. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Returns an `HloModule` (optimized) per partition. virtual absl::StatusOr>> @@ -99,11 +99,6 @@ class Executable : public llvm::RTTIExtends { // differ for different implementations and platforms. virtual absl::StatusOr GetCostAnalysis() const = 0; - // Returns the compile options used to compile this executable. - // TODO(phawkins): consider removing this API and having the client remember - // the compile options used to create the executable. - virtual const CompileOptions* GetCompileOptions() const = 0; - static char ID; // NOLINT }; @@ -187,10 +182,10 @@ class LoadedExecutable // Returns a list of output OpSharding. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Return an HloModule (optimized) per partition. virtual absl::StatusOr>> diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index d0affacfeb439c..8176a32edcce96 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -485,6 +485,7 @@ tsl_pybind_extension( "//xla/python/ifrt/ir/transforms:utils", "//xla/python/ifrt/support:module_parsing", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeaders", diff --git a/third_party/xla/xla/python/ifrt/ir/constants.h b/third_party/xla/xla/python/ifrt/ir/constants.h index 52b22e7b9c5dd2..512b22259fdc03 100644 --- a/third_party/xla/xla/python/ifrt/ir/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/constants.h @@ -57,6 +57,11 @@ inline constexpr llvm::StringLiteral kIfrtMemoryKindAttrName = inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = "ifrt.entry_function"; +// Name of UnitAttr on CallOp used to indicate that an atom program was +// partitioned by the Sdy partitioner. +inline constexpr llvm::StringLiteral kIsSdyPartitioned = + "ifrt.is_sdy_partitioned"; + inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main"; // Name of StringAttr used to store the HloSharding. diff --git a/third_party/xla/xla/python/ifrt/ir/ir_py.cc b/third_party/xla/xla/python/ifrt/ir/ir_py.cc index 73f889eaa7f3f4..806cccd73bbf05 100644 --- a/third_party/xla/xla/python/ifrt/ir/ir_py.cc +++ b/third_party/xla/xla/python/ifrt/ir/ir_py.cc @@ -15,9 +15,9 @@ limitations under the License. #include #include -#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" // IWYU pragma: keep; Needed to allow MlirModule -> ModuleOp. #include "mlir/CAPI/IR.h" @@ -41,8 +41,8 @@ namespace ifrt { namespace { absl::StatusOr SerializedVersionedProgram( - MlirModule module, std::string_view ifrt_ir_version, - std::string_view atom_program_version, bool version_in_place) { + MlirModule module, absl::string_view ifrt_ir_version, + absl::string_view atom_program_version, bool version_in_place) { auto program = std::make_unique(unwrap(module)); TF_ASSIGN_OR_RETURN( auto serialized, @@ -55,8 +55,8 @@ absl::StatusOr SerializedVersionedProgram( } absl::StatusOr SerializedVersionedProgram( - std::string_view module_str, std::string_view ifrt_ir_version, - std::string_view atom_program_version, bool version_in_place) { + absl::string_view module_str, absl::string_view ifrt_ir_version, + absl::string_view atom_program_version, bool version_in_place) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(auto module, support::ParseMlirModuleString(module_str, context)); @@ -72,7 +72,7 @@ absl::StatusOr SerializedVersionedProgram( } absl::StatusOr DeserializeVersionedProgram( - mlir::MLIRContext* context, std::string_view serialized_program) { + mlir::MLIRContext* context, absl::string_view serialized_program) { xla::ifrt::Serialized serialized; serialized.set_type_name(std::string(IfrtIRProgram::type_name())); serialized.set_data(std::string(serialized_program)); @@ -85,7 +85,7 @@ absl::StatusOr DeserializeVersionedProgram( } absl::StatusOr DeserializeVersionedProgram( - std::string_view serialized_program) { + absl::string_view serialized_program) { mlir::MLIRContext context; support::RegisterMlirDialects(context); TF_ASSIGN_OR_RETURN( @@ -121,8 +121,8 @@ PYBIND11_MODULE(ir_py, m) { // modules. m.def( "serialize_versioned_program", - [](MlirModule module, std::string_view ifrt_ir_version, - std::string_view atom_program_version, + [](MlirModule module, absl::string_view ifrt_ir_version, + absl::string_view atom_program_version, bool version_in_place) -> py::bytes { return xla::ValueOrThrow(SerializedVersionedProgram( module, ifrt_ir_version, atom_program_version, version_in_place)); @@ -131,8 +131,8 @@ PYBIND11_MODULE(ir_py, m) { py::arg("atom_program_version"), py::arg("version_in_place")); m.def( "serialize_versioned_program_str", - [](std::string_view module_str, std::string_view ifrt_ir_version, - std::string_view atom_program_version, + [](absl::string_view module_str, absl::string_view ifrt_ir_version, + absl::string_view atom_program_version, bool version_in_place) -> py::bytes { return xla::ValueOrThrow( SerializedVersionedProgram(module_str, ifrt_ir_version, @@ -145,14 +145,14 @@ PYBIND11_MODULE(ir_py, m) { m.def( "deserialize_versioned_program", [](MlirContext context, - std::string_view serialized_program) -> MlirModule { + absl::string_view serialized_program) -> MlirModule { return wrap(xla::ValueOrThrow( DeserializeVersionedProgram(unwrap(context), serialized_program))); }, py::arg("context"), py::arg("serialized_program")); m.def( "deserialize_versioned_program_str", - [](std::string_view serialized_program) -> py::bytes { + [](absl::string_view serialized_program) -> py::bytes { return xla::ValueOrThrow( DeserializeVersionedProgram(serialized_program)); }, diff --git a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc index ff64a5b4dd6219..341bcfc92ca6a9 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc @@ -163,9 +163,7 @@ IfrtIrExecutableImplTestBase::PickDevices(int count) { absl::Span devices = client_->devices(); TF_RET_CHECK(count <= devices.size()) << "Requested " << count << " devices. Only have " << devices.size(); - auto picked = devices.first(count); - return BasicDeviceList::Create( - BasicDeviceList::Devices(picked.begin(), picked.end())); + return BasicDeviceList::Create(devices.first(count)); } } // namespace test_util diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir index b99e0f9a43b79e..22257730e01d5e 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir @@ -25,3 +25,34 @@ module @call_hlo { } } } + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @call_hlo_sdy_lowered +module @call_hlo_sdy_lowered attributes { + mhlo.frontend_attributes = { + xla.sdy.meshes ="{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}} { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.CallLoadedExecutable @fake_component__fake_method_1(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + {ifrt.module_type = "xla", ifrt.is_sdy_partitioned} : (!array) -> !array + return %0 : !array + } + + // module @add_one attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22x\\\22=2]>}"}, sym_visibility = "private"} + // CHECK: ifrt.LoadedExecutable @fake_component__fake_method + // CHECK-SAME: on devices [0, 1] + // CHECK: (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>) + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + module @add_one attributes {sym_visibility = "private"} { + func.func private @main( + %arg0: tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) + -> (tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 7508bb9d935cd9..958d067bb04f94 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -85,6 +85,8 @@ cc_library( "//xla/service:compilation_environments", "//xla/service:computation_placer_hdr", "//xla/service:hlo_proto_cc", + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -95,6 +97,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", @@ -109,6 +112,7 @@ cc_library( "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", + "@shardy//shardy/dialect/sdy/ir:dialect", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc index 04f005ff73cb43..216fb974c024b0 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -42,6 +43,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" +#include "shardy/dialect/sdy/ir/dialect.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/python/ifrt/compiler.h" @@ -52,6 +54,8 @@ limitations under the License. #include "xla/python/ifrt/ir/transforms/passes.h" #include "xla/python/ifrt/ir/transforms/utils.h" #include "xla/service/hlo.pb.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" namespace xla { namespace ifrt { @@ -83,6 +87,7 @@ class IfrtCompileAtomProgramPass void getDependentDialects(::mlir::DialectRegistry& registry) const override { registry.insert(); registry.insert(); + registry.insert(); } void runOnOperation() override; @@ -108,6 +113,14 @@ void IfrtCompileAtomProgramPass::runOnOperation() { // Map from the hash of the CallOp to the compile future. llvm::DenseMap call_to_compile_futures; mlir::ModuleOp module_op = getOperation(); + + mlir::Attribute meshes_round_trip_attr; + // TODO: icgog - This attribute will be deleted in the IFRT -> VIFRT + // legalization. Fix in order to be able to use Sdy with VIFRT. + if (auto front_end_attr = xla::sdy::getFrontendAttrs(module_op)) { + meshes_round_trip_attr = front_end_attr.get(xla::sdy::kMeshesRoundTripAttr); + } + // Walk and dispatch the compilations in parallel. auto compile_result = module_op.walk([&](CallOp call_op) -> mlir::WalkResult { @@ -125,6 +138,21 @@ void IfrtCompileAtomProgramPass::runOnOperation() { << callee.getSymName() << ". Actual callee parent: " << callee->getParentOp()->getName(); } + + if (call_op->hasAttr(kIsSdyPartitioned)) { + // Add the meshes roundtrip attribute to the callee module if the + // atom program was partitioned with sdy. + if (!meshes_round_trip_attr) { + return call_op.emitOpError() + << "requires meshes roundtrip attribute to be set on the " + "program module if the atom program was partitioned " + "with sdy."; + } + xla::sdy::setFrontendAttribute( + callee_module, xla::sdy::kMeshesRoundTripAttr, + meshes_round_trip_attr, /*escapeAttr=*/false); + } + absl::StatusOr compile_future = atom_program_compiler_.CompileModule(call_op, callee_module); if (!compile_future.ok()) { diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc index 82ef8580ff31d1..44423bf6e341e9 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc @@ -170,8 +170,8 @@ class IfrtReshardToCopyArraysPass mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); mlir::ModuleOp module_op = getOperation(); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(module_op, - std::move(patterns)))) { + if (mlir::failed( + mlir::applyPatternsGreedily(module_op, std::move(patterns)))) { signalPassFailure(); } } diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc index 2e4f3d03fac2d1..dc00af01840122 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include @@ -22,6 +21,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -77,7 +77,7 @@ class IfrtVerifyBoundExternalLoadedExecutablePass absl::Status VerifyShardingsEqual( llvm::ArrayRef types, const std::vector& shardings, - std::string_view sharding_type); + absl::string_view sharding_type); // Map from symbol name of LoadedExecutableOp to externally bound // LoadedExecutable. @@ -87,7 +87,7 @@ class IfrtVerifyBoundExternalLoadedExecutablePass absl::Status IfrtVerifyBoundExternalLoadedExecutablePass::VerifyShardingsEqual( llvm::ArrayRef types, const std::vector& shardings, - std::string_view sharding_type) { + absl::string_view sharding_type) { for (const auto& it : llvm::enumerate(llvm::zip(types, shardings))) { const auto& [param_type, sharding] = it.value(); TF_ASSIGN_OR_RETURN(auto hlo_sharding, diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc b/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc index 0bf65c503a8145..79a895a71e0cda 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/ir/atom_program_compiler.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.pb.h" #include "xla/python/ifrt/ir/version.h" namespace xla { diff --git a/third_party/xla/xla/python/ifrt/mock.cc b/third_party/xla/xla/python/ifrt/mock.cc index d62646bf5b78ad..0071575b62d32b 100644 --- a/third_party/xla/xla/python/ifrt/mock.cc +++ b/third_party/xla/xla/python/ifrt/mock.cc @@ -78,9 +78,10 @@ MockArray::MockArray(tsl::RCReference delegated) return delegated_->shared_ptr_sharding(); }); ON_CALL(*this, layout) - .WillByDefault([this]() -> absl::StatusOr> { - return delegated_->layout(); - }); + .WillByDefault( + [this]() -> absl::StatusOr> { + return delegated_->layout(); + }); ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_)) .WillByDefault([this](ArrayCopySemantics semantics) { return delegated_->DisassembleIntoSingleDeviceArrays(semantics); @@ -217,11 +218,12 @@ MockClient::MockClient(std::unique_ptr delegated) [this](const tsl::RCReference& devices) { return delegated_->GetTopologyForDevices(devices); }); - ON_CALL(*this, GetDefaultLayoutForDevice) + ON_CALL(*this, GetDefaultLayout) .WillByDefault([this](xla::ifrt::DType dtype, absl::Span dims, - xla::ifrt::Device* device) { - return delegated_->GetDefaultLayoutForDevice(dtype, dims, device); + xla::ifrt::Device* device, + xla::ifrt::MemoryKind memory_kind) { + return delegated_->GetDefaultLayout(dtype, dims, device, memory_kind); }); } // LINT.ThenChange() diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h index 11ba98cc96326a..9fd960156c1e1b 100644 --- a/third_party/xla/xla/python/ifrt/mock.h +++ b/third_party/xla/xla/python/ifrt/mock.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" @@ -52,7 +53,6 @@ limitations under the License. #include "xla/python/ifrt/topology.h" #include "xla/python/ifrt/tuple.h" #include "xla/python/ifrt/value.h" -#include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" namespace xla { @@ -76,7 +76,7 @@ class MockArray : public llvm::RTTIExtends { MOCK_METHOD(const Sharding&, sharding, (), (const, final)); MOCK_METHOD(absl::Nonnull>, shared_ptr_sharding, (), (const, final)); - MOCK_METHOD(absl::StatusOr>, layout, (), + MOCK_METHOD(absl::StatusOr>, layout, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, DisassembleIntoSingleDeviceArrays, (ArrayCopySemantics semantics), @@ -173,10 +173,10 @@ class MockClient : public llvm::RTTIExtends { MOCK_METHOD(absl::StatusOr>, GetTopologyForDevices, (const tsl::RCReference& devices), (const, final)); - MOCK_METHOD(absl::StatusOr>, - GetDefaultLayoutForDevice, + MOCK_METHOD(absl::StatusOr>, + GetDefaultLayout, (xla::ifrt::DType dtype, absl::Span dims, - xla::ifrt::Device* device), + xla::ifrt::Device* device, xla::ifrt::MemoryKind memory_kind), (const, final)); // LINT.ThenChange(mock.cc:MockClientDelegation) @@ -264,9 +264,9 @@ class MockExecutable : public llvm::RTTIExtends { (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetHloModules, (), (const, final)); @@ -293,9 +293,9 @@ class MockLoadedExecutable (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetOutputMemoryKinds, (), (const, final)); diff --git a/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc index c8531d41791338..7625cc1bbc6f83 100644 --- a/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -90,12 +91,12 @@ struct CppTypeToDType; template <> struct CppTypeToDType { - static constexpr DType::Kind dtype = DType::kS32; + static constexpr DType::Kind kDType = DType::kS32; }; template <> struct CppTypeToDType { - static constexpr DType::Kind dtype = DType::kF32; + static constexpr DType::Kind kDType = DType::kF32; }; template @@ -104,7 +105,7 @@ absl::StatusOr> CreateArray( absl::Span device_indices, Shape shard_shape = Shape({2, 3})) { TF_RET_CHECK(base_values.size() == device_indices.size()); - DType dtype(CppTypeToDType::dtype); + DType dtype(CppTypeToDType::kDType); TF_ASSIGN_OR_RETURN(Shape shape, GetShape(base_values.size(), shard_shape)); std::vector> shards; @@ -147,7 +148,7 @@ void AssertArrayContent(Client* client, Array* array, absl::Span base_values, absl::Span device_indices, Shape expected_shard_shape = Shape({2, 3})) { - DType expected_dtype(CppTypeToDType::dtype); + DType expected_dtype(CppTypeToDType::kDType); TF_ASSERT_OK_AND_ASSIGN(Shape expected_shape, GetShape(base_values.size(), expected_shard_shape)); EXPECT_EQ(array->dtype(), expected_dtype); @@ -531,6 +532,79 @@ TEST(RemapImplTest, BatchMappingDeinterleave) { } } +TEST(RemapImplTest, DetectBadInput) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + // Trivial remap plan for a single device array on device 0. + RemapPlan plan; + plan.input_specs.push_back( + CreateArraySpec(client.get(), /*device_indices=*/{0}).value()); + plan.output_specs.push_back( + CreateArraySpec(client.get(), /*device_indices=*/{0}).value()); + plan.mappings = std::make_shared>(); + plan.mappings->push_back( + RemapPlan::Mapping{/*in_array=*/0, /*out_array=*/0, + /*from=*/{RemapPlan::Interval{0, 1, 1}}, + /*to=*/{RemapPlan::Interval{0, 1, 1}}}); + TF_ASSERT_OK(plan.Validate()); + + { + std::vector> arrays; + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + CreateArray(client.get(), /*base_values=*/{0}, + /*device_indices=*/{0})); + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + CreateArray(client.get(), /*base_values=*/{0}, + /*device_indices=*/{0})); + EXPECT_THAT( + client->RemapArrays(plan, absl::MakeSpan(arrays), + ArrayCopySemantics::kReuseInput), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("RemapArrays expects 1 input arrays, but got 2"))); + } + + { + std::vector> arrays; + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + CreateArray(client.get(), /*base_values=*/{0}, + /*device_indices=*/{0})); + EXPECT_THAT( + client->RemapArrays(plan, absl::MakeSpan(arrays), + ArrayCopySemantics::kReuseInput), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("RemapArrays expects input #0 to have dtype"))); + } + + { + std::vector> arrays; + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + CreateArray(client.get(), /*base_values=*/{0}, + /*device_indices=*/{0}, + /*shard_shape=*/Shape({20, 30}))); + EXPECT_THAT( + client->RemapArrays(plan, absl::MakeSpan(arrays), + ArrayCopySemantics::kReuseInput), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("RemapArrays expects input #0 to have shape"))); + } + + { + std::vector> arrays; + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + CreateArray(client.get(), /*base_values=*/{0}, + /*device_indices=*/{1})); + EXPECT_THAT(client->RemapArrays(plan, absl::MakeSpan(arrays), + ArrayCopySemantics::kReuseInput), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("RemapArrays expects input #0 to be on"))); + } +} + } // namespace } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/remap_plan.cc b/third_party/xla/xla/python/ifrt/remap_plan.cc index 8925cbc47bd9cb..01df47accf7aaf 100644 --- a/third_party/xla/xla/python/ifrt/remap_plan.cc +++ b/third_party/xla/xla/python/ifrt/remap_plan.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/device.h" @@ -216,6 +217,20 @@ absl::Status RemapPlan::Validate() const { output_specs[mapping.out_array].dtype, mapping.out_array); } + const std::shared_ptr& in_layout = + input_specs[mapping.in_array].layout; + const std::shared_ptr& out_layout = + output_specs[mapping.out_array].layout; + if (in_layout != out_layout) { + return InvalidArgument( + "Input and output must have the same layout: %s (input %d) vs. %s " + "(output %d)", + in_layout != nullptr ? in_layout->ToString() : "", + mapping.in_array, + out_layout != nullptr ? out_layout->ToString() : "", + mapping.out_array); + } + std::vector& in_used_buffers = in_used_buffers_list[mapping.in_array]; absl::Span in_devices = input_specs[mapping.in_array].sharding->devices()->devices(); diff --git a/third_party/xla/xla/python/ifrt/remap_plan_test.cc b/third_party/xla/xla/python/ifrt/remap_plan_test.cc index b888b1012deb1b..eeb928f7f56071 100644 --- a/third_party/xla/xla/python/ifrt/remap_plan_test.cc +++ b/third_party/xla/xla/python/ifrt/remap_plan_test.cc @@ -23,6 +23,8 @@ limitations under the License. #include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "llvm/Support/Casting.h" +#include "xla/layout_util.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/device.h" @@ -248,6 +250,42 @@ TEST_P(RemapPlanTest, InvalidOutputDtypeFromMixedInputDtypes) { HasSubstr("Input and output must have the same dtype"))); } +TEST_P(RemapPlanTest, InvalidLayout) { + RemapPlan plan; + plan.input_specs.push_back(ArraySpec{ + /*dtype=*/DType(DType::kS32), + /*shape=*/Shape({2, 3}), + /*sharding=*/ + ConcreteEvenSharding::Create(GetDevices({0}), MemoryKind(), + /*shape=*/Shape({2, 3}), + /*shard_shape=*/Shape({2, 3})), + /*layout=*/ + std::make_shared( + xla::LayoutUtil::MakeDescendingLayout(2)), + }); + plan.output_specs.push_back(ArraySpec{ + /*dtype=*/DType(DType::kS32), + /*shape=*/Shape({2, 3}), + /*sharding=*/ + ConcreteEvenSharding::Create(GetDevices({0}), MemoryKind(), + /*shape=*/Shape({2, 3}), + /*shard_shape=*/Shape({2, 3})), + /*layout=*/ + std::make_shared( + xla::LayoutUtil::MakeAscendingLayout(2)), // layout differs + }); + plan.mappings = std::make_shared>(); + plan.mappings->push_back( + RemapPlan::Mapping{/*in_array=*/0, + /*out_array=*/0, + /*from=*/{RemapPlan::Interval{0, 1, 1}}, + /*to=*/{RemapPlan::Interval{0, 1, 1}}}); + EXPECT_THAT( + plan.Validate(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Input and output must have the same layout"))); +} + TEST_P(RemapPlanTest, InvalidInputArrayIndex) { RemapPlan plan; plan.input_specs.push_back( diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc index 7985c9a07674d3..5d32d711e5cbcb 100644 --- a/third_party/xla/xla/python/ifrt/sharding.cc +++ b/third_party/xla/xla/python/ifrt/sharding.cc @@ -422,7 +422,24 @@ ConcreteSharding::ConcreteSharding(tsl::RCReference devices, : llvm::RTTIExtends( std::move(devices), memory_kind, /*is_fully_replicated=*/false), shape_(std::move(shape)), - shard_shapes_(std::move(shard_shapes)) {} + shard_shapes_(std::move(shard_shapes)) { + // If all per-shard shapes are the same, cache this shape for + // `GetShardShape()`. Ideally, users should have used `ConcreteEvenSharding` + // for such a case, but there are existing use cases that instantiate + // `ConcreteSharding` from a list of per-shard shapes without checking for + // identical per-shard shapes. + const auto& static_shard_shapes = std::get>(shard_shapes_); + bool identical = true; + for (int i = 1; i < static_shard_shapes.size(); ++i) { + if (static_shard_shapes[i] != static_shard_shapes[0]) { + identical = false; + break; + } + } + if (identical) { + shard_shape_ = static_shard_shapes[0]; + } +} ConcreteSharding::ConcreteSharding( tsl::RCReference devices, MemoryKind memory_kind, @@ -434,6 +451,9 @@ ConcreteSharding::ConcreteSharding( absl::StatusOr ConcreteSharding::GetShardShape( const Shape& shape) const { + if (shard_shape_.has_value()) { + return *shard_shape_; + } return InvalidArgument("ConcreteSharding does not have a fixed shard shape"); } diff --git a/third_party/xla/xla/python/ifrt/sharding.h b/third_party/xla/xla/python/ifrt/sharding.h index b2b20da873c28f..4fc4085296cd8d 100644 --- a/third_party/xla/xla/python/ifrt/sharding.h +++ b/third_party/xla/xla/python/ifrt/sharding.h @@ -421,6 +421,7 @@ class ConcreteSharding : public llvm::RTTIExtends { std::variant shape_; std::variant, std::vector> shard_shapes_; + std::optional shard_shape_; }; // Opaque sharding that does not define a fixed semantics for conversion between diff --git a/third_party/xla/xla/python/ifrt/sharding_test.cc b/third_party/xla/xla/python/ifrt/sharding_test.cc index 23c4e015672b1e..b12a1a2ae417b9 100644 --- a/third_party/xla/xla/python/ifrt/sharding_test.cc +++ b/third_party/xla/xla/python/ifrt/sharding_test.cc @@ -325,7 +325,16 @@ TEST_P(ConcreteShardingTest, IsFullyReplicated) { EXPECT_FALSE(sharding->IsFullyReplicated()); } -TEST_P(ConcreteShardingTest, GetShardShape) { +TEST_P(ConcreteShardingTest, GetShardShapeSuccess) { + auto device_list = GetDevices({0, 1}); + Shape shard_shape({30}); + std::vector shard_shapes(2, shard_shape); + std::shared_ptr sharding = ConcreteSharding::Create( + device_list, MemoryKind(), Shape({30}), shard_shapes); + EXPECT_THAT(sharding->GetShardShape(Shape({30})), IsOkAndHolds(shard_shape)); +} + +TEST_P(ConcreteShardingTest, GetShardShapeFailure) { auto device_list = GetDevices({0, 1}); std::vector shard_shapes; shard_shapes.reserve(2); diff --git a/third_party/xla/xla/python/ifrt/support/module_parsing.cc b/third_party/xla/xla/python/ifrt/support/module_parsing.cc index b1740cd5cf0ca9..8d6efaf1a4a560 100644 --- a/third_party/xla/xla/python/ifrt/support/module_parsing.cc +++ b/third_party/xla/xla/python/ifrt/support/module_parsing.cc @@ -52,6 +52,7 @@ void RegisterMlirDialects(mlir::MLIRContext& context) { mlir::DialectRegistry registry; InitializeMlirDialectRegistry(registry); context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); } absl::StatusOr> ParseMlirModuleString( diff --git a/third_party/xla/xla/python/ifrt/topology.h b/third_party/xla/xla/python/ifrt/topology.h index 8d1104aca01f33..f7713239d5f9c9 100644 --- a/third_party/xla/xla/python/ifrt/topology.h +++ b/third_party/xla/xla/python/ifrt/topology.h @@ -42,10 +42,17 @@ class Topology : public llvm::RTTIExtends { // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). virtual absl::string_view platform_version() const = 0; + // Returns an ID that identifies the platform (CPU/GPU/TPU). virtual PjRtPlatformId platform_id() const = 0; + // Returns the topology description. + // TODO(hyeontaek): Consider introducing an IFRT-specific API here instead of + // delegating to PJRT. + virtual const std::shared_ptr& + description() const = 0; + // Returns an unordered list of descriptions for all devices in this topology. - // TODO(phawkins): consider introducing an IFRT-specific API here instead of + // TODO(hyeontaek): Consider introducing an IFRT-specific API here instead of // delegating to PJRT. virtual std::vector> DeviceDescriptions() const = 0; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 2b382e414d415a..2f354d5f771482 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -251,6 +251,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc index eabbbbd66e7987..578799b7db1287 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc @@ -31,6 +31,7 @@ #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" @@ -363,11 +364,22 @@ Array::RemapArrays(xla::ifrt::Client* client, return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointRemapArrays", {{"n_arrays", n_arrays}}); }); + + TF_RETURN_IF_ERROR(plan.CheckArrayCopySemantics(semantics)); + const int num_inputs = plan.input_specs.size(); + const int num_actual_inputs = arrays.size(); + if (num_inputs != num_actual_inputs) { + return absl::InvalidArgumentError( + absl::StrFormat("RemapArrays expects %d input arrays, but got %d", + num_inputs, num_actual_inputs)); + } + auto req = std::make_unique(); TF_RET_CHECK(!arrays.empty()); TF_ASSIGN_OR_RETURN(*req->mutable_plan(), plan.ToProto()); req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); - for (const tsl::RCReference& rcref : arrays) { + for (int i = 0; i < num_inputs; ++i) { + const tsl::RCReference& rcref = arrays[i]; Array* array = llvm::dyn_cast(rcref.get()); if (array == nullptr) { return absl::InvalidArgumentError( @@ -375,6 +387,34 @@ Array::RemapArrays(xla::ifrt::Client* client, "not a xla::ifrt::proxy::Array.", rcref.get())); } + + if (plan.input_specs[i].dtype != arrays[i]->dtype()) { + return absl::InvalidArgumentError(absl::StrFormat( + "RemapArrays expects input #%d to have dtype %v, but got %v", i, + plan.input_specs[i].dtype, arrays[i]->dtype())); + } + if (plan.input_specs[i].shape != arrays[i]->shape()) { + return absl::InvalidArgumentError(absl::StrFormat( + "RemapArrays expects input #%d to have shape %v, but got %v", i, + plan.input_specs[i].shape, arrays[i]->shape().DebugString())); + } + // Skip xla::ifrt::Sharding::HasSamePartitioning() check because RemapArrays + // is currently called with input arrays with implicit sharding + // reinterpretation. Such patterns should be fixed before enabling stricter + // checking to avoid false positives. + if (*plan.input_specs[i].sharding->devices() != + *arrays[i]->sharding().devices() || + plan.input_specs[i].sharding->memory_kind() != + arrays[i]->sharding().memory_kind()) { + return absl::InvalidArgumentError( + absl::StrFormat("RemapArrays expects input #%d to be on %v with " + "%v, but is on %v with %v", + i, *plan.input_specs[i].sharding->devices(), + plan.input_specs[i].sharding->memory_kind(), + *arrays[i]->sharding().devices(), + arrays[i]->sharding().memory_kind())); + } + req->add_array_handles(array->handle_.handle); } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.h b/third_party/xla/xla/python/ifrt_proxy/client/array.h index 2a9ccdf17bea32..5c4b42475f36c7 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.h @@ -112,7 +112,7 @@ class Array final : public llvm::RTTIExtends { std::shared_ptr shared_ptr_sharding() const override { return sharding_; } - absl::StatusOr> layout() const override { + absl::StatusOr> layout() const override { return absl::UnimplementedError( "Array::layout() not implemented for IFRT proxy"); }; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.h b/third_party/xla/xla/python/ifrt_proxy/client/client.h index 3732b5ddd832d7..29edb78c1af009 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.h @@ -140,9 +140,10 @@ class Client final : public llvm::RTTIExtends { return absl::UnimplementedError( "GetTopologyForDevices is not supported for the IFRT proxy client."); } - absl::StatusOr> GetDefaultLayoutForDevice( + absl::StatusOr> GetDefaultLayout( xla::ifrt::DType dtype, absl::Span dims, - xla::ifrt::Device* device) const override { + xla::ifrt::Device* device, + xla::ifrt::MemoryKind memory_kind) const override { return absl::UnimplementedError( "GetDefaultLayout is not supported for the IFRT proxy client."); } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc index 81ef43ec5c0f3b..a4926dfe84bd6b 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc @@ -310,10 +310,11 @@ LoadedExecutable::LoadedExecutable( auto parse_layouts = [](const LoadedExecutableMetadataResponse::LayoutList& list) { - std::vector layouts; + std::vector> layouts; layouts.reserve(list.layouts_size()); for (const auto& layout : list.layouts()) { - layouts.push_back(xla::Layout::CreateFromProto(layout)); + layouts.push_back(std::make_shared( + xla::Layout::CreateFromProto(layout))); } return layouts; }; @@ -433,34 +434,20 @@ std::optional> LoadedExecutable::GetOutputShardings() return (*info)->output_shardings; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetParameterLayouts() const { tsl::profiler::TraceMe traceme_ifrt_entrypoint( "IfrtProxyEntrypointLoadedExecutableGetParameterLayouts"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); - TF_RETURN_IF_ERROR(info->parameter_layouts.status()); - - std::vector> result; - result.reserve(info->parameter_layouts->size()); - for (const xla::Layout& layout : *info->parameter_layouts) { - result.push_back(std::make_unique(layout)); - } - return result; + return info->parameter_layouts; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetOutputLayouts() const { tsl::profiler::TraceMe traceme_ifrt_entrypoint( "IfrtProxyEntrypointLoadedExecutableGetOutputLayouts"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); - TF_RETURN_IF_ERROR(info->output_layouts.status()); - - std::vector> result; - result.reserve(info->output_layouts->size()); - for (const xla::Layout& layout : *info->output_layouts) { - result.push_back(std::make_unique(layout)); - } - return result; + return info->output_layouts; } absl::StatusOr>> diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.h b/third_party/xla/xla/python/ifrt_proxy/client/executable.h index 5ce5292d5a76b8..0af4a14a3e80b6 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.h @@ -35,6 +35,7 @@ #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" @@ -77,9 +78,9 @@ class LoadedExecutable final std::optional> GetParameterShardings() const override; std::optional> GetOutputShardings() const override; - absl::StatusOr>> + absl::StatusOr>> GetParameterLayouts() const override; - absl::StatusOr>> + absl::StatusOr>> GetOutputLayouts() const override; absl::StatusOr>> GetOutputMemoryKinds() const override; @@ -105,8 +106,10 @@ class LoadedExecutable final std::optional> parameter_shardings; std::optional> output_shardings; - absl::StatusOr> parameter_layouts; - absl::StatusOr> output_layouts; + absl::StatusOr>> + parameter_layouts; + absl::StatusOr>> + output_layouts; // Elements in `output_memory_kinds` point to elements in `memory_kinds`. // Required since `GetOutputMemoryKinds()` returns `absl::string_view`. diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc index 70bb1791d3d8f6..9d050f297ac506 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable_test.cc @@ -157,19 +157,14 @@ TEST_F(LoadedExecutableTest, Metadata) { Optional(ElementsAre(EquivToProto(R"pb(type: REPLICATED)pb")))); ASSERT_OK_AND_ASSIGN(auto parameter_layouts, executable.GetParameterLayouts()); - EXPECT_EQ(parameter_layouts.size(), 2); - EXPECT_EQ( - tensorflow::down_cast(parameter_layouts[0].get()) - ->xla_layout(), - xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1)); - EXPECT_EQ( - tensorflow::down_cast(parameter_layouts[1].get()) - ->xla_layout(), - xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); + ASSERT_EQ(parameter_layouts.size(), 2); + EXPECT_EQ(parameter_layouts[0]->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1)); + EXPECT_EQ(parameter_layouts[1]->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); ASSERT_OK_AND_ASSIGN(auto output_layouts, executable.GetOutputLayouts()); - EXPECT_EQ(output_layouts.size(), 1); - EXPECT_EQ(tensorflow::down_cast(output_layouts[0].get()) - ->xla_layout(), + ASSERT_EQ(output_layouts.size(), 1); + EXPECT_EQ(output_layouts[0]->xla_layout(), xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); EXPECT_THAT(executable.GetOutputMemoryKinds(), IsOkAndHolds(ElementsAre(ElementsAre("foo")))); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc index ab36e6c0f17f6f..2c8d52e7e7cff2 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -105,8 +105,10 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle, } if (!writer->WritesDone()) { + writer->Finish().IgnoreError(); promise.Set( absl::InternalError("Failed to write all host buffer chunks")); + return; } } @@ -150,6 +152,7 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle, } } if (!writer->WritesDone()) { + writer->Finish().IgnoreError(); return Future<>( absl::InternalError("Failed to write all host buffer chunks")); } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index 2355ba0a0bc5c7..5cfd3c52e57eb3 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -330,7 +330,6 @@ RPC(CopyToHostBuffer, copy_to_host_buffer); RPC(IsArrayDeleted, is_array_deleted); RPC(DestructArray, destruct_array) RPC(CopyArrays, copy_arrays); -RPC(Reshard, reshard); RPC(FullyReplicatedShard, fully_replicated_shard); RPC(DeleteArray, delete_array); RPC(Compile, compile); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h index 38b61d83cbaa67..ec225b98af94d4 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -112,7 +112,6 @@ class RpcHelper { std::unique_ptr req); ResponseFuture CopyArrays( std::unique_ptr req); - ResponseFuture Reshard(std::unique_ptr req); ResponseFuture FullyReplicatedShard( std::unique_ptr req); ResponseFuture IsArrayDeleted( diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 04cd73a1959d7f..0dcb0ea6005d6b 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -173,6 +173,7 @@ cc_library( "//xla/python/ifrt:plugin_program_serdes", "//xla/python/ifrt/hlo:hlo_program_serdes", "//xla/python/ifrt/ir:ifrt_ir_program_serdes", + "//xla/python/pjrt_ifrt:xla_sharding_serdes", ], alwayslink = True, ) diff --git a/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc b/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc index 6b3bb83863d492..1c3f2ac8643ac0 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc +++ b/third_party/xla/xla/python/ifrt_proxy/common/array_util.cc @@ -186,7 +186,7 @@ absl::Status DeserializeFromCordIntoPreallocatedStringHostBuffer( proto::StringArrayContents string_array_proto; #if defined(PLATFORM_GOOGLE) - if (!string_array_proto.ParseFromCord(serialized_string_buffer)) { + if (!string_array_proto.ParseFromString(serialized_string_buffer)) { #else if (!string_array_proto.ParseFromString( // No absl::Cord support in OSS. std::string(serialized_string_buffer))) { diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 748f8994217bf2..942e0f648a3bb9 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -56,7 +56,6 @@ message IfrtRequest { disassemble_into_single_device_arrays_request = 7; DeleteArrayRequest delete_array_request = 9; CopyArraysRequest copy_arrays_request = 24; - ReshardRequest reshard_request = 10 [deprecated = true]; FullyReplicatedShardRequest fully_replicated_shard_request = 20; IsArrayDeletedRequest is_array_deleted_request = 11; DestructArrayRequest destruct_array_request = 12; @@ -79,6 +78,8 @@ message IfrtRequest { GetDefaultDeviceAssignmentRequest get_default_device_assignment_request = 19; } + + reserved 10; } message IfrtResponse { @@ -103,7 +104,6 @@ message IfrtResponse { disassemble_into_single_device_arrays_response = 7; DeleteArrayResponse delete_array_response = 9; CopyArraysResponse copy_arrays_response = 24; - ReshardResponse reshard_response = 10 [deprecated = true]; FullyReplicatedShardResponse fully_replicated_shard_response = 20; IsArrayDeletedResponse is_array_deleted_response = 11; DestructArrayResponse destruct_array_response = 12; @@ -127,6 +127,8 @@ message IfrtResponse { GetDefaultDeviceAssignmentResponse get_default_device_assignment_response = 19; } + + reserved 10; } // Metadata of an IFRT Request. @@ -323,15 +325,6 @@ message CopyArraysResponse { repeated fixed64 array_handles = 1; } -message ReshardRequest { - fixed64 array_handle = 1; - ShardingProto sharding = 2; - proto.ArrayCopySemantics copy_semantics = 3; -} -message ReshardResponse { - fixed64 array_handle = 1; -} - message FullyReplicatedShardRequest { fixed64 array_handle = 1; proto.ArrayCopySemantics copy_semantics = 2; diff --git a/third_party/xla/xla/python/ifrt_proxy/common/versions.h b/third_party/xla/xla/python/ifrt_proxy/common/versions.h index fca38276e75f78..0a95337040bb93 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/versions.h +++ b/third_party/xla/xla/python/ifrt_proxy/common/versions.h @@ -26,7 +26,7 @@ namespace protocol_version { inline constexpr int kClientMin = 3; // The minimum protocol_version that the current server code understands. -inline constexpr int kServerMin = 1; +inline constexpr int kServerMin = 3; enum { // Versions kAncient are named and are only referred to by their numbers. See diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD index 7bbeebeeedf089..a2beb527bf3eac 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -172,8 +172,8 @@ ifrt_proxy_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:test", "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:test", "//xla/pjrt:host_callback", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index f9167a8c23c026..4bcb18893601cc 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -342,8 +342,6 @@ Future IfrtBackend::ProcessInternal( return Future(HandleCheckValueReadyRequest(std::move(request))); case IfrtRequest::RequestCase::kCopyArraysRequest: return Future(HandleCopyArraysRequest(std::move(request))); - case IfrtRequest::RequestCase::kReshardRequest: - return Future(HandleReshardRequest(std::move(request))); case IfrtRequest::RequestCase::kFullyReplicatedShardRequest: return Future( HandleFullyReplicatedShardRequest(std::move(request))); @@ -1029,44 +1027,6 @@ absl::StatusOr IfrtBackend::HandleCopyArraysRequest( return ifrt_resp; } -absl::StatusOr IfrtBackend::HandleReshardRequest( - std::unique_ptr request) { - const auto& reshard_request = request->reshard_request(); - TF_ASSIGN_OR_RETURN(auto array, GetArray(reshard_request.array_handle())); - TF_ASSIGN_OR_RETURN( - std::shared_ptr sharding, - Sharding::FromProto( - absl::bind_front(&Client::LookupDevice, client_.get()), - reshard_request.sharding())); - TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( - reshard_request.copy_semantics())); - - // Emulate the old `Array::Reshard` behavior using `Client::CopyArrays`. No - // existing IFRT implementations before `Array::Reshard` was deleted actually - // supported resharding, so this should be safe. - if (!array->sharding().HasSamePartitioning(*sharding)) { - return absl::InvalidArgumentError(absl::StrCat( - "IFRT Proxy does not support resharding, but got ", - array->sharding().DebugString(), " as the original sharding and ", - sharding->DebugString(), " as the target sharding")); - } - TF_ASSIGN_OR_RETURN( - auto copied_arrays, - client_->CopyArrays(absl::MakeSpan(&array, 1), sharding->devices(), - sharding->memory_kind(), semantics)); - - uint64_t resharded_array_handle = handle_generator_.GenerateAtServer(); - { - absl::MutexLock lock(&arrays_mutex_); - arrays_.insert({resharded_array_handle, std::move(copied_arrays[0])}); - } - - auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); - ifrt_resp->mutable_reshard_response()->set_array_handle( - resharded_array_handle); - return ifrt_resp; -} - absl::StatusOr IfrtBackend::HandleFullyReplicatedShardRequest( std::unique_ptr request) { @@ -1327,15 +1287,10 @@ IfrtBackend::HandleLoadedExecutableMetadataRequest( parameter_layouts.ok()) { auto* const layouts = metadata_resp->mutable_parameter_layouts_list()->mutable_layouts(); - for (const std::unique_ptr& parameter_layout : + for (const std::shared_ptr& parameter_layout : *parameter_layouts) { // TODO(b/329165105): use PjRtLayout::Serialize instead - const xla::PjRtXlaLayout* layout = - dynamic_cast(parameter_layout.get()); - TF_RET_CHECK(layout != nullptr) - << "IFRT proxy only supports PjRtXlaLayout, got a different " - "subclass"; - layouts->Add(layout->xla_layout().ToProto()); + layouts->Add(parameter_layout->xla_layout().ToProto()); } } else { *metadata_resp->mutable_parameter_layouts_error() = @@ -1345,15 +1300,10 @@ IfrtBackend::HandleLoadedExecutableMetadataRequest( output_layouts.ok()) { auto* const layouts = metadata_resp->mutable_output_layouts_list()->mutable_layouts(); - for (const std::unique_ptr& output_layout : + for (const std::shared_ptr& output_layout : *output_layouts) { // TODO(b/329165105): use PjRtLayout::Serialize instead - const xla::PjRtXlaLayout* layout = - dynamic_cast(output_layout.get()); - TF_RET_CHECK(layout != nullptr) - << "IFRT proxy only supports PjRtXlaLayout, got a different " - "subclass"; - layouts->Add(layout->xla_layout().ToProto()); + layouts->Add(output_layout->xla_layout().ToProto()); } } else { *metadata_resp->mutable_output_layouts_error() = diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 602c01cf5e4382..d19c99c022f9a6 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -38,6 +38,7 @@ #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ExtensibleRTTI.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -68,7 +69,6 @@ #include "xla/service/computation_placer.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/protobuf/error_codes.pb.h" @@ -970,74 +970,6 @@ TEST_P(IfrtBackendHandlerTest, CopyArrays) { SizeIs(copied_arrays.size())); } -TEST_P(IfrtBackendHandlerTest, ReshardSuccess) { - auto src_mock_array = tsl::MakeRef(); - TF_ASSERT_OK_AND_ASSIGN(auto* device, - mock_client_->LookupDevice(DeviceId(0))); - auto src_sharding = SingleDeviceSharding::Create(device, MemoryKind()); - ON_CALL(*src_mock_array, sharding()).WillByDefault(ReturnRef(*src_sharding)); - TF_ASSERT_OK_AND_ASSIGN(auto src_array_handle, - MakeTestArray(std::move(src_mock_array))); - - auto copied_mock_array = tsl::MakeRef(); - EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _)) - .WillOnce(Return(std::vector>( - {copied_mock_array}))); - - auto ifrt_request = NewIfrtRequest(NewOpId()); - auto* reshard_request = ifrt_request->mutable_reshard_request(); - reshard_request->set_array_handle(src_array_handle); - reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); - TF_ASSERT_OK_AND_ASSIGN(auto* new_device, - mock_client_->LookupDevice(DeviceId(1))); - TF_ASSERT_OK_AND_ASSIGN( - *ifrt_request->mutable_reshard_request()->mutable_sharding(), - SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto()); - - TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); - - EXPECT_THAT(tsl::StatusFromProto(response->response_metadata().status()), - IsOk()); - EXPECT_NE(response->reshard_response().array_handle(), 0); -} - -TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) { - auto mock_array = tsl::MakeRef(); - TF_ASSERT_OK_AND_ASSIGN(auto* device, - mock_client_->LookupDevice(DeviceId(1))); - auto sharding = SingleDeviceSharding::Create(device, MemoryKind()); - ON_CALL(*mock_array, sharding()).WillByDefault(ReturnRef(*sharding)); - TF_ASSERT_OK_AND_ASSIGN(auto array_handle, - MakeTestArray(std::move(mock_array))); - - EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _)) - .WillOnce(Return(absl::UnknownError("injected error"))); - - auto ifrt_request = NewIfrtRequest(NewOpId()); - auto* reshard_request = ifrt_request->mutable_reshard_request(); - reshard_request->set_array_handle(array_handle); - reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); - TF_ASSERT_OK_AND_ASSIGN(auto* new_device, - mock_client_->LookupDevice(DeviceId(1))); - TF_ASSERT_OK_AND_ASSIGN( - *ifrt_request->mutable_reshard_request()->mutable_sharding(), - SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto()); - - EXPECT_THAT(CallBackend(std::move(ifrt_request)), - StatusIs(absl::StatusCode::kUnknown, StrEq("injected error"))); -} - -TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) { - auto ifrt_request = NewIfrtRequest(NewOpId()); - auto* reshard_request = ifrt_request->mutable_reshard_request(); - reshard_request->set_array_handle(0); - reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY); - reshard_request->mutable_sharding(); - - EXPECT_THAT(CallBackend(std::move(ifrt_request)), - StatusIs(absl::StatusCode::kNotFound)); -} - TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) { auto fully_replicated_mock_array = tsl::MakeRef(); auto resultant_array = tsl::MakeRef(); @@ -1311,16 +1243,16 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { EXPECT_CALL(*executable, GetOutputShardings()) .WillOnce(Return(std::vector{op_sharding1})); - std::vector> parameter_layouts; - parameter_layouts.push_back(std::make_unique( + std::vector> parameter_layouts; + parameter_layouts.push_back(std::make_shared( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1))); - parameter_layouts.push_back(std::make_unique( + parameter_layouts.push_back(std::make_shared( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); EXPECT_CALL(*executable, GetParameterLayouts()) .WillOnce(Return(std::move(parameter_layouts))); - std::vector> output_layouts; - output_layouts.push_back(std::make_unique( + std::vector> output_layouts; + output_layouts.push_back(std::make_shared( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); EXPECT_CALL(*executable, GetOutputLayouts()) .WillOnce(Return(std::move(output_layouts))); diff --git a/third_party/xla/xla/python/inspect_sharding.cc b/third_party/xla/xla/python/inspect_sharding.cc index dfa03f37f01e01..598ccd925ff52e 100644 --- a/third_party/xla/xla/python/inspect_sharding.cc +++ b/third_party/xla/xla/python/inspect_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/python/inspect_sharding.h" +#include #include #include #include @@ -24,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/service/custom_call_sharding_helper.h" #include "xla/service/spmd/spmd_partitioner_util.h" +#include "xla/xla_data.pb.h" namespace jax { diff --git a/third_party/xla/xla/python/inspect_sharding.h b/third_party/xla/xla/python/inspect_sharding.h index 4afc3a63875a0d..c6ee425071da25 100644 --- a/third_party/xla/xla/python/inspect_sharding.h +++ b/third_party/xla/xla/python/inspect_sharding.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PYTHON_INSPECT_SHARDING_H_ #define XLA_PYTHON_INSPECT_SHARDING_H_ +#include #include #include diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index d5961c9fddbc3d..e6d7ee51ab5f1f 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -29,10 +29,10 @@ limitations under the License. #include #include +#include #include #include #include -#include #include #include @@ -44,6 +44,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -52,6 +53,7 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep #include "xla/python/nb_absl_span.h" // IWYU pragma: keep @@ -136,17 +138,9 @@ static std::string OptionalDebugString( } } -bool FetchMemoriesFlag() { - auto& global_state = GlobalJitState(); - auto& thread_local_state = ThreadLocalJitState(); - CHECK(global_state.enable_memories.has_value()); - return thread_local_state.enable_memories.value_or( - *global_state.enable_memories); -} - std::string ArgumentSignature::DebugString() const { auto py_object_formatter = [](std::string* out, const nb::object& o) { - out->append(nb::cast(nb::str(o))); + out->append(nb::cast(nb::str(o))); }; auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { out->append(d.ToString()); @@ -187,8 +181,8 @@ bool ArgumentSignature::operator==(const ArgumentSignature& other) const { "static arguments should be comparable using __eq__." "The following error was raised when comparing two objects of " "types ", - nb::cast(nb::str(a.type())), " and ", - nb::cast(nb::str(b.type())), + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), ". The error was:\n", e.what())); } }); @@ -196,12 +190,20 @@ bool ArgumentSignature::operator==(const ArgumentSignature& other) const { std::string CallSignature::DebugString() const { auto py_object_formatter = [](std::string* out, const nb::object& o) { - out->append(nb::cast(nb::str(o))); + out->append(nb::cast(nb::str(o))); }; auto signature_formatter = [](std::string* out, const xla::PyArgSignature& s) { out->append(s.DebugString()); }; + auto layout_formatter = [](std::string* out, + const std::shared_ptr& l) { + if (l != nullptr) { + out->append(l->ToString()); + } else { + out->append("None"); + } + }; auto bool_formatter = [](std::string* out, bool o) { out->append(o ? "true" : "false"); }; @@ -209,20 +211,21 @@ std::string CallSignature::DebugString() const { "arg signature: %s\n" "dynamic arg signatures (positional + keyword): %s\n" "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" "committed args: %s\n" "device: %s\n" "default_device: %s\n" "jax_enable_x64: %d\n" - "jax_enable_memories: %d\n" "global_extra_jit_context: %s\n" "thread_local_extra_jit_context: %s\n" "configs: %s\n", arg_signature.DebugString(), absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), absl::StrJoin(committed_args, ",", bool_formatter), device != nullptr ? device->DebugString() : "nullptr", - OptionalDebugString(default_device), jax_enable_x64, jax_enable_memories, + OptionalDebugString(default_device), jax_enable_x64, OptionalDebugString(global_extra_jit_context), OptionalDebugString(thread_local_extra_jit_context), absl::StrJoin(configs, ", ", py_object_formatter)); @@ -241,9 +244,6 @@ bool CallSignature::operator==(const CallSignature& other) const { if (jax_enable_x64 != other.jax_enable_x64) { return false; } - if (jax_enable_memories != other.jax_enable_memories) { - return false; - } if (committed_args != other.committed_args) { return false; } @@ -251,6 +251,11 @@ bool CallSignature::operator==(const CallSignature& other) const { // `==` on py:objects is the Python `is`. We need equal. absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, ShardingEqual) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return (a && b) ? *a == *b : a == b; + }) && (global_extra_jit_context.has_value() == other.global_extra_jit_context.has_value()) && (!global_extra_jit_context.has_value() || @@ -370,16 +375,12 @@ void BuildJaxjitSubmodule(nb::module_& m) { nb::class_ jit_state_(jitlib, "JitState"); jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); - jit_state_.def_rw("enable_memories", &JitState::enable_memories, - nb::arg().none()); jit_state_.def_rw("default_device", &JitState::default_device, nb::arg().none()); jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, nb::arg().none()); jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); - GetEnableMemories = +[] { return FetchMemoriesFlag(); }; - jitlib.def( "global_state", [&]() { return &GlobalJitState(); }, nb::rv_policy::reference); diff --git a/third_party/xla/xla/python/jax_jit.h b/third_party/xla/xla/python/jax_jit.h index df90f26cde750c..59d35abf0daa18 100644 --- a/third_party/xla/xla/python/jax_jit.h +++ b/third_party/xla/xla/python/jax_jit.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include #include #include #include -#include #include #include @@ -60,7 +60,6 @@ struct JitState { std::optional disable_jit; std::optional enable_x64; - std::optional enable_memories; // Used to manually set the default device jax should use. May be unset even // in global state, indicating there is no manual override. @@ -140,8 +139,8 @@ H AbslHashValue(H h, const ArgumentSignature& s) { throw std::invalid_argument(absl::StrCat( "Non-hashable static arguments are not supported. An error occurred " "while trying to hash an object of type ", - nanobind::cast(nanobind::str(static_arg.type())), - ", ", nanobind::cast(nanobind::str(static_arg)), + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), ". The error was:\n", e.what(), "\n")); } h = H::combine(std::move(h), hash); @@ -185,7 +184,7 @@ absl::Status ParseArguments( // (a) equality (delegated to Python) of the static arguments. struct CallSignature { // Not part of the signature, but we need it for error messages. - std::string_view function_name; + absl::string_view function_name; ArgumentSignature arg_signature; @@ -193,10 +192,12 @@ struct CallSignature { // arguments (sorted by keyword name). absl::InlinedVector dynamic_arg_signatures; - // The sharding of the jax.Array arguments. This is only used by pjit with - // jax.Array enabled. + // The sharding of the jax.Array arguments. std::vector dynamic_arg_shardings; + // The layout of the jax.Array arguments. + std::vector> dynamic_arg_layouts; + absl::InlinedVector committed_args; // For JIT, we need this in the key because computation follows the data, so @@ -204,7 +205,6 @@ struct CallSignature { // This is not the case for PMAP, and is set to `nullptr`. xla::PjRtDevice* device = nullptr; bool jax_enable_x64; - bool jax_enable_memories = false; // For JIT on PJIT, we need to fallback to python whenever default_device // changes. @@ -231,12 +231,20 @@ H AbslHashValue(H h, const CallSignature& s) { DCHECK(s.dynamic_arg_shardings.empty() || s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); + DCHECK(s.dynamic_arg_layouts.empty() || + s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size()); + // TODO(chky): For now, we are only hashing the pointer of shardings to avoid // slow python hashing function. Consider implementing hashing function and // equality checks in C++ in jax::Sharding and use those here. for (const auto& sharding : s.dynamic_arg_shardings) { - // TODO(phawkins): remove .ptr() after nanobind transition is complete. - h = H::combine(std::move(h), ShardingHash(sharding.ptr())); + h = H::combine(std::move(h), ShardingHash(sharding)); + } + + for (const auto& layout : s.dynamic_arg_layouts) { + if (layout != nullptr) { + h = H::combine(std::move(h), *layout); + } } h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64); diff --git a/third_party/xla/xla/python/mlir.cc b/third_party/xla/xla/python/mlir.cc index 36e19d2e7f94a8..2083367b87d429 100644 --- a/third_party/xla/xla/python/mlir.cc +++ b/third_party/xla/xla/python/mlir.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include "mhlo/transforms/passes.h" #include "absl/status/status.h" @@ -36,10 +35,8 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "stablehlo/dialect/Serialization.h" -#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/mlir_to_hlo.h" @@ -110,7 +107,7 @@ absl::StatusOr PyXlaComputationToMlirModule( } absl::StatusOr PyMlirModuleToXlaComputation( - std::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); @@ -123,7 +120,7 @@ absl::StatusOr PyMlirModuleToXlaComputation( return computation; } -absl::StatusOr PyMhloToStablehlo(std::string_view mlir_module) { +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); // JAX can be customized in a way that involves operations from custom @@ -156,7 +153,7 @@ absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, ParseMlirModuleString( - std::string_view(mlir_module.c_str(), mlir_module.size()), context)); + absl::string_view(mlir_module.c_str(), mlir_module.size()), context)); mlir::PassManager pm(&context); if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); @@ -171,7 +168,7 @@ absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { } absl::StatusOr PySerializePortableArtifact( - std::string_view mlir_module, std::string_view target) { + absl::string_view mlir_module, absl::string_view target) { mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, @@ -189,7 +186,7 @@ absl::StatusOr PyDeserializePortableArtifact( mlir::MLIRContext context; mlir::OwningOpRef module = mlir::stablehlo::deserializePortableArtifact( - std::string_view(bytecode_str.c_str(), bytecode_str.size()), + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), &context); if (!module) return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); @@ -208,8 +205,8 @@ void BuildMlirSubmodule(nb::module_& m) { "mlir_module_to_xla_computation", [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { return xla::ValueOrThrow(PyMlirModuleToXlaComputation( - std::string_view(bytecode.c_str(), bytecode.size()), use_tuple_args, - return_tuple)); + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); }, nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, nb::arg("return_tuple") = false); @@ -221,7 +218,7 @@ void BuildMlirSubmodule(nb::module_& m) { "mhlo_to_stablehlo", [](const nb::bytes& bytecode) { return xla::ValueOrThrow(PyMhloToStablehlo( - std::string_view(bytecode.c_str(), bytecode.size()))); + absl::string_view(bytecode.c_str(), bytecode.size()))); }, nb::arg("mlir_module")); mlir_module.def("mhlo_to_stablehlo", @@ -232,9 +229,9 @@ void BuildMlirSubmodule(nb::module_& m) { nb::arg("mlir_module")); mlir_module.def( "serialize_portable_artifact", - [](const nb::bytes& bytecode, std::string_view target) { + [](const nb::bytes& bytecode, absl::string_view target) { return xla::ValueOrThrow(PySerializePortableArtifact( - std::string_view(bytecode.c_str(), bytecode.size()), target)); + absl::string_view(bytecode.c_str(), bytecode.size()), target)); }, nb::arg("mlir_module"), nb::arg("target")); mlir_module.def("serialize_portable_artifact", @@ -250,7 +247,7 @@ void BuildMlirSubmodule(nb::module_& m) { std::string buffer; llvm::raw_string_ostream os(buffer); xla::ThrowIfError(RefinePolymorphicShapes( - std::string_view(bytecode.c_str(), bytecode.size()), os, + absl::string_view(bytecode.c_str(), bytecode.size()), os, enable_shape_assertions, validate_static_shapes)); return nb::bytes(buffer.data(), buffer.size()); }, diff --git a/third_party/xla/xla/python/nb_numpy.h b/third_party/xla/xla/python/nb_numpy.h index b4ed1c9cc92c03..94820d464b3022 100644 --- a/third_party/xla/xla/python/nb_numpy.h +++ b/third_party/xla/xla/python/nb_numpy.h @@ -26,8 +26,8 @@ limitations under the License. #include #include -#include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "xla/tsl/python/lib/core/numpy.h" @@ -46,7 +46,7 @@ class nb_dtype : public nanobind::object { explicit nb_dtype(const nanobind::str& format) : nb_dtype(from_args(format)) {} - explicit nb_dtype(std::string_view format) + explicit nb_dtype(absl::string_view format) : nb_dtype(from_args(nanobind::str(format.data(), format.size()))) {} static nb_dtype from_args(const nanobind::object& args); diff --git a/third_party/xla/xla/python/ops.cc b/third_party/xla/xla/python/ops.cc index 904bc7f4015bd6..67ca0c5d273768 100644 --- a/third_party/xla/xla/python/ops.cc +++ b/third_party/xla/xla/python/ops.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_helpers.h" #include "xla/python/types.h" +#include "xla/service/hlo.pb.h" #include "xla/xla_data.pb.h" namespace nb = nanobind; @@ -67,7 +68,7 @@ struct type_caster { const_name("xla::ConvolutionDimensionNumbers")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { value.set_input_batch_dimension( cast(getattr(handle, "input_batch_dimension"))); @@ -147,7 +148,7 @@ struct type_caster { const_name("xla::GatherDimensionNumbers")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { std::vector dims; dims = cast>(getattr(handle, "offset_dims")); @@ -179,7 +180,7 @@ struct type_caster { const_name("xla::ScatterDimensionNumbers")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { std::vector dims; dims = cast>(getattr(handle, "update_window_dims")); @@ -212,7 +213,7 @@ struct type_caster { const_name("xla::ReplicaGroup")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { auto dims = cast>(getattr(handle, "replica_ids")); std::copy(dims.begin(), dims.end(), @@ -232,7 +233,7 @@ struct type_caster { const_name("xla::PaddingConfig")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { sequence dimensions = borrow(getattr(handle, "dimensions")); @@ -260,7 +261,7 @@ struct type_caster { const_name("xla::PrecisionConfig")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { if (handle.is_none()) { return true; @@ -286,7 +287,7 @@ struct type_caster { NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::ResultAccuracy, const_name("xla::ResultAccuracy")); // PyObject -> C++ conversion. - bool from_python(handle handle, uint8_t, cleanup_list*) { + bool from_python(handle handle, uint8_t, cleanup_list*) noexcept { try { if (handle.is_none()) { return true; diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index 5e012002586489..2a009377142333 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include // NOLINT #include #include @@ -35,7 +34,6 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" @@ -131,10 +129,12 @@ class PjitFunctionCache { // We include as part of the cache key `global_cache_key` (and any other // fields that aren't subsumed by the CallSignature we compute for each call). - std::shared_ptr Lookup(nb::handle function, - nb::object global_cache_key); + static std::shared_ptr Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key); std::shared_ptr DefaultCache(); + // These methods require the GIL or the object's lock in no-GIL mode. int Size() const { return lru_list_.Size(); } int Capacity() const { return lru_list_.Capacity(); } void Clear() { @@ -192,10 +192,14 @@ class PjitFunctionCache { std::optional weakref; }; + // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the + // self object lock in freethreading mode. Cache::LRUList lru_list_; - absl::Mutex mu_; // Non-trivial hashes need to be mutex locked. - // ABSL containers are not exception safe: + // We use std::unordered_map because ABSL containers are not exception safe: std::unordered_map, absl::Hash> functions_; + // mu_ prevents concurrent insertions into functions_ if the gil or critical + // section lock is released during insertion. + absl::Mutex mu_; }; PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} @@ -204,31 +208,38 @@ std::shared_ptr PjitFunctionCache::DefaultCache() { return std::make_shared(&lru_list_); } -std::shared_ptr PjitFunctionCache::Lookup( - nb::handle function, +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + xla::nb_class_ptr self, nb::handle function, nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); { - // Because the gil can be released during cache insertion, this forces - // the lock order to be mu_ then gil so we must release the gil first. + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. nb::gil_scoped_release release; // Acquire a mutex to avoid problems where the gil is released during // cache insertion and then a second thread invalidates the cache order. - mu_.Lock(); + self->mu_.Lock(); } - absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.Unlock(); + }; Key key; key.function = function; key.global_cache_key = global_cache_key; - auto insert = functions_.emplace(key, nullptr); + auto insert = self->functions_.emplace(key, nullptr); if (!insert.second) { return insert.first->second->cache; } - std::shared_ptr cache = std::make_shared(&lru_list_); + std::shared_ptr cache = std::make_shared(&self->lru_list_); auto callback = - nb::cpp_function([this, key{std::move(key)}](nb::handle weakref) { - auto it = functions_.find(key); - if (it != functions_.end()) { - functions_.erase(it); + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it != self->functions_.end()) { + self->functions_.erase(it); } }); PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); @@ -241,7 +252,7 @@ std::shared_ptr PjitFunctionCache::Lookup( // `function` is not weak-referenceable. Don't bother adding it to the // shared cache in that case; the `jit` object will hold the only shared // reference to the cache entry. - functions_.erase(insert.first); + self->functions_.erase(insert.first); } return cache; } @@ -254,7 +265,7 @@ class PjitFunction { nb::object global_cache_key, xla::nb_class_ptr pytree_registry, nb::callable shard_arg_fallback, - std::shared_ptr cache); + xla::nb_class_ptr cache); ~PjitFunction(); PjitFunction(const PjitFunction&) = delete; @@ -301,11 +312,22 @@ class PjitFunction { return static_argnames_; } const nb::object& global_cache_key() const { return global_cache_key_; } - const std::shared_ptr& cache() const { return cache_; } + const xla::nb_class_ptr& cache() const { return cache_; } - int cache_capacity() const { return executables_->Size(); } + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } - void ClearCache() { executables_->Clear(); } + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } nb::object PythonSignature() { if (!fun_.has_value()) { @@ -337,40 +359,19 @@ class PjitFunction { xla::nb_class_ptr pytree_registry_; nb::callable shard_arg_fallback_; - std::shared_ptr cache_; - std::shared_ptr executables_; -}; - -// thread-compatible. -class PjitFunctionStore { - public: - void Insert(PjitFunction* function) { compiled_functions_.insert(function); } - - void Erase(PjitFunction* function) { compiled_functions_.erase(function); } - - void ClearFunctionCache() { - for (auto* function : compiled_functions_) { - function->ClearCache(); - } - compiled_functions_.clear(); - } + xla::nb_class_ptr cache_; - private: - absl::flat_hash_set compiled_functions_; + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. + std::shared_ptr executables_; }; -// Protected by GIL. -PjitFunctionStore& GetGlobalPjitFunctionStore() { - static auto* const store = new PjitFunctionStore(); - return *store; -} - PjitFunction::PjitFunction( std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, std::vector static_argnames, nb::object global_cache_key, xla::nb_class_ptr pytree_registry, - nb::callable shard_arg_fallback, std::shared_ptr cache) + nb::callable shard_arg_fallback, xla::nb_class_ptr cache) : function_name_(std::move(function_name)), fun_(std::move(fun)), cache_miss_(std::move(cache_miss)), @@ -386,19 +387,22 @@ PjitFunction::PjitFunction( PyUnicode_InternInPlace(&s); static_argnames_.push_back(nb::steal(s)); } - - GetGlobalPjitFunctionStore().Insert(this); } void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. if (!fun_.has_value()) { executables_ = cache_->DefaultCache(); } else { - executables_ = cache_->Lookup(fun_.value(), global_cache_key_); + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); } } -PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} void CallShardArgFallback( nb::handle arg, nb::handle sharding, nb::handle layout, @@ -481,7 +485,7 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, } continue; } else { - CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, shard_arg_fallback, num_args_arrays, keep_alive_objects); continue; @@ -502,8 +506,8 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); xla::Layout in_xc_layout = nb::cast( in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); - if (in_xc_layout != GetXlaLayoutUnsafe(arr_layout)) { - CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, shard_arg_fallback, num_args_arrays, keep_alive_objects); continue; @@ -511,16 +515,16 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, } if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { - CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - in_device_local_layout, shard_arg_fallback, - num_args_arrays, keep_alive_objects); + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); continue; } if (sharding_num_devices != num_global_devices) { - CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - in_device_local_layout, shard_arg_fallback, - num_args_arrays, keep_alive_objects); + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); continue; } @@ -660,12 +664,15 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); bool inserted = false; - std::shared_ptr cache_entry = - executables_->GetOrCreateIfAbsent( - call_signature, [this, &inserted](const CallSignature& unused) { - inserted = true; - return std::make_shared(pytree_registry_.get()); - }); + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature& unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } if (!cache_entry->compilation_complete.HasBeenNotified()) { // In case of several threads attempting to compile the executable, only @@ -698,6 +705,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, cache_entry->compilation_complete.Notify(); if (remove_cache) { + nb::ft_object_guard lock(cache_); executables_->Remove(call_signature); } @@ -805,12 +813,13 @@ absl::Status PjitFunction::ComputeCallSignature( signature.default_device = GetDefaultDevice(); signature.jax_enable_x64 = jax_enable_x64; - signature.jax_enable_memories = GetEnableMemories(); auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; dynamic_arg_signatures.reserve(flat_dynamic_args.size()); auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); for (nb::handle arg : flat_dynamic_args) { TF_ASSIGN_OR_RETURN(auto arg_signature, @@ -822,9 +831,16 @@ absl::Status PjitFunction::ComputeCallSignature( if (arg.type().ptr() == xla::PyArray::type().ptr()) { auto py_array = nb::borrow(arg); signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } signature.committed_args.push_back(py_array.committed()); } else { signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); signature.committed_args.push_back(false); } } @@ -923,8 +939,64 @@ struct PjitFunctionObject { #endif // PY_VERSION_HEX < 0x030C0000 vectorcallfunc vectorcall; PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; }; +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. + nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + PyObject* PjitFunction_Type = nullptr; bool PjitFunction::IsPjitFunction(nb::handle handle) { @@ -990,6 +1062,7 @@ void PjitFunction_tp_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); PyTypeObject* tp = Py_TYPE(self); PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); PyObject_ClearWeakRefs(self); #if PY_VERSION_HEX < 0x030C0000 Py_CLEAR(o->dict); @@ -1060,8 +1133,8 @@ static PyGetSetDef PjitFunction_tp_getset[] = { PyObject* PjitFunction_tp_repr(PyObject* self) { try { const std::string& repr = absl::StrFormat( - "", - nb::cast(nb::repr(nb::getattr(self, "__wrapped__")))); + "", nb::cast(nb::repr( + nb::getattr(self, "__wrapped__")))); return PyUnicode_FromString(repr.c_str()); } catch (...) { // Ignore all errors when accessing a repr. @@ -1077,7 +1150,9 @@ void InitializePjitFunction( std::vector static_argnums, std::vector static_argnames, nb::object global_cache_key, xla::nb_class_ptr pytree_registry, - nb::callable shard_arg_fallback, std::shared_ptr cache) { + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache) { + fn_obj->next = fn_obj->prev = nullptr; if (nb::isinstance(global_cache_key)) { global_cache_key = nb::tuple(global_cache_key); } @@ -1089,6 +1164,10 @@ void InitializePjitFunction( // Handled separately because it is not exception safe to call this // in the constructor because it leaves the object improperly constructed. fn_obj->fun.InitExecutables(); + + // Only add the executable to the store after executables_ has been + // initialized. We want only fully constructed executables in the store. + pjit_function_store.Insert(fn_obj); } nb::object MakePjitFunction( @@ -1097,12 +1176,12 @@ nb::object MakePjitFunction( std::vector static_argnames, nb::object global_cache_key, xla::nb_class_ptr pytree_registry, nb::callable shard_arg_fallback, - std::optional> cache) { + std::optional> cache) { nb::object obj = nb::steal(PjitFunction_tp_new( reinterpret_cast(PjitFunction_Type), nullptr, nullptr)); PjitFunctionObject* fn_obj = reinterpret_cast(obj.ptr()); if (!cache) { - cache = std::make_shared( + cache = xla::make_nb_class( PjitFunctionCache::kDefaultCapacity); } InitializePjitFunction( @@ -1151,19 +1230,20 @@ void BuildPjitSubmodule(nb::module_& m) { nb::class_ cache(m, "PjitFunctionCache"); cache.def(nb::init(), nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); - cache.def("size", &PjitFunctionCache::Size); - cache.def("capacity", &PjitFunctionCache::Capacity); - cache.def("clear", &PjitFunctionCache::Clear); - cache.def_static("clear_all", - []() { GetGlobalPjitFunctionStore().ClearFunctionCache(); }); - cache.def("__getstate__", - // Pickles as an empty cache; the client can repopulate as needed. - [](const PjitFunctionCache& cache) { - nb::dict pickle; - pickle["version"] = kPjitFunctionPickleVersion; - pickle["capacity"] = cache.Capacity(); - return pickle; - }); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); cache.def("__setstate__", [](PjitFunctionCache* cache, const nb::dict& pickle) { int version = nb::cast(pickle["version"]); @@ -1255,8 +1335,8 @@ void BuildPjitSubmodule(nb::module_& m) { nb::handle(pickle["pytree_registry"].ptr())); nb::callable shard_arg_fallback = nb::cast(pickle["shard_arg_fallback"]); - std::shared_ptr cache = - nb::cast>(pickle["cache"]); + xla::nb_class_ptr cache = + nb::cast>(pickle["cache"]); InitializePjitFunction( reinterpret_cast(self.ptr()), std::move(function_name), std::move(fun), std::move(cache_miss), @@ -1289,7 +1369,7 @@ void BuildPjitSubmodule(nb::module_& m) { nb::callable cache_miss, std::vector static_argnums, std::vector static_argnames, nb::object global_cache_key, nb::object pytree_registry, nb::callable shard_arg_fallback, - std::optional> cache) { + std::optional> cache) { xla::nb_class_ptr registry = nb::cast>( nb::handle(pytree_registry.ptr())); diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 92b3d9e264d36a..c8ba2027b5b495 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -377,9 +377,7 @@ xla_cc_test( deps = [ ":basic_string_array", ":pjrt_cpu_client_multi_process_test_lib", - "//xla:shape_util", "//xla/pjrt:pjrt_future", - "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/tsl/concurrency:ref_count", diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc index d3b9fd1be984f5..7006caaae2f549 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc @@ -53,35 +53,6 @@ limitations under the License. namespace xla { namespace ifrt { -///////////////////////////////////////////////////////////////////////////// -// -// BasicStringArrayLayout -// - -std::string BasicStringArrayLayout::Serialize() const { - // We currently do not have any state that need to be serialized. Return an - // empty string. - return std::string(); -} - -std::string BasicStringArrayLayout::ToString() const { - return "BasicStringArrayLayout: Dense, major-to-minor."; -} - -bool BasicStringArrayLayout::operator==(const PjRtLayout& other) const { - auto* other_basic_string_array_layout = - dynamic_cast(&other); - if (other_basic_string_array_layout == nullptr) { - return false; - } - // All BasicStringArrayLayout objects are the same - they are all dense, - // major-to-minor. So, all of them are equal. - return true; -} - -void BasicStringArrayLayout::Hash(absl::HashState state) const { -} // Nothing to add to the hash state. Just return. - ///////////////////////////////////////////////////////////////////////////// // // BasicStringArray @@ -446,12 +417,9 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( std::move(buffers_future), std::move(on_done_with_buffer)); } -absl::StatusOr> BasicStringArray::layout() const { - absl::MutexLock lock(&mu_); - if (is_deleted_) { - return absl::FailedPreconditionError("Array has already been deleted"); - } - return std::make_unique(); +absl::StatusOr> BasicStringArray::layout() + const { + return absl::UnimplementedError("String arrays do not support PjRtLayout"); } std::string BasicStringArray::DebugString() const { diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h index a430cfa73fdd26..c7ce68d85c9e52 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h @@ -46,22 +46,6 @@ limitations under the License. namespace xla { namespace ifrt { -// Describes the layout of a `BasicStringArray`. -class BasicStringArrayLayout : public PjRtLayout { - public: - BasicStringArrayLayout() = default; - BasicStringArrayLayout(const BasicStringArrayLayout& other) = delete; - - ~BasicStringArrayLayout() override = default; - - std::string Serialize() const override; - std::string ToString() const override; - bool operator==(const PjRtLayout& other) const override; - - protected: - void Hash(absl::HashState state) const override; -}; - // `BasicStringArray` implements an `ifrt::Array` by wrapping a local (aka host) // string buffer. This object is expected to live exclusively in the IFRT layer, // and thus is not specific to any particular backend. However, it is currently @@ -121,7 +105,7 @@ class BasicStringArray final return sharding_; } - absl::StatusOr> layout() const override; + absl::StatusOr> layout() const override; absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc index c402f0a38ecdb2..644abe66d25a3a 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -33,9 +33,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" -#include "xla/layout.h" #include "xla/pjrt/pjrt_future.h" -#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" @@ -124,46 +122,6 @@ CreateNonReadyTestArray( return std::make_pair(std::move(array), std::move(buffers_promise)); } -///////////////////////////////////////////////////////////////////////////// -// -// Tests related to BasicStringArrayLayout. -// - -TEST(BasicStringArrayLayoutTest, Serialize) { - BasicStringArrayLayout layout; - // Seerialize currently has no state to serialize, and so the returned value - // should be an empty string. - EXPECT_TRUE(layout.Serialize().empty()); -} - -TEST(BasicStringArrayLayoutTest, ToString) { - BasicStringArrayLayout layout; - auto output_str = layout.ToString(); - EXPECT_THAT(output_str, HasSubstr("major-to-minor")); -} - -TEST(BasicStringArrayLayoutTest, Equality) { - BasicStringArrayLayout layout_1; - - // In the equality comparisons below, use the PjRtLayout interface for the - // second object so we can avoid the error: `ambiguity is between a regular - // call to this operator and a call with the argument order reversed`. - - // Any two BasicStringArrayLayouts are equal. - BasicStringArrayLayout layout_2; - const PjRtLayout& layout_3 = layout_2; - EXPECT_EQ(layout_1, layout_3); - - // In the next test, EXPECT_NE is not used because the version of EXCEPT_NE - // available in the open sourced libraries requires the operator `!=` to be - // overloaded. - - // Non-BasicStringArrayLayouts are not equal to BasicStringArrayLayouts. - xla::PjRtXlaLayout layout_6((xla::Layout())); - const PjRtLayout& layout_7 = layout_6; - EXPECT_FALSE(layout_7 == layout_1); -} - ///////////////////////////////////////////////////////////////////////////// // // Tests related to BasicStringArray. @@ -948,13 +906,6 @@ TEST(LayoutTest, Success) { CreateTestArray(client.get(), Future(std::move(buffers)), std::move(on_done_with_buffer))); - - // The number of dimensions for the testArray should be 1. Typical usage of - // BasicStringArrayLayout does not require an accessor to retrieve the number - // of dimensions. Instead of adding a test only method, we could just check - // the serialized layout. - TF_ASSERT_OK_AND_ASSIGN(auto layout, array->layout()); - EXPECT_TRUE(layout->Serialize().empty()); } TEST(LayoutTest, FailsAfterDeletion) { @@ -969,8 +920,6 @@ TEST(LayoutTest, FailsAfterDeletion) { std::move(on_done_with_buffer))); array->Delete(); - - EXPECT_THAT(array->layout(), StatusIs(absl::StatusCode::kFailedPrecondition)); } ///////////////////////////////////////////////////////////////////////////// diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 0c04f21a533464..db429bea24f83a 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -72,17 +72,17 @@ absl::Status ValidateArrayCreationInput( if (pjrt_buffers.empty()) { return InvalidArgument("pjrt_buffers must be non-empty"); } - if (sharding->devices()->size() != pjrt_buffers.size()) { + absl::Span sharding_devices = + sharding->devices()->AddressableDeviceList()->devices(); + if (sharding_devices.size() != pjrt_buffers.size()) { return InvalidArgument("device and buffer counts mismatch: %d vs. %d", - sharding->devices()->size(), pjrt_buffers.size()); + sharding_devices.size(), pjrt_buffers.size()); } // Canonicalize memory kind in case it hasn't been done before. - MemoryKind canonicalized_sharding_memory_kind = CanonicalizeMemoryKind( - sharding->memory_kind(), sharding->devices()->devices().front()); - const absl::Span sharding_devices = - sharding->devices()->devices(); - for (int i = 0; i < sharding->devices()->size(); ++i) { + MemoryKind canonicalized_sharding_memory_kind = + CanonicalizeMemoryKind(sharding->memory_kind(), sharding_devices.front()); + for (int i = 0; i < sharding_devices.size(); ++i) { PjRtCompatibleDevice* device = llvm::dyn_cast(sharding_devices[i]); if (!device) { @@ -553,7 +553,7 @@ bool PjRtArray::IsDeleted() const { std::string PjRtArray::DebugString() const { DCHECK(this); - absl::StatusOr> layout_ptr = layout(); + absl::StatusOr> layout_ptr = layout(); std::string layout_str = layout_ptr.ok() ? (*layout_ptr)->ToString() : ""; @@ -566,12 +566,12 @@ std::string PjRtArray::DebugString() const { // TODO(b/330198879): populate layout at construction instead of accessing PJRT // buffer directly for consistency with Pathways. -absl::StatusOr> PjRtArray::layout() const { +absl::StatusOr> PjRtArray::layout() const { CHECK(!pjrt_buffers_.empty()); - std::unique_ptr layout = pjrt_buffers_[0]->layout(); + std::shared_ptr layout = pjrt_buffers_[0]->layout(); #ifndef NDEBUG for (int i = 1; i < pjrt_buffers_.size(); ++i) { - std::unique_ptr layout_i = pjrt_buffers_[i]->layout(); + std::shared_ptr layout_i = pjrt_buffers_[i]->layout(); DCHECK(*layout == *layout_i) << "PjRtArray has mismatched layouts across shards! " << "shard 0: " << layout->ToString() << ", shard " << i << ": " diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h index d14747fea550ea..7a88f708248393 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h @@ -151,7 +151,7 @@ class PjRtArray final return sharding_; } - absl::StatusOr> layout() const override; + absl::StatusOr> layout() const override; absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index fa3dbfe2ee65ca..ccf1a3889a75b9 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -45,9 +44,11 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/topology_util.h" +#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" @@ -466,7 +467,7 @@ MakePjRtDevicesFromGlobalTopology(PjRtClient* client, int64_t slice_index = -1; if (!node.boot_id().empty()) { // Every new boot_id seen is treated as a new host/slice. - std::string_view boot_id = node.boot_id(); + absl::string_view boot_id = node.boot_id(); auto [it, inserted] = boot_id_to_slice_index.try_emplace(boot_id, next_slice_index); slice_index = it->second; @@ -1117,14 +1118,18 @@ absl::StatusOr> PjRtClient::GetTopologyForDevices( topology)); } -absl::StatusOr> -PjRtClient::GetDefaultLayoutForDevice(DType dtype, - absl::Span dims, - Device* device) const { +absl::StatusOr> PjRtClient::GetDefaultLayout( + DType dtype, absl::Span dims, Device* device, + MemoryKind memory_kind) const { + static MemoryKind kUnpinnedHostMemoryKind(UnpinnedHostMemorySpace::kKind); + if (memory_kind == kUnpinnedHostMemoryKind) { + return std::make_shared( + LayoutUtil::MakeDescendingLayout(dims.size())); + } TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); TF_ASSIGN_OR_RETURN(xla::Layout layout, pjrt_client_->GetDefaultLayout(element_type, dims)); - return std::make_unique(std::move(layout)); + return std::make_shared(std::move(layout)); } absl::Status PjRtClient::TransferToInfeed(PjRtDevice* device, diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h index 4849f5329e9e07..634f74d398a1ff 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h @@ -259,9 +259,9 @@ class PjRtClient final absl::StatusOr> GetTopologyForDevices( const tsl::RCReference& devices) const override; - absl::StatusOr> GetDefaultLayoutForDevice( - DType dtype, absl::Span dims, - Device* device) const override; + absl::StatusOr> GetDefaultLayout( + DType dtype, absl::Span dims, Device* device, + MemoryKind memory_kind) const override; absl::StatusOr LookupPjRtDevice( xla::PjRtDevice* pjrt_device) const override; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_compiler.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_compiler.cc index 2d476abc54e633..407ca4bd5da199 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_compiler.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_compiler.cc @@ -120,8 +120,7 @@ absl::StatusOr> PjRtCompiler::Compile( auto executable, PjRtCompile(xla_compile_options->compile_options, xla_program->mlir_module, *pjrt_topology->description())); - return PjRtExecutable::Create(std::move(executable), - std::move(xla_compile_options)); + return PjRtExecutable::Create(std::move(executable)); } absl::StatusOr> diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index 60f0e6bba78b0c..44706a41b54f59 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -184,10 +184,9 @@ char PjRtExecutable::ID = 0; char PjRtLoadedExecutable::ID = 0; absl::StatusOr> PjRtExecutable::Create( - std::shared_ptr pjrt_executable, - std::unique_ptr compile_options) { - return std::unique_ptr(new PjRtExecutable( - std::move(pjrt_executable), std::move(compile_options))); + std::shared_ptr pjrt_executable) { + return std::unique_ptr( + new PjRtExecutable(std::move(pjrt_executable))); } absl::StatusOr> PjRtExecutable::Fingerprint() const { diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h index ce83ee0da24de1..b6c8c359133bfe 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h @@ -86,8 +86,7 @@ class PjRtExecutable final public: // Creates PjRtExecutable from xla::PjRtExecutable. static absl::StatusOr> Create( - std::shared_ptr pjrt_executable, - std::unique_ptr compile_options); + std::shared_ptr pjrt_executable); // PjRtCompatibleExecutable implementation. @@ -116,13 +115,13 @@ class PjRtExecutable final return pjrt_executable_->GetOutputShardings(); } - absl::StatusOr>> + absl::StatusOr>> GetParameterLayouts() const override { DCHECK(this); return pjrt_executable_->GetParameterLayouts(); } - absl::StatusOr>> + absl::StatusOr>> GetOutputLayouts() const override { DCHECK(this); return pjrt_executable_->GetOutputLayouts(); @@ -162,20 +161,13 @@ class PjRtExecutable final return pjrt_executable_->GetOutputMemoryKinds(); } - const XlaCompileOptions* GetCompileOptions() const override { - return compile_options_.get(); - } - static char ID; // NOLINT protected: - explicit PjRtExecutable(std::shared_ptr pjrt_executable, - std::unique_ptr compile_options) - : pjrt_executable_(std::move(pjrt_executable)), - compile_options_(std::move(compile_options)) {} + explicit PjRtExecutable(std::shared_ptr pjrt_executable) + : pjrt_executable_(std::move(pjrt_executable)) {} std::shared_ptr pjrt_executable_; - std::unique_ptr compile_options_; }; // `LoadedExecutable` implementation that wraps a `xla::PjRtLoadedExecutable`. @@ -242,13 +234,13 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->GetOutputShardings(); } - absl::StatusOr>> + absl::StatusOr>> GetParameterLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetParameterLayouts(); } - absl::StatusOr>> + absl::StatusOr>> GetOutputLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetOutputLayouts(); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.cc index 8edb3bfa29fe2c..ebe1d86f915dfa 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.h index 3964ac56b184d5..3e69a151555b53 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_memory.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/memory.h" namespace xla { diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_remap.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_remap.cc index 6544cc32fa5d5a..c77b0e5e608fe0 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_remap.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_remap.cc @@ -26,7 +26,6 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "xla/pjrt/pjrt_client.h" #include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/remap_plan.h" #include "xla/python/ifrt/shape.h" @@ -45,17 +44,49 @@ PjRtCompatibleClientRemapArrays( PjRtCompatibleClient* client, const RemapPlan& plan, absl::Span> arrays, ArrayCopySemantics semantics) { - const int num_inputs = arrays.size(); + TF_RETURN_IF_ERROR(plan.CheckArrayCopySemantics(semantics)); + const int num_inputs = plan.input_specs.size(); + const int num_actual_inputs = arrays.size(); + const int num_outputs = plan.output_specs.size(); + if (num_inputs != num_actual_inputs) { + return InvalidArgument("RemapArrays expects %d input arrays, but got %d", + num_inputs, num_actual_inputs); + } for (int i = 0; i < num_inputs; ++i) { if (!llvm::isa(arrays[i].get())) { return InvalidArgument( - "Only PjRtCompatibleArray is supported: arrays[%d]=%s", i, + "Only PjRtCompatibleArray is supported, but input#%d is %s", i, arrays[i]->DebugString()); } + + if (plan.input_specs[i].dtype != arrays[i]->dtype()) { + return InvalidArgument( + "RemapArrays expects input #%d to have dtype %v, but got %v", i, + plan.input_specs[i].dtype, arrays[i]->dtype()); + } + if (plan.input_specs[i].shape != arrays[i]->shape()) { + return InvalidArgument( + "RemapArrays expects input #%d to have shape %v, but got %v", i, + plan.input_specs[i].shape, arrays[i]->shape().DebugString()); + } + // Skip xla::ifrt::Sharding::HasSamePartitioning() check because RemapArrays + // is currently called with input arrays with implicit sharding + // reinterpretation. Such patterns should be fixed before enabling stricter + // checking to avoid false positives. + if (*plan.input_specs[i].sharding->devices() != + *arrays[i]->sharding().devices() || + plan.input_specs[i].sharding->memory_kind() != + arrays[i]->sharding().memory_kind()) { + return InvalidArgument( + "RemapArrays expects input #%d to be on %v with " + "%v, but is on %v with %v", + i, *plan.input_specs[i].sharding->devices(), + plan.input_specs[i].sharding->memory_kind(), + *arrays[i]->sharding().devices(), + arrays[i]->sharding().memory_kind()); + } } - TF_RETURN_IF_ERROR(plan.CheckArrayCopySemantics(semantics)); - const int num_outputs = plan.output_specs.size(); std::vector out_buffers_list(num_outputs); for (int i = 0; i < num_outputs; ++i) { out_buffers_list[i].resize( diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_topology.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_topology.h index 81adf1bda215df..2543fe750757ef 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_topology.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_topology.h @@ -39,15 +39,15 @@ class PjRtTopology final : public llvm::RTTIExtends { explicit PjRtTopology( std::shared_ptr description); - const std::shared_ptr& description() - const { - return description_; - } - absl::string_view platform_name() const override; absl::string_view platform_version() const override; PjRtPlatformId platform_id() const override; + const std::shared_ptr& description() + const override { + return description_; + } + std::vector> DeviceDescriptions() const override; diff --git a/third_party/xla/xla/python/pmap_lib.cc b/third_party/xla/xla/python/pmap_lib.cc index 397a4786328b0d..f1d9a25144c72e 100644 --- a/third_party/xla/xla/python/pmap_lib.cc +++ b/third_party/xla/xla/python/pmap_lib.cc @@ -139,15 +139,14 @@ absl::StatusOr ShardArg( auto py_array = nb::borrow(arg); if (py_array.sharding().type().ptr() == input_spec.array_sharding.type().ptr()) { - auto* pmap_sharding = - nb::cast(nb::handle(py_array.sharding().ptr())); - auto* cached_pmap_sharding = nb::cast( - nb::handle(input_spec.array_sharding.ptr())); + auto* pmap_sharding = nb::cast(py_array.sharding()); + auto* cached_pmap_sharding = + nb::cast(input_spec.array_sharding); if (pmap_sharding->sharding_spec() == cached_pmap_sharding->sharding_spec()) { ShardArgResult result; - result.owning_sda = nb::borrow(arg.ptr()); + result.owning_sda = nb::borrow(arg); result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); if (result.ifrt_array == nullptr) { return xla::InvalidArgument("Array has been deleted."); @@ -258,7 +257,7 @@ absl::StatusOr ShardArg( auto py_array = nb::cast(py_array_or_bufs); ShardArgResult result; - result.owning_sda = nb::borrow(py_array_or_bufs.ptr()); + result.owning_sda = nb::borrow(py_array_or_bufs); result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); return result; } @@ -331,8 +330,14 @@ class PmapFunction { return inspect->attr("signature")(fun_); } - int cache_size() const { return executables_.size(); } - void cache_clear() { return executables_.clear(); } + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } const nb::callable& fun() const { return fun_; } const nb::callable& cache_miss() const { return cache_miss_; } const std::string& function_name() const { return function_name_; } @@ -407,7 +412,8 @@ class PmapFunction { // cache and recompiles), the list of the string representations of the keys. // // The format can change at any time. - std::string DebugCacheKeys() const { + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); std::vector key_strings = { absl::StrCat("The cache contains ", executables_.size(), " elements:")}; // We will be able to use auto& [key, _] when TF uses C++ 17. @@ -433,13 +439,18 @@ class PmapFunction { // passed to the underlying PyLoadedExecutable. In sorted order. std::vector static_argnums_; xla::nb_class_ptr pytree_registry_; - // We need a `unique_ptr` here to ensure value pointer stability. - absl::flat_hash_map> + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> executables_; // The fallback function to use with `ShardArgs`. // TODO(jblespiau): Add support for more types from C++. nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; }; void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, @@ -496,8 +507,7 @@ void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, } // Outputs specs. - auto out_tree = nb::cast( - nb::handle(pmap_data.attr("out_pytree_def").ptr())); + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); cache_entry.out_pytree_def = std::move(out_tree); nb::list out_avals = pmap_data.attr("out_avals"); @@ -583,15 +593,17 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, } // Retrieve/Maybe add the executable to the cache. - absl::flat_hash_map>::iterator - it; - bool inserted; - std::tie(it, inserted) = executables_.try_emplace( - call_signature, std::unique_ptr()); - if (inserted) { - it->second = std::make_unique(pytree_registry_.get()); + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + cache_entry_ptr = executables_[call_signature]; + } + if (cache_entry_ptr == nullptr) { + inserted = true; + cache_entry_ptr = std::make_shared(pytree_registry_.get()); } - PmapCacheEntry& cache_entry = *(it->second); + PmapCacheEntry& cache_entry = *cache_entry_ptr; if (!cache_entry.compilation_complete.HasBeenNotified()) { // In case of several threads attempting to compile the executable, only @@ -642,7 +654,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, for (int i = 0; i < num_args; ++i) { TF_ASSIGN_OR_RETURN( ShardArgResult sharded_arg, - ShardArg(flat_dynamic_args[i].ptr(), input_devices, input_specs[i], + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], cache_entry.py_devices, python_shard_arg_fallback_)); num_args_arrays[i] = std::move(sharded_arg.ifrt_array); @@ -711,8 +723,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, } } - (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, - nb::handle(out.ptr())); + (*post_hook)(callable, args_tuple, kwargs, out); } return out; @@ -882,9 +893,8 @@ const int kPmapFunctionPickleVersion = 1; void BuildPmapSubmodule(nb::module_& m) { nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); - nb::module_ pmap_lib_nb = nb::cast(nb::borrow(pmap_lib.ptr())); - nb::class_ no_sharding(pmap_lib_nb, "NoSharding"); + nb::class_ no_sharding(pmap_lib, "NoSharding"); no_sharding.def(nb::init<>()) .def("__getstate__", [](const NoSharding& self) { return nb::make_tuple(); }) @@ -901,7 +911,7 @@ void BuildPmapSubmodule(nb::module_& m) { return nb::int_(hash); }); - nb::class_ chunked(pmap_lib_nb, "Chunked"); + nb::class_ chunked(pmap_lib, "Chunked"); chunked.def(nb::init>()) .def("__getstate__", [](const Chunked& self) { return nb::make_tuple(self.chunks); }) @@ -922,7 +932,7 @@ void BuildPmapSubmodule(nb::module_& m) { return self == nb::cast(other); }); - nb::class_ unstacked(pmap_lib_nb, "Unstacked"); + nb::class_ unstacked(pmap_lib, "Unstacked"); unstacked.def(nb::init()) .def("__getstate__", [](const Unstacked& self) { return nb::make_tuple(self.size); }) @@ -942,7 +952,7 @@ void BuildPmapSubmodule(nb::module_& m) { return self == nb::cast(other); }); - nb::class_ sharded_axis(pmap_lib_nb, "ShardedAxis"); + nb::class_ sharded_axis(pmap_lib, "ShardedAxis"); sharded_axis.def(nb::init()) .def("__getstate__", [](const ShardedAxis& self) { return nb::make_tuple(self.axis); }) @@ -959,7 +969,7 @@ void BuildPmapSubmodule(nb::module_& m) { return self == other; }); - nb::class_ replicated(pmap_lib_nb, "Replicated"); + nb::class_ replicated(pmap_lib, "Replicated"); replicated.def(nb::init()) .def("__getstate__", [](const Replicated& self) { return nb::make_tuple(self.replicas); }) @@ -976,7 +986,7 @@ void BuildPmapSubmodule(nb::module_& m) { return self == other; }); - nb::class_ sharding_spec(pmap_lib_nb, "ShardingSpec"); + nb::class_ sharding_spec(pmap_lib, "ShardingSpec"); sharding_spec .def(nb::init(), nb::arg("sharding"), nb::arg("mesh_mapping")) @@ -1091,7 +1101,7 @@ void BuildPmapSubmodule(nb::module_& m) { nb::cast(pickle["python_shard_arg_fallback"]); xla::nb_class_ptr pytree_registry = nb::cast>( - nb::handle(pickle["pytree_registry"].ptr())); + pickle["pytree_registry"]); new (&(reinterpret_cast(self.ptr())->fun)) PmapFunction(std::move(fun), std::move(cache_miss), std::move(static_argnums), @@ -1127,8 +1137,7 @@ void BuildPmapSubmodule(nb::module_& m) { std::vector static_argnums, nb::callable shard_arg_fallback, nb::object pytree_registry) -> nb::object { xla::nb_class_ptr registry = - nb::cast>( - nb::handle(pytree_registry.ptr())); + nb::cast>(pytree_registry); return MakePmapFunction( std::move(fun), std::move(cache_miss), std::move(static_argnums), std::move(shard_arg_fallback), std::move(registry)); diff --git a/third_party/xla/xla/python/pprof_profile_builder.cc b/third_party/xla/xla/python/pprof_profile_builder.cc index e3bf8104eab9aa..483624d417817c 100644 --- a/third_party/xla/xla/python/pprof_profile_builder.cc +++ b/third_party/xla/xla/python/pprof_profile_builder.cc @@ -18,7 +18,6 @@ limitations under the License. #include // IWYU pragma: keep #include -#include #include #include "absl/status/statusor.h" @@ -27,6 +26,7 @@ limitations under the License. #include "xla/util.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/profiler/protobuf/profile.pb.h" namespace xla { @@ -34,7 +34,7 @@ namespace nb = nanobind; PprofProfileBuilder::PprofProfileBuilder() { CHECK_EQ(0, StringId("")); } -int PprofProfileBuilder::StringId(std::string_view s) { +int PprofProfileBuilder::StringId(absl::string_view s) { auto ret = strings_.emplace(s, profile_.string_table_size()); if (ret.second) { profile_.add_string_table(s.data(), s.size()); @@ -48,11 +48,11 @@ int PprofProfileBuilder::FunctionId(PyCodeObject* code) { if (ret.second) { auto* function = profile_.add_function(); function->set_id(ret.first->second); - int name = StringId(nb::cast(nb::str(code->co_name))); + int name = StringId(nb::cast(nb::str(code->co_name))); function->set_name(name); function->set_system_name(name); function->set_filename( - StringId(nb::cast(nb::str(code->co_filename)))); + StringId(nb::cast(nb::str(code->co_filename)))); function->set_start_line(code->co_firstlineno); } return ret.first->second; diff --git a/third_party/xla/xla/python/pprof_profile_builder.h b/third_party/xla/xla/python/pprof_profile_builder.h index ca0e6f04e57f9e..8c1ee9afb784a9 100644 --- a/third_party/xla/xla/python/pprof_profile_builder.h +++ b/third_party/xla/xla/python/pprof_profile_builder.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -36,7 +35,7 @@ class PprofProfileBuilder { tensorflow::tfprof::pprof::Profile& profile() { return profile_; } // Adds or returns the ID of `s` in the table. - int StringId(std::string_view s); + int StringId(absl::string_view s); // Adds or returns the ID of a function. int FunctionId(PyCodeObject* code); diff --git a/third_party/xla/xla/python/profiler.cc b/third_party/xla/xla/python/profiler.cc index 9afe7d695ff7cc..cee7ae5cecbdcc 100644 --- a/third_party/xla/xla/python/profiler.cc +++ b/third_party/xla/xla/python/profiler.cc @@ -15,14 +15,13 @@ limitations under the License. #include "xla/python/profiler.h" -#include #include #include -#include #include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" // IWYU pragma: keep @@ -30,10 +29,7 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "xla/backends/profiler/plugin/plugin_tracer.h" -#include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/aggregate_profile.h" @@ -44,10 +40,11 @@ limitations under the License. #include "xla/tsl/profiler/rpc/profiler_server.h" #include "tsl/platform/macros.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep -#include "tsl/profiler/lib/profiler_factory.h" -#include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/protobuf/profiled_instructions.pb.h" +#include "tsl/profiler/protobuf/profiler_options.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace xla { @@ -93,7 +90,7 @@ class TraceMeWrapper { static void AppendMetadata(std::string* name, const nb::kwargs& kwargs) { name->push_back('#'); for (const auto& kv : kwargs) { - absl::StrAppend(name, nb::cast(kv.first), "=", + absl::StrAppend(name, nb::cast(kv.first), "=", EncodePyObject(kv.second), ","); } name->back() = '#'; @@ -131,7 +128,7 @@ struct ProfilerSessionWrapper { static std::string GetFdoProfile(const std::string& xspace, bool as_textproto = false) { tensorflow::profiler::XSpace xspace_proto; - // TODO(phawkins): change to std::string_view when protobuf is + // TODO(phawkins): change to absl::string_view when protobuf is // updated in XLA. xspace_proto.ParseFromString(std::string(xspace.c_str(), xspace.size())); tensorflow::profiler::ProfiledInstructionsProto fdo_profile; @@ -161,7 +158,7 @@ void BuildProfilerSubmodule(nb::module_& m) { }, nb::arg("port")); profiler.def("register_plugin_profiler", [](nb::capsule c_api) -> void { - if (std::string_view(c_api.name()) != "pjrt_c_api") { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { throw xla::XlaRuntimeError( "Argument to register_plugin_profiler was not a pjrt_c_api capsule."); } @@ -211,7 +208,7 @@ void BuildProfilerSubmodule(nb::module_& m) { [](ProfilerSessionWrapper* sess, nb::bytes xspace, const std::string& tensorboard_dir) -> void { tensorflow::profiler::XSpace xspace_proto; - // TODO(phawkins): change to std::string_view when protobuf is + // TODO(phawkins): change to absl::string_view when protobuf is // updated in XLA. xspace_proto.ParseFromString( std::string(xspace.c_str(), xspace.size())); diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.cc b/third_party/xla/xla/python/profiler/internal/python_hooks.cc index 0da1fe5e0124b5..504f918bccc499 100644 --- a/third_party/xla/xla/python/profiler/internal/python_hooks.cc +++ b/third_party/xla/xla/python/profiler/internal/python_hooks.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/python/profiler/internal/python_hooks.h" #include +#include #include #include "absl/log/log.h" @@ -60,7 +61,7 @@ std::string GetEventName(PyObject* co_filename, PyObject* co_name, " ", function); } -std::string GetEventName(std::string_view method_name, PyObject* module) { +std::string GetEventName(absl::string_view method_name, PyObject* module) { // Python stack does not have a filename/line_no for native calls. // Use module name and function/method name instead. std::string filename; @@ -90,6 +91,7 @@ void AddEventToXLine(const PythonTraceEntry& event, xevent.SetEndTimestampNs(event.end_time_ns); } +#if PY_VERSION_HEX < 0x030C0000 template void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) { // Note: PyThreadState's interp is not accessible in open source due to @@ -117,6 +119,8 @@ void ForEachThread(PyThreadState* curr_thread, ForEachThreadFunc&& callback) { #endif } +#endif // PY_VERSION_HEX + } // namespace /*static*/ PythonHookContext* PythonHooks::e2e_context_ = nullptr; @@ -202,7 +206,7 @@ void PythonHookContext::CollectData(tensorflow::profiler::XPlane* raw_plane) { } tsl::profiler::XPlaneBuilder plane(raw_plane); for (auto& it : entries_) { - uint32_t thread_id = it.first; + int64_t thread_id = it.first; auto& thread_events = it.second; VLOG(1) << "Collecting " << thread_events.completed.size() << ":" << thread_events.active.size() << " events on thread " << thread_id; @@ -283,7 +287,7 @@ void PythonHooks::ProfileSlow(const py::object& frame, const std::string& event, void PythonHookContext::ProfileFast(PyFrameObject* frame, int what, PyObject* arg) { - const uint32_t thread_id = tsl::Env::Default()->GetCurrentThreadId(); + const int64_t thread_id = tsl::Env::Default()->GetCurrentThreadId(); uint64_t now = tsl::profiler::GetCurrentTimeNanos(); auto& thread_traces = entries_[thread_id]; @@ -370,21 +374,29 @@ void PythonHookContext::ProfileFast(PyFrameObject* frame, int what, // NOTE: This must be after `threading.setprofile` otherwise we // end up recording that in our trace. +#if PY_VERSION_HEX < 0x030C0000 PyThreadState* curr_thread = PyThreadState_Get(); ForEachThread(curr_thread, [](PyThreadState* thread) { VLOG(1) << "Setting profiler in " << thread->thread_id; PyEval_SetProfile(&PythonHooks::ProfileFunction, nullptr); }); PyThreadState_Swap(curr_thread); +#else // PY_VERSION_HEX >= 0x030C0000 + PyEval_SetProfileAllThreads(&PythonHooks::ProfileFunction, nullptr); +#endif // PY_VERSION_HEX >= 0x030C0000 } /*static*/ void PythonHookContext::ClearProfilerInAllThreads() { +#if PY_VERSION_HEX < 0x030C0000 PyThreadState* curr_thread = PyThreadState_Get(); ForEachThread(curr_thread, [](PyThreadState* thread) { VLOG(1) << "Clearing profiler in " << thread->thread_id; PyEval_SetProfile(nullptr, nullptr); }); PyThreadState_Swap(curr_thread); +#else // PY_VERSION_HEX >= 0x030C0000 + PyEval_SetProfileAllThreads(nullptr, nullptr); +#endif // PY_VERSION_HEX >= 0x030C0000 // And notify the threading library that we're done. ThreadingSetProfile(py::none()); diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.h b/third_party/xla/xla/python/profiler/internal/python_hooks.h index af97b6a286679e..8ddc6ea985da60 100644 --- a/third_party/xla/xla/python/profiler/internal/python_hooks.h +++ b/third_party/xla/xla/python/profiler/internal/python_hooks.h @@ -138,8 +138,8 @@ class PythonHookContext { void operator=(PythonHookContext&&) = delete; // The thread id to entries map, Note: by convention the thread id is - // uint32_t to be consistent with cpu tracer when serialize to Xspace. - absl::flat_hash_map entries_; + // int64_t to be consistent with cpu tracer when serialize to Xspace. + absl::flat_hash_map entries_; uint64_t start_timestamp_ns_; PythonHooksOptions options_; // In end to end mode, Python get uninitialized before Stop()/Finalize(), we diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index ef1655b1ad97b8..7e4051c1a1adbe 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -22,12 +22,13 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include -#include +#include // NOLINT #include #include #include @@ -46,6 +47,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep @@ -129,7 +131,7 @@ absl::StatusOr XlaDynamicShape(ifrt::Array* ifrt_array, } Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); // TODO(b/327524065): fix this - *shape.mutable_layout() = GetXlaLayoutUnsafe(pjrt_buffer->layout()); + *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout(); scratch = std::move(shape); } return &scratch.value(); @@ -377,6 +379,7 @@ struct ShapedArrayCacheKey { nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { using CacheT = LRUCache>>; + static nb::ft_mutex mu; static auto* lru_list = new CacheT::LRUList(4096); static auto* cache = new CacheT(lru_list); @@ -393,6 +396,7 @@ nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { return nb::none(); } + nb::ft_lock_guard lock(mu); auto value = cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { return std::make_shared>(); @@ -455,8 +459,15 @@ PyArray_Storage::PyArray_Storage( traceback(std::move(traceback)), ifrt_array(std::move(ifrt_array)), result_status(std::move(result_status)) { - next = this->py_client->arrays_; - this->py_client->arrays_ = this; + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard& shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; if (next) { next->prev = this; } @@ -501,7 +512,7 @@ PyArray PyArray::MakeFromSingleDeviceArray( auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); nb::object py_memory_kind = - (jax::GetEnableMemories() && memory_kind.memory_kind().has_value()) + (memory_kind.memory_kind().has_value()) ? nb::object(nb::str(memory_kind.memory_kind()->data(), memory_kind.memory_kind()->size())) : nb::none(); @@ -645,7 +656,7 @@ absl::Status PyArray::set_arrays(nb::object obj) { if (!nb::isinstance(obj)) { return InvalidArgument("Unsupported arg when setting Array._arrays: %s", - nb::cast(nb::str(obj.type()))); + nb::cast(nb::str(obj.type()))); } nb::list list(obj); @@ -676,7 +687,7 @@ absl::Status PyArray::set_arrays(nb::object obj) { shapes.push_back(ifrt_arrays.back()->shape()); } else { return InvalidArgument("Unsupported arg when setting Array._arrays: %s", - nb::cast(nb::str(obj.type()))); + nb::cast(nb::str(obj.type()))); } } const ifrt::MemoryKind first_memory_kind = @@ -786,7 +797,7 @@ absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { arr.GetStorage().dynamic_shape, arr.ifrt_array()); } -absl::StatusOr PyArray::AssertUnsharded(std::string_view api) { +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { if (ifrt_array() == nullptr) { return InvalidArgument("%s( called on deleted or donated buffer", api); } @@ -858,7 +869,7 @@ nb::dict PyArray::CudaArrayInterface() { ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout - Layout xla_layout = GetXlaLayoutUnsafe(pjrt_buffer->layout()); + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { throw nb::attribute_error( "__cuda_array_interface__ is only currently supported for " @@ -1055,14 +1066,18 @@ nb::handle PyArray::Storage::AsHandle() { PyArray::Storage::~PyArray_Storage() { CHECK(PyGILState_Check()); - if (py_client && py_client->arrays_ == this) { - py_client->arrays_ = next; - } - if (prev) { - prev->next = next; - } - if (next) { - next->prev = prev; + if (py_client) { + PyClient::ArraysShard& shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } } // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on // CPU backend caused by interactions between argument donations and host @@ -1119,11 +1134,11 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( auto transfer_guard_formatter = [&py_array, &dst_sharding] { return absl::StrCat( - "aval=", nb::cast(nb::repr(py_array.aval())), + "aval=", nb::cast(nb::repr(py_array.aval())), ", sharding=", - nb::cast(nb::repr(py_array.sharding())), + nb::cast(nb::repr(py_array.sharding())), ", dst_sharding=", - nb::cast(nb::repr(dst_sharding))); + nb::cast(nb::repr(dst_sharding))); }; TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); @@ -1187,8 +1202,8 @@ absl::StatusOr PyArray::BatchedDevicePut( } auto transfer_guard_formatter = [&aval, &sharding] { return absl::StrCat( - "aval=", nb::cast(nb::repr(aval)), - ", dst_sharding=", nb::cast(nb::repr(sharding))); + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); }; GlobalPyRefManager()->CollectGarbage(); @@ -1297,13 +1312,16 @@ absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); } -std::vector PyClient::LiveArrays() const { - std::vector result; - for (PyArray::Storage* array = arrays_; array; array = array->next) { - bool all_deleted = - (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); - if (!all_deleted) { - result.push_back(nb::borrow(array->AsHandle())); +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto& shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage* array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } } } return result; @@ -1400,7 +1418,7 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { } // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout - Layout xla_layout = GetXlaLayoutUnsafe(buffer.layout()); + Layout xla_layout = buffer.layout()->xla_layout(); if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || (flags & PyBUF_STRIDES) == PyBUF_ND) && @@ -1495,8 +1513,8 @@ bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) { // to unpack the array. This could happen for the host buffer // pre-mapped to the TPU device, a.k.a., pinned host buffers for the // device. - bool has_default_layout = buf->layout() == nullptr || - HasDefaultLayout(GetXlaLayoutUnsafe(buf->layout())); + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); // On CPU for values >= 8 bits, we can return the value in a zero-copy way. // For sub-byte values, we must copy in order to unpack the array. return buf->IsOnCpu() && @@ -1702,7 +1720,7 @@ absl::Status PyArray::RegisterTypes(nb::module_& m) { throw nb::type_error( absl::StrCat( "Unsupported type for elements in `arrays`: ", - nb::cast(nb::str(arrays[0].type()))) + nb::cast(nb::str(arrays[0].type()))) .c_str()); } }, diff --git a/third_party/xla/xla/python/py_array.h b/third_party/xla/xla/python/py_array.h index 39731a9b6200e1..d3bf0ca3337966 100644 --- a/third_party/xla/xla/python/py_array.h +++ b/third_party/xla/xla/python/py_array.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -114,6 +113,8 @@ struct PyArray_Storage { // duplicate PjRtBuffers in this list. PyArray_Storage* next; PyArray_Storage* prev; + + uint8_t thread_id_bucket; }; // The C++ implementation of jax.Array. A few key methods and data members are @@ -170,7 +171,7 @@ class PyArray : public nanobind::object { const nanobind::object& sharding() const { return GetStorage().sharding; } - absl::StatusOr> layout() { + absl::StatusOr> layout() { return ifrt_array()->layout(); } @@ -295,7 +296,7 @@ class PyArray : public nanobind::object { std::vector objs); private: - absl::StatusOr AssertUnsharded(std::string_view api); + absl::StatusOr AssertUnsharded(absl::string_view api); void CheckAndRearrange(); diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index c157994149c6fa..da4eca022ed73f 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -36,7 +35,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -91,7 +89,6 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" -#include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -192,6 +189,7 @@ absl::StatusOr> PyClient::DeviceFromLocalHardwareId( nb::list PyClient::LiveExecutables() { CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); nb::list executables; for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { if (!exec->is_deleted()) { @@ -226,15 +224,16 @@ absl::Status PyClient::Defragment() { // Synchronously copy all buffers to host absl::flat_hash_map pjrt_buf_to_tmp_buffer; - for (PyArray_Storage* array = arrays_; array; array = array->next) { + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { // TODO(hyeontaek): Support non-PjRt Arrays. // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that // std::shared_ptr does not need to be updated in-place. - if (array->ifrt_array == nullptr) { + if (array.ifrt_array() == nullptr) { continue; } - auto* arr = llvm::dyn_cast_or_null( - array->ifrt_array.get()); + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); if (arr == nullptr) { throw XlaRuntimeError( "This operation is implemented for a PjRt-compatible backend " @@ -340,10 +339,9 @@ absl::Status PyClient::Defragment() { options.allow_zero_copy = (!force_copy && (host_buffer_semantics == ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); - // TODO(phawkins): remove .ptr() after nanobind transition is complete. - TF_ASSIGN_OR_RETURN( - auto put_fn, DevicePut(argument.ptr(), client->ifrt_client_.get(), device, - options, ifrt::MemoryKind())); + TF_ASSIGN_OR_RETURN(auto put_fn, + DevicePut(argument, client->ifrt_client_.get(), device, + options, ifrt::MemoryKind())); TF_ASSIGN_OR_RETURN(auto put, [&]() { // Must release the GIL before calling IFRT because backends may // decide to block/sleep for device buffer allocation. @@ -490,7 +488,7 @@ PyClient::DeserializeExecutable(nb_class_ptr client, TF_ASSIGN_OR_RETURN( ifrt_loaded_executable, client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( - std::string_view(serialized.c_str(), serialized.size()), + absl::string_view(serialized.c_str(), serialized.size()), std::move(ifrt_deserialize_options))); } TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); @@ -550,12 +548,13 @@ absl::StatusOr PyClient::HeapProfile() { return absl::OkStatus(); }; - for (PyArray_Storage* array = arrays_; array; array = array->next) { - if (array->ifrt_array == nullptr) { + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + if (array.ifrt_array() == nullptr) { continue; } - auto* arr = llvm::dyn_cast_or_null( - array->ifrt_array.get()); + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); // TODO(hyeontaek): Support non-PjRt Arrays. if (arr == nullptr) { throw XlaRuntimeError( @@ -564,7 +563,8 @@ absl::StatusOr PyClient::HeapProfile() { } for (const auto& buffer : arr->pjrt_buffers()) { TF_RETURN_IF_ERROR(add_buffer_to_profile( - buffer.get(), array->traceback ? array->traceback->get() : nullptr)); + buffer.get(), + array.traceback() ? array.traceback()->get() : nullptr)); } } @@ -634,14 +634,13 @@ absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( } absl::StatusOr> -PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable, - nb::object operand_shapes, - nb::object result_shapes) { - TF_ASSIGN_OR_RETURN(auto loaded_host_callback, - PyCpuLoadedHostCallback::Create( - ifrt_client(), std::move(callable), - nb::cast>(operand_shapes), - nb::cast>(result_shapes))); +PyClient::GetEmitPythonCallbackDescriptor( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyCpuLoadedHostCallback::Create(ifrt_client(), std::move(callable), + operand_shapes, result_shapes)); const uint64_t descriptor = loaded_host_callback->descriptor(); nb::capsule callback_capsule( @@ -778,16 +777,16 @@ PyType_Slot PyClient::slots_[] = { .def( "get_default_layout", [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, - nb_class_ptr device) -> std::unique_ptr { + nb_class_ptr device) + -> std::shared_ptr { ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); std::vector dims = SequenceToVector(shard_shape); - return xla::ValueOrThrow( - self.ifrt_client()->GetDefaultLayoutForDevice( - ifrt_type, dims, device->device())); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); }, nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) .def("__getattr__", - [](PyClient& client, std::string_view name) -> nb::object { + [](PyClient& client, absl::string_view name) -> nb::object { const auto& attrs = client.Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { diff --git a/third_party/xla/xla/python/py_client.h b/third_party/xla/xla/python/py_client.h index 0a0b2275b6afbb..a8893a0b41441f 100644 --- a/third_party/xla/xla/python/py_client.h +++ b/third_party/xla/xla/python/py_client.h @@ -18,21 +18,22 @@ limitations under the License. #include +#include +#include #include #include #include #include -#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" -#include "xla/hlo/builder/xla_builder.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -95,7 +96,7 @@ class PyClient { return shared_ptr_pjrt_client(); } - std::string_view platform_name() const { + absl::string_view platform_name() const { // TODO(phawkins): this is a temporary backwards compatibility shim. We // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but // we haven't yet updated JAX clients that expect "gpu". Migrate users and @@ -107,14 +108,16 @@ class PyClient { return ifrt_client_->platform_name(); } } - std::string_view raw_platform_name() const { + absl::string_view raw_platform_name() const { // TODO(parkers): Once platform_name() is the same, remove this. return ifrt_client_->platform_name(); } - std::string_view platform_version() const { + absl::string_view platform_version() const { return ifrt_client_->platform_version(); } - std::string_view runtime_type() const { return ifrt_client_->runtime_type(); } + absl::string_view runtime_type() const { + return ifrt_client_->runtime_type(); + } // Returns implementation-specific attributes about this client, e.g. the PJRT // C API version if applicable. @@ -189,23 +192,10 @@ class PyClient { // The callable receives as arguments NumPy arrays for arguments with array // types, and None for Token argument. The callable must return a tuple of // either arrays or None values. - // TODO(phawkins): pass operand_shapes and result_shapes as - // absl::Span when nanobind transition is complete. absl::StatusOr> GetEmitPythonCallbackDescriptor(nanobind::callable callable, - nanobind::object operand_shapes, - nanobind::object result_shapes); - // Deprecated; please switch to emitting a `CustomCallOp` directly. - absl::StatusOr EmitPythonCallbackFromDescriptor( - XlaBuilder& builder, uint64_t descriptor, - absl::Span operands, absl::Span result_shapes, - std::optional> operand_layouts, bool has_side_effect); - // Deprecated; please switch to using `GetEmitPythonCallbackDescriptor` - // and then emitting a `CustomCall` op instead. - absl::StatusOr> EmitPythonCallback( - nanobind::callable callable, XlaBuilder& builder, - absl::Span operands, absl::Span result_shapes, - std::optional> operand_layouts, bool has_side_effect); + absl::Span operand_shapes, + absl::Span result_shapes); // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable // that takes in arguments of shapes `operand_shapes` and returns results of @@ -229,7 +219,7 @@ class PyClient { absl::Span recv_channel_ids, nanobind::callable serializer); - std::vector LiveArrays() const; + std::vector LiveArrays() const; static void RegisterPythonTypes(nanobind::module_& m); @@ -251,8 +241,20 @@ class PyClient { // to iterate over all known objects when heap profiling. The list structure // is protected by the GIL. + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. PyLoadedExecutable* executables_ = nullptr; - PyArray_Storage* arrays_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage* arrays; + }; + std::array arrays_; absl::flat_hash_map> devices_; absl::flat_hash_map> diff --git a/third_party/xla/xla/python/py_client_gpu.cc b/third_party/xla/xla/python/py_client_gpu.cc index d1c01a62d16a7a..73d2e8edafaa9e 100644 --- a/third_party/xla/xla/python/py_client_gpu.cc +++ b/third_party/xla/xla/python/py_client_gpu.cc @@ -13,30 +13,35 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include +#include #include #include "absl/base/casts.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" +#include "absl/types/span.h" #include "xla/service/custom_call_status.h" -#include "tsl/platform/errors.h" #if TENSORFLOW_USE_ROCM #include "rocm/include/hip/hip_runtime.h" #else #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/driver_types.h" #endif #include "nanobind/nanobind.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" #include "xla/python/callback.h" #include "xla/python/nb_numpy.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" - #if TENSORFLOW_USE_ROCM #define gpuSuccess hipSuccess #define gpuStreamHandle hipStream_t @@ -109,7 +114,7 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, callback->Call(host_input_arrays); LeaveHostCallback(); if (!maybe_result_tuple.ok()) { - std::string_view msg = maybe_result_tuple.status().message(); + absl::string_view msg = maybe_result_tuple.status().message(); XlaCustomCallStatusSetFailure(status, msg.data(), msg.length()); return; } diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index 9dde801ff5a7fd..9d9ea0cfe59bf4 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -39,6 +38,8 @@ limitations under the License. #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" @@ -63,7 +64,6 @@ limitations under the License. #include "xla/python/ifrt/tuple.h" #include "xla/python/ifrt/value.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" @@ -338,13 +338,17 @@ class CompileOnlyIfRtClient final return topology_; } - absl::StatusOr> GetDefaultLayoutForDevice( - ifrt::DType dtype, absl::Span dims, - ifrt::Device* device) const override { + absl::StatusOr> GetDefaultLayout( + ifrt::DType dtype, absl::Span dims, ifrt::Device* device, + ifrt::MemoryKind memory_kind) const override { + if (memory_kind == ifrt::MemoryKind(UnpinnedHostMemorySpace::kKind)) { + return std::make_shared( + LayoutUtil::MakeDescendingLayout(dims.size())); + } TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); TF_ASSIGN_OR_RETURN(xla::Layout layout, topology_->GetDefaultLayout(element_type, dims)); - return std::make_unique(std::move(layout)); + return std::make_shared(std::move(layout)); } private: @@ -372,7 +376,7 @@ class CompileOnlyPyClient : public PyClient { } absl::StatusOr> CompileUnloaded( - std::string_view mlir_module, CompileOptions options, + absl::string_view mlir_module, CompileOptions options, std::vector host_callbacks) { if (!host_callbacks.empty()) { return Unimplemented( @@ -397,8 +401,7 @@ class CompileOnlyPyClient : public PyClient { PjRtCompile(std::move(options), module.get(), *ifrt_client->topology().description())); TF_ASSIGN_OR_RETURN(auto ifrt_executable, - ifrt::PjRtExecutable::Create(std::move(executable), - std::move(xla_options))); + ifrt::PjRtExecutable::Create(std::move(executable))); return std::shared_ptr(std::move(ifrt_executable)); } @@ -422,7 +425,7 @@ void RegisterCompileOnlyClient(nb::module_& m) { [](CompileOnlyPyClient& self, nb::bytes mlir_module, CompileOptions options, std::vector host_callbacks) { return ValueOrThrow(self.CompileUnloaded( - std::string_view(mlir_module.c_str(), mlir_module.size()), + absl::string_view(mlir_module.c_str(), mlir_module.size()), std::move(options), std::move(host_callbacks))); }, nb::arg("computation"), nb::arg("compile_options") = CompileOptions(), diff --git a/third_party/xla/xla/python/py_device.cc b/third_party/xla/xla/python/py_device.cc index 9139454bc36cd4..6a9f4ef781b845 100644 --- a/third_party/xla/xla/python/py_device.cc +++ b/third_party/xla/xla/python/py_device.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -66,7 +66,7 @@ int PyDevice::id() const { return device_->Id().value(); } int PyDevice::process_index() const { return device_->ProcessIndex(); } -std::string_view PyDevice::platform() const { +absl::string_view PyDevice::platform() const { // TODO(phawkins): this is a temporary backwards // compatibility shim. We changed the name PJRT // reports for GPU platforms to "cuda" or "rocm", @@ -75,13 +75,13 @@ std::string_view PyDevice::platform() const { // code. if (client_->platform_name() == "cuda" || client_->platform_name() == "rocm") { - return std::string_view("gpu"); + return absl::string_view("gpu"); } else { return client_->platform_name(); } } -std::string_view PyDevice::device_kind() const { return device_->Kind(); } +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } std::optional PyDevice::local_hardware_id() const { // TODO(phawkins): consider supporting this for non-PJRT devices. @@ -96,9 +96,9 @@ std::optional PyDevice::local_hardware_id() const { return local_hardware_id; } -std::string_view PyDevice::Str() const { return device_->DebugString(); } +absl::string_view PyDevice::Str() const { return device_->DebugString(); } -std::string_view PyDevice::Repr() const { return device_->ToString(); } +absl::string_view PyDevice::Repr() const { return device_->ToString(); } absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { GlobalPyRefManager()->CollectGarbage(); @@ -136,7 +136,7 @@ absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { } absl::StatusOr> PyDevice::Memory( - std::string_view kind) const { + absl::string_view kind) const { ifrt::Memory* result_memory_space = nullptr; for (auto* memory_space : device_->Memories()) { if (memory_space->Kind().memory_kind() == kind) { @@ -321,7 +321,7 @@ PyType_Slot PyDevice::slots_[] = { } try { auto device = nb::cast(nb::handle(self)); - auto name = nb::cast(nb::handle(key)); + auto name = nb::cast(nb::handle(key)); const auto& attrs = device->device_->Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { diff --git a/third_party/xla/xla/python/py_device.h b/third_party/xla/xla/python/py_device.h index 7151fccb114a62..6acd35b1da9906 100644 --- a/third_party/xla/xla/python/py_device.h +++ b/third_party/xla/xla/python/py_device.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -49,18 +48,18 @@ class PyDevice { int id() const; int process_index() const; - std::string_view platform() const; - std::string_view device_kind() const; + absl::string_view platform() const; + absl::string_view device_kind() const; std::optional local_hardware_id() const; - std::string_view Str() const; - std::string_view Repr() const; + absl::string_view Str() const; + absl::string_view Repr() const; absl::Status TransferToInfeed(LiteralSlice literal); absl::StatusOr TransferFromOutfeed(Shape shape); absl::StatusOr> Memory( - std::string_view kind) const; + absl::string_view kind) const; absl::StatusOr> DefaultMemory() const; nanobind::list AddressableMemories() const; absl::StatusOr> MemoryStats() const; diff --git a/third_party/xla/xla/python/py_device_list.cc b/third_party/xla/xla/python/py_device_list.cc index a0ea40ce1efb81..0ecc2dc5ba32e8 100644 --- a/third_party/xla/xla/python/py_device_list.cc +++ b/third_party/xla/xla/python/py_device_list.cc @@ -113,27 +113,40 @@ int64_t PyDeviceList::Hash() { return *hash_; } -bool PyDeviceList::operator==(nb::handle other) { +/*static*/ bool PyDeviceList::Equal(xla::nb_class_ptr self, + nb::handle other) { if (!nb::isinstance(other)) { return false; } auto o = nb::cast(other); // Fast-path using a pointer equality check. - if (this == o) { + if (self.get() == o) { return true; } - if (Hash() != o->Hash()) { + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { return false; } - if (device_list_.index() == 0 && o->device_list_.index() == 0) { + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { nb::gil_scoped_release gil_release; - return *std::get<0>(device_list_) == *std::get<0>(o->device_list_); + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); } else { - return AsTuple().equal(o->AsTuple()); + return self->AsTuple().equal(o->AsTuple()); } } -bool PyDeviceList::operator!=(nb::handle other) { return !(*this == other); } +/*static*/ bool PyDeviceList::NotEqual(xla::nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} int PyDeviceList::Len() const { switch (device_list_.index()) { @@ -281,6 +294,7 @@ bool PyDeviceList::IsFullyAddressable() { /*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); if (self->IsFullyAddressable()) { // Do not cache this result in `addressable_device_list_`. Otherwise, it // will create a cycle that prevents deletion of this object. @@ -395,38 +409,36 @@ void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { } } -absl::StatusOr PyDeviceList::MemoryKinds() { - if (!GetEnableMemories()) { - return nb::tuple(); - } - if (!memory_kind_info_.has_value()) { - PopulateMemoryKindInfo(); +/*static*/ absl::StatusOr PyDeviceList::MemoryKinds( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); } - if (!memory_kind_info_->ok()) { - return memory_kind_info_->status(); + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); } - return (*memory_kind_info_)->memory_kinds; + return (*self->memory_kind_info_)->memory_kinds; } -absl::StatusOr PyDeviceList::DefaultMemoryKind() { - if (!GetEnableMemories()) { - return nb::none(); - } - if (!memory_kind_info_.has_value()) { - PopulateMemoryKindInfo(); +/*static*/ absl::StatusOr PyDeviceList::DefaultMemoryKind( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); } - if (!memory_kind_info_->ok()) { - return memory_kind_info_->status(); + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); } - return (*memory_kind_info_)->default_memory_kind; + return (*self->memory_kind_info_)->default_memory_kind; } -void RegisterDeviceList(nb::module_& m) { +/*static*/ void PyDeviceList::Register(nb::module_& m) { nb::class_(m, "DeviceList") .def(nb::init()) - .def("__hash__", &PyDeviceList::Hash) - .def("__eq__", &PyDeviceList::operator==) - .def("__ne__", &PyDeviceList::operator!=) + .def("__hash__", &PyDeviceList::Hash, nb::lock_self()) + .def("__eq__", &PyDeviceList::Equal) + .def("__ne__", &PyDeviceList::NotEqual) .def("__len__", &PyDeviceList::Len) .def("__getitem__", &PyDeviceList::GetItem) .def("__getitem__", &PyDeviceList::GetSlice) @@ -438,21 +450,22 @@ void RegisterDeviceList(nb::module_& m) { [](PyDeviceList& self, nb::tuple t) { new (&self) PyDeviceList(std::move(t)); }) - .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable, + nb::lock_self()) .def_prop_ro("addressable_device_list", &PyDeviceList::AddressableDeviceList) // `xla::ValueOrThrowWrapper` does not work with // `def_prop_ro()`. Manually convert an error into an exception. .def_prop_ro("default_memory_kind", - [](PyDeviceList* l) { - auto kind = l->DefaultMemoryKind(); + [](xla::nb_class_ptr l) { + auto kind = DefaultMemoryKind(l); if (!kind.ok()) { throw nb::value_error(kind.status().ToString().c_str()); } return *kind; }) - .def_prop_ro("memory_kinds", [](PyDeviceList* l) { - auto kinds = l->MemoryKinds(); + .def_prop_ro("memory_kinds", [](xla::nb_class_ptr l) { + auto kinds = MemoryKinds(l); if (!kinds.ok()) { throw nb::value_error(kinds.status().ToString().c_str()); } diff --git a/third_party/xla/xla/python/py_device_list.h b/third_party/xla/xla/python/py_device_list.h index 8113ead6aee373..d44065f59d43a0 100644 --- a/third_party/xla/xla/python/py_device_list.h +++ b/third_party/xla/xla/python/py_device_list.h @@ -53,13 +53,33 @@ class PyDeviceList { absl::StatusOr> ifrt_device_list() const; - // Methods below require GIL. - int64_t Hash(); - bool operator==(nanobind::handle other); - bool operator!=(nanobind::handle other); + int Len() const; // Requires the GIL in GIL mode. + nanobind::object GetItem(int index); // Requires the GIL in GIL mode. + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static xla::nb_class_ptr AddressableDeviceList( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr DefaultMemoryKind( + xla::nb_class_ptr self); - int Len() const; - nanobind::object GetItem(int index); + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr MemoryKinds( + xla::nb_class_ptr self); + + // go/pywald-pybind-annotation BEGIN + // refs { + // module_path: "third_party/tensorflow/compiler/xla/python/xla.cc" + // module_arg {} + // } + // go/pywald-pybind-annotation END + static void Register(nanobind::module_& m); + + private: + nanobind::tuple AsTuple() const; + + // Methods below require GIL. nanobind::object GetSlice(nanobind::slice slice); nanobind::iterator Iter(); @@ -67,21 +87,24 @@ class PyDeviceList { nanobind::tuple Dump() const; - bool IsFullyAddressable(); - static xla::nb_class_ptr AddressableDeviceList( - xla::nb_class_ptr self); - absl::StatusOr DefaultMemoryKind(); - absl::StatusOr MemoryKinds(); + int64_t Hash(); // Mutates hash_, needs self lock. - private: - nanobind::tuple AsTuple() const; + static bool Equal(xla::nb_class_ptr self, + nanobind::handle other); + static bool NotEqual(xla::nb_class_ptr self, + nanobind::handle other); - // Finds the memory kind info from an addressable device. + // Finds the memory kind info from an addressable device. Requires the GIL + // or self lock. void PopulateMemoryKindInfo(); // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` // instead of `ifrt_device_list_` to support duck-typed device objects. + // Requires the GIL or self lock. void PopulateMemoryKindInfoForDuckTypedDevices(); + // Requires the self lock or GIL is held. + bool IsFullyAddressable(); + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and // non-empty. xla::nb_class_ptr py_client_; @@ -90,32 +113,27 @@ class PyDeviceList { // TODO(hyeontaek): Remove support for Python duck-type devices once all // JAX backends and tests are migrated to use an `xla::ifrt::Device` type // for JAX devices. + // Immutable after constructor; no locking needed. std::variant, nanobind::tuple> device_list_; - std::optional hash_; // Populated on demand. + // Populated on demand. Guarded by the object's self lock. + std::optional hash_; // TODO(hyeontaek): Make the following property cached within // `xla::ifrt::DeviceList`. - std::optional is_fully_addressable_; // Populated on demand. - std::optional> - addressable_device_list_; // Populated on demand. + // Populated on demand. Guarded by the object's self lock. + std::optional is_fully_addressable_; + // Populated on demand. Guarded by the object's self lock. + std::optional> addressable_device_list_; struct MemoryKindInfo { nanobind::object default_memory_kind; nanobind::tuple memory_kinds; }; - std::optional> - memory_kind_info_; // Populated on demand. + // Populated on demand. Guarded by the object's self lock. + std::optional> memory_kind_info_; }; -// go/pywald-pybind-annotation BEGIN -// refs { -// module_path: "third_party/tensorflow/compiler/xla/python/xla.cc" -// module_arg {} -// } -// go/pywald-pybind-annotation END -void RegisterDeviceList(nanobind::module_& m); - } // namespace jax #endif // XLA_PYTHON_PY_DEVICE_LIST_H_ diff --git a/third_party/xla/xla/python/py_executable.cc b/third_party/xla/xla/python/py_executable.cc index 0bdff1204ac2f8..bd582d3035cf58 100644 --- a/third_party/xla/xla/python/py_executable.cc +++ b/third_party/xla/xla/python/py_executable.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -30,10 +29,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" @@ -87,21 +86,23 @@ PyLoadedExecutable::PyLoadedExecutable( traceback_(std::move(traceback)), fingerprint_(std::move(fingerprint)) { CHECK(PyGILState_Check()); + if (fingerprint_) { + options_.launch_id = tsl::Fingerprint32(*fingerprint_); + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); next_ = client_->executables_; client_->executables_ = this; prev_ = nullptr; if (next_) { next_->prev_ = this; } - if (fingerprint_) { - options_.launch_id = tsl::Fingerprint32(*fingerprint_); - VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() - << ": " << *fingerprint_; - } } PyLoadedExecutable::~PyLoadedExecutable() { CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); if (client_->executables_ == this) { client_->executables_ = next_; } @@ -408,19 +409,19 @@ PyLoadedExecutable::HloModules() const { return ifrt_loaded_executable_->GetHloModules(); } -absl::StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetOutputMemoryKinds() const { nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputMemoryKinds(); } -absl::StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetParameterLayouts() const { nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetParameterLayouts(); } -absl::StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetOutputLayouts() const { nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputLayouts(); diff --git a/third_party/xla/xla/python/py_executable.h b/third_party/xla/xla/python/py_executable.h index e032ee7b4acdda..f4c22b52c431c7 100644 --- a/third_party/xla/xla/python/py_executable.h +++ b/third_party/xla/xla/python/py_executable.h @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -187,14 +186,14 @@ class PyLoadedExecutable { absl::StatusOr>> HloModules() const; - absl::StatusOr>> + absl::StatusOr>> GetOutputMemoryKinds() const; - absl::StatusOr>> GetParameterLayouts() - const; + absl::StatusOr>> + GetParameterLayouts() const; - absl::StatusOr>> GetOutputLayouts() - const; + absl::StatusOr>> + GetOutputLayouts() const; std::optional> GetParameterShardings() const; @@ -208,15 +207,6 @@ class PyLoadedExecutable { // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. - PjRtLoadedExecutable* pjrt_executable() const { - auto* exec = llvm::dyn_cast_or_null( - ifrt_loaded_executable_.get()); - if (exec == nullptr) { - throw XlaRuntimeError( - "This operation is implemented for a PjRt-compatible backend only."); - } - return exec->pjrt_loaded_executable(); - } std::shared_ptr shared_ptr_pjrt_executable() { auto* exec = llvm::dyn_cast_or_null( ifrt_loaded_executable_.get()); diff --git a/third_party/xla/xla/python/py_memory_space.cc b/third_party/xla/xla/python/py_memory_space.cc index c55f0d04383960..990b1ba6ec5f84 100644 --- a/third_party/xla/xla/python/py_memory_space.cc +++ b/third_party/xla/xla/python/py_memory_space.cc @@ -17,12 +17,11 @@ limitations under the License. #include -#include #include +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "xla/pjrt/pjrt_client.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/py_client.h" @@ -37,7 +36,7 @@ PyMemorySpace::PyMemorySpace(nb_class_ptr client, int PyMemorySpace::process_index() const { return client_->process_index(); } -std::string_view PyMemorySpace::platform() const { +absl::string_view PyMemorySpace::platform() const { // TODO(phawkins): this is a temporary backwards // compatibility shim. We changed the name PJRT // reports for GPU platforms to "cuda" or "rocm", @@ -46,19 +45,19 @@ std::string_view PyMemorySpace::platform() const { // code. if (client_->platform_name() == "cuda" || client_->platform_name() == "rocm") { - return std::string_view("gpu"); + return absl::string_view("gpu"); } else { return client_->platform_name(); } } -std::string_view PyMemorySpace::kind() const { +absl::string_view PyMemorySpace::kind() const { return *memory_->Kind().memory_kind(); } -std::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } -std::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } nb::list PyMemorySpace::AddressableByDevices() const { nb::list devices; diff --git a/third_party/xla/xla/python/py_memory_space.h b/third_party/xla/xla/python/py_memory_space.h index 9b5507b55422ef..bc0773ed436672 100644 --- a/third_party/xla/xla/python/py_memory_space.h +++ b/third_party/xla/xla/python/py_memory_space.h @@ -18,8 +18,6 @@ limitations under the License. #include -#include - #include "nanobind/nanobind.h" #include "xla/python/ifrt/memory.h" #include "xla/python/nb_class_ptr.h" @@ -42,11 +40,11 @@ class PyMemorySpace { ifrt::Memory* memory_space() const { return memory_; } int process_index() const; - std::string_view platform() const; - std::string_view kind() const; + absl::string_view platform() const; + absl::string_view kind() const; - std::string_view Str() const; - std::string_view Repr() const; + absl::string_view Str() const; + absl::string_view Repr() const; nanobind::list AddressableByDevices() const; diff --git a/third_party/xla/xla/python/py_values.cc b/third_party/xla/xla/python/py_values.cc index ed7388aa9ba53e..631b0bcb9b9562 100644 --- a/third_party/xla/xla/python/py_values.cc +++ b/third_party/xla/xla/python/py_values.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -32,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/complex.h" // IWYU pragma: keep @@ -44,7 +44,6 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/py_array.h" @@ -83,7 +82,7 @@ absl::StatusOr HandlePythonScalar( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - nb::cast(nb::repr(obj))); + nb::cast(nb::repr(obj))); } std::variant data; @@ -130,7 +129,7 @@ absl::StatusOr HandlePythonInt( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - nb::cast(nb::repr(obj))); + nb::cast(nb::repr(obj))); } type = S32; } else { @@ -141,7 +140,7 @@ absl::StatusOr HandlePythonInt( "Unable to convert Python scalar to %s. This most likely means the " "value (%s) overflows the range of the type.", PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), - nb::cast(nb::repr(obj))); + nb::cast(nb::repr(obj))); } type = S64; } @@ -378,8 +377,7 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[reinterpret_cast(&PyComplex_Type)] = HandlePythonScalar; - const auto numpy = nb::module_::import_("numpy"); - (*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray; + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; // Numpy scalar types. For some of them, we share the handler with // Python types (np_int64, np_float64, np_complex128). @@ -452,7 +450,7 @@ absl::StatusOr DevicePut(nb::handle arg, "Not supported: The C++ jax jit execution path, only accepts " "DeviceArray, Numpy arrays scalars of supported types " "(see implementation), or Python scalars. Got type ", - nb::cast(nb::str(arg.type())))); + nb::cast(nb::str(arg.type())))); } return res->second(arg, client, to_device, options, to_memory_kind); } @@ -553,8 +551,7 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, numpy_array.ndim()), /*weak_type=*/false); }; - const auto numpy = nb::module_::import_("numpy"); - (*p)[numpy.attr("ndarray").ptr()] = numpy_handler; + (*p)[reinterpret_cast(&PyArray_Type)] = numpy_handler; ToPyArgSignatureHandler np_uint64_handler = [](nb::handle h, @@ -643,7 +640,7 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, "Buffer/DeviceArray, Numpy " "arrays scalars of supported types " "(see implementation), or Python scalars. Got type ", - nb::cast(nb::str(arg.type())))); + nb::cast(nb::str(arg.type())))); } return res->second(arg, jax_enable_x64); } diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 5a165cde069201..d5799b8695cb72 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -29,7 +29,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -40,6 +39,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -96,11 +96,12 @@ void PyTreeRegistry::Register( registration->to_iterable = std::move(to_iterable); registration->from_iterable = std::move(from_iterable); registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); auto it = registrations_.emplace(type, std::move(registration)); if (!it.second) { throw std::invalid_argument( absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", - nb::cast(nb::repr(type)))); + nb::cast(nb::repr(type)))); } } @@ -112,11 +113,12 @@ void PyTreeRegistry::RegisterDataclass(nb::object type, registration->type = type; registration->data_fields = std::move(data_fields); registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); auto it = registrations_.emplace(type, std::move(registration)); if (!it.second) { throw std::invalid_argument(absl::StrFormat( "Duplicate custom dataclass PyTreeDef type registration for %s.", - nb::cast(nb::repr(std::move(type))))); + nb::cast(nb::repr(std::move(type))))); } } @@ -129,7 +131,7 @@ PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { throw std::invalid_argument(absl::StrCat( "The to_iterable function for a custom PyTree node should return " "a (children, aux_data) tuple, got ", - nb::cast(nb::repr(out)))); + nb::cast(nb::repr(out)))); } nb::iterable leaves; if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { @@ -137,7 +139,7 @@ PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { "The to_iterable function for a custom PyTree node should return " "a (children, aux_data) tuple where 'children' is iterable, " "got ", - nb::cast(nb::repr(out)))); + nb::cast(nb::repr(out)))); } return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); } @@ -161,7 +163,7 @@ PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { throw std::invalid_argument(absl::StrCat( "The to_iterable_with_keys function for a custom PyTree " "node should return a (key_leaf_pairs, aux_data) tuple, got ", - nb::cast(nb::repr(out)))); + nb::cast(nb::repr(out)))); } nb::iterable key_leaf_pairs; if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { @@ -169,7 +171,7 @@ PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { "The to_iterable_with_keys function for a custom PyTree node should " "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " "iterable, got ", - nb::cast(nb::repr(leaves_and_aux_data)))); + nb::cast(nb::repr(leaves_and_aux_data)))); } for (nb::handle key_leaf_pair : key_leaf_pairs) { nb::tuple key_leaf_pair_tuple; @@ -178,7 +180,7 @@ PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { throw std::invalid_argument(absl::StrCat( "The to_iterable_with_keys function for a custom PyTree node should " "return a (key_leaf_pairs, aux_data) tuple where 'child", - nb::cast(nb::repr(key_leaf_pair)))); + nb::cast(nb::repr(key_leaf_pair)))); } result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), nb::borrow(key_leaf_pair_tuple[1]))); @@ -222,6 +224,7 @@ PyTreeKind PyTreeRegistry::KindOfObject( /*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( nb::handle type) const { + nb::ft_lock_guard lock(mu_); auto it = registrations_.find(type); return it == registrations_.end() ? nullptr : it->second.get(); } @@ -291,22 +294,62 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { } nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { PyTreeRegistry::Registration const* custom; PyTreeKind kind = KindOfObject(x, &custom); switch (kind) { case PyTreeKind::kNone: return nb::make_tuple(nb::make_tuple(), nb::none()); - case PyTreeKind::kTuple: - case PyTreeKind::kList: + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } return nb::make_tuple(nb::borrow(x), nb::none()); + } case PyTreeKind::kDict: { nb::dict dict = nb::borrow(x); std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); for (size_t i = 0; i < sorted_keys.size(); ++i) { - PyTuple_SET_ITEM(values.ptr(), i, - nb::object(dict[sorted_keys[i]]).release().ptr()); + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); } return nb::make_tuple(std::move(values), std::move(keys)); @@ -314,12 +357,32 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { case PyTreeKind::kNamedTuple: { nb::tuple in = nb::borrow(x); nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } for (size_t i = 0; i < in.size(); ++i) { out.append(in[i]); } return nb::make_tuple(std::move(out), x.type()); } case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } auto [leaves, aux_data] = custom->ToIterable(x); return nb::make_tuple(std::move(leaves), std::move(aux_data)); } @@ -327,9 +390,12 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { auto data_size = custom->data_fields.size(); nb::list leaves = nb::steal(PyList_New(data_size)); for (int leaf = 0; leaf < data_size; ++leaf) { - PyList_SET_ITEM( - leaves.ptr(), leaf, - nb::getattr(x, custom->data_fields[leaf]).release().ptr()); + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); } auto meta_size = custom->meta_fields.size(); nb::object aux_data = nb::steal(PyTuple_New(meta_size)); @@ -356,6 +422,7 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { void* arg) { PyTreeRegistry* registry = nb::inst_ptr(self); Py_VISIT(Py_TYPE(self)); + nb::ft_lock_guard lock(registry->mu_); for (const auto& [key, value] : registry->registrations_) { Py_VISIT(key.ptr()); int rval = value->tp_traverse(visit, arg); @@ -368,6 +435,7 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { /* static */ int PyTreeRegistry::tp_clear(PyObject* self) { PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); registry->registrations_.clear(); return 0; } @@ -401,21 +469,21 @@ std::string SequenceKey::ToReprString() const { } std::string DictKey::ToString() const { - return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); } std::string DictKey::ToReprString() const { return absl::StrFormat("DictKey(key=%s)", - nb::cast(nb::repr(key_))); + nb::cast(nb::repr(key_))); } std::string GetAttrKey::ToString() const { - return absl::StrFormat(".%s", nb::cast(name_)); + return absl::StrFormat(".%s", nb::cast(name_)); } std::string GetAttrKey::ToReprString() const { return absl::StrFormat("GetAttrKey(name='%s')", - nb::cast(name_)); + nb::cast(name_)); } std::string FlattenedIndexKey::ToString() const { @@ -483,7 +551,7 @@ void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, } else if (!nb::try_cast(o, is_known_leaf)) { throw std::invalid_argument(absl::StrCat( "is_leaf predicate returned a non-boolean value ", - nb::cast(nb::repr(o)), "; expected a boolean")); + nb::cast(nb::repr(o)), "; expected a boolean")); } } if (is_known_leaf) { @@ -836,7 +904,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (it == traversal_.rend()) { throw std::invalid_argument(absl::StrFormat( "Tree structures did not match: %s vs %s", - nb::cast(nb::repr(xs)), ToString())); + nb::cast(nb::repr(xs)), ToString())); } const Node& node = *it; nb::object object = agenda.back(); @@ -861,7 +929,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { "the previous behavior, you can usually write:\n" " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " "b, is_leaf=lambda x: x is None)", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } break; @@ -869,13 +937,13 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (!PyTuple_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected tuple, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::tuple tuple = nb::borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } for (nb::handle entry : tuple) { agenda.push_back(nb::borrow(entry)); @@ -887,13 +955,13 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (!PyList_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected list, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::list list = nb::borrow(object); if (list.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "List arity mismatch: %d != %d; list: %s.", list.size(), - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } for (nb::handle entry : list) { agenda.push_back(nb::borrow(entry)); @@ -905,7 +973,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (!PyDict_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected dict, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::dict dict = nb::borrow(object); std::vector keys = GetSortedPyDictKeys(dict.ptr()); @@ -914,9 +982,9 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { // vector. This is error path so it is fine to pay conversion cost. throw std::invalid_argument( absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", - nb::cast( + nb::cast( nb::repr(nb::cast(node.sorted_dict_keys))), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } for (nb::handle key : keys) { agenda.push_back(dict[key]); @@ -929,19 +997,19 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { !nb::hasattr(object, "_fields")) { throw std::invalid_argument( absl::StrFormat("Expected named tuple, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::tuple tuple = nb::borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } if (tuple.type().not_equal(node.node_data)) { throw std::invalid_argument(absl::StrFormat( "Named tuple type mismatch: expected type: %s, tuple: %s.", - nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); } for (nb::handle entry : tuple) { agenda.push_back(nb::borrow(entry)); @@ -954,16 +1022,16 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (registration != node.custom) { throw std::invalid_argument(absl::StrFormat( "Custom node type mismatch: expected type: %s, value: %s.", - nb::cast(nb::repr(node.custom->type)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); } auto [leaves, aux_data] = node.custom->ToIterable(object); if (node.node_data.not_equal(aux_data)) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom node data: %s != %s; value: %s.", - nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(aux_data)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); } int arity = 0; for (nb::handle entry : leaves) { @@ -973,7 +1041,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (arity != node.arity) { throw std::invalid_argument(absl::StrFormat( "Custom type arity mismatch: %d != %d; value: %s.", arity, - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } break; } @@ -984,8 +1052,8 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { throw std::invalid_argument(absl::StrFormat( "Custom dataclasss node type mismatch: expected type: %s, value: " "%s.", - nb::cast(nb::repr(node.custom->type)), - nb::cast(nb::repr(std::move(object))))); + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); } auto meta_size = node.custom->meta_fields.size(); nb::object aux_data = nb::steal(PyTuple_New(meta_size)); @@ -999,15 +1067,15 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (node.node_data.not_equal(aux_data)) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom dataclass node data: %s != %s; value: %s.", - nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(aux_data)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); } auto data_size = node.custom->data_fields.size(); if (data_size != node.arity) { throw std::invalid_argument(absl::StrFormat( "Custom type arity mismatch: %d != %d; value: %s.", data_size, - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } for (int leaf = 0; leaf < data_size; ++leaf) { agenda.push_back(nb::borrow( @@ -1020,7 +1088,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (it != traversal_.rend() || leaf != -1) { throw std::invalid_argument( absl::StrFormat("Tree structures did not match: %s vs %s", - nb::cast(nb::repr(xs)), ToString())); + nb::cast(nb::repr(xs)), ToString())); } return leaves; } @@ -1213,7 +1281,7 @@ std::string PyTreeDef::ToString() const { auto child_iter = agenda.end() - node.arity; for (const nb::handle& key : node.sorted_dict_keys) { absl::StrAppendFormat(&representation, "%s%s: %s", separator, - nb::cast(nb::repr(key)), + nb::cast(nb::repr(key)), *child_iter); child_iter++; separator = ", "; @@ -1232,7 +1300,7 @@ std::string PyTreeDef::ToString() const { if (node.node_data) { // Node data for named tuples is the type. data = absl::StrFormat( - "[%s]", nb::cast( + "[%s]", nb::cast( nb::str(nb::getattr(node.node_data, "__name__")))); } } else { @@ -1240,7 +1308,7 @@ std::string PyTreeDef::ToString() const { nb::str(nb::getattr(node.custom->type, "__name__"))); if (node.node_data) { data = absl::StrFormat( - "[%s]", nb::cast(nb::str(node.node_data))); + "[%s]", nb::cast(nb::str(node.node_data))); } } @@ -1309,7 +1377,7 @@ void PyTreeDef::FromPickle(nb::object pickle) { if (node.custom == nullptr) { throw xla::XlaRuntimeError( absl::StrCat("Unknown custom type in pickled PyTreeDef: ", - nb::cast(nb::repr(t[3])))); + nb::cast(nb::repr(t[3])))); } } else { if (!t[3].is_none()) { @@ -1503,7 +1571,7 @@ nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( if (registration == nullptr) { throw std::logic_error(absl::StrFormat( "Could not find type: %s.", - nb::cast(nb::repr(node_data->first)))); + nb::cast(nb::repr(node_data->first)))); } node.kind = registration->kind; if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { @@ -1577,6 +1645,9 @@ void BuildPytreeSubmodule(nb::module_& m) { nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); registry.def( "flatten_with_path", [](nb_class_ptr registry, nb::object x, @@ -1637,7 +1708,7 @@ void BuildPytreeSubmodule(nb::module_& m) { "deserialize_using_proto", [](nb_class_ptr registry, nb::bytes data) { jax::PyTreeDefProto input; - std::string_view serialized(data.c_str(), data.size()); + absl::string_view serialized(data.c_str(), data.size()); if (serialized.size() > std::numeric_limits::max()) { throw xla::XlaRuntimeError( "Pytree serialization too large to deserialize."); diff --git a/third_party/xla/xla/python/pytree.h b/third_party/xla/xla/python/pytree.h index 1dc8c6effc24e8..f526893d8dc818 100644 --- a/third_party/xla/xla/python/pytree.h +++ b/third_party/xla/xla/python/pytree.h @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -115,6 +114,11 @@ class PyTreeRegistry { // Flattens a pytree one level, returning either a tuple of the leaves and // the node data, or None, if the entry is a leaf. nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. + nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; static PyType_Slot slots_[]; @@ -139,9 +143,10 @@ class PyTreeRegistry { return a.ptr() == b.ptr(); } }; + mutable nanobind::ft_mutex mu_; absl::flat_hash_map, TypeHash, TypeEq> - registrations_; + registrations_; // Guarded by mu_ bool enable_namedtuple_; static int tp_traverse(PyObject* self, visitproc visit, void* arg); diff --git a/third_party/xla/xla/python/sharding.cc b/third_party/xla/xla/python/sharding.cc index c1bae6a50a58a1..d9d509cd95a5bc 100644 --- a/third_party/xla/xla/python/sharding.cc +++ b/third_party/xla/xla/python/sharding.cc @@ -20,22 +20,20 @@ limitations under the License. #include #include #include -#include #include #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/nb_class_ptr.h" -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/py_client.h" #include "xla/python/py_device_list.h" @@ -48,24 +46,13 @@ namespace jax { namespace nb = nanobind; -bool (*GetEnableMemories)() = +[] { - static bool fetch_memory_kind_on_executable = [] { - char* v = getenv("JAX_ENABLE_MEMORIES"); - if (v == nullptr || *v == '\0') { - return false; - } - return true; - }(); - return fetch_memory_kind_on_executable; -}; - nb::object CheckAndCanonicalizeMemoryKind( nb::object memory_kind, const xla::nb_class_ptr& device_list) { if (!memory_kind.is_none()) { // If memory kind is not None, check if it's supported by the devices // mentioned in the Sharding. - auto supported_memory_kinds = device_list->MemoryKinds(); + auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list); if (!supported_memory_kinds.ok()) { supported_memory_kinds = nb::tuple(); } @@ -83,9 +70,10 @@ nb::object CheckAndCanonicalizeMemoryKind( } nb::object device_kind = addressable_device_list->GetItem(0).attr("device_kind"); - std::string_view device_kind_str = nb::cast(device_kind); + absl::string_view device_kind_str = + nb::cast(device_kind); auto py_str_formatter = [](std::string* out, nb::handle h) { - *out += nb::cast(nb::str(h)); + *out += nb::cast(nb::str(h)); }; throw nb::value_error( absl::StrCat( @@ -93,12 +81,12 @@ nb::object CheckAndCanonicalizeMemoryKind( ". Device ", device_kind_str, " can address the following memory kinds: ", absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), - ". Got memory kind: ", nb::cast(memory_kind)) + ". Got memory kind: ", nb::cast(memory_kind)) .c_str()); } // If memory kind is None, canonicalize to default memory. absl::StatusOr default_memory_kind = - device_list->DefaultMemoryKind(); + PyDeviceList::DefaultMemoryKind(device_list); if (!default_memory_kind.ok()) { return nb::none(); } diff --git a/third_party/xla/xla/python/sharding.h b/third_party/xla/xla/python/sharding.h index 5b41ae04110689..3d484e3c217f6f 100644 --- a/third_party/xla/xla/python/sharding.h +++ b/third_party/xla/xla/python/sharding.h @@ -52,8 +52,6 @@ class Sharding { std::optional num_devices_; }; -extern bool (*GetEnableMemories)(); - // Checks if the memory kind is valid, and canonicalizes the // memory kind to default memory on backends that support memories. nanobind::object CheckAndCanonicalizeMemoryKind( diff --git a/third_party/xla/xla/python/to_ifrt_sharding.cc b/third_party/xla/xla/python/to_ifrt_sharding.cc new file mode 100644 index 00000000000000..f7f27a5793fc30 --- /dev/null +++ b/third_party/xla/xla/python/to_ifrt_sharding.cc @@ -0,0 +1,115 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/python/to_ifrt_sharding.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_class_ptr.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/py_device_list.h" +#include "xla/python/sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(nb::handle(jax::GSPMDSharding::type().ptr()))) { + return nb::cast(nb::handle(sharding.ptr())) + ->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr> GetIfrtDeviceList( + nb::handle sharding_py) { + nb::handle sharding(sharding_py.ptr()); + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast( + sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr> GetIfrtHloSharding( + nb::handle sharding, const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(tsl::RCReference device_list, + GetIfrtDeviceList(sharding)); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + return xla::ifrt::HloSharding::Create( + std::move(device_list), xla::ifrt::MemoryKind(), std::move(hlo_sharding)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr> +GetIfrtConcreteEvenSharding(nb::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(tsl::RCReference device_list, + GetIfrtDeviceList(sharding)); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type, + xla::ifrt::ToPrimitiveType(dtype)); + // The XLA shape's layout is irrelevant because we only need to know the + // tile shape, which is independent from the layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla_primitive_type, shape.dims()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape); + xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions( + tile_shape.dimensions().begin(), tile_shape.dimensions().end())); + return xla::ifrt::ConcreteEvenSharding::Create( + std::move(device_list), xla::ifrt::MemoryKind(), shape, + /*shard_shape=*/std::move(shard_shape)); +} + +} // namespace xla diff --git a/third_party/xla/xla/python/to_ifrt_sharding.h b/third_party/xla/xla/python/to_ifrt_sharding.h new file mode 100644 index 00000000000000..dad74f5dc4a818 --- /dev/null +++ b/third_party/xla/xla/python/to_ifrt_sharding.h @@ -0,0 +1,47 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_TO_IFRT_SHARDING_H_ +#define XLA_PYTHON_TO_IFRT_SHARDING_H_ + +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/sharding.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nanobind::handle sharding, + int64_t num_dimensions); + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr> GetIfrtDeviceList( + nanobind::handle sharding_py); + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr> GetIfrtHloSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr> +GetIfrtConcreteEvenSharding(nanobind::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape); + +} // namespace xla + +#endif // XLA_PYTHON_TO_IFRT_SHARDING_H_ diff --git a/third_party/xla/xla/python/traceback.cc b/third_party/xla/xla/python/traceback.cc index 19e4f94d4f8d9b..a9d35e4d04d745 100644 --- a/third_party/xla/xla/python/traceback.cc +++ b/third_party/xla/xla/python/traceback.cc @@ -20,14 +20,15 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/base/casts.h" #include "absl/hash/hash.h" +#include "absl/log/check.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep @@ -108,8 +109,8 @@ Traceback::Traceback(Traceback&& other) noexcept } std::string Traceback::Frame::ToString() const { - return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), - line_num, nb::cast(function_name)); + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); } std::string Traceback::ToString() const { @@ -230,8 +231,8 @@ void BuildTracebackSubmodule(nb::module_& m) { .def_ro("line_num", &Traceback::Frame::line_num) .def("__repr__", [](const Traceback::Frame& frame) { return absl::StrFormat( - "%s;%s:%d", nb::cast(frame.function_name), - nb::cast(frame.file_name), frame.line_num); + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); }); nb::class_ traceback(m, "Traceback", diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index 125f96a75fdf25..50366be350bc08 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" // IWYU pragma: keep @@ -39,7 +39,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/python/ifrt/dtype.h" -#include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/shape.h" @@ -175,7 +174,7 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { return custom_it->second; } return InvalidArgument("Unknown NumPy dtype %s char %c kind %c itemsize %d", - nb::cast(nb::repr(np_type)), + nb::cast(nb::repr(np_type)), np_type.char_(), np_type.kind(), np_type.itemsize()); } diff --git a/third_party/xla/xla/python/types.h b/third_party/xla/xla/python/types.h index 59c27d99184e5c..aacfea1a17997f 100644 --- a/third_party/xla/xla/python/types.h +++ b/third_party/xla/xla/python/types.h @@ -186,7 +186,7 @@ struct type_caster { // Pybind appears to keep type_casters alive until the callee has run. absl::InlinedVector arrays; - bool from_python(handle input, uint8_t, cleanup_list*) { + bool from_python(handle input, uint8_t, cleanup_list*) noexcept { // TODO(b/79707221): support nested tuples if/when XLA adds support for // nested BorrowingLiterals. if (nanobind::isinstance(input)) { @@ -227,7 +227,8 @@ struct type_caster { // Pybind appears to keep type_casters alive until the callee has run. type_caster literal_caster; - bool from_python(handle handle, uint8_t flags, cleanup_list* cleanup) { + bool from_python(handle handle, uint8_t flags, + cleanup_list* cleanup) noexcept { if (!literal_caster.from_python(handle, flags, cleanup)) { return false; } diff --git a/third_party/xla/xla/python/weakref_lru_cache.cc b/third_party/xla/xla/python/weakref_lru_cache.cc index 3cba509adb8a2c..34209751067054 100644 --- a/third_party/xla/xla/python/weakref_lru_cache.cc +++ b/third_party/xla/xla/python/weakref_lru_cache.cc @@ -39,6 +39,7 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/lru_cache.h" +#include "xla/tsl/platform/logging.h" namespace nb = nanobind; @@ -215,6 +216,12 @@ class WeakrefLRUCache : public std::enable_shared_from_this { if (cache == nullptr) { return; } + // Set up PyCriticalSection for cache python associated object; + auto py_cache = nb::find(cache); + // This should never happen as python cache should always be found + CHECK(py_cache.ptr() != nullptr); + nb::ft_object_guard lock(py_cache); + // The object the reference referred to is now in the process of being // destroyed, so we cannot refer to its contents. Python weakref // objects compare based on identity if the object they refer to is @@ -367,10 +374,10 @@ void BuildWeakrefLRUCacheAPI(nb::module_& m) { nb::class_(m, "WeakrefLRUCache", nb::is_weak_referenceable(), nb::type_slots(WeakrefLRUCache::slots_)) - .def("__call__", &WeakrefLRUCache::Call) - .def("cache_keys", &WeakrefLRUCache::GetKeys) - .def("cache_info", &WeakrefLRUCache::GetCacheInfo) - .def("cache_clear", &WeakrefLRUCache::Clear); + .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) + .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) + .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); nb::class_(weakref_lru_cache, "WeakrefLRUCacheInfo") .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) diff --git a/third_party/xla/xla/python/weakref_lru_cache_test.py b/third_party/xla/xla/python/weakref_lru_cache_test.py index 55d33fb895c8f2..018b70c0351adc 100644 --- a/third_party/xla/xla/python/weakref_lru_cache_test.py +++ b/third_party/xla/xla/python/weakref_lru_cache_test.py @@ -76,6 +76,36 @@ def Body(): cache(wrkey, GilReleasingCacheKey()) t.join() + def testAnotherMultiThreaded(self): + num_workers = 5 + barrier = threading.Barrier(num_workers) + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + + class WRKey: + pass + + def WorkerAddToCache(): + barrier.wait() + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + + def WorkerCleanCache(): + barrier.wait() + for _ in range(10): + cache.cache_clear() + + workers = [ + threading.Thread(target=WorkerAddToCache) + for _ in range(num_workers - 1) + ] + [threading.Thread(target=WorkerCleanCache)] + + for t in workers: + t.start() + + for t in workers: + t.join() + def testKwargsDictOrder(self): miss_id = 0 diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 0fe3da546b9526..0085e3224efe20 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -46,6 +46,7 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep #include "nanobind/stl/variant.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" @@ -63,23 +64,22 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/py_client.h" #include "xla/python/py_program.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT #if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" -#include "xla/pjrt/cpu/gloo_collectives.h" -#include "xla/pjrt/cpu/gloo_kv_store.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" #elif defined(__APPLE__) #include "gloo/transport/uv/device.h" -#include "xla/pjrt/cpu/gloo_collectives.h" // NOLINT -#include "xla/pjrt/cpu/gloo_kv_store.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT #endif // defined(__linux__) #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) -#include "xla/pjrt/cpu/mpi_collectives.h" +#include "xla/backends/cpu/collectives/mpi_collectives.h" #endif // !_WIN32 && !PLATFORM_GOOGLE #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -169,7 +169,7 @@ bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } } // namespace -NB_MODULE(xla_extension, m_nb) { +NB_MODULE(xla_extension, m) { // Initialize ABSL logging because code within XLA uses it. #ifndef PLATFORM_GOOGLE InitializeAbslLogging(); @@ -182,7 +182,7 @@ NB_MODULE(xla_extension, m_nb) { tsl::ImportNumpy(); // Exceptions - nb::exception xla_runtime_error(m_nb, "XlaRuntimeError", + nb::exception xla_runtime_error(m, "XlaRuntimeError", PyExc_RuntimeError); xla_runtime_error.attr("__doc__") = nb::str( "Runtime errors thrown by the JAX runtime. While the JAX runtime may " @@ -190,7 +190,7 @@ NB_MODULE(xla_extension, m_nb) { "are instances of this class."); // Types - nb::enum_(m_nb, "PrimitiveType", nb::is_arithmetic()) + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) .value("PRED", PRED) .value("S4", S4) @@ -222,58 +222,52 @@ NB_MODULE(xla_extension, m_nb) { .value("TOKEN", TOKEN); // Must be before PyClient.compile. - BuildXlaCompilerSubmodule(m_nb); + BuildXlaCompilerSubmodule(m); - PyDevice::RegisterPythonType(m_nb); - PyMemorySpace::RegisterPythonType(m_nb); - PyClient::RegisterPythonTypes(m_nb); + PyDevice::RegisterPythonType(m); + PyMemorySpace::RegisterPythonType(m); + PyClient::RegisterPythonTypes(m); - nb::enum_(m_nb, "ArrayCopySemantics", + nb::enum_(m, "ArrayCopySemantics", nb::is_arithmetic()) .value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy) .value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput) .value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput); - nb::class_(m_nb, "PjRtLayout") + nb::class_(m, "PjRtLayout") .def("__str__", &PjRtLayout::ToString) .def("__eq__", [](const PjRtLayout& layout, const PjRtLayout& other) { return layout == other; }) .def("__hash__", - [](const PjRtLayout& layout) { return absl::HashOf(layout); }); - - nb::class_(m_nb, "PjRtXlaLayout") - .def("_xla_layout", &PjRtXlaLayout::xla_layout) + [](const PjRtLayout& layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &PjRtLayout::xla_layout) .def("__getstate__", - [](const PjRtXlaLayout& layout) -> nb::tuple { + [](const PjRtLayout& layout) -> nb::tuple { absl::StatusOr serialized = layout.Serialize(); ThrowIfError(serialized.status()); return nb::make_tuple( nb::bytes(serialized->data(), serialized->size())); }) - .def("__setstate__", [](PjRtXlaLayout* self, nb::tuple t) { - // TODO(b/328671718): don't assume PjRtXlaLayout. We probably want a - // generic method on PjRtCompiler instead, although we'll have - // somehow have to attach a compiler to this PjRtLayout (something - // like ClientAndPtr). + .def("__setstate__", [](PjRtLayout* self, nb::tuple t) { nb::bytes serialized = nb::cast(t[0]); - absl::StatusOr layout = PjRtXlaLayout::Deserialize( - std::string_view(serialized.c_str(), serialized.size())); + absl::StatusOr> layout = + PjRtLayout::Deserialize( + absl::string_view(serialized.c_str(), serialized.size())); ThrowIfError(layout.status()); - new (self) PjRtXlaLayout(std::move(*layout)); + new (self) PjRtLayout((*layout)->xla_layout()); }); - jax::BuildWeakrefLRUCacheAPI(m_nb); + jax::BuildWeakrefLRUCacheAPI(m); - nb::class_ cpu_collectives(m_nb, - "CpuCollectives"); + nb::class_ cpu_collectives(m, "CpuCollectives"); - m_nb.def( + m.def( "make_gloo_tcp_collectives", [](std::shared_ptr distributed_client, std::optional hostname, std::optional interface) - -> std::shared_ptr { + -> std::shared_ptr { #if defined(__linux__) std::shared_ptr kv_store = nullptr; if (distributed_client != nullptr) { @@ -317,29 +311,28 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("interface").none() = std::nullopt); #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) - nb::class_ mpi_collectives(m_nb, "MpiCollectives", + nb::class_ mpi_collectives(m, "MpiCollectives", cpu_collectives); mpi_collectives.def("Init", &cpu::MpiCollectives::Init); mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); - m_nb.def("make_mpi_collectives", - []() -> std::shared_ptr { - return std::make_shared(); - }); + m.def("make_mpi_collectives", []() -> std::shared_ptr { + return std::make_shared(); + }); #else // !_WIN32 && !PLATFORM_GOOGLE - m_nb.def("make_mpi_collectives", - []() -> std::shared_ptr { - throw xla::XlaRuntimeError( - "make_mpi_collectives is not implemented for Windows"); - }); + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); #endif // !_WIN32 && !PLATFORM_GOOGLE - m_nb.def( + m.def( "get_tfrt_cpu_client", [](bool asynchronous, std::shared_ptr distributed_client, int node_id, int num_nodes, - std::shared_ptr collectives) - -> nb_class_ptr { + std::shared_ptr collectives, + std::optional num_devices) -> nb_class_ptr { std::unique_ptr ifrt_client; { nb::gil_scoped_release gil_release; @@ -348,6 +341,7 @@ NB_MODULE(xla_extension, m_nb) { options.asynchronous = asynchronous; options.collectives = std::move(collectives); options.process_id = node_id; + options.cpu_device_count = num_devices; std::unique_ptr client = xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); ifrt::PjRtClient::CreateOptions ifrt_options; @@ -368,12 +362,13 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, nb::arg("collectives").none() = - std::shared_ptr()); - m_nb.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); }); - m_nb.def( + m.def( "load_pjrt_plugin", [](std::string platform_name, std::optional library_path, std::optional c_api) -> nb::capsule { @@ -393,14 +388,14 @@ NB_MODULE(xla_extension, m_nb) { }, nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, nb::arg("c_api").none() = std::nullopt); - m_nb.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); }); - m_nb.def("initialize_pjrt_plugin", [](std::string platform_name) { + m.def("initialize_pjrt_plugin", [](std::string platform_name) { return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); }); - m_nb.def( + m.def( "get_c_api_client", [](std::string platform_name, const absl::flat_hash_map& options, @@ -426,54 +421,53 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("distributed_client").none() = nullptr); // TODO(b/322357665): Delete this method after TPU plugin changes to use the // standard registration. - m_nb.def("get_default_c_api_topology", - [](std::string platform_name, std::string topology_name, - const absl::flat_hash_map& options) - -> std::shared_ptr { - return std::make_shared(xla::ValueOrThrow( - GetCApiTopology(platform_name, topology_name, options))); - }); - m_nb.def( - "get_c_api_topology", - [](nb::capsule c_api, std::string topology_name, - const absl::flat_hash_map& options) - -> std::shared_ptr { - if (absl::string_view(c_api.name()) != "pjrt_c_api") { - throw nb::value_error( - "Argument to get_c_api_topology was not a pjrt_c_api capsule."); - } - return std::make_shared(xla::ValueOrThrow( - GetCApiTopology(static_cast(c_api.data()), - topology_name, options))); - }); - m_nb.def("get_topology_for_devices", - [](const std::vector>& py_devices) { - if (py_devices.empty()) { - throw nb::value_error( - "get_topology_for_devices requires >= 1 devices."); - } - auto client = py_devices[0]->client(); - ifrt::BasicDeviceList::Devices ifrt_devices; - ifrt_devices.reserve(py_devices.size()); - for (const auto& py_device : py_devices) { - if (py_device->client().get() != client.get()) { - throw nb::value_error( - "devices passed to get_topology_for_devices come from " - "different clients."); - } - ifrt_devices.push_back(py_device->device()); - } - tsl::RCReference device_list = - ifrt::BasicDeviceList::Create(std::move(ifrt_devices)); - return xla::ValueOrThrow( - client->ifrt_client()->GetTopologyForDevices(device_list)); - }); + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + ifrt::BasicDeviceList::Devices ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + tsl::RCReference device_list = + ifrt::BasicDeviceList::Create(std::move(ifrt_devices)); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); - TF_CHECK_OK(PyArray::RegisterTypes(m_nb)); - jax::RegisterDeviceList(m_nb); - jax::RegisterSharding(m_nb); + TF_CHECK_OK(PyArray::RegisterTypes(m)); + jax::PyDeviceList::Register(m); + jax::RegisterSharding(m); - nb::class_(m_nb, "CompiledMemoryStats") + nb::class_(m, "CompiledMemoryStats") .def_rw("generated_code_size_in_bytes", &CompiledMemoryStats::generated_code_size_in_bytes) .def_rw("argument_size_in_bytes", @@ -499,7 +493,7 @@ NB_MODULE(xla_extension, m_nb) { }) .def("__str__", &CompiledMemoryStats::DebugString); - nb::class_(m_nb, "ExecuteResults") + nb::class_(m, "ExecuteResults") .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) .def("disassemble_into_single_device_arrays", &PyExecuteResults::DisassembleIntoSingleDeviceArrays) @@ -508,7 +502,7 @@ NB_MODULE(xla_extension, m_nb) { .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) .def("consume_token", &PyExecuteResults::ConsumeToken); - nb::class_(m_nb, "LoadedExecutable") + nb::class_(m, "LoadedExecutable") .def_prop_ro("client", &PyLoadedExecutable::client) .def("local_devices", &PyLoadedExecutable::AddressableDevices) .def("size_of_generated_code_in_bytes", @@ -540,11 +534,6 @@ NB_MODULE(xla_extension, m_nb) { .def("get_parameter_shardings", &PyLoadedExecutable::GetParameterShardings) .def("keep_alive", &PyLoadedExecutable::KeepAlive) - .def("compile_options", - [](const PyLoadedExecutable& self) { - return xla::ValueOrThrow( - self.pjrt_executable()->GetCompileOptions()); - }) .def("cost_analysis", [](const PyLoadedExecutable& self) { auto map = ValueOrThrow(self.GetCostAnalysis()); @@ -559,20 +548,20 @@ NB_MODULE(xla_extension, m_nb) { return nb::none(); } }); - nb::class_ token(m_nb, "Token"); + nb::class_ token(m, "Token"); token.def("block_until_ready", [](PyToken& self) { xla::ThrowIfError(self.Await()); }); - nb::class_ sharded_token(m_nb, "ShardedToken"); + nb::class_ sharded_token(m, "ShardedToken"); sharded_token.def("block_until_ready", [](PyShardedToken& self) { xla::ThrowIfError(self.Await()); }); sharded_token.def("get_token", &PyShardedToken::GetPyToken); - m_nb.def("buffer_to_dlpack_managed_tensor", - xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), - nb::arg("buffer"), nb::arg("stream").none() = nb::none()); - m_nb.def( + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( "dlpack_managed_tensor_to_buffer", [](const nb::capsule& tensor, nb_class_ptr device, std::optional stream) { @@ -581,7 +570,7 @@ NB_MODULE(xla_extension, m_nb) { }, nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none()); // Legacy overload - m_nb.def( + m.def( "dlpack_managed_tensor_to_buffer", [](const nb::capsule& tensor, std::optional> cpu_client, @@ -591,30 +580,30 @@ NB_MODULE(xla_extension, m_nb) { }, nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), nb::arg("gpu_backend").none() = nb::none()); - m_nb.def("cuda_array_interface_to_buffer", - xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), - nb::arg("gpu_backend").none() = nb::none(), - nb::arg("device_id").none() = nb::none()); - - jax::BuildConfigSubmodule(m_nb); - BuildIfrtProgramsSubmodule(m_nb); - BuildProfilerSubmodule(m_nb); - BuildOpsSubmodule(m_nb); - BuildPytreeSubmodule(m_nb); - jax::BuildGuardSubmodule(m_nb); - jax::BuildJaxjitSubmodule(m_nb); - jax::BuildPmapSubmodule(m_nb); - jax::BuildPjitSubmodule(m_nb); - BuildTracebackSubmodule(m_nb); - BuildMlirSubmodule(m_nb); - BuildCustomCallShardingPybindAPI(m_nb); + m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + jax::BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildProfilerSubmodule(m); + BuildOpsSubmodule(m); + BuildPytreeSubmodule(m); + jax::BuildGuardSubmodule(m); + jax::BuildJaxjitSubmodule(m); + jax::BuildPmapSubmodule(m); + jax::BuildPjitSubmodule(m); + BuildTracebackSubmodule(m); + BuildMlirSubmodule(m); + BuildCustomCallShardingPybindAPI(m); // The following uses python bindings for PyClient defined above using // pybind11, and hence needs pybind11::module_ (not just nanobind::module_). - xla::ifrt::proxy::BuildIfrtProxySubmodule(m_nb); + xla::ifrt::proxy::BuildIfrtProxySubmodule(m); nb::class_ preemption_sync_manager( - m_nb, "PreemptionSyncManager"); + m, "PreemptionSyncManager"); preemption_sync_manager .def( "initialize", @@ -629,16 +618,16 @@ NB_MODULE(xla_extension, m_nb) { [](tsl::PreemptionSyncManager& manager, int step_counter) { return manager.ReachedSyncPoint(step_counter); }); - m_nb.def("create_preemption_sync_manager", - []() { return tsl::CreatePreemptionSyncManager(); }); + m.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); nb::class_ distributed_runtime_service( - m_nb, "DistributedRuntimeService"); + m, "DistributedRuntimeService"); distributed_runtime_service.def("shutdown", &DistributedRuntimeService::Shutdown, nb::call_guard()); nb::class_ distributed_runtime_client( - m_nb, "DistributedRuntimeClient"); + m, "DistributedRuntimeClient"); distributed_runtime_client .def("connect", [](DistributedRuntimeClient& self) { @@ -674,6 +663,21 @@ NB_MODULE(xla_extension, m_nb) { return nb::bytes(result.data(), result.size()); }, nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { + nb::gil_scoped_release gil_release; + std::string result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) .def( "wait_at_barrier", [](DistributedRuntimeClient& client, std::string barrier_id, @@ -693,8 +697,8 @@ NB_MODULE(xla_extension, m_nb) { // `blocking_key_value_get_bytes()`. .def( "key_value_set", - [](DistributedRuntimeClient& client, std::string_view key, - std::string_view value, bool allow_overwrite) { + [](DistributedRuntimeClient& client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { nb::gil_scoped_release gil_release; xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); }, @@ -704,18 +708,18 @@ NB_MODULE(xla_extension, m_nb) { // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. .def( "key_value_set_bytes", - [](DistributedRuntimeClient& client, std::string_view key, + [](DistributedRuntimeClient& client, absl::string_view key, nb::bytes value, bool allow_overwrite) { nb::gil_scoped_release gil_release; xla::ThrowIfError(client.KeyValueSet( - key, std::string_view(value.c_str(), value.size()), + key, absl::string_view(value.c_str(), value.size()), allow_overwrite)); }, nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) // Assumes that all values in the directory are Python strings. .def( "key_value_dir_get", - [](DistributedRuntimeClient& client, std::string_view key) { + [](DistributedRuntimeClient& client, absl::string_view key) { nb::gil_scoped_release gil_release; return xla::ValueOrThrow(client.KeyValueDirGet(key)); }, @@ -725,7 +729,7 @@ NB_MODULE(xla_extension, m_nb) { // explicitly. .def( "key_value_dir_get_bytes", - [](DistributedRuntimeClient& client, std::string_view key) + [](DistributedRuntimeClient& client, absl::string_view key) -> std::vector> { nb::gil_scoped_release gil_release; std::vector> result = @@ -742,13 +746,13 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("key")) .def( "key_value_delete", - [](DistributedRuntimeClient& client, std::string_view key) { + [](DistributedRuntimeClient& client, absl::string_view key) { nb::gil_scoped_release gil_release; return xla::ThrowIfError(client.KeyValueDelete(key)); }, nb::arg("key")); - m_nb.def( + m.def( "get_distributed_runtime_service", [](std::string address, int num_nodes, std::optional heartbeat_interval, @@ -781,7 +785,7 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("cluster_register_timeout").none() = std::nullopt, nb::arg("shutdown_timeout").none() = std::nullopt); - m_nb.def( + m.def( "get_distributed_runtime_client", [](std::string address, int node_id, std::optional rpc_timeout, std::optional init_timeout, std::optional shutdown_timeout, @@ -829,21 +833,19 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("shutdown_on_destruction").none() = std::nullopt, nb::arg("use_compression").none() = std::nullopt); - m_nb.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); - m_nb.def("is_optimized_build", &IsOptimizedBuild); + m.def("is_optimized_build", &IsOptimizedBuild); - m_nb.def("json_to_pprof_profile", - xla::ValueOrThrowWrapper(JsonToPprofProfile), - "Encodes the JSON representation of a pprof Profile into its binary " - "protocol buffer encoding."); - m_nb.def("pprof_profile_to_json", - xla::ValueOrThrowWrapper(PprofProfileToJson), - "Decodes an uncompressed pprof Profile protocol buffer into a JSON " - "representation"); + m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); - RegisterCompileOnlyClient(m_nb); - nb::class_(m_nb, "DeviceTopology") + RegisterCompileOnlyClient(m); + nb::class_(m, "DeviceTopology") .def("_make_compile_only_devices", [](std::shared_ptr topology) { if (!llvm::isa(*topology)) { @@ -865,7 +867,7 @@ NB_MODULE(xla_extension, m_nb) { return nb::bytes(serialized.data(), serialized.size()); }) .def("__getattr__", - [](ifrt::Topology& topology, std::string_view name) -> nb::object { + [](ifrt::Topology& topology, absl::string_view name) -> nb::object { const auto& attrs = topology.Attributes().map(); auto it = attrs.find(name); if (it != attrs.end()) { @@ -876,7 +878,7 @@ NB_MODULE(xla_extension, m_nb) { absl::StrCat("Unknown attribute ", name).c_str()); }); - nb::class_(m_nb, "Executable") + nb::class_(m, "Executable") .def("hlo_modules", ValueOrThrowWrapper(&ifrt::Executable::GetHloModules)) .def("get_output_memory_kinds", xla::ValueOrThrowWrapper(&ifrt::Executable::GetOutputMemoryKinds)) @@ -888,7 +890,6 @@ NB_MODULE(xla_extension, m_nb) { .def("get_parameter_shardings", &ifrt::Executable::GetParameterShardings) .def("get_compiled_memory_stats", xla::ValueOrThrowWrapper(&ifrt::Executable::GetCompiledMemoryStats)) - .def("compile_options", &ifrt::Executable::GetCompileOptions) .def("serialize", [](const ifrt::Executable& exec) -> nb::bytes { std::string serialized = ValueOrThrow(exec.Serialize()); @@ -899,34 +900,33 @@ NB_MODULE(xla_extension, m_nb) { return ifrt::ToPjRtAttributeMap(std::move(attrs)); }); - m_nb.def("is_asan", IsAsan); - m_nb.def("is_msan", IsMsan); - m_nb.def("is_tsan", IsTsan); - m_nb.def("is_sanitized", IsSanitized); + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); - m_nb.def( + m.def( "batched_device_put", [](nb::object aval, nb::object sharding, std::vector xs, std::vector dst_devices, bool committed, bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { return ValueOrThrow(PyArray::BatchedDevicePut( - nb::borrow(aval.ptr()), nb::borrow(sharding.ptr()), std::move(xs), - std::move(dst_devices), committed, force_copy, - host_buffer_semantics, jax::GetEnableX64())); + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, jax::GetEnableX64())); }, nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), nb::arg("committed") = true, nb::arg("force_copy") = false, nb::arg("host_buffer_semantics") = PjRtClient::HostBufferSemantics::kImmutableZeroCopy); - m_nb.def("batched_block_until_ready", [](std::vector xs) { + m.def("batched_block_until_ready", [](std::vector xs) { ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); }); - m_nb.def("check_and_canonicalize_memory_kind", - &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), - nb::arg("device_list")); + m.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index aadc1c2f6c71ca..46dd4a72edd1e7 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 300 +_version = 303 # Version number for MLIR:Python components. mlir_api_version = 57 @@ -70,15 +70,18 @@ def make_cpu_client( distributed_client=None, node_id=0, num_nodes=1, - collectives=None + collectives=None, + num_devices=None, ) -> ...: register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_id_handler('cpu', _xla.register_custom_type_id) return _xla.get_tfrt_cpu_client( asynchronous=asynchronous, distributed_client=distributed_client, node_id=node_id, num_nodes=num_nodes, collectives=collectives, + num_devices=num_devices, ) @@ -111,6 +114,8 @@ def make_gpu_client( config.collective_memory_size = options['collective_memory_size'] register_custom_call_handler('CUDA', _xla.register_custom_call_target) register_custom_call_handler('ROCM', _xla.register_custom_call_target) + register_custom_type_id_handler('CUDA', _xla.register_custom_type_id) + register_custom_type_id_handler('ROCM', _xla.register_custom_type_id) return _xla.get_gpu_client( asynchronous=True, @@ -625,6 +630,7 @@ def register_custom_call_handler( If a custom call handler for the platform already exist, calling this method is a no-op and it will not register a new handler. + Args: platform: the target platform. handler: the function to register a custom call. @@ -645,6 +651,67 @@ def register_custom_call_handler( del _custom_callback[xla_platform_name] +class CustomTypeIdHandler(Protocol): + + def __call__(self, name: str, capsule: Any) -> None: + ... + + +_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {} +_custom_type_id: dict[str, Any] = {} +_custom_type_id_lock = threading.Lock() + + +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = 'cpu', +) -> None: + """Register a custom type id for use with the FFI. + + Args: + type_name: a unique name for the type. + type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``. + platform: the target platform. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_type_id_lock: + if xla_platform_name in _custom_type_id_handler: + _custom_type_id_handler[xla_platform_name](type_name, type_id) + else: + _custom_type_id.setdefault(xla_platform_name, []).append( + (type_name, type_id) + ) + + +def register_custom_type_id_handler( + platform: str, handler: CustomTypeIdHandler +) -> None: + """Register a custom type id handler and use it to register existing type ids. + + If a custom type id handler for the platform already exist, calling this + method is a no-op and it will not register a new handler. + + Args: + platform: the target platform. + handler: the function to register a custom type id. + """ + xla_platform_name = xla_platform_names.get(platform, platform) + with _custom_callback_lock: + if xla_platform_name in _custom_type_id_handler: + logger.debug( + 'Custom type id handler for %s is already register. Will not ' + 'register a new one', + xla_platform_name, + ) + return + _custom_type_id_handler[xla_platform_name] = handler + if xla_platform_name in _custom_type_id: + for name, capsule in _custom_type_id[xla_platform_name]: + handler(name, capsule) + del _custom_type_id[xla_platform_name] + + register_custom_call_partitioner = _xla.register_custom_call_partitioner encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback hlo_sharding_util = _xla.hlo_sharding_util diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index 07149713148a74..efc3d2573b2224 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -89,6 +89,7 @@ def make_cpu_client( node_id: int = ..., num_nodes: int = ..., collectives: _xla.CpuCollectives | None = ..., + num_devices: int | None = ..., ) -> Client: ... @@ -297,6 +298,14 @@ def register_custom_call_handler( def custom_call_targets(platform: str) -> dict[str, Any]: ... +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = ..., +) -> None: ... + +def register_custom_type_id_handler(platform: str, handler: Any) -> None: ... + def encode_inspect_sharding_callback(handler: Any) -> bytes: ... register_custom_call_partitioner = _xla.register_custom_call_partitioner diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index 6aef3213764cce..f0cecc9903295e 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -52,7 +52,6 @@ xla_client._xla.jax_jit.set_thread_local_state_initialization_callback( lambda: None ) -xla_client._xla.jax_jit.global_state().enable_memories = False bfloat16 = xla_client.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. @@ -204,8 +203,10 @@ def setUp(self): if self.backend.platform == "cpu" and not _CUSTOM_CALLS_REGISTERED: for name, fn in custom_calls_testlib.registrations().items(): xla_client.register_custom_call_target( - name, {"execute": fn}, platform="cpu", api_version=1 + name, fn, platform="cpu", api_version=1 ) + for name, val in custom_calls_testlib.type_ids().items(): + xla_client.register_custom_type_id(name, val, platform="cpu") _CUSTOM_CALLS_REGISTERED = True def _NewComputation(self, name=None): @@ -323,7 +324,9 @@ def testFingerprint(self): xla_computation_to_mlir_module(computation)) fingerprint = executable.fingerprint if ( - self.backend.platform == "tpu" or self.backend.platform == "gpu" + self.backend.platform == "tpu" + or self.backend.platform == "gpu" + or self.backend.platform == "cpu" ) and not (cloud_tpu or pathways or pathways_ifrt): logging.info("fingerprint: %s", fingerprint) self.assertNotEmpty(fingerprint) @@ -617,6 +620,21 @@ def testCustomCallTypedFfiSubtract(self): ) self._ExecuteAndCompareClose(c, expected=[-1.75]) + def testStatefulCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + ops.CustomCallWithLayout( + c, + b"stateful", + operands=[], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.int32), (), ()), + operand_shapes_with_layout=[], + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_TYPED_FFI) + self._ExecuteAndCompareClose(c, expected=[42]) + def testCustomCallLookup(self): if self.backend.platform != "cpu": self.skipTest("Test requires cpu platform") @@ -2739,6 +2757,8 @@ def testDevices(self): def testLocalDevices(self): self.assertNotEmpty(self.backend.local_devices()) + if self.backend.platform == "cpu": + self.assertLen(self.backend.local_devices(), 2) def testGetAllDevices(self): # TODO(hyeontaek): Remove this method once we have a unified API for @@ -3674,7 +3694,7 @@ def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): backends = { - "cpu": xla_client.make_cpu_client, + "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2), "gpu": xla_client.make_gpu_client, } diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index 02610edd83d1cb..91bc9690a06e86 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -101,7 +100,7 @@ struct type_caster { NB_TYPE_CASTER_FROM_PYTHON_ONLY(xla::OpMetadata, const_name("xla::OpMetadata")); - bool from_python(handle h, uint8_t, cleanup_list*) { + bool from_python(handle h, uint8_t, cleanup_list*) noexcept { handle op_type = getattr(h, "op_type"); if (!op_type.is_none()) { value.set_op_type(cast(op_type)); @@ -376,6 +375,20 @@ absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, api_version)); } +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, + nb::object type_id) { + nb::capsule capsule; + if (!nb::try_cast(type_id, capsule)) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type_id must be a " + "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + } + XLA_FFI_TypeId* type_id_ptr = + reinterpret_cast(static_cast(capsule.data())); + return ffi::TakeStatus(ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), + type_name, type_id_ptr)); +} + template void DefRepeatedProperty(nb::class_& cls, const char* name, Container* (T::*getter)()) { @@ -729,7 +742,8 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { nb::cast(obj)); }, nb::arg("dtype").none() = nb::none(), - nb::arg("copy").none() = nb::none()); + nb::arg("copy").none() = nb::none()) + .def("shape", &Literal::shape); nb::class_(m, "XlaComputation") .def("__init__", @@ -1141,7 +1155,8 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { for (const auto& [name, registration] : *ffi_handlers) { nb::dict bundle; - auto export_handler = [&](std::string_view name, XLA_FFI_Handler* h) { + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler* h) { if (h != nullptr) { bundle[nb::str(name.data(), name.size())] = nb::capsule(reinterpret_cast(h)); @@ -1161,6 +1176,13 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + m.def( + "register_custom_type_id", + [](absl::string_view type_name, nb::object type_id) { + xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); + }, + nb::arg("type_name"), nb::arg("type_id")); + nb::class_(m, "DebugOptions") .def("__repr__", &DebugOptions::DebugString) .def_prop_rw("xla_backend_optimization_level", @@ -1503,31 +1525,45 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { .def("is_tiled", &xla::HloSharding::IsTiled) .def("tile", [](const xla::HloSharding& self, xla::Shape shape) { return self.TileShape(shape); }) - .def("tuple_elements", - [](const xla::HloSharding& self) { return self.tuple_elements(); }) - .def("num_devices", - [](const xla::HloSharding& self) { - return self.tile_assignment().num_elements(); - }) - .def("num_dimensions", - [](const xla::HloSharding& self) { - return self.tile_assignment().num_dimensions(); - }) - .def("tile_assignment_dimensions", - [](const xla::HloSharding& self) { - absl::Span span = - self.tile_assignment().dimensions(); - CHECK(span.data()); - return span; - }) - .def("tile_assignment_devices", - [](const xla::HloSharding& self) { - auto span = - absl::MakeConstSpan(self.tile_assignment().array().data(), - self.tile_assignment().num_elements()); - CHECK(span.data()); - return span; - }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding& self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding& self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding& self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) .def("replicate_on_last_tile_dim", &xla::HloSharding::ReplicateOnLastTileDim) .def("subgroup_types", &xla::HloSharding::subgroup_types) diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 003482ac200840..ec3ff508a21cb9 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -169,6 +169,7 @@ class Literal: def __array__( self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None ) -> np.ndarray: ... + def shape(self) -> Shape: ... class XlaComputation: def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... @@ -500,7 +501,7 @@ class PjRtLayout: def __eq__(self, other: PjRtLayout) -> bool: ... def __hash__(self) -> int: ... def __getstate__(self) -> Any: ... - def __setstate__(self, Any): ... + def __setstate__(self, _: Any): ... def _xla_layout(self) -> Layout: ... class GpuAllocatorConfig: @@ -606,6 +607,7 @@ def get_tfrt_cpu_client( node_id: int = ..., num_nodes: int = ..., collectives: Optional[CpuCollectives] = ..., + num_devices: int | None = ..., ) -> Client: ... def get_gpu_client( asynchronous: bool = ..., @@ -735,7 +737,6 @@ class LoadedExecutable: def get_parameter_layouts(self) -> List[Layout]: ... def get_output_layouts(self) -> List[Layout]: ... def keep_alive(self) -> None: ... - def compile_options(self) -> CompileOptions: ... def cost_analysis(self) -> Dict[str, Any]: ... traceback: Traceback fingerprint: Optional[bytes] @@ -749,7 +750,6 @@ class Executable: def get_output_layouts(self) -> List[Layout]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def serialize(self) -> str: ... - def compile_options(self) -> CompileOptions: ... def cost_analysis(self) -> Dict[str, Any]: ... class DeviceTopology: @@ -829,6 +829,8 @@ class DistributedRuntimeClient: def blocking_key_value_get_bytes( self, key: str, timeout_in_ms: int ) -> _Status: ... + def key_value_try_get(self, key: str) -> _Status: ... + def key_value_try_get_bytes(self, key: str) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str, diff --git a/third_party/xla/xla/python/xla_extension/jax_jit.pyi b/third_party/xla/xla/python/xla_extension/jax_jit.pyi index 931ee12dfb8779..aa731b5bfaa98b 100644 --- a/third_party/xla/xla/python/xla_extension/jax_jit.pyi +++ b/third_party/xla/xla/python/xla_extension/jax_jit.pyi @@ -27,7 +27,6 @@ Device = xla_extension.Device class JitState: disable_jit: Optional[bool] enable_x64: Optional[bool] - enable_memories: Optional[bool] default_device: Optional[Any] extra_jit_context: Optional[Any] post_hook: Optional[Callable[..., Any]] diff --git a/third_party/xla/xla/python/xla_extension/pytree.pyi b/third_party/xla/xla/python/xla_extension/pytree.pyi index a777e364e65036..a90bb59ad876fd 100644 --- a/third_party/xla/xla/python/xla_extension/pytree.pyi +++ b/third_party/xla/xla/python/xla_extension/pytree.pyi @@ -48,6 +48,9 @@ class PyTreeRegistry: def flatten_one_level( self, tree: Any ) -> Optional[Tuple[Iterable[Any], Any]]: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... def flatten_with_path( self, tree: Any, diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions.cc b/third_party/xla/xla/python/xplane_to_profile_instructions.cc index c95bd724f3d5fb..8d5dbf23223e08 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions.cc @@ -117,7 +117,7 @@ void GetXPlaneLatencyInfo( if (fingerprint.has_value()) { key = absl::StrCat(fingerprint.value(), kCostNameSep, hlo_name.value()); } - (*hlo_latency_info)[key].durations.emplace_back(latency); + (*hlo_latency_info)[key].durations.push_back(latency); }); }); } @@ -194,7 +194,7 @@ absl::Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( tensorflow::profiler::XSpace xspace; TF_RETURN_IF_ERROR( ReadBinaryProto(tsl::Env::Default(), xspace_path, &xspace)); - xspaces.emplace_back(xspace); + xspaces.push_back(xspace); } } diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc b/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc index ee77891fb6b61c..75f6d8aee2eedf 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/hlo.pb.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" #include "xla/tsl/profiler/utils/file_system_utils.h" diff --git a/third_party/xla/xla/refcounting_hash_map.h b/third_party/xla/xla/refcounting_hash_map.h deleted file mode 100644 index 68520a636cfcbb..00000000000000 --- a/third_party/xla/xla/refcounting_hash_map.h +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_REFCOUNTING_HASH_MAP_H_ -#define XLA_REFCOUNTING_HASH_MAP_H_ - -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/node_hash_map.h" -#include "absl/functional/function_ref.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" - -namespace xla { - -// RefcountingHashMap is an "eager, thread-safe cache". -// -// Given a key k you can retrieve a shared_ptr to a value v. If k is not -// already in the map, we construct a new V; if it is already in the map, we'll -// return the existing v. Once all shared_ptrs are destroyed, the entry is -// removed from the map. -// -// This class is thread-safe. -// -// Word to the wise: You might want an erase() function here that removes a -// value from the map but leaves existing shared_ptrs intact. My experience is, -// this is extremely complicated to implement correctly. -template -class RefcountingHashMap { - public: - // Default-constructs new values. - RefcountingHashMap() = default; - - // Not copyable or movable because this contains internal pointers (namely, - // instances of Deleter contain pointers to `this` and into `map_`). - RefcountingHashMap(const RefcountingHashMap&) = delete; - RefcountingHashMap(RefcountingHashMap&&) = delete; - RefcountingHashMap& operator=(const RefcountingHashMap&) = delete; - RefcountingHashMap& operator=(RefcountingHashMap&&) = delete; - - // Gets the value for the given key. - // - // If the map doesn't contain a live value for the key, constructs one - // using `value_factory`. - std::shared_ptr GetOrCreateIfAbsent( - const K& key, - absl::FunctionRef(const K&)> value_factory) { - absl::MutexLock lock(&mu_); - auto it = map_.find(key); - if (it != map_.end()) { - // We ensure that the entry has not expired in case deleter was running - // when we have entered this block. - if (std::shared_ptr value = it->second.lock()) { - return value; - } - } - - // Create entry in the map and then set its value, so the value can - // contain a pointer back into the map. - it = map_.emplace(key, std::weak_ptr()).first; - std::shared_ptr value(value_factory(key).release(), - Deleter{it->first, *this}); - it->second = value; // Set the weak ptr to the shared ptr. - return value; - } - - private: - struct Deleter { - const K& key; // Points into parent->map_. - RefcountingHashMap& parent; - - void operator()(V* v) { - delete v; - absl::MutexLock lock(&parent.mu_); - // We must check if that the entry is still expired in case the value was - // replaced while the deleter was running. - auto it = parent.map_.find(key); - if (it != parent.map_.end() && it->second.expired()) { - parent.map_.erase(it); - } - } - }; - - absl::Mutex mu_; - absl::node_hash_map> map_ ABSL_GUARDED_BY(mu_); -}; - -} // namespace xla - -#endif // XLA_REFCOUNTING_HASH_MAP_H_ diff --git a/third_party/xla/xla/refcounting_hash_map_test.cc b/third_party/xla/xla/refcounting_hash_map_test.cc deleted file mode 100644 index 71211cc36c02e0..00000000000000 --- a/third_party/xla/xla/refcounting_hash_map_test.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/refcounting_hash_map.h" - -#include -#include -#include - -#include "xla/test.h" - -namespace xla { -namespace { - -struct DeleteNotifier { - DeleteNotifier() = default; - DeleteNotifier(const DeleteNotifier&) = delete; - DeleteNotifier& operator=(const DeleteNotifier&) = delete; - DeleteNotifier(DeleteNotifier&& o) noexcept : fn(std::move(o.fn)) { - o.fn = nullptr; - } - DeleteNotifier& operator=(DeleteNotifier&& o) noexcept { - fn = o.fn; - o.fn = nullptr; - return *this; - } - - ~DeleteNotifier() { - if (fn) { - fn(); - } - } - - std::function fn; -}; - -TEST(RefcountingHashMapTest, PointerIdentity) { - RefcountingHashMap m; - auto factory = [](const int) { return std::make_unique(); }; - std::shared_ptr a = m.GetOrCreateIfAbsent(0, factory); - std::shared_ptr b = m.GetOrCreateIfAbsent(0, factory); - std::shared_ptr c = m.GetOrCreateIfAbsent(1, factory); - EXPECT_EQ(a.get(), b.get()); - EXPECT_NE(a.get(), c.get()); -} - -TEST(RefcountingHashMapTest, DefaultInitialized) { - RefcountingHashMap m; - auto factory = [](const int) { return std::make_unique(); }; - EXPECT_EQ(*m.GetOrCreateIfAbsent(42, factory), 0); -} - -TEST(RefcountingHashMapTest, DeletesEagerly) { - RefcountingHashMap m; - bool deleted = false; - auto factory = [](const int) { return std::make_unique(); }; - auto handle = m.GetOrCreateIfAbsent(0, factory); - handle->fn = [&] { deleted = true; }; - EXPECT_FALSE(deleted); - handle = nullptr; - EXPECT_TRUE(deleted); -} - -TEST(RefcountingHashMapTest, CustomFactory) { - RefcountingHashMap m; - auto factory = [](const int x) { return std::make_unique(x + 1); }; - EXPECT_EQ(*m.GetOrCreateIfAbsent(0, factory), 1); - EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101); -} - -} // anonymous namespace -} // namespace xla diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc index 09419db81191dd..25ec47105e7b7d 100644 --- a/third_party/xla/xla/reference_util.cc +++ b/third_party/xla/xla/reference_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/array3d.h" diff --git a/third_party/xla/xla/reference_util_test.cc b/third_party/xla/xla/reference_util_test.cc index f53b584aa14b66..4ad7c660f8c902 100644 --- a/third_party/xla/xla/reference_util_test.cc +++ b/third_party/xla/xla/reference_util_test.cc @@ -19,14 +19,15 @@ limitations under the License. #include #include +#include #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/error_spec.h" #include "xla/hlo/builder/padding.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/runtime/BUILD b/third_party/xla/xla/runtime/BUILD index 15f2e48a4ce8e6..d9ba81074bc04c 100644 --- a/third_party/xla/xla/runtime/BUILD +++ b/third_party/xla/xla/runtime/BUILD @@ -32,6 +32,7 @@ xla_cc_test( deps = [ ":buffer_use", "//xla/service:buffer_assignment", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], diff --git a/third_party/xla/xla/runtime/buffer_use_test.cc b/third_party/xla/xla/runtime/buffer_use_test.cc index 31050af3125214..fa0de3a2cc74b3 100644 --- a/third_party/xla/xla/runtime/buffer_use_test.cc +++ b/third_party/xla/xla/runtime/buffer_use_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/runtime/buffer_use.h" +#include #include "xla/service/buffer_assignment.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 31b2057f353801..458f8e9e0ac7a3 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -214,8 +214,8 @@ xla_cc_test( cc_library( name = "all_reduce_folder", hdrs = ["all_reduce_folder.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:all_reduce_folder instead.", - deps = ["//xla/hlo/transforms:all_reduce_folder"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:all_reduce_folder instead.", + deps = ["//xla/hlo/transforms/simplifiers:all_reduce_folder"], ) cc_library( @@ -231,22 +231,22 @@ cc_library( cc_library( name = "broadcast_canonicalizer", hdrs = ["broadcast_canonicalizer.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:broadcast_canonicalizer instead.", - deps = ["//xla/hlo/transforms:broadcast_canonicalizer"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:broadcast_canonicalizer instead.", + deps = ["//xla/hlo/transforms/simplifiers:broadcast_canonicalizer"], ) cc_library( name = "bfloat16_conversion_folding", hdrs = ["bfloat16_conversion_folding.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:bfloat16_conversion_folding instead.", - deps = ["//xla/hlo/transforms:bfloat16_conversion_folding"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:bfloat16_conversion_folding instead.", + deps = ["//xla/hlo/transforms/simplifiers:bfloat16_conversion_folding"], ) cc_library( name = "float_normalization", hdrs = ["float_normalization.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:float_normalization instead.", - deps = ["//xla/hlo/transforms:float_normalization"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:float_normalization instead.", + deps = ["//xla/hlo/transforms/simplifiers:float_normalization"], ) cc_library( @@ -256,22 +256,47 @@ cc_library( deps = ["//xla/hlo/transforms:bfloat16_propagation"], ) +cc_library( + name = "collective_permute_utils", + srcs = ["collective_permute_utils.cc"], + hdrs = ["collective_permute_utils.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service/graphcycles", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "collective_permute_utils_test", + srcs = ["collective_permute_utils_test.cc"], + deps = [ + ":collective_permute_utils", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "collective_permute_decomposer", srcs = ["collective_permute_decomposer.cc"], hdrs = ["collective_permute_decomposer.h"], deps = [ ":collective_ops_utils", + ":collective_permute_utils", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", - "//xla/service/graphcycles", + "//xla/tsl/platform:errors", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) @@ -283,15 +308,16 @@ xla_cc_test( ":collective_ops_utils", ":collective_permute_decomposer", "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "//xla/hlo/utils:hlo_query", - "//xla/service/gpu:backend_configs_cc", - "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -380,7 +406,7 @@ cc_library( "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -413,8 +439,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", @@ -609,8 +635,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_constant_splitter", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -634,6 +660,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service/spmd:shard_barrier_partitioner", "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -772,13 +799,8 @@ cc_library( name = "pattern_matcher_gmock", testonly = 1, hdrs = ["pattern_matcher_gmock.h"], - deps = [ - ":pattern_matcher", - "//xla:shape_util", - "//xla:test", - "//xla/hlo/ir:hlo", - "@local_tsl//tsl/platform:test", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/testlib:pattern_matcher_gmock instead.", + deps = ["//xla/hlo/testlib:pattern_matcher_gmock"], ) xla_cc_test( @@ -914,8 +936,8 @@ xla_cc_test( cc_library( name = "flatten_call_graph", hdrs = ["flatten_call_graph.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:flatten_call_graph instead.", - deps = ["//xla/hlo/transforms:flatten_call_graph"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:flatten_call_graph instead.", + deps = ["//xla/hlo/transforms/simplifiers:flatten_call_graph"], ) cc_library( @@ -931,7 +953,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -958,8 +980,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -972,8 +994,8 @@ xla_cc_test( cc_library( name = "hlo_computation_deduplicator", hdrs = ["hlo_computation_deduplicator.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_computation_deduplicator instead.", - deps = ["//xla/hlo/transforms:hlo_computation_deduplicator"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_computation_deduplicator instead.", + deps = ["//xla/hlo/transforms/simplifiers:hlo_computation_deduplicator"], ) cc_library( @@ -1157,12 +1179,14 @@ cc_library( ":hlo_value", "//xla:debug_options_flags", "//xla:shape_util", + "//xla:side_effect_util", "//xla:status_macros", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:ptrvec", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1175,8 +1199,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) @@ -1720,9 +1742,9 @@ xla_cc_test( "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/service/memory_space_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1834,21 +1856,21 @@ xla_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], deps = [ + ":buffer_value", + "//xla:literal_util", "//xla:shape_util", "//xla:test_helpers", - "//xla:types", "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", ], ) @@ -1873,9 +1895,9 @@ xla_cc_test( cc_library( name = "hlo_memory_scheduler", hdrs = ["hlo_memory_scheduler.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_memory_scheduler instead.", + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_memory_scheduler instead.", local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = ["//xla/hlo/transforms:hlo_memory_scheduler"], + deps = ["//xla/hlo/transforms/simplifiers:hlo_memory_scheduler"], ) cc_library( @@ -1947,7 +1969,7 @@ cc_library( "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -2021,14 +2043,22 @@ xla_cc_test( ":hlo_creation_utils", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:array2d", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", ], ) @@ -2227,13 +2257,16 @@ xla_cc_test( shard_count = 12, deps = [ ":triangular_solve_expander", + "//xla:array2d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:reference_util", - "//xla:test", - "//xla:types", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -2307,15 +2340,15 @@ cc_library( name = "algebraic_simplifier", hdrs = ["algebraic_simplifier.h"], copts = tsl_copts(), - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:algebraic_simplifier instead.", - deps = ["//xla/hlo/transforms:algebraic_simplifier"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:algebraic_simplifier instead.", + deps = ["//xla/hlo/transforms/simplifiers:algebraic_simplifier"], ) cc_library( name = "tree_reduction_rewriter", hdrs = ["tree_reduction_rewriter.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:tree_reduction_rewriter instead.", - deps = ["//xla/hlo/transforms:tree_reduction_rewriter"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:tree_reduction_rewriter instead.", + deps = ["//xla/hlo/transforms/simplifiers:tree_reduction_rewriter"], ) xla_test( @@ -2332,8 +2365,8 @@ xla_test( cc_library( name = "simplify_fp_conversions", hdrs = ["simplify_fp_conversions.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:simplify_fp_conversions instead.", - deps = ["//xla/hlo/transforms:simplify_fp_conversions"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:simplify_fp_conversions instead.", + deps = ["//xla/hlo/transforms/simplifiers:simplify_fp_conversions"], ) cc_library( @@ -2567,8 +2600,8 @@ xla_cc_test( cc_library( name = "batch_dot_simplification", hdrs = ["batch_dot_simplification.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:batch_dot_simplification instead.", - deps = ["//xla/hlo/transforms:batch_dot_simplification"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:batch_dot_simplification instead.", + deps = ["//xla/hlo/transforms/simplifiers:batch_dot_simplification"], ) xla_cc_test( @@ -2645,8 +2678,8 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -2682,8 +2715,8 @@ xla_cc_test( cc_library( name = "convolution_group_converter", hdrs = ["convolution_group_converter.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convolution_group_converter instead.", - deps = ["//xla/hlo/transforms:convolution_group_converter"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:convolution_group_converter instead.", + deps = ["//xla/hlo/transforms/simplifiers:convolution_group_converter"], ) cc_library( @@ -2754,7 +2787,7 @@ cc_library( "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -2772,10 +2805,10 @@ xla_cc_test( ":scan_loop_accumulator_input_unification", "//xla:literal", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", @@ -2797,7 +2830,7 @@ cc_library( "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2832,11 +2865,13 @@ cc_library( deps = [ ":call_inliner", ":collective_ops_utils", + ":constant_value", ":hlo_buffer", ":hlo_creation_utils", ":hlo_cse", ":hlo_value", ":pattern_matcher", + ":value_range", ":while_loop_constant_sinking", "//xla:comparison_util", "//xla:literal", @@ -2850,11 +2885,12 @@ cc_library( "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -2873,9 +2909,9 @@ xla_cc_test( ":while_loop_unroller", "//xla:literal", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -2911,7 +2947,7 @@ cc_library( "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2938,8 +2974,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2974,22 +3010,22 @@ cc_library( cc_library( name = "dot_dimension_merger", hdrs = ["dot_dimension_merger.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dot_dimension_merger instead.", - deps = ["//xla/hlo/transforms:dot_dimension_merger"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:dot_dimension_merger instead.", + deps = ["//xla/hlo/transforms/simplifiers:dot_dimension_merger"], ) cc_library( name = "dot_merger", hdrs = ["dot_merger.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dot_merger instead.", - deps = ["//xla/hlo/transforms:dot_merger"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:dot_merger instead.", + deps = ["//xla/hlo/transforms/simplifiers:dot_merger"], ) cc_library( name = "convert_mover", hdrs = ["convert_mover.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_mover instead.", - deps = ["//xla/hlo/transforms:convert_mover"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:convert_mover instead.", + deps = ["//xla/hlo/transforms/simplifiers:convert_mover"], ) cc_library( @@ -3047,15 +3083,15 @@ xla_cc_test( cc_library( name = "tuple_simplifier", hdrs = ["tuple_simplifier.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:tuple_simplifier instead.", - deps = ["//xla/hlo/transforms:tuple_simplifier"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:tuple_simplifier instead.", + deps = ["//xla/hlo/transforms/simplifiers:tuple_simplifier"], ) cc_library( name = "reshape_mover", hdrs = ["reshape_mover.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:reshape_mover instead.", - deps = ["//xla/hlo/transforms:reshape_mover"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:reshape_mover instead.", + deps = ["//xla/hlo/transforms/simplifiers:reshape_mover"], ) cc_library( @@ -3130,8 +3166,8 @@ cc_library( cc_library( name = "dynamic_dimension_simplifier", hdrs = ["dynamic_dimension_simplifier.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dynamic_dimension_simplifier instead.", - deps = ["//xla/hlo/transforms:dynamic_dimension_simplifier"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier instead.", + deps = ["//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier"], ) cc_library( @@ -3156,7 +3192,7 @@ cc_library( "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/tsl/lib/monitoring:gauge", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -3176,7 +3212,7 @@ cc_library( xla_test( name = "dynamic_padder_test", srcs = ["dynamic_padder_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":dynamic_dimension_inference", ":dynamic_padder", @@ -3194,10 +3230,10 @@ xla_test( "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:dynamic_dimension_simplifier", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_matchers", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", @@ -3233,8 +3269,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -3490,25 +3526,35 @@ xla_cc_test( name = "hlo_module_test", srcs = ["hlo_module_test.cc"], deps = [ + ":buffer_value", ":computation_placer_hdr", + ":hlo_module_config", ":test_compilation_environment_proto_cc", - "//xla:literal", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/lib/strings:proto_serialization", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:protobuf", ], ) @@ -3720,8 +3766,8 @@ cc_library( "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3765,8 +3811,8 @@ cc_library( "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3868,8 +3914,8 @@ cc_library( cc_library( name = "hlo_dce", hdrs = ["hlo_dce.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_dce instead.", - deps = ["//xla/hlo/transforms:hlo_dce"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_dce instead.", + deps = ["//xla/hlo/transforms/simplifiers:hlo_dce"], ) cc_library( @@ -3884,8 +3930,8 @@ cc_library( "//xla/hlo/analysis:hlo_liveness_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", @@ -3989,16 +4035,16 @@ xla_cc_test( cc_library( name = "hlo_rematerialization", hdrs = ["hlo_rematerialization.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_rematerialization instead.", - deps = ["//xla/hlo/transforms:hlo_rematerialization"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_rematerialization instead.", + deps = ["//xla/hlo/transforms/simplifiers:hlo_rematerialization"], ) cc_library( name = "hlo_rematerialization_test_utils", testonly = 1, hdrs = ["hlo_rematerialization_test_utils.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_rematerialization_test_utils instead.", - deps = ["//xla/hlo/transforms:hlo_rematerialization_test_utils"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_rematerialization_test_utils instead.", + deps = ["//xla/hlo/transforms/simplifiers:hlo_rematerialization_test_utils"], ) xla_cc_test( @@ -4036,7 +4082,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -4075,9 +4121,9 @@ cc_library( deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass_pipeline instead.", deps = [ ":compilation_stats", - ":hlo_pass", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -4127,7 +4173,7 @@ xla_cc_test( cc_library( name = "hlo_constant_folding", hdrs = ["hlo_constant_folding.h"], - deps = ["//xla/hlo/transforms:hlo_constant_folding"], + deps = ["//xla/hlo/transforms/simplifiers:hlo_constant_folding"], ) cc_library( @@ -4213,15 +4259,15 @@ xla_cc_test( cc_library( name = "hlo_element_type_converter", hdrs = ["hlo_element_type_converter.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_element_type_converter instead.", - deps = ["//xla/hlo/transforms:hlo_element_type_converter"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:hlo_element_type_converter instead.", + deps = ["//xla/hlo/transforms/simplifiers:hlo_element_type_converter"], ) cc_library( name = "conditional_canonicalizer", hdrs = ["conditional_canonicalizer.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:conditional_canonicalizer instead.", - deps = ["//xla/hlo/transforms:conditional_canonicalizer"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:conditional_canonicalizer instead.", + deps = ["//xla/hlo/transforms/simplifiers:conditional_canonicalizer"], ) cc_library( @@ -4284,6 +4330,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -4496,8 +4543,8 @@ xla_cc_test( cc_library( name = "zero_sized_hlo_elimination", hdrs = ["zero_sized_hlo_elimination.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:zero_sized_hlo_elimination instead.", - deps = ["//xla/hlo/transforms:zero_sized_hlo_elimination"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination instead.", + deps = ["//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination"], ) cc_library( @@ -4560,14 +4607,21 @@ cc_library( deps = [ ":computation_placer", ":executable", - "//xla:status_macros", - "//xla:types", + ":hlo_module_config", + "//xla:literal", + "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", ], ) @@ -4610,9 +4664,11 @@ cc_library( hdrs = ["hlo_runner_pjrt.h"], deps = [ ":computation_layout", + ":computation_placer_hdr", ":executable", ":hlo_module_util", ":hlo_runner_interface", + "//xla:literal", "//xla:shape_layout", "//xla:shape_util", "//xla:status_macros", @@ -4622,15 +4678,22 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/pjrt:host_memory_spaces", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:casts", ], ) @@ -4650,8 +4713,8 @@ cc_library( cc_library( name = "sort_simplifier", hdrs = ["sort_simplifier.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:sort_simplifier instead.", - deps = ["//xla/hlo/transforms:sort_simplifier"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:sort_simplifier instead.", + deps = ["//xla/hlo/transforms/simplifiers:sort_simplifier"], ) cc_library( @@ -4693,7 +4756,6 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", @@ -4703,8 +4765,8 @@ xla_cc_test( cc_library( name = "root_instruction_sinker", hdrs = ["root_instruction_sinker.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:root_instruction_sinker instead.", - deps = ["//xla/hlo/transforms:root_instruction_sinker"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:root_instruction_sinker instead.", + deps = ["//xla/hlo/transforms/simplifiers:root_instruction_sinker"], ) cc_library( @@ -4725,8 +4787,8 @@ cc_library( cc_library( name = "host_memory_transfer_asyncifier", hdrs = ["host_memory_transfer_asyncifier.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_memory_transfer_asyncifier instead.", - deps = ["//xla/hlo/transforms:host_memory_transfer_asyncifier"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:host_memory_transfer_asyncifier instead.", + deps = ["//xla/hlo/transforms/simplifiers:host_memory_transfer_asyncifier"], ) cc_library( @@ -4908,8 +4970,8 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -4952,7 +5014,7 @@ cc_library( "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5025,8 +5087,8 @@ xla_cc_test( cc_library( name = "fusion_constant_sinking", hdrs = ["fusion_constant_sinking.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:fusion_constant_sinking instead.", - deps = ["//xla/hlo/transforms:fusion_constant_sinking"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:fusion_constant_sinking instead.", + deps = ["//xla/hlo/transforms/simplifiers:fusion_constant_sinking"], ) cc_library( @@ -5102,7 +5164,7 @@ xla_cc_test( deps = [ ":while_loop_fusible_sinking", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "@com_google_absl//absl/log:check", @@ -5165,6 +5227,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -5178,15 +5241,15 @@ cc_library( cc_library( name = "optimize_input_output_buffer_alias", hdrs = ["optimize_input_output_buffer_alias.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:optimize_input_output_buffer_alias instead.", - deps = ["//xla/hlo/transforms:optimize_input_output_buffer_alias"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias instead.", + deps = ["//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias"], ) cc_library( name = "ar_crs_combiner", hdrs = ["ar_crs_combiner.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:ar_crs_combiner instead.", - deps = ["//xla/hlo/transforms:ar_crs_combiner"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:ar_crs_combiner instead.", + deps = ["//xla/hlo/transforms/simplifiers:ar_crs_combiner"], ) cc_library( @@ -5224,6 +5287,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", ], ) @@ -5273,7 +5337,7 @@ cc_library( name = "slice_sinker", hdrs = ["slice_sinker.h"], deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:slice_sinker instead.", - deps = ["//xla/hlo/transforms:slice_sinker"], + deps = ["//xla/hlo/transforms/simplifiers:slice_sinker"], ) cc_library( @@ -5429,20 +5493,22 @@ cc_library( "//xla:executable_run_options", "//xla:literal", "//xla:literal_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:statusor", ], ) @@ -5505,8 +5571,8 @@ xla_cc_test( ":pattern_matcher_gmock", ":topk_rewriter", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -5527,8 +5593,8 @@ cc_library( cc_library( name = "result_caster", hdrs = ["result_caster.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:result_caster instead.", - deps = ["//xla/hlo/transforms:result_caster"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:result_caster instead.", + deps = ["//xla/hlo/transforms/simplifiers:result_caster"], ) cc_library( @@ -5546,8 +5612,8 @@ cc_library( cc_library( name = "convert_operand_folding", hdrs = ["convert_operand_folding.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_operand_folding instead.", - deps = ["//xla/hlo/transforms:convert_operand_folding"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:convert_operand_folding instead.", + deps = ["//xla/hlo/transforms/simplifiers:convert_operand_folding"], ) cc_library( @@ -5638,6 +5704,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", + "@com_google_googletest//:gtest", ], ) @@ -5659,6 +5726,7 @@ xla_cc_test( deps = [ ":lockable", "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -5670,16 +5738,18 @@ cc_library( srcs = ["rendezvous.cc"], hdrs = ["rendezvous.h"], deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -5788,8 +5858,8 @@ cc_library( cc_library( name = "instruction_hoister", hdrs = ["instruction_hoister.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:instruction_hoister instead.", - deps = ["//xla/hlo/transforms:instruction_hoister"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:instruction_hoister instead.", + deps = ["//xla/hlo/transforms/simplifiers:instruction_hoister"], ) cc_library( @@ -5903,7 +5973,7 @@ cc_library( name = "gather_simplifier", hdrs = ["gather_simplifier.h"], deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:gather_simplifier instead.", - deps = ["//xla/hlo/transforms:gather_simplifier"], + deps = ["//xla/hlo/transforms/simplifiers:gather_simplifier"], ) cc_library( @@ -5929,8 +5999,8 @@ cc_library( cc_library( name = "reduce_window_rewriter", hdrs = ["reduce_window_rewriter.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:reduce_window_rewriter instead.", - deps = ["//xla/hlo/transforms:reduce_window_rewriter"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:reduce_window_rewriter instead.", + deps = ["//xla/hlo/transforms/simplifiers:reduce_window_rewriter"], ) cc_library( @@ -5954,8 +6024,8 @@ cc_library( cc_library( name = "sub_byte_normalization", hdrs = ["sub_byte_normalization.h"], - deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:sub_byte_normalization instead.", - deps = ["//xla/hlo/transforms:sub_byte_normalization"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms/simplifiers:sub_byte_normalization instead.", + deps = ["//xla/hlo/transforms/simplifiers:sub_byte_normalization"], ) cc_library( @@ -6310,8 +6380,8 @@ cc_library( ":while_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:hlo_dce", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -6333,6 +6403,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -6349,8 +6420,11 @@ cc_library( srcs = ["legalize_scheduling_annotations.cc"], hdrs = ["legalize_scheduling_annotations.h"], deps = [ + "//xla:side_effect_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:ptrvec", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -6369,14 +6443,15 @@ xla_cc_test( srcs = ["legalize_scheduling_annotations_test.cc"], deps = [ ":legalize_scheduling_annotations", + "//xla:side_effect_util", "//xla:test_helpers", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "@com_google_absl//absl/status", + "//xla/hlo/testlib:test_helpers", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 6562eee2048bf9..12540a782a8dba 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -656,7 +656,7 @@ void BufferAssignment::CombineTempAllocations( // size constraint. VLOG(1) << "Combined temp allocation for color " << color << " is: " << temp_allocation; - combined_allocations.emplace_back(temp_allocation); + combined_allocations.push_back(temp_allocation); combined_allocation_map.emplace(color, &combined_allocations.back()); continue; } @@ -666,7 +666,7 @@ void BufferAssignment::CombineTempAllocations( // combined_it. VLOG(1) << "Due to size constraint, reset temp allocation for color " << color << " to: " << temp_allocation; - combined_allocations.emplace_back(temp_allocation); + combined_allocations.push_back(temp_allocation); combined_allocation_map.emplace(color, &combined_allocations.back()); continue; } diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index 7435d0d65c4fd3..d91b13a7bff1d6 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -2836,7 +2836,7 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { LOG(INFO) << buffers->ToString(); - auto get_slice = [&](std::string_view hlo_name, const ShapeIndex& index) { + auto get_slice = [&](absl::string_view hlo_name, const ShapeIndex& index) { return buffers->GetUniqueSlice(FindInstruction(m.get(), hlo_name), index) .value(); }; @@ -2929,7 +2929,7 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { LOG(INFO) << buffers->ToString(); - auto get_slice = [&](std::string_view hlo_name, const ShapeIndex& index) { + auto get_slice = [&](absl::string_view hlo_name, const ShapeIndex& index) { return buffers->GetUniqueSlice(FindInstruction(m.get(), hlo_name), index) .value(); }; @@ -3040,7 +3040,7 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] { LOG(INFO) << buffers->ToString(); - auto get_slice = [&](std::string_view hlo_name, const ShapeIndex& index) { + auto get_slice = [&](absl::string_view hlo_name, const ShapeIndex& index) { return buffers->GetUniqueSlice(FindInstruction(m.get(), hlo_name), index) .value(); }; @@ -3104,7 +3104,7 @@ TEST_F(BufferAssignmentTest, AsyncCallImplicitSharding) { LOG(INFO) << buffers->ToString(); - auto get_slice = [&](std::string_view hlo_name, const ShapeIndex& index) { + auto get_slice = [&](absl::string_view hlo_name, const ShapeIndex& index) { return buffers ->GetUniqueSlice(FindInstruction(module.get(), hlo_name), index) .value(); diff --git a/third_party/xla/xla/service/call_inliner_test.cc b/third_party/xla/xla/service/call_inliner_test.cc index 31130894231607..dd6d5e2b301902 100644 --- a/third_party/xla/xla/service/call_inliner_test.cc +++ b/third_party/xla/xla/service/call_inliner_test.cc @@ -24,12 +24,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 4436bd6f9ba67b..c95e0381278665 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -26,6 +26,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -37,9 +39,10 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/pattern_matcher.h" +#include "xla/status_macros.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -165,6 +168,35 @@ absl::StatusOr GetCollectiveOpGroupMode( return Internal("Unexpected instruction type."); } +absl::StatusOr GetCollectiveUseGlobalDeviceIds( + const HloInstruction* hlo) { + const bool is_all_reduce = (hlo->opcode() == HloOpcode::kAllReduce || + hlo->opcode() == HloOpcode::kAllReduceStart || + hlo->opcode() == HloOpcode::kReduceScatter); + const bool is_all_gather = (hlo->opcode() == HloOpcode::kAllGather || + hlo->opcode() == HloOpcode::kAllGatherStart); + if (!is_all_reduce && !is_all_gather) { + return absl::InvalidArgumentError( + "GetReplicaGroupCountAndSize only supports AllReduce and AllGather."); + } + return is_all_reduce + ? Cast(hlo)->use_global_device_ids() + : Cast(hlo)->use_global_device_ids(); +} + +std::optional GetCollectiveChannelId(const HloInstruction* hlo) { + return Cast(hlo)->channel_id(); +} + +const CollectiveDeviceList& GetCollectiveDeviceList(const HloInstruction* hlo) { + return Cast(hlo)->device_list(); +} + +const std::vector& GetCollectiveReplicaGroups( + const HloInstruction* hlo) { + return Cast(hlo)->replica_groups(); +} + // Returns the group formation mode implied by (a) whether the operation has // channel_id and (b) if it has use_global_device_ids and if yes, its value. absl::StatusOr GetCollectiveOpGroupMode( @@ -310,6 +342,21 @@ GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, } } +absl::StatusOr>> +GetParticipatingDevicesGroups(const HloInstruction* collective) { + CHECK(collective->GetModule()->config().has_static_device_assignment()); + const DeviceAssignment& device_assignment = + collective->GetModule()->config().static_device_assignment(); + TF_ASSIGN_OR_RETURN(bool use_global_device_ids, + GetCollectiveUseGlobalDeviceIds(collective)); + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode mode, + GetCollectiveOpGroupMode(GetCollectiveChannelId(collective).has_value(), + use_global_device_ids)); + return GetParticipatingDevicesGroups( + device_assignment, GetCollectiveReplicaGroups(collective), mode); +} + absl::StatusOr> GetParticipatingFlattenedIdGroups( const DeviceAssignment& device_assignment, absl::Span replica_groups, @@ -410,59 +457,31 @@ absl::StatusOr> GetParticipatingFlattenedIdGroups( absl::StatusOr> GetParticipatingFlattenedIdGroups( const HloInstruction* hlo, const DeviceAssignment& device_assignment) { - if (hlo->opcode() != HloOpcode::kAllGather && - hlo->opcode() != HloOpcode::kAllGatherStart && - hlo->opcode() != HloOpcode::kAllReduce && - hlo->opcode() != HloOpcode::kAllReduceStart && - hlo->opcode() != HloOpcode::kReduceScatter) { - return absl::InvalidArgumentError( - "GetParticipatingFlattenedIdGroups only supports AllGather and " - "AllReduce."); - } - bool use_global_device_ids = - (hlo->opcode() == HloOpcode::kAllGather || - hlo->opcode() == HloOpcode::kAllGatherStart) - ? Cast(hlo)->use_global_device_ids() - : Cast(hlo)->use_global_device_ids(); - const HloCollectiveInstruction* hlo_collective = - Cast(hlo); + TF_ASSIGN_OR_RETURN(bool use_global_device_ids, + GetCollectiveUseGlobalDeviceIds(hlo)); TF_ASSIGN_OR_RETURN( CollectiveOpGroupMode mode, - GetCollectiveOpGroupMode(hlo_collective->channel_id().has_value(), + GetCollectiveOpGroupMode(GetCollectiveChannelId(hlo).has_value(), use_global_device_ids)); TF_ASSIGN_OR_RETURN( std::vector replica_groups, - GetParticipatingFlattenedIdGroups( - device_assignment, hlo_collective->replica_groups(), mode)); + GetParticipatingFlattenedIdGroups(device_assignment, + GetCollectiveReplicaGroups(hlo), mode)); return replica_groups; } // Same as above, used for cases where static_device_assignment is not present. absl::StatusOr> GetParticipatingFlattenedIdGroups( const HloInstruction* hlo, int replica_count, int partition_count) { - if (hlo->opcode() != HloOpcode::kAllGather && - hlo->opcode() != HloOpcode::kAllGatherStart && - hlo->opcode() != HloOpcode::kAllReduce && - hlo->opcode() != HloOpcode::kAllReduceStart && - hlo->opcode() != HloOpcode::kReduceScatter) { - return absl::InvalidArgumentError( - "GetParticipatingFlattenedIdGroups only supports AllGather and " - "AllReduce."); - } - bool use_global_device_ids = - (hlo->opcode() == HloOpcode::kAllGather || - hlo->opcode() == HloOpcode::kAllGatherStart) - ? Cast(hlo)->use_global_device_ids() - : Cast(hlo)->use_global_device_ids(); - const HloCollectiveInstruction* hlo_collective = - Cast(hlo); + TF_ASSIGN_OR_RETURN(bool use_global_device_ids, + GetCollectiveUseGlobalDeviceIds(hlo)); TF_ASSIGN_OR_RETURN( CollectiveOpGroupMode mode, - GetCollectiveOpGroupMode(hlo_collective->channel_id().has_value(), + GetCollectiveOpGroupMode(GetCollectiveChannelId(hlo).has_value(), use_global_device_ids)); TF_ASSIGN_OR_RETURN( std::vector replica_groups, - GetParticipatingFlattenedIdGroups(hlo_collective->replica_groups(), mode, + GetParticipatingFlattenedIdGroups(GetCollectiveReplicaGroups(hlo), mode, replica_count, partition_count)); return replica_groups; } @@ -637,22 +656,7 @@ absl::StatusOr> GetPariticipantCountsForReplicaGroups( absl::StatusOr>> GetReplicaGroupCountAndSize(const HloInstruction* hlo) { - const bool is_all_reduce = (hlo->opcode() == HloOpcode::kAllReduce || - hlo->opcode() == HloOpcode::kAllReduceStart || - hlo->opcode() == HloOpcode::kReduceScatter); - const bool is_all_gather = (hlo->opcode() == HloOpcode::kAllGather || - hlo->opcode() == HloOpcode::kAllGatherStart); - if (!is_all_reduce && !is_all_gather) { - return absl::InvalidArgumentError( - "GetReplicaGroupCountAndSize only supports AllReduce and AllGather."); - } - const CollectiveDeviceList& device_list = - Cast(hlo)->device_list(); - const std::optional channel_id = hlo->channel_id(); - const bool use_global_ids = - is_all_reduce - ? Cast(hlo)->use_global_device_ids() - : Cast(hlo)->use_global_device_ids(); + const CollectiveDeviceList& device_list = GetCollectiveDeviceList(hlo); auto config = hlo->GetModule()->config(); if (device_list.iota_replica_group_list().has_value()) { @@ -660,9 +664,12 @@ GetReplicaGroupCountAndSize(const HloInstruction* hlo) { device_list.iota_replica_group_list()->num_replica_groups(), device_list.iota_replica_group_list()->num_devices_per_group()); } + TF_ASSIGN_OR_RETURN(bool use_global_device_ids, + GetCollectiveUseGlobalDeviceIds(hlo)); TF_ASSIGN_OR_RETURN( CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(channel_id.has_value(), use_global_ids)); + GetCollectiveOpGroupMode(GetCollectiveChannelId(hlo).has_value(), + use_global_device_ids)); TF_ASSIGN_OR_RETURN(std::vector participant_counts, GetPariticipantCountsForReplicaGroups( config.replica_count(), config.num_partitions(), diff --git a/third_party/xla/xla/service/collective_ops_utils.h b/third_party/xla/xla/service/collective_ops_utils.h index 242975a5d6bf6a..9c7776e5bdb8ed 100644 --- a/third_party/xla/xla/service/collective_ops_utils.h +++ b/third_party/xla/xla/service/collective_ops_utils.h @@ -17,32 +17,35 @@ limitations under the License. #define XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_ #include -#include #include #include -#include #include #include #include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/executable_run_options.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/literal.h" #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" #include "xla/service/pattern_matcher.h" #include "xla/stream_executor/device_memory.h" -#include "tsl/platform/blocking_counter.h" namespace xla { enum class ReductionKind { SUM, PRODUCT, MIN, MAX }; -constexpr std::string_view ReductionKindToString(ReductionKind reduction_kind) { +constexpr absl::string_view ReductionKindToString( + ReductionKind reduction_kind) { switch (reduction_kind) { case ReductionKind::SUM: return "sum"; @@ -120,6 +123,15 @@ absl::StatusOr> GetParticipatingIDs( absl::string_view CollectiveOpGroupModeToString( CollectiveOpGroupMode group_mode); +absl::StatusOr GetCollectiveUseGlobalDeviceIds(const HloInstruction* hlo); + +std::optional GetCollectiveChannelId(const HloInstruction* hlo); + +const CollectiveDeviceList& GetCollectiveDeviceList(const HloInstruction* hlo); + +const std::vector& GetCollectiveReplicaGroups( + const HloInstruction* hlo); + // Returns the group formation mode of instr, assuming that instr is, or is // dervied from, an HloAllGatherInstruction, HloAllReduceInstructionBase, // HloAllToAllInstruction, HloCollectiveBroadcastInstruction or @@ -159,6 +171,10 @@ GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, absl::Span replica_groups, CollectiveOpGroupMode group_mode); +// Same as above, except taking an HloInstruction instead. +absl::StatusOr>> +GetParticipatingDevicesGroups(const HloInstruction* collective); + // Same as above, except that it returns the flattened id in the replica groups // instead of device id. absl::StatusOr> GetParticipatingFlattenedIdGroups( @@ -318,132 +334,6 @@ struct RendezvousKey { int64_t op_id; }; -template -void WaitAndLogIfStuck(tsl::BlockingCounter* counter, const DescFn& desc_fn) { - VLOG(3) << "Begin: " << desc_fn(); - const std::chrono::milliseconds timeout(5000); - bool ok = counter->WaitFor(timeout); - if (ok) { - VLOG(3) << "Finished: " << desc_fn(); - return; - } - LOG(ERROR) << "This thread has been waiting for " << timeout.count() - << "ms for and may be stuck: " << desc_fn(); - counter->Wait(); - LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. " - "Perhaps the timeout is too short: " - << desc_fn(); -} - -// Participant data for each rendezvous. -struct ParticipantData { - ParticipantData(const RendezvousKey& rendezvous_key, int local_rank) - : rendezvous_key(rendezvous_key), local_rank(local_rank) {} - - virtual ~ParticipantData() {} - - RendezvousKey rendezvous_key; - int local_rank; // Which of the local participants is this? - - virtual std::string ToString() const = 0; -}; - -// The set of threads that want to do a collective op together all pick the same -// Rendezvous object out of the global cache and call SubmitParticipant. -// -// The Rendezvous instance handles waiting for all threads to join, ensuring -// that a clique exists for the desired set of GPUs, etc. -// -// Rendezvous objects can only be used once. -// -// I: Participant data. -// O: Participant output. -template ::value>> -class Rendezvous { - public: - virtual ~Rendezvous() {} - explicit Rendezvous(const RendezvousKey& k) - : participants_(k.num_local_participants), key_(k) {} - - // Submit a participant to the rendezvous. We get the rendezvous from - // `rendezvous_getter`, which we can then use to drop the existing reference. - static absl::StatusOr SubmitParticipant( - absl::FunctionRef>()> rendezvous_getter, - I participant) { - std::shared_ptr> rendezvous = rendezvous_getter(); - TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant)); - - // Drop our reference to the Rendezvous and wait for all other threads to do - // the same. If we didn't do this, one of the threads could run past this - // point, reenter ExecuteOnStream for another all-reduce, and attempt to - // reuse the Rendezvous! - // - // An alternative way of accomplishing this goal would be to implement - // RefcountingHashMap::erase() and call it during SubmitParticipant. But - // erase() is deceptively complex to implement correctly. - std::shared_ptr blocking_counter = p.second; - rendezvous.reset(); - blocking_counter->DecrementCount(); - xla::WaitAndLogIfStuck(blocking_counter.get(), [&] { - return absl::StrFormat( - "participant waiting for all threads to drop their reference to the " - "rendezvous: %p", - rendezvous.get()); - }); - return std::move(p.first); - } - - protected: - // Returns domain-specific output O and whether this replica is primary. - virtual absl::StatusOr RunCollectiveOp(const I& participant) = 0; - - // Adding participants_ requires holding mu_. - // Not annotated with ABSL_GUARDED_BY(mu_) because we do not require the lock - // to be held during CollectiveOp(), since at that point all the data is known - // to be present due to the global barrier. - std::vector> participants_; - - private: - absl::Mutex mu_; - - // Runs the all-reduce on the given thread. If successful, returns - // - a handle to the clique that was used, so that the caller may keep the - // clique alive if it chooses. - // - a BlockingCounter initialized to the number of participants, so that - // the caller can coordinate with the participants one last time if it - // chooses. This is useful for coordinating destruction of the Rendezvous. - absl::StatusOr>> - SubmitParticipant(const I& participant) { - { - absl::MutexLock lock(&mu_); - CHECK(!participants_[participant.local_rank].has_value()); - participants_[participant.local_rank] = participant; - } - - // Wait for all participants to arrive. - all_participants_present_.DecrementCount(); - WaitAndLogIfStuck(&all_participants_present_, [&] { - return absl::StrFormat( - "participant %s waiting for all participants to arrive at rendezvous " - "%s", - participant.ToString(), key_.ToString()); - }); - - TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant)); - return std::make_pair(std::move(output), returned_blocking_counter_); - } - - const RendezvousKey key_; - - tsl::BlockingCounter all_participants_present_{key_.num_local_participants}; - - // tsl::BlockingCounter returned by SubmitParticipant. - std::shared_ptr returned_blocking_counter_{ - std::make_shared(key_.num_local_participants)}; -}; - // We only pipeline Send-Recv chains with channel_id > 0, where each chain // has a unique channel_id, and allows multiple Send-Recv chains using // channel_id 0. diff --git a/third_party/xla/xla/service/collective_permute_decomposer.cc b/third_party/xla/xla/service/collective_permute_decomposer.cc index d1ea5974fd82cc..9f051576e5fc00 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer.cc +++ b/third_party/xla/xla/service/collective_permute_decomposer.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/collective_permute_decomposer.h" -#include #include #include #include @@ -23,10 +22,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -34,51 +34,20 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/collective_permute_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/graphcycles/graphcycles.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/errors.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { - namespace { using SourceTargetPair = std::pair; using SourceTargetPairs = std::vector; -// Returns true if the (source, target) relationship has a cycle. -bool HasCycles(const SourceTargetPairs& pairs) { - // Build a direct graph to check for cycles in (source, target) relationship. - GraphCycles graph; - - // Map replica numbers to graph node ids. - absl::flat_hash_map replica_to_node_id; - auto get_node_id = [&](int64_t replica) { - auto it_and_inserted = replica_to_node_id.emplace(replica, -1); - auto it = it_and_inserted.first; - auto inserted = it_and_inserted.second; - if (inserted) { - // First time to see the replica, create a node for it. - it->second = graph.NewNode(); - } - return it->second; - }; - - for (auto pair : pairs) { - int source = get_node_id(pair.first); - int target = get_node_id(pair.second); - VLOG(3) << "See source " << source << " -> target " << target; - if (!graph.InsertEdge(source, target)) { - VLOG(3) << "Detected cycles"; - return true; - } - } - return false; -} - // Returns true if the CollectivePermute instruction should be transformed // to Send/Recv. We currently limit the transformation to CollectivePermute // operations without any cycle in their (source, target) relationship, @@ -96,7 +65,7 @@ bool ShouldDecompose(const HloCollectivePermuteInstruction& collective_permute, if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) { return false; } - return !HasCycles(collective_permute.source_target_pairs()); + return !cp_utils::HasCycles(collective_permute.source_target_pairs()); } // Returns true for a pipelineable collective-permute. As a simple heuristic, @@ -110,82 +79,77 @@ bool MayPipeline(const HloCollectivePermuteInstruction& collective_permute) { // Contains source-target pairs from the permute operation and send and recv // instructions it was decomposed to. -struct CpWithDecomposedOps { - HloInstruction* inserted_send; - HloInstruction* inserted_recv; +struct DecomposedCp { + HloInstruction* send; + HloInstruction* recv; SourceTargetPairs source_target_pairs; }; -// Decomposes a collective-permute and adds frontend attributes to record -// pipeline decision. The present of the frontend attribute means that the -// collective-permute will be pipelined and the value of the attribute -// represents the runtime stream to execute the instruction. Without the -// frontend attribute, the collective-permute will not be pipelined. -absl::StatusOr DecomposeCollectivePermute( - HloCollectivePermuteInstruction* collective_permute, - HloComputation* computation, const std::string& pipeline_decision) { - // We currently only decompose collective-permute with a channel_id. - std::optional channel_id = collective_permute->channel_id(); - - HloInstruction* data = collective_permute->mutable_operand(0); - const Shape& data_shape = data->shape(); - const OpMetadata& metadata = collective_permute->metadata(); - - const xla::FrontendAttributes& old_attributes = - collective_permute->frontend_attributes(); +xla::FrontendAttributes ExtractFrontendAttributes( + const HloCollectivePermuteInstruction& cp) { + const xla::FrontendAttributes& old_attributes = cp.frontend_attributes(); xla::FrontendAttributes attributes; - std::string source_target_pairs_string = - "{" + - absl::StrJoin(collective_permute->source_target_pairs(), ",", - absl::PairFormatter( - [](std::string* out, int64_t value) { - absl::StrAppend(out, "{", value); - }, - ",", - [](std::string* out, int64_t value) { - absl::StrAppend(out, value, "}"); - })) + - "}"; attributes.mutable_map()->insert(old_attributes.map().begin(), old_attributes.map().end()); (*attributes.mutable_map())[kSendRecvSourceTargetPairsAttr] = - source_target_pairs_string; + cp_utils::SourceTargetPairsString(cp); + return attributes; +} - HloInstruction* after_all = - computation->AddInstruction(HloInstruction::CreateToken()); - HloInstruction* recv = computation->AddInstruction(HloInstruction::CreateRecv( - data_shape, after_all, channel_id, /*is_host_transfer=*/false)); - recv->add_frontend_attributes(attributes); +// Decomposes a collective-permute into send, send-done, recv, recv-done. +// Adds frontend attributes to record pipeline decision. The present of the +// frontend attribute means that the collective-permute will be pipelined and +// the value of the attribute represents the runtime stream to execute the +// instruction. Without the frontend attribute, the collective-permute will not +// be pipelined. +absl::StatusOr DecomposeCollectivePermute( + HloCollectivePermuteInstruction* cp, HloComputation* computation, + const std::string& pipeline_decision) { + absl::string_view cp_name = cp->name(); + std::optional channel_id = cp->channel_id(); + HloInstruction* data = cp->mutable_operand(0); + const Shape& shape = data->shape(); + const OpMetadata& metadata = cp->metadata(); + const xla::FrontendAttributes attributes = ExtractFrontendAttributes(*cp); + + HloInstruction* after_all = computation->AddInstruction( + HloInstruction::CreateToken(), absl::StrCat(cp_name, "-after-all")); + HloInstruction* recv = computation->AddInstruction( + HloInstruction::CreateRecv(shape, after_all, channel_id, + /*is_host_transfer=*/false), + absl::StrCat(cp_name, "-recv")); + recv->set_frontend_attributes(attributes); recv->set_metadata(metadata); - HloInstruction* send = computation->AddInstruction(HloInstruction::CreateSend( - data, after_all, channel_id, /*is_host_transfer=*/false)); - send->add_frontend_attributes(attributes); + HloInstruction* send = computation->AddInstruction( + HloInstruction::CreateSend(data, after_all, channel_id, + /*is_host_transfer=*/false), + absl::StrCat(cp_name, "-send")); + send->set_frontend_attributes(attributes); send->set_metadata(metadata); - HloInstruction* recv_done = - computation->AddInstruction(HloInstruction::CreateRecvDone( - recv, channel_id, /*is_host_transfer=*/false)); - HloInstruction* send_done = - computation->AddInstruction(HloInstruction::CreateSendDone( - send, channel_id, /*is_host_transfer=*/false)); - - // We will add control dependence to represent how we want to order Send/Recv - // and other collective operations. Here we only add the necessary control - // dependence to avoid optimization that can cause problems, in particular, - // to prevent fusion from fusing the computation of Send-data with the - // computation that requires the Recv-result. - TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); + HloInstruction* recv_done = computation->AddInstruction( + HloInstruction::CreateRecvDone(recv, channel_id, + /*is_host_transfer=*/false), + absl::StrCat(cp_name, "-recv-done")); + HloInstruction* send_done = computation->AddInstruction( + HloInstruction::CreateSendDone(send, channel_id, + /*is_host_transfer=*/false), + absl::StrCat(cp_name, "-send-done")); HloInstruction* recv_data = computation->AddInstruction( - HloInstruction::CreateGetTupleElement(recv_done, 0)); - TF_RETURN_IF_ERROR(collective_permute->ReplaceAllUsesWith(recv_data)); + HloInstruction::CreateGetTupleElement(recv_done, 0), + absl::StrCat(cp_name, "-recv-data")); - CpWithDecomposedOps decomposed_cp = { - send, recv, collective_permute->source_target_pairs()}; + TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data)); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp)); - TF_RETURN_IF_ERROR( - computation->RemoveInstructionAndUnusedOperands(collective_permute)); + // Control dependencies are require to assure order of the instructions. + // To avoid deadlocks as the program runs on multiple devices, we need to + // assure that we initiate receival before initiating sending and that receive + // done is executed after send is initiated. + TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send)); + TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); if (!pipeline_decision.empty()) { xla::FrontendAttributes attributes; @@ -195,48 +159,7 @@ absl::StatusOr DecomposeCollectivePermute( recv->add_frontend_attributes(attributes); recv_done->add_frontend_attributes(attributes); } - - return decomposed_cp; -} - -// Returns true if the (source, target) pairs form a forward cycle with all -// participants in the cycle, such as {{0,1},{1,2},{2,3},{3,0}}. We assume that -// the (source, target) pairs are ordered via increasing source IDs, as they are -// currently generated by SPMD partitioning. -// -bool IsForwardCycle(const SourceTargetPair& backedge, - const SourceTargetPairs& others) { - int64_t num_pairs = others.size() + 1; - if (backedge.first != num_pairs - 1 || backedge.second != 0) { - return false; - } - for (int64_t i = 0; i < num_pairs - 1; ++i) { - const SourceTargetPair& pair = others[i]; - if (pair.first != i || pair.second != i + 1) { - return false; - } - } - return true; -} - -// Returns true if the (source, target) pairs form a backward cycle with all -// participants in the cycle, such as {{0,3},{1,0},{2,1},{3,2}}. We assume that -// the (source, target) pairs are ordered via increasing source IDs, as they are -// currently generated by SPMD partitioning. -// -bool IsBackwardCycle(const SourceTargetPair& backedge, - const SourceTargetPairs& others) { - int64_t num_pairs = others.size() + 1; - if (backedge.first != 0 || backedge.second != num_pairs - 1) { - return false; - } - for (int64_t i = 0; i < num_pairs - 1; ++i) { - const SourceTargetPair& pair = others[i]; - if (pair.first != i + 1 || pair.second != i) { - return false; - } - } - return true; + return DecomposedCp{send, recv, cp->source_target_pairs()}; } // Checks whether the two collective-permutes for a forward cycle or a backward @@ -250,15 +173,15 @@ CheckCyclePatterns(HloCollectivePermuteInstruction* cp0, const SourceTargetPairs& cp0_pairs = cp0->source_target_pairs(); const SourceTargetPairs& cp1_pairs = cp1->source_target_pairs(); if (cp0_pairs.size() == 1) { - if (IsForwardCycle(cp0_pairs.front(), cp1_pairs) || - IsBackwardCycle(cp0_pairs.front(), cp1_pairs)) { + if (cp_utils::IsForwardCycle(cp0_pairs.front(), cp1_pairs) || + cp_utils::IsBackwardCycle(cp0_pairs.front(), cp1_pairs)) { // cp0 represents the backedge for the cycle. return std::make_pair(cp0, cp1); } } if (cp1_pairs.size() == 1) { - if (IsForwardCycle(cp1_pairs.front(), cp0_pairs) || - IsBackwardCycle(cp1_pairs.front(), cp0_pairs)) { + if (cp_utils::IsForwardCycle(cp1_pairs.front(), cp0_pairs) || + cp_utils::IsBackwardCycle(cp1_pairs.front(), cp0_pairs)) { // cp1 represents the forward edge for the cycle. return std::make_pair(cp1, cp0); } @@ -270,43 +193,16 @@ CheckCyclePatterns(HloCollectivePermuteInstruction* cp0, // The order protects from a potential deadlock when every device tries to // execute recv with no devices executing send - if there are no constraints, // the scheduler is free to schedule all recv ops first. -// -// The input argument is a vector of decomposed collective permutes in the order -// they were added into instructions. +// deco_post_order is expected to be post order within a computation. +// TODO b/388072780 add second hueristic to enforce back edge before the forward +// edge for max performance. absl::Status EnforceOrderOfSendRecvChains( - std::vector& decomposed_cps) { - // Order the decomposed permutes in order of the intended scheduling: - // 1. Permutes with fewer target pairs go first. This is a heuristic to - // prioritize backwards edges, which would normally have fewer pairs. - // 2. The permute appearing earlier in the instructions should be scheduled - // earlier. - // The incoming vector is already in the order of instructions, so we use - // stable sort to preserve the existing ordering. - // - // This scheduling order is a performance optimization heuristic. It is not - // necessary to prevent deadlocks - all we need to do is to prevent recv being - // executed on every device at once, so any sorting criteria should work. - // However, we know that back edges should generally be scheduled earlier for - // better overlap with compute. - std::stable_sort( - decomposed_cps.begin(), decomposed_cps.end(), - [](const CpWithDecomposedOps& lhs, const CpWithDecomposedOps& rhs) { - return lhs.source_target_pairs.size() < rhs.source_target_pairs.size(); - }); - - for (size_t i = 0; i < decomposed_cps.size(); ++i) { - // Link within the current send and recv pair. - CpWithDecomposedOps& cur_decomposed_cp = decomposed_cps[i]; - TF_RETURN_IF_ERROR(cur_decomposed_cp.inserted_recv->AddControlDependencyTo( - cur_decomposed_cp.inserted_send)); - - // Link between the previous and current send/recv pair. - if (i < 1) continue; - CpWithDecomposedOps& prev_decomposed_cp = decomposed_cps[i - 1]; - TF_RETURN_IF_ERROR(prev_decomposed_cp.inserted_send->AddControlDependencyTo( - cur_decomposed_cp.inserted_recv)); + std::vector& deco_post_order) { + for (size_t i = 1; i < deco_post_order.size(); ++i) { + DecomposedCp& cur = deco_post_order[i]; + DecomposedCp& prev = deco_post_order[i - 1]; + TF_RETURN_IF_ERROR(prev.send->AddControlDependencyTo(cur.recv)); } - return absl::OkStatus(); } @@ -341,18 +237,18 @@ absl::StatusOr CollectivePermuteDecomposer::Run( std::vector cps_to_decompose; HloCollectivePermuteInstruction* cp0_to_pipeline = nullptr; HloCollectivePermuteInstruction* cp1_to_pipeline = nullptr; - for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { - if (hlo->opcode() == HloOpcode::kWhile) { + for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kWhile) { // Collect while-body computations. - while_bodies.insert(hlo->while_body()); + while_bodies.insert(instr->while_body()); continue; } - if (hlo->opcode() != HloOpcode::kCollectivePermute) { + if (instr->opcode() != HloOpcode::kCollectivePermute) { continue; } HloCollectivePermuteInstruction* cp = - Cast(hlo); + Cast(instr); if (!ShouldDecompose(*cp, threshold_in_bytes_)) { continue; } @@ -363,7 +259,7 @@ absl::StatusOr CollectivePermuteDecomposer::Run( continue; } if (cp0_to_pipeline != nullptr && cp1_to_pipeline != nullptr) { - // Already find a pair of collective-permute that forms a cycle to + // Already found a pair of collective-permute that forms a cycle to // pipeline. continue; } @@ -385,10 +281,12 @@ absl::StatusOr CollectivePermuteDecomposer::Run( // Collective-permute for the forward edges. cp1_to_pipeline = optional_pair.value().second; } - } + } // for MakeInstructionPostOrder - std::vector decomposed_cps; - decomposed_cps.reserve(cps_to_decompose.size()); + // cps to decompose were collected post order, similarly we will collect + // the decomposed send/recv pairs. + std::vector deco_post_order; + deco_post_order.reserve(cps_to_decompose.size()); // Decompose the collective-permute, may add frontend attribute to record // pipeline decision. for (HloCollectivePermuteInstruction* cp : cps_to_decompose) { @@ -399,19 +297,15 @@ absl::StatusOr CollectivePermuteDecomposer::Run( pipeline_decision = "1"; } TF_ASSIGN_OR_RETURN( - auto decomposed_ops, + DecomposedCp decomposed_ops, DecomposeCollectivePermute(cp, computation, pipeline_decision)); - decomposed_cps.push_back(decomposed_ops); + deco_post_order.push_back(decomposed_ops); } - - TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChains(decomposed_cps)); - + TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChains(deco_post_order)); if (!cps_to_decompose.empty()) { changed = true; } - } - + } // for reverse MakeComputationPostOrder return changed; } - } // namespace xla diff --git a/third_party/xla/xla/service/collective_permute_decomposer.h b/third_party/xla/xla/service/collective_permute_decomposer.h index 11e96e5005e11b..33716f24f7eda6 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer.h +++ b/third_party/xla/xla/service/collective_permute_decomposer.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_COLLECTIVE_PERMUTE_DECOMPOSER_H_ #define XLA_SERVICE_COLLECTIVE_PERMUTE_DECOMPOSER_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -60,7 +65,6 @@ class CollectivePermuteDecomposer : public HloModulePass { return "collective-permute-decomposer"; } - using HloPassInterface::Run; // Runs CollectivePermuteDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. absl::StatusOr Run( diff --git a/third_party/xla/xla/service/collective_permute_decomposer_test.cc b/third_party/xla/xla/service/collective_permute_decomposer_test.cc index cc0634472ecf1f..ec386b43e3834b 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer_test.cc +++ b/third_party/xla/xla/service/collective_permute_decomposer_test.cc @@ -15,183 +15,230 @@ limitations under the License. #include "xla/service/collective_permute_decomposer.h" +#include #include +#include #include #include +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { using ::testing::ElementsAre; using ::testing::HasSubstr; + namespace op = xla::testing::opcode_matchers; -using CollectivePermuteDecomposerTest = HloTestBase; - -TEST_F(CollectivePermuteDecomposerTest, WithCycleNotTransformed) { - const absl::string_view kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,0}} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); +using Pass = CollectivePermuteDecomposer; + +struct Decomposed { + std::string cp_name; + HloInstruction* after_all; + HloInstruction* send; + HloInstruction* recv; + HloInstruction* send_done; + HloInstruction* recv_done; +}; + +class DecomposerTest : public HloHardwareIndependentTestBase { + protected: + void AssertNoTranform(absl::string_view hlo, int64_t threshold = 0) { + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(threshold), false)); + }; + auto Transform(absl::string_view hlo, int64_t threshold = 0) { + return RunAndCheckHloRewrite(hlo, Pass(threshold), true); + }; + void AssertTransform(absl::string_view hlo, int64_t threshold = 0) { + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Pass(threshold), true)); + }; + Decomposed FindComponents(HloModule* module, absl::string_view cp_name) { + Decomposed result; + result.cp_name = cp_name; + result.after_all = + FindInstruction(module, absl::StrCat(cp_name, "-after-all")); + result.send = FindInstruction(module, absl::StrCat(cp_name, "-send")); + result.recv = FindInstruction(module, absl::StrCat(cp_name, "-recv")); + result.send_done = + FindInstruction(module, absl::StrCat(cp_name, "-send-done")); + result.recv_done = + FindInstruction(module, absl::StrCat(cp_name, "-recv-done")); + CHECK(result.after_all != nullptr) << cp_name; + CHECK(result.send != nullptr) << cp_name; + CHECK(result.recv != nullptr) << cp_name; + CHECK(result.send_done != nullptr) << cp_name; + CHECK(result.recv_done != nullptr) << cp_name; + return result; + } +}; + +TEST_F(DecomposerTest, WithCycleNotTransformed) { + AssertNoTranform(R"(HloModule test + ENTRY test_computation { + data = u32[] parameter(0) + ROOT cp = u32[] collective-permute(data), channel_id=1, source_target_pairs={{0,1}, {1,0}} + })"); } -TEST_F(CollectivePermuteDecomposerTest, WithContextDataNotTransformed) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = (u32[], u32[], u32[], u32[]) collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} - } - )"; +TEST_F(DecomposerTest, ThresholdNotTransformed) { + AssertNoTranform(R"(HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} + })", + 8); +} - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); +TEST_F(DecomposerTest, Basic) { + AssertTransform(R"(HloModule test + ENTRY test_computation { + data = u32[] parameter(0) + ROOT cp = u32[] collective-permute(data), channel_id=1, source_target_pairs={{0,1}, {1,2}} + })"); } -TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, - metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - } - )"; +TEST_F(DecomposerTest, NoChannelId) { + AssertTransform(R"(HloModule test + ENTRY test_computation { + data = u32[] parameter(0) + ROOT cp = u32[] collective-permute(data), source_target_pairs={{0,1}, {1,2}} + })"); +} - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); +TEST_F(DecomposerTest, ControlDependency_IndependentCPs) { + absl::string_view hlo = R"(HloModule test + ENTRY test_computation { + data1 = u32[] parameter(0) + data2 = u32[] parameter(1) + cp3 = u32[] collective-permute(data2), source_target_pairs={{6,7}} + cp1 = u32[] collective-permute(data1), source_target_pairs={{3,0}} + cp2 = u32[] collective-permute(data2), source_target_pairs={{0,1},{1,2},{2,3}} + ROOT out = (u32[],u32[],u32[]) tuple(cp2, cp3, cp1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp1 = FindComponents(module.get(), "cp1"); + Decomposed cp2 = FindComponents(module.get(), "cp2"); + Decomposed cp3 = FindComponents(module.get(), "cp3"); + // Sequence in tuple determines the port order and therefore control + // dependency of consecutive CPs. + EXPECT_THAT(cp3.recv->control_predecessors(), ElementsAre(cp2.send)); + EXPECT_THAT(cp1.recv->control_predecessors(), ElementsAre(cp3.send)); +} - auto check_metadata = [](const HloInstruction* inst) { - EXPECT_EQ(inst->metadata().op_name(), "op1/op2/add"); - EXPECT_EQ(inst->metadata().source_file(), "foo/bar/mysource.py"); - EXPECT_EQ(inst->metadata().source_line(), 35); - }; +// Negative test to assure that the decomposer does not create cyclic +// instructions when there is dependency from one cp to another. +TEST_F(DecomposerTest, ControlDependency_BasicDependency) { + absl::string_view hlo = R"(HloModule test + ENTRY test_computation { + p0 = f32[] parameter(0) + cp-a = f32[] collective-permute(p0), source_target_pairs={{0,1}, {1,2}, {2,3}} + ROOT cp-b = f32[] collective-permute(cp-a), source_target_pairs={{3,0}} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp_a = FindComponents(module.get(), "cp-a"); + Decomposed cp_b = FindComponents(module.get(), "cp-b"); + EXPECT_THAT(cp_b.recv->control_predecessors(), ElementsAre(cp_a.send)) + << "Recv-start from cp1 should depend on send start from cp2"; +} - auto check_not_pipelined = [](const HloInstruction* instr) { - const FrontendAttributes& attributes = instr->frontend_attributes(); - EXPECT_EQ(attributes.map().end(), - attributes.map().find(kSendRecvPipelineAttr)); - }; +TEST_F(DecomposerTest, ControlDependency_MoreDependencies) { + absl::string_view hlo = R"(HloModule test + ENTRY test_computation { + data1 = u32[] parameter(0) + data2 = u32[] parameter(1) + // misordered names to assure that dependencies are honored + cp1 = u32[] collective-permute(data1), source_target_pairs={{3,0}} + cp2 = u32[] collective-permute(cp1), source_target_pairs={{0,1},{1,2},{2,3}} + cp3 = u32[] collective-permute(cp2), source_target_pairs={{6,7}} + ROOT out = u32[8] broadcast(cp3), dimensions={} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp1 = FindComponents(module.get(), "cp1"); + Decomposed cp2 = FindComponents(module.get(), "cp2"); + Decomposed cp3 = FindComponents(module.get(), "cp3"); + EXPECT_THAT(cp2.recv->control_predecessors(), ElementsAre(cp1.send)); + EXPECT_THAT(cp3.recv->control_predecessors(), ElementsAre(cp2.send)); +} - HloInstruction* after_all = FindInstruction(module.get(), "after-all"); - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->operand(0), after_all); - EXPECT_EQ(recv->channel_id().value(), 1); - EXPECT_THAT( - recv->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - check_metadata(recv); - check_not_pipelined(recv); - HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); - EXPECT_EQ(recv_done->operand(0), recv); - - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_EQ(send->operand(1), after_all); - EXPECT_EQ(send->channel_id().value(), 1); +void EnsurePreservedInfo(const HloInstruction* instr) { + SCOPED_TRACE("AssurePreservedInfo for: " + instr->ToString()); + EXPECT_EQ(instr->channel_id().value(), 1); + EXPECT_EQ(instr->metadata().op_name(), "op1/op2/add"); + EXPECT_EQ(instr->metadata().source_file(), "foo/bar/mysource.py"); + EXPECT_EQ(instr->metadata().source_line(), 35); EXPECT_THAT( - send->ToString(), + instr->ToString(), HasSubstr( "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - check_metadata(send); - check_not_pipelined(send); - HloInstruction* send_done = FindInstruction(module.get(), "send-done"); - EXPECT_EQ(send_done->operand(0), send); - - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); } -TEST_F(CollectivePermuteDecomposerTest, NotTransformedDefaultChannelId) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}} +std::string PipelineAttr(const HloInstruction* instr) { + const FrontendAttributes& attr = instr->frontend_attributes(); + if (auto it = attr.map().find(kSendRecvPipelineAttr); + it != attr.map().end()) { + return it->second; } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* after_all = FindInstruction(module.get(), "after-all"); - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->operand(0), after_all); - EXPECT_FALSE(recv->channel_id().has_value()); - EXPECT_THAT( - recv->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); - EXPECT_EQ(recv_done->operand(0), recv); + return ""; +} +std::string OtherAttr(const HloInstruction* instr) { + const FrontendAttributes& attributes = instr->frontend_attributes(); + return attributes.map().find("_xla_other_attribute")->second; +} - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_EQ(send->operand(1), after_all); - EXPECT_FALSE(send->channel_id().has_value()); - EXPECT_THAT( - send->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - HloInstruction* send_done = FindInstruction(module.get(), "send-done"); - EXPECT_EQ(send_done->operand(0), send); +void EnsurePipelineAttr(Decomposed cp, std::string val) { + SCOPED_TRACE("ExpectePipelineAttr for " + cp.cp_name); + EXPECT_EQ(PipelineAttr(cp.recv), val); + EXPECT_EQ(PipelineAttr(cp.send), val); + EXPECT_EQ(PipelineAttr(cp.recv_done), val); + EXPECT_EQ(PipelineAttr(cp.send_done), val); +} - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); +void EnsureControlDependency(Decomposed cp) { + SCOPED_TRACE("ExpectOpControlDependency for " + cp.cp_name); + EXPECT_EQ(cp.recv->operand(0), cp.after_all); + EXPECT_EQ(cp.send->operand(1), cp.after_all); + EXPECT_EQ(cp.recv_done->operand(0), cp.recv); + EXPECT_EQ(cp.send_done->operand(0), cp.send); + + EXPECT_THAT(cp.send->control_predecessors(), ElementsAre(cp.recv)) + << "Send should depend on recv when decoposed"; + EXPECT_THAT(cp.recv_done->control_predecessors(), ElementsAre(cp.send)) + << "Recv-done should depend on send when decoposed"; } -TEST_F(CollectivePermuteDecomposerTest, ThresholdNotTransformed) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - p = u32[] replica-id() - ROOT cp = u32[] collective-permute(p), channel_id=1, - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, - metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - } +TEST_F(DecomposerTest, StructureAndMetadata) { + absl::string_view hlo = R"( + HloModule test + ENTRY test_computation { + p = u32[] replica-id() + ROOT cp = u32[] collective-permute(p), channel_id=1, + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, + metadata={op_name="op1/op2/add" + source_file="foo/bar/mysource.py" source_line=35} + } )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/8); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp = FindComponents(module.get(), "cp"); + EnsurePreservedInfo(cp.send); + EnsurePreservedInfo(cp.recv); + EnsurePipelineAttr(cp, ""); + EnsureControlDependency(cp); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(cp.recv_done, 0)); } -TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, Pipeline1) { + absl::string_view hlo = R"( HloModule module cond { param = (u32[], u32[2]) parameter(0) @@ -205,7 +252,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { count = get-tuple-element(param), index=0 send-data = get-tuple-element(param), index=1 - recv-data = u32[2] collective-permute(send-data), channel_id=1, + cp = u32[2] collective-permute(send-data), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, frontend_attributes={_xla_other_attribute="xyz"} @@ -213,7 +260,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { new_count = u32[] add(count, c1) r = u32[2] broadcast(c1), dimensions={} - s = u32[2] add(r, recv-data) + s = u32[2] add(r, cp) ROOT result = (u32[], u32[2]) tuple(new_count, s) } @@ -229,41 +276,16 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { ROOT result = u32[2] get-tuple-element(while_result), index=1 })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->channel_id().value(), 1); - EXPECT_THAT( - recv->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); - EXPECT_THAT(recv->ToString(), HasSubstr("_xla_other_attribute=\"xyz\"")); - HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); - EXPECT_THAT(recv_done->ToString(), - HasSubstr("_xla_send_recv_pipeline=\"0\"")); - - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_EQ(send->channel_id().value(), 1); - EXPECT_THAT( - send->ToString(), - HasSubstr( - "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); - EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); - EXPECT_THAT(send->ToString(), HasSubstr("_xla_other_attribute=\"xyz\"")); - HloInstruction* send_done = FindInstruction(module.get(), "send-done"); - EXPECT_THAT(send_done->ToString(), - HasSubstr("_xla_send_recv_pipeline=\"0\"")); - - EXPECT_THAT(send->control_predecessors(), ElementsAre(recv)); - EXPECT_THAT(recv_done->control_predecessors(), ElementsAre(send)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp = FindComponents(module.get(), "cp"); + EnsurePipelineAttr(cp, "0"); + EXPECT_EQ(OtherAttr(cp.recv), "xyz") << "Preseving other attributes"; + EXPECT_EQ(OtherAttr(cp.send), "xyz") << "Preseving other attributes"; + EnsureControlDependency(cp); } -TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, ForwardPipeline2) { + absl::string_view hlo = R"( HloModule module cond { param = (u32[], u32[2]) parameter(0) @@ -277,17 +299,17 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { count = get-tuple-element(param), index=0 send-data = get-tuple-element(param), index=1 - recv-data.0 = u32[2] collective-permute(send-data), channel_id=1, - source_target_pairs={{3,0}} - - recv-data.1 = u32[2] collective-permute(send-data), channel_id=2, + cp_fwd = u32[2] collective-permute(send-data), channel_id=2, source_target_pairs={{0,1}, {1,2}, {2,3}} + cp_back = u32[2] collective-permute(send-data), channel_id=1, + source_target_pairs={{3,0}} + replica = u32[] replica-id() constant0 = u32[] constant(0) compare0 = pred[] compare(replica, constant0), direction=EQ compare = pred[2] broadcast(compare0), dimensions={} - recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + recv-data = u32[2] select(compare, cp_back, cp_fwd) c1 = u32[] constant(1) new_count = u32[] add(count, c1) @@ -309,49 +331,25 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { ROOT result = u32[2] get-tuple-element(while_result), index=1 })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->channel_id().value(), 1); - EXPECT_THAT(recv->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}")); - EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_THAT(send->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}")); - EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); - - HloInstruction* recv1 = FindInstruction(module.get(), "recv.1"); - EXPECT_EQ(recv1->channel_id().value(), 2); - EXPECT_THAT( - recv1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}")); - EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); - HloInstruction* recv_done1 = FindInstruction(module.get(), "recv-done.1"); - EXPECT_THAT(recv_done1->ToString(), - HasSubstr("_xla_send_recv_pipeline=\"1\"")); - HloInstruction* send1 = FindInstruction(module.get(), "send.1"); - EXPECT_THAT( - send1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}")); - EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); - HloInstruction* send_done1 = FindInstruction(module.get(), "send-done.1"); - EXPECT_THAT(send_done1->ToString(), - HasSubstr("_xla_send_recv_pipeline=\"1\"")); - - EXPECT_THAT(send->control_predecessors(), ElementsAre(recv)); - EXPECT_THAT(recv1->control_predecessors(), ElementsAre(send)); - EXPECT_THAT(send1->control_predecessors(), ElementsAre(recv1)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp_back = FindComponents(module.get(), "cp_back"); + Decomposed cp_fwd = FindComponents(module.get(), "cp_fwd"); + + EXPECT_EQ(cp_back.recv->channel_id().value(), 1); + EXPECT_EQ(cp_fwd.recv->channel_id().value(), 2); + EnsurePipelineAttr(cp_back, "0"); + EnsurePipelineAttr(cp_fwd, "1"); + EnsureControlDependency(cp_back); + EnsureControlDependency(cp_fwd); + EXPECT_THAT(cp_fwd.recv->control_predecessors(), ElementsAre(cp_back.send)) + << "Per sequence of select operands, cp_back should come before cp_fwd"; } -TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { +TEST_F(DecomposerTest, ForwardPipelineWithMatmul) { // The HLO module below is generated by passing the HLO in // CollectiveOpsTest.CollectivePermute_CircularPipelinePreOptimization through // the collective_permute_cycle_decomposer.transformation. - const char* const kModuleStr = R"( + absl::string_view hlo = R"( HloModule test while_body { @@ -370,11 +368,11 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { cp_back = f32[2,2] collective-permute(data), channel_id=1, source_target_pairs={{3,0}}, frontend_attributes={_xla_send_recv_validation="{{3,10}}"} - cp_forward = f32[2,2] collective-permute(data), channel_id=2, + cp_fwd = f32[2,2] collective-permute(data), channel_id=2, source_target_pairs={{0,1},{1,2},{2,3}}, frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9}}"} - select = f32[2,2] select(broadcast, cp_back, cp_forward) + select = f32[2,2] select(broadcast, cp_back, cp_fwd) matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, rhs_contracting_dims={0} @@ -400,66 +398,20 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { ROOT data_out = f32[2,2] get-tuple-element(while_result), index=1 } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - HloModule* transformed_module = module.get(); - // Check the annotations and ordering of the decomposed send-recv pairs. - // We expect the recv to come before the send in the while body, both for the - // forward edge ({0,1},{1,2},{2,3}}) and the backward edge ({3,0}). This is - // an XLA invariant that shouldn't be broken (see - // https://openxla.org/xla/operation_semantics#send for details of the - // semantics). - HloComputation* while_body = - FindComputation(transformed_module, "while_body"); - HloInstruction* recv_bwd = hlo_query::FindInstruction(while_body, "recv"); - EXPECT_EQ(recv_bwd->channel_id().value(), 1); - auto recv_bwd_frontend_attributes = recv_bwd->frontend_attributes().map(); - EXPECT_EQ(recv_bwd_frontend_attributes.size(), 3); - EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvValidationAttr), - "{{3,10}}"); - EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvPipelineAttr), "0"); - EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), - "{{3,0}}"); - - HloInstruction* send_bwd = hlo_query::FindInstruction(while_body, "send"); - auto send_bwd_frontend_attributes = send_bwd->frontend_attributes().map(); - EXPECT_THAT(send_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), - "{{3,0}}"); - - HloInstruction* recv_fwd = hlo_query::FindInstruction(while_body, "recv.1"); - EXPECT_EQ(recv_fwd->channel_id().value(), 2); - auto recv_fwd_frontend_attributes = recv_fwd->frontend_attributes().map(); - EXPECT_EQ(recv_fwd_frontend_attributes.size(), 3); - EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); - EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), - "{{0,1},{1,2},{2,3}}"); - - HloInstruction* send_fwd = hlo_query::FindInstruction(while_body, "send.1"); - auto send_fwd_frontend_attributes = send_fwd->frontend_attributes().map(); - EXPECT_EQ(send_fwd_frontend_attributes.size(), 3); - EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); - EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), - "{{0,1},{1,2},{2,3}}"); - - EXPECT_NE(while_body, nullptr); - HloInstruction* recv_done_fwd = - hlo_query::FindInstruction(while_body, "recv-done"); - HloInstruction* recv_done_bwd = - hlo_query::FindInstruction(while_body, "recv-done.1"); - - EXPECT_THAT(send_bwd->control_predecessors(), ElementsAre(recv_bwd)); - EXPECT_THAT(recv_fwd->control_predecessors(), ElementsAre(send_bwd)); - EXPECT_THAT(send_fwd->control_predecessors(), ElementsAre(recv_fwd)); - - EXPECT_THAT(recv_done_fwd->control_predecessors(), ElementsAre(send_bwd)); - EXPECT_THAT(recv_done_bwd->control_predecessors(), ElementsAre(send_fwd)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp_back = FindComponents(module.get(), "cp_back"); + Decomposed cp_fwd = FindComponents(module.get(), "cp_fwd"); + EXPECT_EQ(cp_back.recv->channel_id().value(), 1); + EXPECT_EQ(cp_fwd.recv->channel_id().value(), 2); + EnsurePipelineAttr(cp_back, "0"); + EnsurePipelineAttr(cp_fwd, "1"); + EnsureControlDependency(cp_back); + EnsureControlDependency(cp_fwd); + EXPECT_THAT(cp_fwd.recv->control_predecessors(), ElementsAre(cp_back.send)); } -TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { - const char* const kModuleStr = R"( +TEST_F(DecomposerTest, BackwardPipeline2) { + absl::string_view hlo = R"( HloModule module cond { param = (u32[], u32[2]) parameter(0) @@ -473,17 +425,17 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { count = get-tuple-element(param), index=0 send-data = get-tuple-element(param), index=1 - recv-data.0 = u32[2] collective-permute(send-data), channel_id=1, + cp_fwd = u32[2] collective-permute(send-data), channel_id=1, source_target_pairs={{1,0},{2,1},{3,2}} - recv-data.1 = u32[2] collective-permute(send-data), channel_id=2, + cp_back = u32[2] collective-permute(send-data), channel_id=2, source_target_pairs={{0,3}} replica = u32[] replica-id() constant0 = u32[] constant(0) compare0 = pred[] compare(replica, constant0), direction=NE compare = pred[2] broadcast(compare0), dimensions={} - recv-data = u32[2] select(compare, recv-data.0, recv-data.1) + recv-data = u32[2] select(compare, cp_fwd, cp_back) c1 = u32[] constant(1) new_count = u32[] add(count, c1) @@ -505,83 +457,18 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { ROOT result = u32[2] get-tuple-element(while_result), index=1 })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - HloInstruction* recv = FindInstruction(module.get(), "recv"); - EXPECT_EQ(recv->channel_id().value(), 1); - EXPECT_THAT( - recv->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}")); - EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); - HloInstruction* send = FindInstruction(module.get(), "send"); - EXPECT_THAT( - send->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}")); - EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); - - HloInstruction* recv1 = FindInstruction(module.get(), "recv.1"); - EXPECT_EQ(recv1->channel_id().value(), 2); - EXPECT_THAT(recv1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}")); - EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); - HloInstruction* send1 = FindInstruction(module.get(), "send.1"); - EXPECT_THAT(send1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}")); - EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); - - EXPECT_THAT(send1->control_predecessors(), ElementsAre(recv1)); - EXPECT_THAT(recv->control_predecessors(), ElementsAre(send1)); - EXPECT_THAT(send->control_predecessors(), ElementsAre(recv)); -} - -TEST_F(CollectivePermuteDecomposerTest, - DecomposeCrossReplicaCollectivePermute) { - const char* const kModuleStr = R"( - HloModule module - ENTRY body { - data = f32[16] parameter(0) - ROOT data_ = f32[16] collective-permute(data), - source_target_pairs={{0,1}, {1,2}, {2,3}} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnUnverifiedModule((kModuleStr))); - - CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - - HloComputation* comp = module->entry_computation(); - HloInstruction* root = comp->root_instruction(); - HloInstruction* send = hlo_query::FindInstruction(comp, "send"); - HloInstruction* send_done = hlo_query::FindInstruction(comp, "send-done"); - HloInstruction* recv = hlo_query::FindInstruction(comp, "recv"); - HloInstruction* recv_done = hlo_query::FindInstruction(comp, "recv-done"); - - EXPECT_THAT(send, op::Send(op::Parameter(0), op::AfterAll())); - EXPECT_EQ( - send->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr), - "{{0,1},{1,2},{2,3}}"); - EXPECT_FALSE(send->channel_id().has_value()); - - EXPECT_THAT(send_done, op::SendDone(send)); - EXPECT_FALSE(send_done->channel_id().has_value()); - - EXPECT_THAT(recv, op::Recv(op::AfterAll())); - EXPECT_EQ( - recv->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr), - "{{0,1},{1,2},{2,3}}"); - EXPECT_FALSE(recv->channel_id().has_value()); - - EXPECT_THAT(recv_done, op::RecvDone(recv)); - EXPECT_FALSE(recv_done->channel_id().has_value()); - - EXPECT_THAT(root, op::GetTupleElement(recv_done, 0)); - - EXPECT_THAT(send->control_predecessors(), ElementsAre(recv)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Transform(hlo)); + Decomposed cp_back = FindComponents(module.get(), "cp_back"); + Decomposed cp_fwd = FindComponents(module.get(), "cp_fwd"); + EXPECT_EQ(cp_back.recv->channel_id().value(), 2); + EXPECT_EQ(cp_fwd.recv->channel_id().value(), 1); + + EnsurePipelineAttr(cp_back, "0"); + EnsurePipelineAttr(cp_fwd, "1"); + EnsureControlDependency(cp_back); + EnsureControlDependency(cp_fwd); + EXPECT_THAT(cp_back.recv->control_predecessors(), ElementsAre(cp_fwd.send)) + << "Per sequence of select operands, cp_fwd should come before cp_back"; } } // namespace diff --git a/third_party/xla/xla/service/collective_permute_utils.cc b/third_party/xla/xla/service/collective_permute_utils.cc new file mode 100644 index 00000000000000..3ee67e3d86096f --- /dev/null +++ b/third_party/xla/xla/service/collective_permute_utils.cc @@ -0,0 +1,99 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/collective_permute_utils.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/graphcycles/graphcycles.h" + +namespace xla { +namespace cp_utils { + +using ::xla::HloCollectivePermuteInstruction; + +std::string SourceTargetPairsString(const HloCollectivePermuteInstruction& cp) { + auto formatter = absl::PairFormatter( + [](std::string* out, int64_t value) { absl::StrAppend(out, "{", value); }, + ",", + [](std::string* out, int64_t value) { + absl::StrAppend(out, value, "}"); + }); + const std::string pairs_str = + absl::StrJoin(cp.source_target_pairs(), ",", formatter); + return absl::StrCat("{", pairs_str, "}"); +} + +namespace { +int32_t GetNodeId(int64_t replica, GraphCycles& graph, + absl::flat_hash_map& map) { + if (!map.contains(replica)) { + map.emplace(replica, graph.NewNode()); + } + return map.at(replica); +} +} // namespace + +bool HasCycles(const SourceTargetPairs& pairs) { + GraphCycles graph; + absl::flat_hash_map replica_to_node_id; + for (const SourceTargetPair& pair : pairs) { + const int source = GetNodeId(pair.first, graph, replica_to_node_id); + const int target = GetNodeId(pair.second, graph, replica_to_node_id); + if (!graph.InsertEdge(source, target)) { + return true; + } + } + return false; +} + +// TODO: b/388623407 - remove assumptions that pairs are ordered and 0 based. +bool IsForwardCycle(const SourceTargetPair& backedge, + const SourceTargetPairs& others) { + const int64_t num_pairs = others.size() + 1; + if (backedge.first != num_pairs - 1 || backedge.second != 0) { + return false; + } + for (int64_t i = 0; i < num_pairs - 1; ++i) { + const SourceTargetPair& pair = others[i]; + if (pair.first != i || pair.second != i + 1) { + return false; + } + } + return true; +} + +bool IsBackwardCycle(const SourceTargetPair& backedge, + const SourceTargetPairs& others) { + const int64_t num_pairs = others.size() + 1; + if (backedge.first != 0 || backedge.second != num_pairs - 1) { + return false; + } + for (int64_t i = 0; i < num_pairs - 1; ++i) { + const SourceTargetPair& pair = others[i]; + if (pair.first != i + 1 || pair.second != i) { + return false; + } + } + return true; +} + +} // namespace cp_utils +} // namespace xla diff --git a/third_party/xla/xla/service/collective_permute_utils.h b/third_party/xla/xla/service/collective_permute_utils.h new file mode 100644 index 00000000000000..46c62ea25bb381 --- /dev/null +++ b/third_party/xla/xla/service/collective_permute_utils.h @@ -0,0 +1,54 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_COLLECTIVE_PERMUTE_UTILS_H_ +#define XLA_SERVICE_COLLECTIVE_PERMUTE_UTILS_H_ + +#include +#include +#include +#include + +#include "xla/hlo/ir/hlo_instructions.h" + +namespace xla { +namespace cp_utils { + +using SourceTargetPair = std::pair; +using SourceTargetPairs = std::vector; + +// Source Targe Pairs to a cannoical string such as {{0,1},{1,2},{2,3},{3,0}}. +std::string SourceTargetPairsString(const HloCollectivePermuteInstruction& cp); + +// Returns true if the (source, target) relationship has a cycle. +bool HasCycles(const SourceTargetPairs& pairs); + +// Returns true if the (source, target) pairs form a forward cycle with all +// participants in the cycle, such as {{0,1},{1,2},{2,3},{3,0}}. We assume that +// the (source, target) pairs are ordered via increasing source IDs, as they are +// currently generated by SPMD partitioning. +bool IsForwardCycle(const SourceTargetPair& backedge, + const SourceTargetPairs& others); + +// Returns true if the (source, target) pairs form a backward cycle with all +// participants in the cycle, such as {{0,3},{1,0},{2,1},{3,2}}. We assume that +// the (source, target) pairs are ordered via increasing source IDs, as they are +// currently generated by SPMD partitioning. +bool IsBackwardCycle(const SourceTargetPair& backedge, + const SourceTargetPairs& others); + +} // namespace cp_utils +} // namespace xla +#endif // XLA_SERVICE_COLLECTIVE_PERMUTE_UTILS_H_ diff --git a/third_party/xla/xla/service/collective_permute_utils_test.cc b/third_party/xla/xla/service/collective_permute_utils_test.cc new file mode 100644 index 00000000000000..54a2a66eb349ba --- /dev/null +++ b/third_party/xla/xla/service/collective_permute_utils_test.cc @@ -0,0 +1,107 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/collective_permute_utils.h" + +#include + +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape_util.h" + +namespace xla { +namespace cp_utils { + +struct Cannonical { + SourceTargetPairs cycle; + SourceTargetPairs fwd_edge; + SourceTargetPairs bwd_edge; +}; + +class CollectivePermuteUtilsTest : public ::testing::Test { + protected: + Cannonical fwd2_ = { + .cycle = {{0, 1}, {1, 0}}, .fwd_edge = {{0, 1}}, .bwd_edge = {{1, 0}}}; + Cannonical bwd2_ = { + .cycle = {{1, 0}, {0, 1}}, .fwd_edge = {{1, 0}}, .bwd_edge = {{0, 1}}}; + Cannonical fwd4_ = {.cycle = {{0, 1}, {1, 2}, {2, 3}, {3, 0}}, + .fwd_edge = {{0, 1}, {1, 2}, {2, 3}}, + .bwd_edge = {{3, 0}}}; + Cannonical bwd4_ = {.cycle = {{0, 3}, {1, 0}, {2, 1}, {3, 2}}, + .fwd_edge = {{1, 0}, {2, 1}, {3, 2}}, + .bwd_edge = {{0, 3}}}; + std::unique_ptr simple_input_ = HloInstruction::CreateToken(); + + HloCollectivePermuteInstruction CreateCollectivePermute( + const SourceTargetPairs& pairs) { + return HloCollectivePermuteInstruction(HloOpcode::kCollectivePermute, + ShapeUtil::MakeShape(U32, {8, 8}), + simple_input_.get(), pairs, 1); + } +}; + +TEST_F(CollectivePermuteUtilsTest, HasCycles) { + EXPECT_TRUE(HasCycles(fwd2_.cycle)); + EXPECT_TRUE(HasCycles(bwd2_.cycle)); + EXPECT_TRUE(HasCycles(fwd4_.cycle)); + EXPECT_TRUE(HasCycles(bwd4_.cycle)); + + EXPECT_TRUE(HasCycles({{0, 1}, {1, 2}, {2, 3}, {3, 2}})) << "Lasso 3->2"; + EXPECT_TRUE(HasCycles({{0, 1}, {1, 2}, {2, 3}, {3, 1}})) << "Lasso 3->1"; + + EXPECT_FALSE(HasCycles({{1, 2}, {2, 3}, {3, 0}})) << "Forward only"; + EXPECT_FALSE(HasCycles({{1, 2}})) << "Single edge"; +} + +bool IsForwardCycle(Cannonical& canonical) { + return IsForwardCycle(canonical.bwd_edge[0], canonical.fwd_edge); +} +bool IsBackwardCycle(Cannonical& canonical) { + return IsBackwardCycle(canonical.bwd_edge[0], canonical.fwd_edge); +} + +TEST_F(CollectivePermuteUtilsTest, IsForwardCycle) { + EXPECT_TRUE(IsForwardCycle(fwd2_)); + EXPECT_TRUE(IsForwardCycle(fwd4_)); + + EXPECT_FALSE(IsForwardCycle(bwd2_)); + EXPECT_FALSE(IsForwardCycle(bwd4_)); + + EXPECT_FALSE(IsForwardCycle({3, 0}, {{0, 2}, {2, 3}, {3, 0}})) << "Skip 1"; +} + +TEST_F(CollectivePermuteUtilsTest, IsBackwardCycle) { + EXPECT_TRUE(IsBackwardCycle(bwd2_)); + EXPECT_TRUE(IsBackwardCycle(bwd4_)); + + EXPECT_FALSE(IsBackwardCycle(fwd2_)); + EXPECT_FALSE(IsBackwardCycle(fwd4_)); +} + +TEST_F(CollectivePermuteUtilsTest, SourceTargetPairsString) { + EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(fwd2_.cycle)), + "{{0,1},{1,0}}"); + EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(bwd2_.cycle)), + "{{1,0},{0,1}}"); + EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(fwd4_.cycle)), + "{{0,1},{1,2},{2,3},{3,0}}"); + EXPECT_EQ(SourceTargetPairsString(CreateCollectivePermute(bwd4_.cycle)), + "{{0,3},{1,0},{2,1},{3,2}}"); +} + +} // namespace cp_utils +} // namespace xla diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index d02424990edea9..69a4af5295c5f6 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -148,7 +148,7 @@ std::optional GetSlicedDimension( bool CheckIndexIsMonotonic( const HloInstruction* index, - const absl::flat_hash_map& induction_map) { + absl::flat_hash_map& induction_map) { // Because the only math operations supported by RecursivelyIdentifyRange() // are only sub/add then checking that we can compute the range here is enough // to guarantee that the index is monotonic if the base index is monotonic. If @@ -156,7 +156,7 @@ bool CheckIndexIsMonotonic( // sophisticated check for monotonicity. Range range = RecursivelyIdentifyRange(index, induction_map); VLOG(6) << "Range for: " << index->ToString() << " " << range.ToString(); - return !range.IsEmpty() && range.IsLinear(); + return !range.IsEmpty() && range.IsBounded() && range.IsLinear(); } // Check that the parameter is only used in a pattern param -> gte -> @@ -789,8 +789,7 @@ class WhileLoopAnalysis { CollectivePipeliner::PipeliningDirection direction, int64_t level_to_operate_on, const absl::flat_hash_map& parameter_gtes_count, - const absl::flat_hash_map& index_ranges) - const; + absl::flat_hash_map& index_ranges) const; // Merges the new collective (instr) with the existing one stored in // move_infos_[indices_to_merge[0]]. indices_to_merge.size() should be 1. @@ -981,8 +980,7 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice( CollectivePipeliner::PipeliningDirection direction, int64_t level_to_operate_on, const absl::flat_hash_map& parameter_gtes_count, - const absl::flat_hash_map& index_ranges) - const { + absl::flat_hash_map& index_ranges) const { HloComputation* while_body = while_->while_body(); const HloInstruction* loop_parameter = while_body->parameter_instructions()[0]; diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 0d82839f27c552..baba47db849303 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -38,13 +38,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/collective_utils.h b/third_party/xla/xla/service/collective_utils.h index 916e007dc9b2eb..dc69009445686d 100644 --- a/third_party/xla/xla/service/collective_utils.h +++ b/third_party/xla/xla/service/collective_utils.h @@ -32,6 +32,11 @@ constexpr int64_t kDefaultAllGatherCombineThreshold = 30 * 1024 * 1024 + 7; // pass will combine collectives. constexpr int64_t kDefaultReduceScatterCombineThreshold = 30 * 1024 * 1024 + 7; +// Defines the default coefficient for the SoL NCCL collective cost model. +// Note: XLA flags allow a user to override the default values of the model. +constexpr float kDefaultNcclCostModelCoeff = 0.45f; +constexpr int64_t kDefaultNcclCostModelChunkSizeBytes = 4194304; // 4MB +constexpr int64_t kDefaultNcclCostModelGPUsPerNode = 8; } // namespace xla #endif // XLA_SERVICE_COLLECTIVE_UTILS_H_ diff --git a/third_party/xla/xla/service/compilation_environments.cc b/third_party/xla/xla/service/compilation_environments.cc index f4e6f5b404d917..f2e0dff2b5c5b3 100644 --- a/third_party/xla/xla/service/compilation_environments.cc +++ b/third_party/xla/xla/service/compilation_environments.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -58,7 +57,7 @@ class GlobalCompEnvStats { return *singleton; } - void DefaultEnvCreatedByCompilationEnvironments(std::string_view env_type) + void DefaultEnvCreatedByCompilationEnvironments(absl::string_view env_type) ABSL_LOCKS_EXCLUDED(mu_) { { absl::MutexLock l(&mu_); @@ -68,7 +67,7 @@ class GlobalCompEnvStats { VLOG(1) << "New GlobalCompEnvStats value: " << ToString(); } - void EnvAdded(std::string_view env_type) ABSL_LOCKS_EXCLUDED(mu_) { + void EnvAdded(absl::string_view env_type) ABSL_LOCKS_EXCLUDED(mu_) { { absl::MutexLock l(&mu_); ++stats_[std::string(env_type)].env_added; @@ -230,12 +229,12 @@ CompilationEnvironments::GetProcessNewEnvFn( } void CompilationEnvironments::DefaultEnvCreatedByCompilationEnvironments( - std::string_view env_type) { + absl::string_view env_type) { GlobalCompEnvStats::GetSingleton().DefaultEnvCreatedByCompilationEnvironments( env_type); } -void CompilationEnvironments::EnvAdded(std::string_view env_type) { +void CompilationEnvironments::EnvAdded(absl::string_view env_type) { GlobalCompEnvStats::GetSingleton().EnvAdded(env_type); } diff --git a/third_party/xla/xla/service/compilation_environments.h b/third_party/xla/xla/service/compilation_environments.h index 08a79df01e09cd..fe845c23a2e711 100644 --- a/third_party/xla/xla/service/compilation_environments.h +++ b/third_party/xla/xla/service/compilation_environments.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -118,11 +117,11 @@ class CompilationEnvironments { // track stats about how many such environments are created by // CompilationEnvironments. static void DefaultEnvCreatedByCompilationEnvironments( - std::string_view env_type); + absl::string_view env_type); // Called by AddEnv(), to globally track stats about how many environments // are added to CompilationEnvironments. - static void EnvAdded(std::string_view env_type); + static void EnvAdded(absl::string_view env_type); absl::Status AddEnvImpl(const tsl::protobuf::Descriptor& descriptor, std::unique_ptr env); diff --git a/third_party/xla/xla/service/compiler.h b/third_party/xla/xla/service/compiler.h index 45dc7298c4e8d4..dd923a4ce45043 100644 --- a/third_party/xla/xla/service/compiler.h +++ b/third_party/xla/xla/service/compiler.h @@ -300,7 +300,7 @@ class Compiler { // Returns a function that computes the size in bytes of a given // logical buffer. - std::function BufferSizeBytesFunction() { + std::function BufferSizeBytesFunction() const { HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); return [shape_size](const BufferValue& buffer) { return shape_size(buffer.shape()); diff --git a/third_party/xla/xla/service/cost_modelling/op_cost.cc b/third_party/xla/xla/service/cost_modelling/op_cost.cc index c7becbef9ff35b..53e8bdc73cde8a 100644 --- a/third_party/xla/xla/service/cost_modelling/op_cost.cc +++ b/third_party/xla/xla/service/cost_modelling/op_cost.cc @@ -46,7 +46,7 @@ namespace xla { namespace { // Used in LOG(INFO) statements for analysis logging. -constexpr std::string_view kLoggingAnalysisId = "COST_LOGGING"; +constexpr absl::string_view kLoggingAnalysisId = "COST_LOGGING"; } // namespace @@ -291,7 +291,7 @@ class CalculationLeaf : public OpCostManager::CalculationNode { return cost_value.value(); } - std::string_view Name() const override { return name_; } + absl::string_view Name() const override { return name_; } std::vector LeafCalculatorNames() const override { return {name_}; @@ -373,7 +373,7 @@ class DelegationCalculationNode : public OpCostManager::CalculationNode { return final_result; } - std::string_view Name() const override { return name_; } + absl::string_view Name() const override { return name_; } std::vector LeafCalculatorNames() const override { std::vector result; diff --git a/third_party/xla/xla/service/cost_modelling/op_cost.h b/third_party/xla/xla/service/cost_modelling/op_cost.h index dce6b3d50b305e..356599707fc706 100644 --- a/third_party/xla/xla/service/cost_modelling/op_cost.h +++ b/third_party/xla/xla/service/cost_modelling/op_cost.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -161,6 +162,9 @@ class CostValue { // Suitable for logging analysis for debugging. std::string ToString() const; + friend std::ostream& operator<<(std::ostream& os, const CostValue& value) { + return os << value.ToString(); + } private: enum class Type : std::uint8_t { kNotFound, kError, kOk }; @@ -265,7 +269,7 @@ class OpCostManager { const CostMetricId& metric_id, LeafCalculatorValueMap* calculator_value_map) = 0; - virtual std::string_view Name() const = 0; + virtual absl::string_view Name() const = 0; // Returns the names of leaf calculators at or below the node (in the tree). // Leaf calculator names are used to uniquely identify the costs associated diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 91166c2235fff4..f0ea8b779baa63 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -235,6 +235,7 @@ cc_library( "//xla:cpu_function_runtime", "//xla:debug_options_flags", "//xla:literal", + "//xla:literal_pool", "//xla:protobuf_util", "//xla:shape_util", "//xla:status_macros", @@ -254,44 +255,45 @@ cc_library( "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:batch_dot_simplification", "//xla/hlo/transforms:bitcast_dtypes_expander", - "//xla/hlo/transforms:broadcast_canonicalizer", "//xla/hlo/transforms:cholesky_expander", "//xla/hlo/transforms:comparison_expander", - "//xla/hlo/transforms:conditional_canonicalizer", - "//xla/hlo/transforms:convolution_group_converter", "//xla/hlo/transforms:dot_decomposer", - "//xla/hlo/transforms:dynamic_dimension_simplifier", "//xla/hlo/transforms:dynamic_index_splitter", "//xla/hlo/transforms:eigh_expander", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:float_normalization", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms:literal_canonicalizer", "//xla/hlo/transforms:logistic_expander", "//xla/hlo/transforms:operand_upcaster", "//xla/hlo/transforms:optimization_barrier_expander", - "//xla/hlo/transforms:optimize_input_output_buffer_alias", "//xla/hlo/transforms:qr_expander", "//xla/hlo/transforms:reduce_decomposer", - "//xla/hlo/transforms:reduce_window_rewriter", "//xla/hlo/transforms:reshape_decomposer", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:result_caster", "//xla/hlo/transforms:rng_bit_generator_expander", "//xla/hlo/transforms:rng_expander", - "//xla/hlo/transforms:simplify_fp_conversions", - "//xla/hlo/transforms:slice_sinker", - "//xla/hlo/transforms:sort_simplifier", "//xla/hlo/transforms:stochastic_convert_decomposer", - "//xla/hlo/transforms:sub_byte_normalization", - "//xla/hlo/transforms:tree_reduction_rewriter", - "//xla/hlo/transforms:tuple_simplifier", "//xla/hlo/transforms:while_loop_trip_count_annotator", - "//xla/hlo/transforms:zero_sized_hlo_elimination", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:batch_dot_simplification", + "//xla/hlo/transforms/simplifiers:broadcast_canonicalizer", + "//xla/hlo/transforms/simplifiers:conditional_canonicalizer", + "//xla/hlo/transforms/simplifiers:convolution_group_converter", + "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias", + "//xla/hlo/transforms/simplifiers:reduce_window_rewriter", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:result_caster", + "//xla/hlo/transforms/simplifiers:simplify_fp_conversions", + "//xla/hlo/transforms/simplifiers:slice_sinker", + "//xla/hlo/transforms/simplifiers:sort_simplifier", + "//xla/hlo/transforms/simplifiers:sub_byte_normalization", + "//xla/hlo/transforms/simplifiers:tree_reduction_rewriter", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination", "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir_hlo", @@ -424,7 +426,6 @@ cc_library( deps = [ "cpu_compiler_pure", ":executable_proto_cc", - ":xla_framework", "//xla:cpu_function_runtime", "//xla:util", "//xla/backends/cpu/codegen:target_machine_features", @@ -437,6 +438,7 @@ cc_library( "//xla/service:hlo_profile_printer_data_cc", "//xla/service:hlo_proto_cc", "//xla/service:llvm_compiler", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/status", @@ -453,20 +455,17 @@ xla_test( "cpu", ], tags = [ - "test_hlo_pjrt_runner", + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", ], deps = [ - "//xla:shape_util", - "//xla/pjrt:pjrt_client", - "//xla/service:hlo_runner", - "//xla/service:hlo_runner_interface", - "//xla/service:hlo_runner_pjrt", - "//xla/service:platform_util", - "//xla/tests:hlo_runner_agnostic_test_base", - "//xla/tests:pjrt_client_registry", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/tests:hlo_pjrt_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/monitoring:collected_metrics", "//xla/tsl/lib/monitoring:collection_registry", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -630,17 +629,12 @@ cc_library( srcs = ["elemental_math_emitter.cc"], hdrs = ["elemental_math_emitter.h"], deps = [ - "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:math_ops", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:logging", ], ) @@ -651,29 +645,31 @@ cc_library( deps = [ ":backend_config_proto_cc", ":dot_op_emitter", - ":elemental_math_emitter", + ":elemental_ir_emitter", ":ir_emitter", - ":ir_function", ":parallel_loop_emitter", ":shape_partition", - "//xla:cpu_function_runtime", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/cpu/codegen:kernel_api_ir_builder", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:elemental_ir_emitter", + "//xla/service:hlo_module_config", "//xla/service/llvm_ir:dynamic_update_slice_util", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/stream_executor:launch_dim", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -682,8 +678,6 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) @@ -692,49 +686,39 @@ xla_cc_test( name = "ir_emitter_test", srcs = ["ir_emitter_test.cc"], deps = [ + ":cpu_compiler", + ":cpu_executable", + ":cpu_options", ":ir_emitter", ":ir_function", - ":target_machine_features_stub", - "//xla/hlo/analysis:hlo_ordering", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/service:buffer_assignment", - "//xla/service:hlo_module_config", - "//xla/service:logical_buffer", - "//xla/tests:hlo_test_base", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "ir_emitter2_test", - srcs = ["ir_emitter2_test.cc"], - deps = [ - ":ir_emitter", - ":ir_emitter2", + ":runtime_symbol_generator", ":target_machine_features_stub", "//xla:cpu_function_runtime", - "//xla:shape_util", - "//xla:xla_data_proto_cc", + "//xla/backends/cpu/codegen:cpu_features", + "//xla/backends/cpu/codegen:ir_compiler", + "//xla/backends/cpu/codegen:jit_compiler", + "//xla/backends/cpu/codegen:target_machine_features", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/service:buffer_assignment", + "//xla/service:buffer_value", "//xla/service:hlo_module_config", "//xla/service:logical_buffer", - "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -748,9 +732,11 @@ cc_library( copts = tsl_copts(), deps = [ ":backend_config_proto_cc", + ":cpu_instruction_fusion", ":cpu_options", ":cpu_runtime", ":dot_op_emitter", + ":elemental_ir_emitter", ":elemental_math_emitter", ":ir_emission_utils", ":ir_function", @@ -858,6 +844,7 @@ cc_library( srcs = ["thunk_emitter.cc"], hdrs = ["thunk_emitter.h"], deps = [ + ":backend_config_proto_cc", ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter2", @@ -867,6 +854,9 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu:xnn_emitter", + "//xla/backends/cpu/codegen:elemental_kernel_emitter", + "//xla/backends/cpu/codegen:llvm_ir_kernel_spec", "//xla/backends/cpu/codegen:target_machine_features", "//xla/backends/cpu/runtime:all_gather_thunk", "//xla/backends/cpu/runtime:all_reduce_thunk", @@ -891,22 +881,28 @@ cc_library( "//xla/backends/cpu/runtime:thunk", "//xla/backends/cpu/runtime:topk_thunk", "//xla/backends/cpu/runtime:while_thunk", + "//xla/backends/cpu/runtime/xnnpack:xnn_dot_thunk", + "//xla/backends/cpu/runtime/xnnpack:xnn_fusion_thunk", + "//xla/codegen:kernel_spec", + "//xla/codegen:llvm_ir_kernel_source", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "//xla/service:pattern_matcher", - "//xla/service/cpu:backend_config_proto_cc", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "@llvm-project//llvm:JITLink", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:statusor", ], ) @@ -1005,33 +1001,39 @@ cc_library( ], copts = runtime_copts(), deps = [ - ":collectives_interface", ":cpu_executable_run_options", - ":in_process_collectives", "//xla:executable_run_options", "//xla:shape_util", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_clique", + "//xla/backends/cpu/collectives:cpu_clique_key", + "//xla/backends/cpu/collectives:cpu_cliques", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:in_process_collectives", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/parser:hlo_parser", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -1320,11 +1322,10 @@ cc_library( "//xla:executable_run_options", "//xla/service:custom_call_status_internal", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:logging", ], ) @@ -1388,6 +1389,7 @@ xla_cc_test( tags = ["not_run:arm"], deps = [ ":cpu_instruction_fusion", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1459,7 +1461,7 @@ xla_cc_test( deps = [ ":ir_emission_utils", ":target_machine_features_stub", - "//xla:test", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -1497,11 +1499,11 @@ xla_cc_test( "//xla:literal", "//xla:shape_layout", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/service:computation_layout", "//xla/tests:hlo_test_base", @@ -1541,11 +1543,12 @@ xla_cc_test( deps = [ ":conv_canonicalization", ":target_machine_features_stub", - "//xla:test", - "//xla:test_helpers", + "//xla:literal_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -1567,9 +1570,9 @@ xla_cc_test( srcs = ["shape_partition_test.cc"], deps = [ ":shape_partition", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:test_helpers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", @@ -1611,9 +1614,9 @@ xla_cc_test( ":cpu_executable", ":parallel_task_assignment", ":target_machine_features_stub", - "//xla:test", "//xla/backends/cpu/codegen:target_machine_features", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/service:hlo_cost_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1653,7 +1656,7 @@ xla_cc_test( deps = [ ":ir_emission_utils", ":target_machine_features_stub", - "//xla:test", + "//xla/hlo/testlib:test", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -1669,11 +1672,11 @@ xla_cc_test( ":cpu_compiler", ":cpu_transfer_manager", ":test_header_helper", - "//xla:test", "//xla:util", "//xla/backends/cpu/codegen:target_machine_features", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/testlib:test", "//xla/service:compiler", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1731,8 +1734,9 @@ cc_library( ":onednn_config_proto_cc", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:env", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", ] + mkl_deps(), @@ -1785,8 +1789,8 @@ cc_library( "//xla/hlo/ir:hlo", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", @@ -1809,8 +1813,8 @@ cc_library( "//xla:shape_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", @@ -1832,10 +1836,11 @@ cc_library( ":onednn_memory_util", ":runtime_lightweight_check", "//xla:executable_run_options", + "//xla/tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", ] + mkl_deps(), @@ -1856,10 +1861,11 @@ cc_library( ":onednn_memory_util", ":runtime_lightweight_check", "//xla:executable_run_options", + "//xla/tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", ] + mkl_deps(), @@ -1902,10 +1908,11 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_creation_utils", "//xla/service:pattern_matcher", + "//xla/tsl/platform:env", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", @@ -1956,48 +1963,10 @@ cc_library( ], ) -cc_library( - name = "collectives_interface", - hdrs = ["collectives_interface.h"], - deps = [ - "//xla:xla_data_proto_cc", - "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "in_process_collectives", - srcs = ["in_process_collectives.cc"], - hdrs = ["in_process_collectives.h"], - deps = [ - ":collectives_interface", - "//xla:refcounting_hash_map", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:collective_ops_utils", - "//xla/service:global_device_id", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - cc_library( name = "cpu_executable_run_options", hdrs = ["cpu_executable_run_options.h"], - deps = [":collectives_interface"], + deps = ["//xla/backends/cpu/collectives:cpu_collectives"], ) cc_library( @@ -2009,6 +1978,24 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:stacktrace", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +cc_library( + name = "elemental_ir_emitter", + srcs = ["elemental_ir_emitter.cc"], + hdrs = ["elemental_ir_emitter.h"], + deps = [ + ":elemental_math_emitter", + "//xla/hlo/ir:hlo", + "//xla/service:elemental_ir_emitter", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:ir_headers", ], ) diff --git a/third_party/xla/xla/service/cpu/backend_config.proto b/third_party/xla/xla/service/cpu/backend_config.proto index 426f7e83229d74..3779fd755963d2 100644 --- a/third_party/xla/xla/service/cpu/backend_config.proto +++ b/third_party/xla/xla/service/cpu/backend_config.proto @@ -15,6 +15,10 @@ message CustomCallBackendConfig { } } +message FusionBackendConfig { + string kind = 1; +} + // Backend config for XLA:CPU. message BackendConfig { // Number of partitions per outer dimension (in order, starting with @@ -32,5 +36,7 @@ message BackendConfig { OneDnnConvolutionConfig onednn_conv_config = 5; // Configuration to be used by general custom call, e.g., FFI. CustomCallBackendConfig custom_call_config = 6; + // Configuration for custom fusions. + FusionBackendConfig fusion_config = 7; } } diff --git a/third_party/xla/xla/service/cpu/benchmarks/BUILD b/third_party/xla/xla/service/cpu/benchmarks/BUILD index 6ef77577c0c4d9..49b2292f3708d9 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/service/cpu/benchmarks/BUILD @@ -30,8 +30,10 @@ cc_library( "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/service:hlo_module_config", "//xla/tests:test_utils", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -48,10 +50,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -65,10 +68,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -82,11 +86,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -100,10 +104,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -117,11 +122,12 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -135,10 +141,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -152,10 +159,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -169,9 +177,9 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -187,13 +195,12 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/ffi", "//xla/ffi:ffi_api", - "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -208,10 +215,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:test_main", ], ) @@ -225,10 +233,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -242,10 +251,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -258,9 +268,10 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test_benchmark", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:test_main", ], ) @@ -274,10 +285,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -291,10 +303,11 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -308,10 +321,10 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/service/cpu/benchmarks/concatenate_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/concatenate_benchmark_test.cc index 9daa20e011df35..caaf72e9c08493 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/concatenate_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/concatenate_benchmark_test.cc @@ -15,20 +15,20 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { @@ -38,7 +38,7 @@ static void BM_ConcatenateTwoR3F32(benchmark::State& state) { Shape shape = ShapeUtil::MakeShape(F32, dims); int64_t axis = state.range(4); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule concatenate_r3f32_$shape_repr ENTRY test { diff --git a/third_party/xla/xla/service/cpu/benchmarks/convolution_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/convolution_benchmark_test.cc index c59b9af562ebf8..57fe8f3bd735a9 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/convolution_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/convolution_benchmark_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { namespace { @@ -81,15 +81,73 @@ static void BM_Conv2D(benchmark::State& state) { padding_w, "_", padding_w)}})); } +static void BM_GroupedConv2D(benchmark::State& state) { + int batch = state.range(0); + int height = state.range(1); + int width = state.range(2); + int input_channels = state.range(3); + int kernel_h = state.range(4); + int kernel_w = state.range(5); + int output_channels = state.range(6); + int feature_group_count = state.range(7); + + // Derive filter channels from input channels and feature group count. + int filter_channels = input_channels / feature_group_count; + + // Padding values for 'SAME' padding. Only odd kernel sizes are supported. + CHECK(IsOdd(kernel_h) && IsOdd(kernel_w)); + int padding_h = (kernel_h - 1) / 2; + int padding_w = (kernel_w - 1) / 2; + + std::string hlo_module = R"( + HloModule TestModule + + ENTRY TestComputation { + %p0 = $input_shape parameter(0) + %p1 = $kernel_shape parameter(1) + ROOT conv = convolution(p0, p1), window={size=$window_size pad=$padding}, + dim_labels=b01f_01io->b01f, feature_group_count=$feature_group_count + } + )"; + + std::minstd_rand0 engine; + + // Input format is NHWC. + auto input_shape = + ShapeUtil::MakeShape(F32, {batch, height, width, input_channels}); + // Filter format is HWIO. + auto kernel_shape = ShapeUtil::MakeShape( + F32, {kernel_h, kernel_w, filter_channels, output_channels}); + + auto input = + *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); + auto kernel = + *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); + + std::vector args = {&input, &kernel}; + + CHECK_OK(RunHloBenchmark( + state, hlo_module, args, + {{"$input_shape", input_shape.ToString()}, + {"$kernel_shape", kernel_shape.ToString()}, + {"$window_size", absl::StrCat(kernel_h, "x", kernel_w)}, + {"$padding", absl::StrCat(padding_h, "_", padding_h, "x", padding_w, "_", + padding_w)}, + {"$feature_group_count", absl::StrCat(feature_group_count)}})); +} + // Regular strided 1D convolution. Shapes come from an actual use case. static void BM_Conv1DStrided(benchmark::State& state) { + int input_channels = state.range(0); + int output_channels = state.range(1); + std::string hlo_module = R"( HloModule jit_jconvf ENTRY main.6 { - Arg_0.1 = f32[16,1,25600]{2,1,0} parameter(0) - Arg_1.2 = f32[1,129,256]{2,1,0} parameter(1) - ROOT conv.3 = f32[16,129,400]{2,1,0} convolution(Arg_0.1, Arg_1.2), + Arg_0.1 = $input_shape parameter(0) + Arg_1.2 = $kernel_shape parameter(1) + ROOT conv.3 = $output_shape convolution(Arg_0.1, Arg_1.2), window={size=256 stride=64 pad=96_96}, dim_labels=bf0_io0->bf0 } )"; @@ -97,9 +155,11 @@ static void BM_Conv1DStrided(benchmark::State& state) { std::minstd_rand0 engine; // NCW layout - auto input_shape = ShapeUtil::MakeShape(F32, {16, 1, 25600}); + auto input_shape = ShapeUtil::MakeShape(F32, {16, input_channels, 25600}); + auto output_shape = ShapeUtil::MakeShape(F32, {16, output_channels, 400}); // IOW layout - auto kernel_shape = ShapeUtil::MakeShape(F32, {1, 129, 256}); + auto kernel_shape = + ShapeUtil::MakeShape(F32, {input_channels, output_channels, 256}); auto input = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); @@ -107,7 +167,10 @@ static void BM_Conv1DStrided(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); std::vector args = {&input, &kernel}; - CHECK_OK(RunHloBenchmark(state, hlo_module, args)); + CHECK_OK(RunHloBenchmark(state, hlo_module, args, + {{"$input_shape", input_shape.ToString()}, + {"$kernel_shape", kernel_shape.ToString()}, + {"$output_shape", output_shape.ToString()}})); } // Transposed version (i.e. gradient) of BM_Conv1DStrided. In terms of shapes, @@ -117,13 +180,16 @@ static void BM_Conv1DStrided(benchmark::State& state) { // Currently, the performance is few times worse than regular conv when they // should be similar. static void BM_Conv1DTransposedStrided(benchmark::State& state) { + int input_channels = state.range(0); + int output_channels = state.range(1); + std::string hlo_module = R"( HloModule jit_jconvt ENTRY main.6 { - Arg_0.1 = f32[16,129,400]{2,1,0} parameter(0) - Arg_1.2 = f32[129,1,256]{2,1,0} parameter(1) - ROOT conv.3 = f32[16,1,25600]{2,1,0} convolution(Arg_0.1, Arg_1.2), + Arg_0.1 = $input_shape parameter(0) + Arg_1.2 = $kernel_shape parameter(1) + ROOT conv.3 = $output_shape convolution(Arg_0.1, Arg_1.2), window={size=256 pad=159_159 lhs_dilate=64}, dim_labels=bf0_io0->bf0 } )"; @@ -131,9 +197,11 @@ static void BM_Conv1DTransposedStrided(benchmark::State& state) { std::minstd_rand0 engine; // NCW layout - auto input_shape = ShapeUtil::MakeShape(F32, {16, 129, 400}); + auto input_shape = ShapeUtil::MakeShape(F32, {16, input_channels, 400}); + auto output_shape = ShapeUtil::MakeShape(F32, {16, output_channels, 25600}); // IOW layout - auto kernel_shape = ShapeUtil::MakeShape(F32, {129, 1, 256}); + auto kernel_shape = + ShapeUtil::MakeShape(F32, {input_channels, output_channels, 256}); auto input = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); @@ -141,19 +209,24 @@ static void BM_Conv1DTransposedStrided(benchmark::State& state) { *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); std::vector args = {&input, &kernel}; - CHECK_OK(RunHloBenchmark(state, hlo_module, args)); + CHECK_OK(RunHloBenchmark(state, hlo_module, args, + {{"$input_shape", input_shape.ToString()}, + {"$kernel_shape", kernel_shape.ToString()}, + {"$output_shape", output_shape.ToString()}})); } // The same shapes as BM_Conv1DTransposedStrided, but with a different layout. static void BM_Conv1DTransposedStridedNonDefaultLayout( benchmark::State& state) { + int input_channels = state.range(0); + int output_channels = state.range(1); std::string hlo_module = R"( HloModule jit_jconvt ENTRY main.6 { - Arg_0.1 = f32[16,400,129]{2,1,0} parameter(0) - Arg_1.2 = f32[256,1,129]{2,1,0} parameter(1) - ROOT conv.3 = f32[16,25600,1]{2,1,0} convolution(Arg_0.1, Arg_1.2), + Arg_0.1 = $input_shape parameter(0) + Arg_1.2 = $kernel_shape parameter(1) + ROOT conv.3 = $output_shape convolution(Arg_0.1, Arg_1.2), window={size=256 pad=159_159 lhs_dilate=64}, dim_labels=b0f_0oi->b0f } )"; @@ -161,9 +234,11 @@ static void BM_Conv1DTransposedStridedNonDefaultLayout( std::minstd_rand0 engine; // NWC layout - auto input_shape = ShapeUtil::MakeShape(F32, {16, 400, 129}); + auto input_shape = ShapeUtil::MakeShape(F32, {16, 400, input_channels}); + auto output_shape = ShapeUtil::MakeShape(F32, {16, 25600, output_channels}); // WOI layout - auto kernel_shape = ShapeUtil::MakeShape(F32, {256, 1, 129}); + auto kernel_shape = + ShapeUtil::MakeShape(F32, {256, output_channels, input_channels}); auto input = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); @@ -171,7 +246,10 @@ static void BM_Conv1DTransposedStridedNonDefaultLayout( *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); std::vector args = {&input, &kernel}; - CHECK_OK(RunHloBenchmark(state, hlo_module, args)); + CHECK_OK(RunHloBenchmark(state, hlo_module, args, + {{"$input_shape", input_shape.ToString()}, + {"$kernel_shape", kernel_shape.ToString()}, + {"$output_shape", output_shape.ToString()}})); } // Regular strided 2D convolution. Buffer sizes and convolution parameters are @@ -239,59 +317,91 @@ static void BM_Conv2DTransposedStrided(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo_module, args)); } -static void BM_GroupedConv2D(benchmark::State& state) { - int batch = state.range(0); - int height = state.range(1); - int width = state.range(2); - int input_channels = state.range(3); - int kernel_h = state.range(4); - int kernel_w = state.range(5); - int output_channels = state.range(6); - int feature_group_count = state.range(7); +// Regular (i.e. non-transposed) grouped and strided 2D convolution. +static void BM_GroupedConv2DStrided(benchmark::State& state) { + int input_channels = state.range(0); + int output_channels = state.range(1); + int feature_group_count = state.range(2); // Derive filter channels from input channels and feature group count. int filter_channels = input_channels / feature_group_count; - // Padding values for 'SAME' padding. Only odd kernel sizes are supported. - CHECK(IsOdd(kernel_h) && IsOdd(kernel_w)); - int padding_h = (kernel_h - 1) / 2; - int padding_w = (kernel_w - 1) / 2; - std::string hlo_module = R"( - HloModule TestModule + HloModule jit_jconvf - ENTRY TestComputation { - %p0 = $input_shape parameter(0) - %p1 = $kernel_shape parameter(1) - ROOT conv = convolution(p0, p1), window={size=$window_size pad=$padding}, - dim_labels=b01f_01io->b01f, feature_group_count=$feature_group_count + ENTRY main.6 { + Arg_0.1 = $input_shape parameter(0) + Arg_1.2 = $kernel_shape parameter(1) + ROOT conv.3 = convolution(Arg_0.1, Arg_1.2), + window={size=16x16 stride=8x8 pad=4_4x4_4}, dim_labels=bf01_io01->bf01, + feature_group_count=$feature_group_count } )"; std::minstd_rand0 engine; - // Input format is NHWC. - auto input_shape = - ShapeUtil::MakeShape(F32, {batch, height, width, input_channels}); - // Filter format is HWIO. - auto kernel_shape = ShapeUtil::MakeShape( - F32, {kernel_h, kernel_w, filter_channels, output_channels}); + // NCHW layout + auto input_shape = ShapeUtil::MakeShape(F32, {2, input_channels, 80, 80}); + // IOHW layout + auto kernel_shape = + ShapeUtil::MakeShape(F32, {filter_channels, output_channels, 16, 16}); auto input = *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); auto kernel = *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); + std::vector args = {&input, &kernel}; + + CHECK_OK(RunHloBenchmark( + state, hlo_module, args, + {{"$input_shape", input_shape.ToString()}, + {"$kernel_shape", kernel_shape.ToString()}, + {"$feature_group_count", std::to_string(feature_group_count)}})); +} + +// Transposed version (i.e. gradient) of BM_GroupedConv2DStrided. In terms of +// shapes, this operation can be thought of as reverse of regular strided +// convolution, that's why input and output shapes are swapped (so we can +// directly compare performance of this function with BM_GroupedConv2DStrided). +static void BM_GroupedConv2DTransposedStrided(benchmark::State& state) { + int input_channels = state.range(0); + int output_channels = state.range(1); + int feature_group_count = state.range(2); + + // Derive filter channels from input channels and feature group count. + int filter_channels = input_channels / feature_group_count; + + std::string hlo_module = R"( + HloModule jit_jconvt + + ENTRY main.6 { + Arg_0.1 = $input_shape parameter(0) + Arg_1.2 = $kernel_shape parameter(1) + ROOT conv.3 = convolution(Arg_0.1, Arg_1.2), + window={size=16x16 pad=11_11x11_11 lhs_dilate=8x8}, + dim_labels=bf01_io01->bf01, feature_group_count=$feature_group_count + } + )"; + + std::minstd_rand0 engine; + + // NCHW layout + auto input_shape = ShapeUtil::MakeShape(F32, {2, input_channels, 10, 10}); + // IOHW layout + auto kernel_shape = + ShapeUtil::MakeShape(F32, {filter_channels, output_channels, 16, 16}); + auto input = + *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); + auto kernel = + *LiteralUtil::CreateRandomLiteral(kernel_shape, &engine, 1.0f, 0.1f); std::vector args = {&input, &kernel}; CHECK_OK(RunHloBenchmark( state, hlo_module, args, {{"$input_shape", input_shape.ToString()}, {"$kernel_shape", kernel_shape.ToString()}, - {"$window_size", absl::StrCat(kernel_h, "x", kernel_w)}, - {"$padding", absl::StrCat(padding_h, "_", padding_h, "x", padding_w, "_", - padding_w)}, - {"$feature_group_count", absl::StrCat(feature_group_count)}})); + {"$feature_group_count", std::to_string(feature_group_count)}})); } // -------------------------------------------------------------------------- // @@ -346,24 +456,51 @@ BENCHMARK(BM_Conv2D) ->Args({32, 64, 64, 4, 3, 3, 16}) ->Args({32, 32, 32, 96, 3, 3, 96}); +// -------------------------------------------------------------------------- // +// Grouped convolution +// -------------------------------------------------------------------------- // + +BENCHMARK(BM_GroupedConv2D) + ->MeasureProcessCPUTime() + ->Args({1, 45, 45, 1024, 5, 5, 1024, 1024}); + // -------------------------------------------------------------------------- // // 1D and 2D strided convolutions // -------------------------------------------------------------------------- // -BENCHMARK(BM_Conv1DStrided)->MeasureProcessCPUTime(); -BENCHMARK(BM_Conv1DTransposedStrided)->MeasureProcessCPUTime(); -BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout)->MeasureProcessCPUTime(); +BENCHMARK(BM_Conv1DStrided) + ->MeasureProcessCPUTime() + ->Args({1, 129}) + ->Args({3, 129}); +BENCHMARK(BM_Conv1DTransposedStrided) + ->MeasureProcessCPUTime() + ->MeasureProcessCPUTime() + ->Args({129, 1}) + ->Args({129, 3}); +BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout) + ->MeasureProcessCPUTime() + ->Args({129, 1}) + ->Args({129, 3}); BENCHMARK(BM_Conv2DStrided)->MeasureProcessCPUTime(); BENCHMARK(BM_Conv2DTransposedStrided)->MeasureProcessCPUTime(); // -------------------------------------------------------------------------- // -// Grouped convolution +// Grouped strided convolutions // -------------------------------------------------------------------------- // -BENCHMARK(BM_GroupedConv2D) +BENCHMARK(BM_GroupedConv2DStrided) ->MeasureProcessCPUTime() - ->Args({1, 45, 45, 1024, 5, 5, 1024, 1024}); + ->Args({128, 128, 128}); +BENCHMARK(BM_GroupedConv2DTransposedStrided) + ->MeasureProcessCPUTime() + ->Args({128, 128, 128}); +BENCHMARK(BM_GroupedConv2DStrided) + ->MeasureProcessCPUTime() + ->Args({128, 128, 16}); +BENCHMARK(BM_GroupedConv2DTransposedStrided) + ->MeasureProcessCPUTime() + ->Args({128, 128, 16}); } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc index fb9d35311108bc..f8e95caa312d6a 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" @@ -29,9 +29,9 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { namespace { @@ -95,7 +95,7 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$many_int_attributes", "Host", kManyIntAttributes); static void BM_CustomCall_16IntAttributes(benchmark::State& state) { - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule module ENTRY custom_call { @@ -154,7 +154,7 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$many_float_buffers", static void BM_CustomCall_16FloatBuffers(benchmark::State& state) { int64_t d = 128; - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule module ENTRY custom_call { diff --git a/third_party/xla/xla/service/cpu/benchmarks/dag_execution_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/dag_execution_benchmark_test.cc index 6b28f468439e30..86de7a23691fcd 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/dag_execution_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/dag_execution_benchmark_test.cc @@ -15,18 +15,18 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { @@ -36,7 +36,7 @@ static void BM_DagExecution(benchmark::State& state) { // We use this benchmark to test how well XLA does the scheduling of the HLO // module to extract available parallelism, and how well ThunkExecutor // exploits that parallelism at run time. - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule fusion_f32_$d0 add { diff --git a/third_party/xla/xla/service/cpu/benchmarks/dot_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/dot_benchmark_test.cc index 58c9be6ab2b900..69aab1ecd9a2a7 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/dot_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/dot_benchmark_test.cc @@ -15,74 +15,91 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/primitive_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { -static void BM_BatchedDotF32(benchmark::State& state) { - int64_t d0 = state.range(0); - int64_t d1 = state.range(1); +static void BM_BatchedDot(benchmark::State& state) { + PrimitiveType dtype = static_cast(state.range(0)); + int64_t d0 = state.range(1); + int64_t d1 = state.range(2); - std::string_view hlo = R"( - HloModule dot_f32_b$d0_d$d1 + absl::string_view hlo = R"( + HloModule dot_$dtype_b$d0_d$d1 ENTRY e { - p0 = f32[$d0,$d1,$d1] parameter(0) - p1 = f32[$d0,$d1,$d1] parameter(1) - ROOT dot = f32[$d0,$d1,$d1] dot(p0, p1), + p0 = $dtype[$d0,$d1,$d1] parameter(0) + p1 = $dtype[$d0,$d1,$d1] parameter(1) + ROOT dot = $dtype[$d0,$d1,$d1] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} } )"; + Literal p0, p1; + double mean = 1.0f; + double stddev = 0.1f; std::minstd_rand0 engine; - - auto shape = ShapeUtil::MakeShape(F32, {d0, d1, d1}); - auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); - auto p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + auto shape = ShapeUtil::MakeShape(dtype, {d0, d1, d1}); + if (dtype == F32) { + p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, mean, stddev); + p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, mean, stddev); + } else if (dtype == BF16) { + p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, mean, stddev); + p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, mean, stddev); + } else { + LOG(FATAL) << "Add dtype to the if-else block before use: " << dtype; + } std::vector args = {&p0, &p1}; - CHECK_OK( - RunHloBenchmark(state, hlo, args, - {{"$d0", absl::StrCat(d0)}, {"$d1", absl::StrCat(d1)}})); + CHECK_OK(RunHloBenchmark( + state, hlo, args, + {{"$dtype", primitive_util::LowercasePrimitiveTypeName(dtype)}, + {"$d0", absl::StrCat(d0)}, + {"$d1", absl::StrCat(d1)}})); } -BENCHMARK(BM_BatchedDotF32) - ->MeasureProcessCPUTime() - ->ArgPair(1, 2) - ->ArgPair(1, 32) - ->ArgPair(1, 64) - ->ArgPair(1, 128) - ->ArgPair(1, 256) - ->ArgPair(1, 512) - ->ArgPair(2, 2) - ->ArgPair(2, 32) - ->ArgPair(2, 64) - ->ArgPair(2, 128) - ->ArgPair(2, 256) - ->ArgPair(2, 512) - ->ArgPair(4, 2) - ->ArgPair(4, 32) - ->ArgPair(4, 64) - ->ArgPair(4, 128) - ->ArgPair(4, 256) - ->ArgPair(4, 512) - ->ArgPair(8, 2) - ->ArgPair(8, 32) - ->ArgPair(8, 64) - ->ArgPair(8, 128) - ->ArgPair(8, 256) - ->ArgPair(8, 512); +#define BENCHMARK_BATCHED_DOT(dtype) \ + BENCHMARK(BM_BatchedDot) \ + ->MeasureProcessCPUTime() \ + ->Args({dtype, 1, 2}) \ + ->Args({dtype, 1, 32}) \ + ->Args({dtype, 1, 64}) \ + ->Args({dtype, 1, 128}) \ + ->Args({dtype, 1, 256}) \ + ->Args({dtype, 1, 512}) \ + ->Args({dtype, 2, 2}) \ + ->Args({dtype, 2, 32}) \ + ->Args({dtype, 2, 64}) \ + ->Args({dtype, 2, 128}) \ + ->Args({dtype, 2, 256}) \ + ->Args({dtype, 2, 512}) \ + ->Args({dtype, 4, 2}) \ + ->Args({dtype, 4, 32}) \ + ->Args({dtype, 4, 64}) \ + ->Args({dtype, 4, 128}) \ + ->Args({dtype, 4, 256}) \ + ->Args({dtype, 4, 512}) \ + ->Args({dtype, 8, 2}) \ + ->Args({dtype, 8, 32}) \ + ->Args({dtype, 8, 64}) \ + ->Args({dtype, 8, 128}) \ + ->Args({dtype, 8, 256}) \ + ->Args({dtype, 8, 512}) + +BENCHMARK_BATCHED_DOT(F32); // Shown as "11" in the benchmark name. +BENCHMARK_BATCHED_DOT(BF16); // Shown as "16" in the benchmark name. } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc index 195a98523d2f29..2031189cf24848 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/dynamic_update_slice_benchmark_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_DynamicUpdateSliceF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule dynamic_update_slice_f32_$d0 ENTRY e { diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/README.md b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/README.md new file mode 100644 index 00000000000000..a334d0d6b66133 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/README.md @@ -0,0 +1,16 @@ +# Gemma2 Flax 2b-it Benchmark + +This repository provides scripts for benchmarking the Gemma2 Flax 2b-it model. + +## Scripts Instructions + +* **setup.sh:** This script sets up the necessary environment for the benchmark. + * Usage: `bash setup.sh` +* **run.sh:** This script executes the benchmark and displays the results. + * Usage: `bash run.sh` + +## Model Page on Kaggle + +The Gemma Flax model can be accessed and used on Kaggle: + +https://www.kaggle.com/models/google/gemma-2/flax \ No newline at end of file diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/benchmark.py b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/benchmark.py new file mode 100644 index 00000000000000..23d482fde3a91c --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/benchmark.py @@ -0,0 +1,100 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark gemma2-2b-it Flax performance.""" + +import datetime +import os +import statistics + +from gemma import params as params_lib +from gemma import sampler as sampler_lib +from gemma import transformer as transformer_lib +import sentencepiece as spm + + +GEMMA_VARIANT = 'gemma2-2b-it' + +# Assign Gemma path +GEMMA_PATH = os.environ.get('MODEL_DIR') + +# Ensure that the tokenizer is present +TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model') +assert os.path.isfile(TOKENIZER_PATH), 'Tokenizer not found!' + +# Ensure that the checkpoint is present +CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT) +assert os.path.exists(CKPT_PATH), 'Flax checkpoint not found!' + +# Set up model sampler +params = params_lib.load_and_format_params(CKPT_PATH) +vocab = spm.SentencePieceProcessor() +vocab.Load(TOKENIZER_PATH) +transformer_config = transformer_lib.TransformerConfig.from_params( + params=params, cache_size=1024 +) +transformer = transformer_lib.Transformer(transformer_config) +sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=vocab, + params=params['transformer'], +) + +OUTPUT_TOKEN_LEN = 128 +prompt = ['What is JAX in 3 bullet points?'] + + +def benchmark_generation_time(output_token_len): + """Benchmark generation time given output token length.""" + timestamp_start = datetime.datetime.now() + reply = sampler(input_strings=prompt, total_generation_steps=output_token_len) + timestamp_end = datetime.datetime.now() + timer_delta = timestamp_end - timestamp_start + # Prints generated tokens when benchmarking the full length. + if output_token_len == OUTPUT_TOKEN_LEN: + print(reply.text) + return timer_delta.total_seconds() * 1000 + + +def display_tpot(): + """Calculate the time per output token.""" + e2e_latency_mean = statistics.mean(latency_list) + ttft_mean = statistics.mean(ttft_ms_list) + generation_time_mean = e2e_latency_mean - ttft_mean + tpot = generation_time_mean / (OUTPUT_TOKEN_LEN - 1) + print(f'TPOT: {round(tpot, 2)} ms') + + +def display_benchmark_results(timer_list, metric_name): + """Display mean and stdev for a given metric.""" + mean_time = statistics.mean(timer_list) + stdev_time = statistics.stdev(timer_list) + stdev_time_percentage = (stdev_time / mean_time) * 100 + + print( + '%s: %.2f ms ± %.2f%%' % (metric_name, mean_time, stdev_time_percentage) + ) + + +if __name__ == '__main__': + # Measure time to first token. + ttft_ms_list = [benchmark_generation_time(1) for _ in range(5)] + # Measure time for full tokens. + latency_list = [benchmark_generation_time(OUTPUT_TOKEN_LEN) for _ in range(5)] + + # Display benchmark results + display_benchmark_results(ttft_ms_list, 'TTFT') + display_benchmark_results(latency_list, 'E2E Latency') + display_tpot() + del sampler diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/config.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/config.sh new file mode 100644 index 00000000000000..6028e10109fee5 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/config.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x + +# Temporary directory for the virtual environment +export TMP_DIR="${HOME}/tmp" + +# Cache directory for the Gemma2 Flax model +export CACHE_DIR="${HOME}/.cache" + +# Path to virtual environment +export VENV_DIR="${TMP_DIR}/gemma-2-flax" + +# Path to the Gemma2 Flax model files +export MODEL_DIR="${CACHE_DIR}/gemma-2-flax-2b-it" diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/requirements.txt b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/requirements.txt new file mode 100644 index 00000000000000..59c242835f35f2 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/requirements.txt @@ -0,0 +1,30 @@ +absl-py==2.1.0 +chex==0.1.87 +etils==1.11.0 +flax==0.10.2 +fsspec==2024.10.0 +gemma @ git+https://github.com/google-deepmind/gemma.git@af38d6eb413cb98446b78a906c77cf5ba28be149 +humanize==4.11.0 +importlib_resources==6.4.5 +jax==0.4.37 +jaxlib==0.4.36 +markdown-it-py==3.0.0 +mdurl==0.1.2 +ml_dtypes==0.5.0 +msgpack==1.1.0 +nest-asyncio==1.6.0 +numpy==2.2.0 +opt_einsum==3.4.0 +optax==0.2.4 +orbax-checkpoint==0.10.2 +protobuf==5.29.1 +Pygments==2.18.0 +PyYAML==6.0.2 +rich==13.9.4 +scipy==1.14.1 +sentencepiece==0.2.0 +simplejson==3.19.3 +tensorstore==0.1.69 +toolz==1.0.0 +typing_extensions==4.12.2 +zipp==3.21.0 \ No newline at end of file diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.py b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/run.sh similarity index 69% rename from third_party/xla/xla/backends/cpu/testlib/kernel_runner.py rename to third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/run.sh index 8f6a4ce6ea4c65..3c71bcc45ddc13 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.py +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/run.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright 2024 The OpenXLA Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,10 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""CPU specific kernel runner implementations.""" -from xla.backends.cpu.testlib import kernel_runner_extention +set -x -LlvmIrKernelSpec = kernel_runner_extention.LlvmIrKernelSpec -LlvmIrKernelEmitter = kernel_runner_extention.LlvmIrKernelEmitter -KernelRunner = kernel_runner_extention.KernelRunner +source ./config.sh + +if [[ ! -d "$VENV_DIR" ]]; then + echo "Virtual environment not found. Please run setup.sh first." +else + # Activate the virtual environment + source "$VENV_DIR"/bin/activate + + # Run the benchmark + python benchmark.py +fi diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/setup.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/setup.sh new file mode 100644 index 00000000000000..dc11d1b4ec4381 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/flax_2b/setup.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x + +source ./config.sh + +# Create tmp and cache directories if they don't exist +mkdir -p "$TMP_DIR" +mkdir -p "$CACHE_DIR" + +if [[ ! -d "$VENV_DIR" ]]; then + # Create a virtual environment + python3 -m venv "$VENV_DIR" + # Activate the virtual environment + source "$VENV_DIR"/bin/activate + # Install Gemma2 Flax dependencies + pip install -r requirements.txt +else + # Activate the virtual environment + source "$VENV_DIR"/bin/activate +fi + + +TAR_FILE="${CACHE_DIR}/gemma-2-flax-2b-it.tar" +# Download and extract Gemma2 Flax model files +if [[ ! -d "$MODEL_DIR" ]]; then + # Copy the tar file to the tmp directory + wget -P "$CACHE_DIR" https://storage.googleapis.com/xla-benchmarking-temp/gemma-2-flax-2b-it.tar + # Change to cache directory and extract the tar file + cd "$CACHE_DIR" + tar -xf "$TAR_FILE" +fi diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md new file mode 100644 index 00000000000000..35337b27d053d6 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/README.md @@ -0,0 +1,32 @@ +# Gemma2 2B Keras model + +Scripts to run Gemma2 2B Keras model on CPU. + +Model link: https://www.kaggle.com/models/google/gemma-2/keras + +Instructions: + +* Set up your Kaggle API key by following + [these instructions](https://www.kaggle.com/docs/api#authentication). +* `$ bash setup.sh` + * This only needs to be run once. It will create a virtual environment at + a location read from `config.sh` and install the necessary dependencies. + * Change the `VENV_BASE` variable in `config.sh` before running `setup.sh` + if you want to use a different location. +* `$ KERAS_BACKEND=jax bash run.sh` + * This script activates the right virtual environment and runs the + benchmark in `benchmark.py`. + * Set `KERAS_BACKEND=tensorflow` or `torch` to run with TensorFlow or + PyTorch backend. +* (Optional) Delete the virtual environment: `$ bash cleanup.sh` + +To try other model variations with different numbers of parameters, modify the +following line in `benchmark.py`: + +``` +gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en") +``` + +Replace "gemma2_2b_en" with other preset names, e.g., +"gemma2_instruct_2b_en","gemma2_9b_en", etc. See the full preset list +[here](https://github.com/keras-team/keras-hub/blob/86607dc921999e33f5b8a0bcf81ec987b60c9dee/keras_hub/src/models/gemma/gemma_presets.py#L5-L200). diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py new file mode 100644 index 00000000000000..46d5e4355c1136 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/benchmark.py @@ -0,0 +1,107 @@ +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark Gemma2-2B Keras performance.""" + +import time +import keras_nlp +import numpy as np + +_NUM_OUTPUT_TOKENS = 30 +_QUERY = "What is JAX in 3 bullet points?" +_VERBOSE = True + + +def compute_stats(array): + """Reports mean and ± range for the given array. + + The range computation follows benchstat's. + + Args: + array: The array to compute stats for. + + Returns: + mean and ± %diff range. + """ + q1 = np.percentile(array, 25) + q3 = np.percentile(array, 75) + low = q1 - 1.5 * (q3 - q1) + high = q3 + 1.5 * (q3 - q1) + + # Remove outliers. + filtered_array = list(filter(lambda x: low <= x and x <= high, array)) + + mean = np.mean(filtered_array) + min_val = np.min(filtered_array) + max_val = np.max(filtered_array) + max_diff = max(max_val - mean, mean - min_val) + diff = max_diff / mean * 100.0 + + return (mean, diff) + + +def run(gemma_lm, max_len): + """Benchmarks inferences with at most `max_len` output tokens. + + Args: + gemma_lm: The Gemma2 Keras model. + max_len: The maximum number of output tokens per one inference. + + Returns: + mean ± %diff and the actual number of output tokens generated per inference. + """ + # Warm up. + start = time.time() + output = gemma_lm.generate(_QUERY, max_length=max_len + 1) + num_actual_output_tokens = len(output.split(" ")) + warmup_time = (time.time() - start) * 1000 + + if _VERBOSE: + print("=== Max len: %d ===" % max_len) + print("Warmup: %lf ms" % warmup_time) + print("Output:\n%s\n" % output) + + times = [] + for i in range(1, 6): + start = time.time() + output = gemma_lm.generate(_QUERY, max_length=max_len + 1) + assert num_actual_output_tokens == len(output.split(" ")) + elapsed_time = (time.time() - start) * 1000 + times.append(elapsed_time) + if _VERBOSE: + print("%d: %lf ms" % (i, elapsed_time)) + + mean, diff = compute_stats(times) + if _VERBOSE: + print("Mean: %lf ± %d%% ms\n" % (mean, diff)) + + return (mean, diff, num_actual_output_tokens) + + +def main(): + if _VERBOSE: + print("Query: %s" % _QUERY) + + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en") + mean_1, diff_1, _ = run(gemma_lm, 1) + mean_n, diff_n, num_output_tokens = run(gemma_lm, _NUM_OUTPUT_TOKENS) + + print("Generated %d tokens", num_output_tokens) + tpot = (mean_n - mean_1) / (num_output_tokens - 1) + print("TTFT: %lf ± %d%% ms" % (mean_1, diff_1)) + print("TPOT: %lf ± %d%% ms" % (tpot, diff_n)) + + +if __name__ == "__main__": + main() diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh new file mode 100644 index 00000000000000..8cb893f5b1d38b --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/cleanup.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +source config.sh + +rm -rf ${GEMMA2_VENV} diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh new file mode 100644 index 00000000000000..55f1139b818f28 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/config.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +export VENV_BASE=~/venv +export GEMMA2_VENV=${VENV_BASE}/gemma2-keras diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt new file mode 100644 index 00000000000000..d9866bf65bad57 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/requirements.txt @@ -0,0 +1,5 @@ +keras==3.8.0 +keras_nlp==0.18.1 +tensorflow==2.18.0 +jax==0.4.38 +torch==2.5.1 diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh new file mode 100644 index 00000000000000..876625a65658e1 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/run.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +source config.sh +source ${GEMMA2_VENV}/bin/activate + +python benchmark.py diff --git a/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh new file mode 100644 index 00000000000000..2258692d608be1 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/e2e/gemma2/keras/setup.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright 2025 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -x +set -e + +source config.sh + +mkdir -p ${VENV_BASE} +python3 -m venv ${GEMMA2_VENV} +source ${GEMMA2_VENV}/bin/activate +pip install -r requirements.txt diff --git a/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc index 9b9d205097d695..c33bd97e6f25fa 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_AddF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule add_f32_$d0 ENTRY e { @@ -56,7 +56,7 @@ static void BM_AddF32(benchmark::State& state) { static void BM_AddBF16(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule add_bf16_$d0 ENTRY e { @@ -79,7 +79,7 @@ static void BM_AddBF16(benchmark::State& state) { static void BM_ConvertF32ToBF16(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule convert_f32_to_bf16_$d0 ENTRY e { diff --git a/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc index 97412cb2301977..38e43af34b6988 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/fusion_benchmark_test.cc @@ -16,26 +16,26 @@ limitations under the License. #include #include #include -#include #include #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_FusionF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule fusion_f32_$d0 ENTRY e { @@ -68,7 +68,7 @@ static void BM_FusionF32(benchmark::State& state) { static void BM_FusionF32_2(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule fusion_f32_2_$d0 ENTRY e { @@ -144,7 +144,7 @@ static void BM_FusionF32_2(benchmark::State& state) { static void BM_BcastFusionF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule fusion_f32_$d0 ENTRY e { @@ -169,7 +169,7 @@ static void BM_BcastFusionF32(benchmark::State& state) { static void BM_DynamicUpdateSliceFusionF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule dynamic_update_slice_fusion_f32_$d0 ENTRY e { @@ -198,7 +198,7 @@ static void BM_ChainOfAddF32(benchmark::State& state) { // In this benchmark we create a chain of additions starting from `p2` and // ending with `p$size`. The chain is fused into a single fusion node. - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule chain_of_add_f32_$size ENTRY e { diff --git a/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc index 128711c0e740e3..597bc7c0e8c792 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/gather_benchmark_test.cc @@ -15,19 +15,19 @@ limitations under the License. #include #include -#include #include #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/array2d.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { @@ -36,7 +36,7 @@ static void BM_GatherS32(benchmark::State& state) { int64_t d1 = state.range(1); int64_t slice_size = state.range(2); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule gather_s32_d$d0_d$d1_s$slice_size ENTRY e { diff --git a/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc b/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc index 4431eff2758aec..600d1c001fad21 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include -#include #include #include "absl/status/status.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -40,7 +40,7 @@ limitations under the License. namespace xla::cpu { absl::Status RunHloBenchmark(benchmark::State& state, - std::string_view hlo_module, + absl::string_view hlo_module, absl::Span args, StrToStrMapping replacements, bool disable_parallel_task_assigner) { @@ -123,7 +123,7 @@ absl::Status RunHloBenchmark(benchmark::State& state, } absl::Status CompileHloBenchmark(benchmark::State& state, - std::string_view hlo_module, + absl::string_view hlo_module, StrToStrMapping replacements, bool disable_parallel_task_assigner) { xla::CpuClientOptions options; diff --git a/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.h b/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.h index 23fca54359e93d..5891f6488c87b7 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.h +++ b/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.h @@ -16,13 +16,11 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_BENCHMARKS_HLO_BENCHMARK_RUNNER_H_ #define XLA_SERVICE_CPU_BENCHMARKS_HLO_BENCHMARK_RUNNER_H_ -#include - #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/test_benchmark.h" namespace xla::cpu { @@ -40,7 +38,7 @@ using StrToStrMapping = // not be run on the HLO module before running the benchmark. Therefore, // parallel backend will not be executed. absl::Status RunHloBenchmark(benchmark::State& state, - std::string_view hlo_module, + absl::string_view hlo_module, absl::Span args, StrToStrMapping replacements = {}, bool disable_parallel_task_assigner = false); @@ -50,7 +48,7 @@ absl::Status RunHloBenchmark(benchmark::State& state, // Takes the same options as RunHloBenchmark, except no arguments since the // HLO is only compiled, not run. absl::Status CompileHloBenchmark(benchmark::State& state, - std::string_view hlo_module, + absl::string_view hlo_module, StrToStrMapping replacements = {}, bool disable_parallel_task_assigner = false); diff --git a/third_party/xla/xla/service/cpu/benchmarks/optimizer_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/optimizer_benchmark_test.cc index 3d553885e47349..b7aa400c578a37 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/optimizer_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/optimizer_benchmark_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_Optimizer0(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule jit_update_fn_$d0 add { diff --git a/third_party/xla/xla/service/cpu/benchmarks/pad_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/pad_benchmark_test.cc index a2857ef274b521..1bef38a5c2fce7 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/pad_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/pad_benchmark_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_PadF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule pad_f32_$d0 ENTRY e { diff --git a/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc index af51cdcf6c395b..9d90e42548f99b 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_ReduceAddF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule reduce_add_f32_$d0 add { @@ -61,7 +61,7 @@ static void BM_ReduceAddF32(benchmark::State& state) { static void BM_ReduceAddBF16(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule reduce_add_bf16_$d0 add { diff --git a/third_party/xla/xla/service/cpu/benchmarks/scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/scatter_benchmark_test.cc index d9bf151c5ec045..962f15eafb6432 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/scatter_benchmark_test.cc @@ -26,9 +26,9 @@ limitations under the License. #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index 600c2ea319df3d..b92557b6d99501 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -15,18 +15,18 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { @@ -34,7 +34,7 @@ static void BM_SelectAndScatterF32(benchmark::State& state) { int64_t d0 = state.range(0); int64_t d1 = (d0 - 1) / 2; - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule select_and_scatter_f32_$d0 ge { diff --git a/third_party/xla/xla/service/cpu/benchmarks/tanh_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/tanh_benchmark_test.cc index 4f8aa0670b7e07..b210d75a1176ec 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/tanh_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/tanh_benchmark_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include #include -#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { static void BM_TanhF32(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule tanh_f32_$d0 ENTRY e { @@ -54,7 +54,7 @@ static void BM_TanhF32(benchmark::State& state) { static void BM_TanhF64(benchmark::State& state) { int64_t d0 = state.range(0); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule tanh_f64_$d0 ENTRY e { diff --git a/third_party/xla/xla/service/cpu/benchmarks/topk_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/topk_benchmark_test.cc index f062213a725117..99f48a2caa225c 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/topk_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/topk_benchmark_test.cc @@ -15,16 +15,16 @@ limitations under the License. #include #include -#include #include #include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/literal_util.h" #include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test_benchmark.h" namespace xla::cpu { @@ -34,7 +34,7 @@ static void BM_TopKCustomCall_F32(benchmark::State& state) { int64_t length = state.range(2); CHECK_LE(k, length); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule topk_custom_call ENTRY test { @@ -62,7 +62,7 @@ static void BM_TopK_BF16(benchmark::State& state) { int64_t length = state.range(2); CHECK_LE(k, length); - std::string_view hlo = R"( + absl::string_view hlo = R"( HloModule topk ENTRY test { diff --git a/third_party/xla/xla/service/cpu/collectives_interface.h b/third_party/xla/xla/service/cpu/collectives_interface.h deleted file mode 100644 index 54b6a280f59910..00000000000000 --- a/third_party/xla/xla/service/cpu/collectives_interface.h +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ -#define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/global_device_id.h" -#include "xla/xla_data.pb.h" - -namespace xla::cpu { - -class CollectivesCommunicator { - public: - virtual ~CollectivesCommunicator() = default; - - // Performs an all-reduce. - virtual absl::Status AllReduce(const RendezvousKey& key, - ReductionKind reduction_kind, - PrimitiveType element_type, - size_t num_elements, const void* input_buffer, - void* output_buffer, - absl::Duration timeout) = 0; - - // Performs a collective permute. - // Arguments: - // source_rank: the rank from which this rank should receive its data. - // Optional; if absent, then the output is filled with zeros. - // target_rank: the ranks to which this rank should send its data. - virtual absl::Status CollectivePermute(const RendezvousKey& key, - size_t num_bytes, - std::optional source_rank, - absl::Span target_ranks, - const void* input_buffer, - void* output_buffer, - absl::Duration timeout) = 0; - - // Performs an all-to-all. - // The all-to-all chunks are passed separately and do not have to be - // contiguous in memory. - virtual absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, - absl::Duration timeout) = 0; - - // Performs an all-gather. - virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) = 0; - - // Performs a reduce-scatter - virtual absl::Status ReduceScatter( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, - void* output_buffer, absl::Duration timeout) = 0; -}; - -class CollectivesInterface { - public: - virtual ~CollectivesInterface() = default; - - // Builds a context for a collective group. - // Args: - // devices: the devices participating in this collective. - // rank: the rank of this process. - virtual absl::StatusOr> - GetCommunicator(absl::Span devices, int rank) = 0; -}; - -} // namespace xla::cpu - -#endif // XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ diff --git a/third_party/xla/xla/service/cpu/conv_canonicalization_test.cc b/third_party/xla/xla/service/cpu/conv_canonicalization_test.cc index 00c9ee256452c9..80d8313b7c752c 100644 --- a/third_party/xla/xla/service/cpu/conv_canonicalization_test.cc +++ b/third_party/xla/xla/service/cpu/conv_canonicalization_test.cc @@ -20,9 +20,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" +#include "xla/literal_util.h" #include "xla/service/cpu/target_machine_features_stub.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 3e97c271371667..5c28de6021def4 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -42,6 +41,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" @@ -111,6 +111,7 @@ limitations under the License. #include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" #include "xla/hlo/transforms/expanders/rng_expander.h" #include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" +#include "xla/hlo/transforms/literal_canonicalizer.h" #include "xla/hlo/transforms/operand_upcaster.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" @@ -135,6 +136,7 @@ limitations under the License. #include "xla/hlo/transforms/while_loop_trip_count_annotator.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/literal.h" +#include "xla/literal_pool.h" #include "xla/map_util.h" #include "xla/mlir_hlo/transforms/passes.h" #include "xla/primitive_util.h" @@ -203,9 +205,6 @@ limitations under the License. #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/tsl/concurrency/async_value.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "xla/tsl/concurrency/chain.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -217,7 +216,6 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" -#include "tsl/platform/threadpool_async_executor.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" @@ -235,14 +233,11 @@ limitations under the License. namespace xla { namespace { -using tsl::AsyncValue; -using tsl::AsyncValueRef; -using tsl::Chain; using tsl::profiler::TraceMe; using tsl::profiler::TraceMeEncode; // A module identifier (prefix) for emitted LLVM modules. -static constexpr std::string_view kXlaModuleIdentifier = "__compute_module"; +static constexpr absl::string_view kXlaModuleIdentifier = "__compute_module"; // Returns a global (per-process) thread pool for XLA CPU compilation tasks. static tsl::thread::ThreadPool* GetCompilationThreadPool() { @@ -758,6 +753,12 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( SubByteNormalization::SET_ELEMENT_SIZE); } + + // Finally canonicalize all literals larger than 1024 bytes in the module to + // reuse the same literal across multiple HLO modules. + pipeline.AddPass(LiteralPool::Default(), + /*min_size_bytes=*/1024); + return pipeline.Run(module).status(); } @@ -943,7 +944,7 @@ std::pair GetIRModuleHooks( // Include LLVM module identifier suffix in case `llvm_module` is just a // part of the original LLVM module constructed by the XLA. - std::string_view id = llvm_module.getModuleIdentifier(); + absl::string_view id = llvm_module.getModuleIdentifier(); size_t pos = std::min(id.size(), 1 + kXlaModuleIdentifier.size()); llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized, /*filename_suffix=*/id.substr(pos)); @@ -1031,16 +1032,21 @@ namespace { // Post-compilation callback functor for use by SimpleOrcJIT. // // Dumps machine code if dumping is enabled for the module. -static std::function -CreateOrcJITPostCompilationHook(const HloModule* module, +static std::function +CreateOrcJITPostCompilationHook(const HloModule* hlo_module, std::vector* obj_files) { - return [=](const llvm::object::ObjectFile& obj_file) { + return [=](const llvm::Module& llvm_module, + const llvm::object::ObjectFile& obj_file) { if (obj_files) obj_files->push_back(obj_file.getData().str()); - if (DumpingEnabledForHloModule(*module)) { - DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o", - absl::string_view(obj_file.getData().data(), - obj_file.getData().size())); + if (DumpingEnabledForHloModule(*hlo_module)) { + std::string_view id = llvm_module.getModuleIdentifier(); + size_t pos = std::min(id.size(), 1 + kXlaModuleIdentifier.size()); + DumpToFileInDir( + *hlo_module, /*file_prefix=*/"", + /*file_suffix=*/absl::StrCat("obj-file.", id.substr(pos), ".o"), + absl::string_view(obj_file.getData().data(), + obj_file.getData().size())); } }; } @@ -1324,6 +1330,23 @@ inline void VlogMaxIsa(absl::string_view max_cpu_isa) { } } +// We keep HloProto in the CpuExecutable, but we don't need to keep literals +// payload in it as we use it only for debugging and memory analysis. +static void StripPayloadFromLiteralProto(HloProto& proto) { + auto* module = proto.mutable_hlo_module(); + for (auto& computation : *module->mutable_computations()) { + for (auto& instruction : *computation.mutable_instructions()) { + // We only keep literal shape to correctly estimate memory usage of the + // HLO module, but we don't need the actual literal data. + if (instruction.has_literal()) { + LiteralProto literal; + *literal.mutable_shape() = instruction.literal().shape(); + *instruction.mutable_literal() = std::move(literal); + } + } + } +} + absl::StatusOr> CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { TraceMe trace([&] { @@ -1414,28 +1437,11 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { const bool embed_ir_in_executable = debug_options.xla_embed_ir_in_executable(); - // Select a memory scheduler optimized for concurrency vs minimal memory. - auto scheduler = - debug_options.xla_cpu_enable_concurrency_optimized_scheduler() - ? BFSMemoryScheduler - : DFSMemoryScheduler; - - // Select an order for emitting the HLO instructions for each - // computation. Using this sequence enables tighter buffer liveness analysis - // and reduced memory usage (as compared to using `DependencyHloOrdering`). - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule(module.get(), BufferSizeBytesFunction(), - ComputationSchedulerToModuleScheduler(scheduler))); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, CreateHloSchedule(*module)); TF_RETURN_IF_ERROR(module->set_schedule(schedule)); - // Run buffer allocation on the HLO graph. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - std::make_unique(schedule), - BufferSizeBytesFunction(), memory_alignment, - /*allocate_buffers_for_constants=*/true)); + TF_ASSIGN_OR_RETURN(std::unique_ptr assignment, + CreateBufferAssignment(*module)); DumpHloModuleIfEnabled(*module, *assignment, absl::StrCat("cpu_", kAfterOptimizationsDumpName)); @@ -1446,6 +1452,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { *hlo_proto->mutable_hlo_module() = cpu_executable->module().ToProto(); *hlo_proto->mutable_buffer_assignment() = cpu_executable->buffer_assignment().ToProto(); + StripPayloadFromLiteralProto(*hlo_proto); cpu_executable->set_hlo_proto(std::move(hlo_proto)); return cpu_executable; }; @@ -1474,17 +1481,15 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { #endif ); - // Emit global variables for constants. - // - // TODO(ezhulenev): Figure out how to emit constants that are only needed for - // thread local computations as with Thunks runtime we keep constants outside - // of the LLVM module. Currently we end up doubling memory for constants. - TF_RETURN_IF_ERROR(nested_ir_emitter.EmitConstantGlobals()); // If we use Thunk runtime then instead of emitting LLVM function for the // entry computation we emit a sequence of thunks that implement the // computation as a sequence of interpreted commands. if (module->config().debug_options().xla_cpu_use_thunk_runtime()) { + // The thunk runtime manages large constants, therefore we only emit + // small ones. + TF_RETURN_IF_ERROR(nested_ir_emitter.EmitSmallConstantGlobals()); + // IR emitter is responsible for building LLVM module with host kernels for // corresponding HLO instructions (fusions, elemental instructions, etc.). IrEmitter2 ir_emitter2(*module, llvm_module.get(), &nested_ir_emitter); @@ -1499,16 +1504,30 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { std::string ir_module_string; if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpToString(llvm_module.get()); + std::string emitter2_ir = llvm_ir::DumpToString(llvm_module.get()); + + auto thunk_kernel_fmt = [](std::string* out, + const ThunkEmitter::EmittedKernel& kernel) { + absl::StrAppend( + out, llvm_ir::DumpToString(kernel.module.getModuleUnlocked())); + }; + std::string thunks_ir = + absl::StrJoin(thunk_emitter.kernels(), "\n", thunk_kernel_fmt); + + ir_module_string = absl::StrCat(emitter2_ir, "\n", thunks_ir); } TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); + for (const auto& [name, module] : thunk_emitter.kernels()) { + TF_RETURN_IF_ERROR(VerifyLlvmModule(*module.getModuleUnlocked())); + } // We define the number of module parts based on the total number of // compiled functions (kernels and comparators) that are called from thunks, // and the maximum number of parts that we want to split the module into. - size_t num_compiled_functions = - ir_emitter2.kernels().size() + ir_emitter2.comparators().size(); + size_t num_compiled_functions = ir_emitter2.kernels().size() + + ir_emitter2.comparators().size() + + thunk_emitter.kernels().size(); size_t num_parts = std::min(num_compiled_functions, parallel_codegen_split_count); @@ -1572,6 +1591,18 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // Collect compiled symbols from all LLVM module parts. std::vector compiled_symbols; + VLOG(3) << "Adding " << thunk_emitter.kernels().size() + << " kernels to the JIT compiler"; + int kernel_dylib_index = 0; + for (auto& [name, module] : thunk_emitter.kernels()) { + compiled_symbols.push_back( + FunctionLibrary::Sym(name)); + TF_CHECK_OK( + jit_compiler.AddModule(std::move(module), kernel_dylib_index)); + // Simply roundrobin the kernel dylibs + kernel_dylib_index = (kernel_dylib_index + 1) % num_parts; + } + for (const CompiledSymbolsPart& part : compiled_parts) { for (const IrEmitter2::KernelInfo& kernel : part.kernels) { compiled_symbols.push_back( @@ -1618,6 +1649,8 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { return with_hlo_proto(std::move(cpu_executable)); } + TF_RETURN_IF_ERROR(nested_ir_emitter.EmitAllConstantGlobals()); + // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // SubcomputationEmissionOrder guarantees that a called computation occurs @@ -1687,6 +1720,10 @@ absl::StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, [[maybe_unused]] se::StreamExecutor* stream_exec, const CompileOptions& options) { + TraceMe trace([&] { + return TraceMeEncode("CpuCompiler::RunBackend", {{"name", module->name()}}); + }); + VLOG(1) << "Compiling: " << module->name(); RecordCpuCompilerStacktrace(); XLA_SCOPED_LOGGING_TIMER( @@ -1871,7 +1908,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // TODO(b/66051036): Run full msan for AOT. /*emit_code_for_msan=*/false); - TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + TF_RETURN_IF_ERROR(ir_emitter.EmitAllConstantGlobals()); for (ComputationToEmit subcomputation : SubcomputationEmissionOrder(computation)) { @@ -1914,13 +1951,18 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, TF_RETURN_IF_ERROR(verify_status); } - auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) { + auto post_codegen_hook = [&](const llvm::Module& llvm_module, + const llvm::object::ObjectFile& obj_file) { if (!DumpingEnabledForHloModule(*module)) { return; } - DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o", - absl::string_view(obj_file.getData().data(), - obj_file.getData().size())); + std::string_view id = llvm_module.getModuleIdentifier(); + size_t pos = std::min(id.size(), 1 + kXlaModuleIdentifier.size()); + DumpToFileInDir( + *module, /*file_prefix=*/"", + /*file_suffix=*/absl::StrCat("obj-file.", id.substr(pos), ".o"), + absl::string_view(obj_file.getData().data(), + obj_file.getData().size())); }; IrCompiler::Options ir_compiler_options = { @@ -1980,7 +2022,7 @@ class CpuExecutableAotCompilationResult : public AotCompilationResult { public: CpuExecutableAotCompilationResult( const HloModule* hlo_module, const BufferAssignment* buffer_assignment, - std::string_view function_name, std::vector obj_files, + absl::string_view function_name, std::vector obj_files, CompilationResultProto::ObjFileKind obj_file_kind) { *proto_.mutable_hlo_module()->mutable_hlo_module() = hlo_module->ToProto(); *proto_.mutable_hlo_module()->mutable_config() = @@ -2137,6 +2179,10 @@ CpuExecutableAotCompilationResult::LoadExecutable( // Collect compiled symbols from IrEmitter2. std::vector compiled_symbols; + for (auto& [name, module] : thunk_emitter.kernels()) { + compiled_symbols.push_back( + FunctionLibrary::Sym(name)); + } for (const auto& kernel : ir_emitter2.kernels()) { compiled_symbols.push_back( FunctionLibrary::Sym(kernel.name)); @@ -2217,5 +2263,30 @@ CpuCompiler::LoadAotCompilationResult( return CpuExecutableAotCompilationResult::FromString(serialized_aot_result); } +absl::StatusOr CpuCompiler::CreateHloSchedule( + const HloModule& hlo_module) const { + // Select a memory scheduler optimized for concurrency vs minimal memory. + auto scheduler = hlo_module.config() + .debug_options() + .xla_cpu_enable_concurrency_optimized_scheduler() + ? BFSMemoryScheduler + : DFSMemoryScheduler; + + // Select an order for emitting the HLO instructions for each + // computation. Using this sequence enables tighter buffer liveness analysis + // and reduced memory usage (as compared to using `DependencyHloOrdering`). + return ScheduleModule(&hlo_module, BufferSizeBytesFunction(), + ComputationSchedulerToModuleScheduler(scheduler)); +} + +absl::StatusOr> +CpuCompiler::CreateBufferAssignment(const HloModule& module) const { + // Run buffer allocation on the HLO graph. + return BufferAssigner::Run( + &module, std::make_unique(module.schedule()), + BufferSizeBytesFunction(), memory_alignment, + /*allocate_buffers_for_constants=*/true); +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.h b/third_party/xla/xla/service/cpu/cpu_compiler.h index e9afc008a93e68..b38409f7a455df 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.h +++ b/third_party/xla/xla/service/cpu/cpu_compiler.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/status.h" @@ -29,15 +28,16 @@ limitations under the License. #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/buffer_assignment.h" #include "xla/service/compiler.h" #include "xla/service/cpu/executable.pb.h" -#include "xla/service/cpu/xla_framework.h" #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/llvm_compiler.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" @@ -189,6 +189,12 @@ class CpuCompiler : public LLVMCompiler { std::unique_ptr module, mlir::DialectRegistry* registry = nullptr); + absl::StatusOr CreateHloSchedule( + const HloModule& hlo_module) const; + + absl::StatusOr> CreateBufferAssignment( + const HloModule& module) const; + private: // Initialize the LLVM target. static void InitializeLLVMTarget(); diff --git a/third_party/xla/xla/service/cpu/cpu_compiler_test.cc b/third_party/xla/xla/service/cpu/cpu_compiler_test.cc index b60d1161c96910..a2afebce2e8285 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler_test.cc @@ -10,18 +10,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include #include #include -#include "xla/pjrt/pjrt_client.h" -#include "xla/service/hlo_runner.h" -#include "xla/service/hlo_runner_interface.h" -#include "xla/service/hlo_runner_pjrt.h" -#include "xla/service/platform_util.h" -#include "xla/shape.h" -#include "xla/tests/hlo_runner_agnostic_test_base.h" -#include "xla/tests/pjrt_client_registry.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "tsl/platform/statusor.h" @@ -31,57 +27,52 @@ namespace xla { namespace cpu { namespace { -std::unique_ptr CreateHloRunner() { - if (!ShouldUsePjRt()) { - return std::make_unique( - PlatformUtil::GetDefaultPlatform().value()); - } - - PjRtClientTestFactoryRegistry& pjrt_registry = - GetGlobalPjRtClientTestFactory(); - std::unique_ptr client = pjrt_registry.Get()().value(); - PjRtClientTestFactoryRegistry::DeviceShapeRepresentationFn - device_shape_representation_fn = - pjrt_registry.GetDeviceShapeRepresentationFn(client.get()); - PjRtClientTestFactoryRegistry::DeviceShapeSizeFn device_shape_size_fn = - pjrt_registry.GetDeviceShapeSizeFn(client.get()); - return std::make_unique( - std::move(client), [](const Shape& host_shape) { return host_shape; }, - device_shape_size_fn); -} +using CpuCompilerTest = HloPjRtTestBase; -class CpuCompilerTest : public HloRunnerAgnosticTestBase { - public: - CpuCompilerTest() - : HloRunnerAgnosticTestBase(CreateHloRunner(), CreateHloRunner()) {} -}; +constexpr absl::string_view kCpuCompilerStacktraceMetricName = + "/xla/service/cpu/compiler_stacktrace_count"; TEST_F(CpuCompilerTest, RecordsStreamzStackTrace) { - const char* hlo_text = R"( + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( HloModule test ENTRY main { p = f32[10]{0} parameter(0) ROOT neg = f32[10]{0} negate(p) } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + )")); EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/true)); - const std::string kCpuCompilerStacktraceMetricName = - "/xla/service/cpu/compiler_stacktrace_count"; - tsl::monitoring::CollectionRegistry::CollectMetricsOptions options; std::unique_ptr metrics = tsl::monitoring::CollectionRegistry::Default()->CollectMetrics(options); - EXPECT_TRUE(metrics->point_set_map.find(kCpuCompilerStacktraceMetricName) != - metrics->point_set_map.end()); + + const auto it = metrics->point_set_map.find( + std::string(kCpuCompilerStacktraceMetricName)); + ASSERT_TRUE(it != metrics->point_set_map.end()); // Since Streamz is recorded every call, we expect at least one point. // All other callers may increment the counter as well. - EXPECT_GT( - metrics->point_set_map[kCpuCompilerStacktraceMetricName]->points.size(), - 0); + EXPECT_GT(it->second->points.size(), 0); +} + +TEST_F(CpuCompilerTest, CompilationWithLargeConstants) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[1000,1000]{1,0} parameter(0) + b = f32[1000,1000]{1,0} constant({...}) + a_plus_b = f32[1000,1000]{1,0} add(a, b) + c = f32[1000,1000]{1,0} constant({...}) + ROOT result = f32[1000,1000]{1,0} add(a_plus_b, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/true)); } } // namespace diff --git a/third_party/xla/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/third_party/xla/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index 1193443a806a36..b898eab1c5d2f7 100644 --- a/third_party/xla/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/test.h" #include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/cpu/target_machine_features_stub.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index e80221146db6bf..0e75a05ab6ca30 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/third_party/xla/xla/service/cpu/cpu_executable_run_options.h b/third_party/xla/xla/service/cpu/cpu_executable_run_options.h index ee1a47e1382283..6d78723c8c30a5 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable_run_options.h +++ b/third_party/xla/xla/service/cpu/cpu_executable_run_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_ #define XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_ -#include "xla/service/cpu/collectives_interface.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" namespace xla::cpu { @@ -25,16 +25,16 @@ namespace xla::cpu { // dependencies to ExecutableRunOptions. class CpuExecutableRunOptions { public: - CpuExecutableRunOptions& set_collectives(CollectivesInterface* collectives) { + CpuExecutableRunOptions& set_collectives(CpuCollectives* collectives) { collectives_ = collectives; return *this; } - CollectivesInterface* collectives() const { return collectives_; } + CpuCollectives* collectives() const { return collectives_; } private: // For cross-process collectives, use this collective implementation to // communicate. - CollectivesInterface* collectives_; + CpuCollectives* collectives_; }; } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index 3a4aafa88a5b17..5435f0441b9134 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -19,6 +19,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/log.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/instruction_fusion.h" @@ -81,6 +84,10 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, constexpr int kFusionThresholdBytes = 16 * 1024; + if (IsLargeConstant(producer)) { + return FusionDecision::Forbid("Don't fuse large constants."); + } + if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; return FusionDecision::Allow(); @@ -219,5 +226,12 @@ HloInstruction* CpuInstructionFusion::FuseInstruction( evaluation->second.UpdateEvaluationCache(new_producer, indexing_users); return new_producer; } + +bool CpuInstructionFusion::IsLargeConstant( + const HloInstruction* constant) const { + return constant->IsConstant() && + Cast(constant)->literal().size_bytes() > + GetLargeConstantThresholdBytes(); +} } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.h b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.h index 87eec792924f64..e5c4c54b0005ed 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.h +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.h @@ -43,6 +43,12 @@ class CpuInstructionFusion : public InstructionFusion { return InstructionFusion::Run(module, execution_threads); } + // Returns the threshold for a constant to be considered a large constant. + static constexpr int64_t GetLargeConstantThresholdBytes() { + constexpr int64_t kLargeConstantThresholdBytes = 10000; + return kLargeConstantThresholdBytes; + } + protected: FusionDecision ShouldFuse(HloInstruction* consumer, int64_t operand_index) override; @@ -53,6 +59,9 @@ class CpuInstructionFusion : public InstructionFusion { HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, HloInstruction* producer) override; + // Returns if a constant is large enough to be considered a large constant. + bool IsLargeConstant(const HloInstruction* constant) const; + // Keep track of the number of times each instruction inside a fusion node is // indexed with different index vectors. absl::flat_hash_map diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc index 933d5133e759ba..787c4d138b3448 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/service/transpose_folding.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" @@ -935,5 +936,45 @@ ENTRY main { EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion()); } +TEST_F(OpcodeFusionTest, BigConstantNotInFusion) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[1000,1000]{1,0} parameter(0) + b = f32[1000,1000]{1,0} constant({...}) + a_plus_b = f32[1000,1000]{1,0} add(a, b) + c = f32[1000,1000]{1,0} constant({...}) + ROOT result = f32[1000,1000]{1,0} add(a_plus_b, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_string)); + RunFusionAndCheckOpcodesWereFused( + module.get(), {HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kAdd, HloOpcode::kAdd}); +} + +TEST_F(OpcodeFusionTest, SmallConstantInFusion) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[10,10]{1,0} parameter(0) + b = f32[10,10]{1,0} constant({...}) + a_plus_b = f32[10,10]{1,0} add(a, b) + c = f32[10,10]{1,0} constant({...}) + ROOT result = f32[10,10]{1,0} add(a_plus_b, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(module_string)); + RunFusionAndCheckOpcodesWereFused( + module.get(), {HloOpcode::kParameter, HloOpcode::kConstant, + HloOpcode::kConstant, HloOpcode::kAdd, HloOpcode::kAdd}); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc index 252049664af8f5..66c3a4f509a4fa 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -34,8 +36,6 @@ limitations under the License. #include "xla/service/cpu/target_machine_features_stub.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index c1c44bd8600b79..f6efca7936780b 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -20,10 +20,8 @@ limitations under the License. #include #include #include -#include #include #include -#include #include #include @@ -31,32 +29,39 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/base/dynamic_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_cliques.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/backends/cpu/collectives/in_process_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/layout_util.h" +#include "xla/primitive_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_executable_run_options.h" -#include "xla/service/cpu/in_process_collectives.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" #include "tsl/profiler/lib/traceme.h" namespace xla { @@ -334,13 +339,12 @@ RendezvousKey GetRendezvousKey(const ExecutableRunOptions* run_options, num_local_participants, op_kind, op_id}; } -CollectivesInterface* GetInProcessCollectivesImpl() { +CpuCollectives* GetInProcessCollectivesImpl() { static InProcessCollectives* c = new InProcessCollectives(); return c; } -CollectivesInterface* GetCollectivesImpl( - const ExecutableRunOptions* run_options) { +CpuCollectives* GetCollectivesImpl(const ExecutableRunOptions* run_options) { if (run_options->cpu_executable_run_options() && run_options->cpu_executable_run_options()->collectives()) { return run_options->cpu_executable_run_options()->collectives(); @@ -371,7 +375,7 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, int64_t buffer_size, void** source_buffers, void** destination_buffers) { GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view replica_groups_serialized( + absl::string_view replica_groups_serialized( static_cast(replica_groups_str), replica_groups_str_size); std::vector group = ParseReplicaGroupsOnly(replica_groups_serialized).value(); @@ -381,19 +385,31 @@ void AllToAllImpl(const ExecutableRunOptions* run_options, int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetCollectivesImpl(run_options); + CpuCollectives* collectives = GetCollectivesImpl(run_options); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(source_buffers, sizeof(void*) * num_buffers); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(destination_buffers, sizeof(void*) * num_buffers); - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); - TF_CHECK_OK(communicator->AllToAll( - rendezvous_key, buffer_size, - absl::Span(source_buffers, num_buffers), - absl::Span(destination_buffers, num_buffers), - DefaultCollectiveTimeout())); + + CpuCliqueKey clique_key(rendezvous_key.global_devices); + Communicator* communicator = + AcquireCommunicator(collectives, clique_key, RankId(rank)).value(); + + CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); + + std::vector source_buffers_data; + std::vector destination_buffers_data; + for (int i = 0; i < num_buffers; i++) { + source_buffers_data.push_back( + se::DeviceMemoryBase(source_buffers[i], buffer_size)); + destination_buffers_data.push_back( + se::DeviceMemoryBase(destination_buffers[i], buffer_size)); + } + + TF_CHECK_OK(communicator->AllToAll(source_buffers_data, + destination_buffers_data, U8, buffer_size, + executor)); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -403,7 +419,7 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer, void* destination_buffer) { GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view replica_groups_serialized( + absl::string_view replica_groups_serialized( static_cast(replica_groups_str), replica_groups_str_size); std::vector group = ParseReplicaGroupsOnly(replica_groups_serialized).value(); @@ -413,13 +429,18 @@ void AllGatherImpl(const ExecutableRunOptions* run_options, int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetCollectivesImpl(run_options); + CpuCollectives* collectives = GetCollectivesImpl(run_options); + + CpuCliqueKey clique_key(rendezvous_key.global_devices); + Communicator* communicator = + AcquireCommunicator(collectives, clique_key, RankId(rank)).value(); + + se::DeviceMemoryBase input_buffer_data(source_buffer, buffer_size); + se::DeviceMemoryBase output_buffer_data(destination_buffer, buffer_size); - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); - TF_CHECK_OK(communicator->AllGather(rendezvous_key, buffer_size, - source_buffer, destination_buffer, - DefaultCollectiveTimeout())); + CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); + TF_CHECK_OK(communicator->AllGather(input_buffer_data, output_buffer_data, U8, + buffer_size, executor)); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -432,7 +453,7 @@ void ReduceScatterImpl(const ExecutableRunOptions* run_options, int64_t chunk_elems, void* input_buffer, void* output_buffer) { GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view replica_groups_serialized( + absl::string_view replica_groups_serialized( static_cast(replica_groups_str), replica_groups_str_size); std::vector group = ParseReplicaGroupsOnly(replica_groups_serialized).value(); @@ -442,14 +463,23 @@ void ReduceScatterImpl(const ExecutableRunOptions* run_options, int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetCollectivesImpl(run_options); + CpuCollectives* collectives = GetCollectivesImpl(run_options); + + CpuCliqueKey clique_key(rendezvous_key.global_devices); + Communicator* communicator = + AcquireCommunicator(collectives, clique_key, RankId(rank)).value(); + + auto dtype = static_cast(element_type); - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); + se::DeviceMemoryBase input_buffer_data(input_buffer, + primitive_util::ByteWidth(dtype)); + se::DeviceMemoryBase output_buffer_data(output_buffer, + primitive_util::ByteWidth(dtype)); + + CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); TF_CHECK_OK(communicator->ReduceScatter( - rendezvous_key, static_cast(reduction_kind), - static_cast(element_type), chunk_elems, input_buffer, - output_buffer, DefaultCollectiveTimeout())); + input_buffer_data, output_buffer_data, dtype, chunk_elems, + static_cast(reduction_kind), executor)); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY @@ -461,7 +491,7 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, int32_t shape_length, int32_t num_buffers, void** input_buffers, void** output_buffers) { GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view replica_groups_serialized( + absl::string_view replica_groups_serialized( static_cast(replica_groups_str), replica_groups_str_size); std::vector group = ParseReplicaGroupsOnly(replica_groups_serialized).value(); @@ -479,16 +509,31 @@ void AllReduceImpl(const ExecutableRunOptions* run_options, int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetCollectivesImpl(run_options); + CpuCollectives* collectives = GetCollectivesImpl(run_options); + + CpuCliqueKey clique_key(rendezvous_key.global_devices); + Communicator* communicator = + AcquireCommunicator(collectives, clique_key, RankId(rank)).value(); + + // Convert input/output buffers to DeviceMemoryBase. + std::vector input_buffers_data; + std::vector output_buffers_data; + for (int i = 0; i < num_buffers; i++) { + Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i); + input_buffers_data.push_back(se::DeviceMemoryBase( + input_buffers[i], ShapeUtil::ByteSizeOf(subshape))); + output_buffers_data.push_back(se::DeviceMemoryBase( + output_buffers[i], ShapeUtil::ByteSizeOf(subshape))); + } + + CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); for (int i = 0; i < num_buffers; i++) { Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i); TF_CHECK_OK(communicator->AllReduce( - rendezvous_key, static_cast(reduction_kind), - subshape.element_type(), ShapeUtil::ElementsIn(subshape), - input_buffers[i], output_buffers[i], DefaultCollectiveTimeout())); + input_buffers_data[i], output_buffers_data[i], subshape.element_type(), + ShapeUtil::ElementsIn(subshape), + static_cast(reduction_kind), executor)); } } @@ -499,7 +544,7 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, void* output_buffer, const void* source_target_pairs, int32_t source_target_pairs_size) { GlobalDeviceId device(GetDeviceOrdinal(run_options)); - std::string_view source_target_pairs_serialized( + absl::string_view source_target_pairs_serialized( static_cast(source_target_pairs), source_target_pairs_size); auto pairs = absl::StrSplit(source_target_pairs_serialized, ','); const DeviceAssignment::LogicalID logical_id = @@ -507,19 +552,19 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, int32_t logical_device_id = channel_id_present ? logical_id.computation_id : logical_id.replica_id; - std::optional source_replica_id; - std::vector copy_to; + std::optional source_replica_id; + std::vector copy_to; for (auto& p : pairs) { std::vector mapping = absl::StrSplit(p, '='); CHECK_EQ(mapping.size(), 2); int from = std::stoi(mapping[0]); int to = std::stoi(mapping[1]); if (from == logical_device_id) { - copy_to.push_back(to); + copy_to.push_back(RankId(to)); } if (to == logical_device_id) { CHECK(!source_replica_id.has_value()); - source_replica_id = from; + source_replica_id = RankId(from); } } RendezvousKey rendezvous_key = @@ -528,13 +573,20 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options, int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value(); - CollectivesInterface* collectives = GetCollectivesImpl(run_options); + CpuCollectives* collectives = GetCollectivesImpl(run_options); + + CpuCliqueKey clique_key(rendezvous_key.global_devices); + Communicator* communicator = + AcquireCommunicator(collectives, clique_key, RankId(rank)).value(); + + CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout()); + + se::DeviceMemoryBase input_buffer_data(input_buffer, byte_size); + se::DeviceMemoryBase output_buffer_data(output_buffer, byte_size); - auto communicator = - collectives->GetCommunicator(rendezvous_key.global_devices, rank).value(); TF_CHECK_OK(communicator->CollectivePermute( - rendezvous_key, byte_size, source_replica_id, copy_to, input_buffer, - output_buffer, DefaultCollectiveTimeout())); + input_buffer_data, output_buffer_data, U8, byte_size, source_replica_id, + copy_to, executor)); } } // namespace } // namespace runtime diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index 91eee483beff52..4911cbcf235a05 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -36,6 +36,7 @@ limitations under the License. #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Support/Alignment.h" @@ -770,6 +771,7 @@ absl::Status DotOpEmitter::EmitCallToRuntime() { bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Function* function = b_->GetInsertBlock()->getParent(); + llvm::LLVMContext& context = b_->getContext(); llvm::Module* module = function->getParent(); llvm::Type* float_type; const char* fn_name; @@ -797,13 +799,13 @@ absl::Status DotOpEmitter::EmitCallToRuntime() { fn_name = multi_threaded ? runtime::kEigenMatMulC64SymbolName : runtime::kEigenSingleThreadedMatMulC64SymbolName; - float_type = llvm_ir::PrimitiveTypeToIrType(C64, module); + float_type = llvm_ir::PrimitiveTypeToIrType(C64, context); break; case C128: fn_name = multi_threaded ? runtime::kEigenMatMulC128SymbolName : runtime::kEigenSingleThreadedMatMulC128SymbolName; - float_type = llvm_ir::PrimitiveTypeToIrType(C128, module); + float_type = llvm_ir::PrimitiveTypeToIrType(C128, context); break; case S32: fn_name = multi_threaded @@ -1108,13 +1110,12 @@ Shape CollapseFirstNDims(const Shape& shape, int64_t n) { llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilderBase* b, const llvm_ir::IrArray& array, int64_t n) { - llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); const Shape& shape = array.GetShape(); CHECK(shape.has_layout() && LayoutUtil::IsMonotonicWithDim0Major(shape.layout())); CHECK_GE(shape.dimensions_size(), n); Shape new_shape = CollapseFirstNDims(shape, n); - llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); + llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, b->getContext()); return llvm_ir::IrArray(array.GetBasePointer(), new_ir_type, std::move(new_shape)); } @@ -1138,8 +1139,6 @@ absl::Status ValidateDotDimensionNumbers( llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, llvm::Value* batch_index, llvm::IRBuilderBase* b) { - llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - Shape inner_shape = DropFirstDim(outer_array.GetShape()); std::vector multidim_index(inner_shape.rank() + 1, b->getInt64(0)); @@ -1147,7 +1146,8 @@ llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(), batch_index->getType()); llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b); - llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(inner_shape, module); + llvm::Type* new_ir_type = + llvm_ir::ShapeToIrType(inner_shape, b->getContext()); return llvm_ir::IrArray(slice_ptr, new_ir_type, std::move(inner_shape)); } diff --git a/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc b/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc new file mode 100644 index 00000000000000..41a6f0524befaf --- /dev/null +++ b/third_party/xla/xla/service/cpu/elemental_ir_emitter.cc @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/elemental_ir_emitter.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/Value.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/cpu/elemental_math_emitter.h" + +namespace xla::cpu { + +absl::StatusOr CpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs, + absl::string_view) { + return xla::cpu::EmitAtan2(module(), *b(), prim_type, lhs, rhs); +} + +absl::StatusOr CpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) { + return xla::cpu::EmitTanh(module(), *b(), prim_type, value); +} + +absl::StatusOr CpuElementalIrEmitter::EmitErf( + PrimitiveType prim_type, llvm::Value* value) { + return xla::cpu::EmitErf(module(), *b(), prim_type, value); +} + +absl::StatusOr> +CpuElementalIrEmitter::EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name, bool is_reducer) { + if (thread_local_call_fn_ == nullptr) { + return absl::InternalError("Thread local call function is not set."); + } + + return thread_local_call_fn_(callee, parameters, name, is_reducer); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/elemental_ir_emitter.h b/third_party/xla/xla/service/cpu/elemental_ir_emitter.h new file mode 100644 index 00000000000000..921df54d7c8c7d --- /dev/null +++ b/third_party/xla/xla/service/cpu/elemental_ir_emitter.h @@ -0,0 +1,74 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_ +#define XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/service/elemental_ir_emitter.h" + +namespace xla::cpu { + +class CpuElementalIrEmitter final : public ElementalIrEmitter { + public: + using ThreadLocalCallPrototype = absl::StatusOr>( + const HloComputation& callee, absl::Span parameters, + absl::string_view name, bool is_reducer); + using ThreadLocalCallCallback = absl::AnyInvocable; + + CpuElementalIrEmitter(llvm::Module* llvm_module, llvm::IRBuilderBase* builder, + ThreadLocalCallCallback thread_local_call_fn, + bool use_truncate_f32_to_bf16_conversion, + bool fast_min_max) + : ElementalIrEmitter(llvm_module, builder, + Options{use_truncate_f32_to_bf16_conversion}), + thread_local_call_fn_(std::move(thread_local_call_fn)), + fast_min_max_(fast_min_max) {} + + private: + absl::StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, llvm::Value* rhs, + absl::string_view) override; + + absl::StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) override; + + absl::StatusOr EmitErf(PrimitiveType prim_type, + llvm::Value* value) override; + + absl::StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name, bool is_reducer) override; + + bool fast_min_max() override { return fast_min_max_; } + + private: + ThreadLocalCallCallback thread_local_call_fn_; + bool fast_min_max_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_ diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.cc b/third_party/xla/xla/service/cpu/in_process_collectives.cc deleted file mode 100644 index d63862b6b3bbbe..00000000000000 --- a/third_party/xla/xla/service/cpu/in_process_collectives.cc +++ /dev/null @@ -1,656 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/in_process_collectives.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "xla/primitive_util.h" -#include "xla/refcounting_hash_map.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" -#include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" - -namespace xla { -namespace cpu { -namespace runtime { -namespace { - -void FormatGlobalId(std::string* out, const GlobalDeviceId& device) { - absl::StrAppend(out, device.value()); -} - -struct AllReduceParticipantData : ParticipantData { - explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - int64_t element_count; - const void* source_data; - void* destination_data; - PrimitiveType primitive_type; - - ReductionKind reduction_kind; - - std::string ToString() const override { - return absl::StrFormat( - "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, " - "rendezvous_key=%s}", - local_rank, element_count, PrimitiveType_Name(primitive_type), - rendezvous_key.ToString()); - } -}; - -template -T GetInitialValue(ReductionKind reduction_kind) { - switch (reduction_kind) { - case ReductionKind::SUM: - return static_cast(0); - case ReductionKind::PRODUCT: - return static_cast(1); - case ReductionKind::MIN: - return std::numeric_limits::has_infinity - ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - case ReductionKind::MAX: - return std::numeric_limits::has_infinity - ? -std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - } -} - -// We cannot use static_assert(false), because the C++ standard (prior to -// CWG2518) does not allow the statement discarded by a constexpr if to -// be ill-formed for every possible specialization. -// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if -template -constexpr bool always_false_v = false; - -template -void ReduceHelper(absl::Span acc, absl::Span inputs) { - // TODO(penporn): make sure this gets vectorized. - if constexpr (reduction_kind == ReductionKind::SUM) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] += inputs[j][i]; - } - } - } else if constexpr (reduction_kind == ReductionKind::PRODUCT) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] *= inputs[j][i]; - } - } - } else if constexpr (reduction_kind == ReductionKind::MIN) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] = std::min(acc[i], inputs[j][i]); - } - } - } else if constexpr (reduction_kind == ReductionKind::MAX) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] = std::max(acc[i], inputs[j][i]); - } - } - } else { - static_assert(always_false_v, "Unsupported reduction kind"); - } -} - -template -absl::Status ReduceScatter(ReductionKind reduction_kind, - absl::Span inputs, void* output, - int64_t num_elems) { - using T = typename primitive_util::PrimitiveTypeToNative::type; - T initial_value = GetInitialValue(reduction_kind); - - absl::Span out_chunk = - absl::MakeSpan(reinterpret_cast(output), num_elems); - for (int64_t i = 0; i < num_elems; ++i) { - out_chunk[i] = initial_value; - } - - absl::Span input_chunks( - reinterpret_cast(inputs.data()), inputs.size()); - switch (reduction_kind) { - case ReductionKind::SUM: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::PRODUCT: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); -} - -class CpuAllReduceRendezvous - : public Rendezvous { - public: - explicit CpuAllReduceRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - absl::StatusOr RunCollectiveOp( - const AllReduceParticipantData& me) override { - VLOG(3) << me.ToString(); - int64_t world_size = participants_.size(); - // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th - // chunk of the output. - int64_t chunk_elems = CeilOfRatio(me.element_count, world_size); - - int64_t start_elem = me.local_rank * chunk_elems; - int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count); - chunk_elems = std::max(int64_t{0}, end_elem - start_elem); - if (chunk_elems == 0) { - return nullptr; - } - - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; - void* reduce_output = - reinterpret_cast(me.destination_data) + chunk_offset; - - std::vector inputs; - inputs.reserve(world_size); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_data) + - chunk_offset); - } - - switch (me.primitive_type) { - case S8: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case S16: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case U16: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case S32: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case U32: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case S64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case U64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case F16: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case F32: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case F64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case C64: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - case C128: - TF_RETURN_IF_ERROR(ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems)); - break; - default: - return absl::UnimplementedError("Unexpected datatype"); - } - - // All-gather the reduced chunks. - for (const auto& p : participants_) { - if (p->local_rank != me.local_rank) { - std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, - reduce_output, chunk_bytes); - } - } - return nullptr; - } -}; - -struct CollectivePermuteParticipantData : ParticipantData { - CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - const void* source_buffer; - void* destination_buffer; - size_t num_bytes; - - // From which rank is this participant receiving its data? Optional; if - // absent fill with zeros. - std::optional source_rank; - - std::string ToString() const override { - return absl::StrFormat( - "CollectivePermuteParticipantData{rank=%d, " - "source_buffer=%p, destination_buffer=%p, num_bytes=%d, " - "source_replica_id=%d, " - "devices=[%s]}", - local_rank, source_buffer, destination_buffer, num_bytes, - source_rank.value_or(-1), - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId)); - } -}; - -class CpuCollectivePermuteRendezvous - : public Rendezvous { - public: - explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - - absl::StatusOr RunCollectiveOp( - const CollectivePermuteParticipantData& p) override { - VLOG(3) << p.ToString(); - if (p.source_rank) { - std::memcpy(p.destination_buffer, - participants_[*p.source_rank]->source_buffer, p.num_bytes); - } else { - std::memset(p.destination_buffer, 0, p.num_bytes); - } - return nullptr; - } -}; - -struct AllToAllParticipantData : ParticipantData { - AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - std::vector source_buffers; - std::vector destination_buffers; - size_t chunk_size; - - std::string ToString() const override { - auto addr_formatter = [](std::string* out, const void* mem) { - absl::StrAppend(out, absl::StrFormat("%p", mem)); - }; - return absl::StrFormat( - "AllToAllParticipantData{rank=%d, " - "devices=[%s], source_buffers=[%s], " - "destination_buffers=[%s], chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - absl::StrJoin(source_buffers, ", ", addr_formatter), - absl::StrJoin(destination_buffers, ", ", addr_formatter), chunk_size); - } -}; - -class CpuAllToAllRendezvous - : public Rendezvous { - public: - explicit CpuAllToAllRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllToAllParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - for (int i = 0; i < world_size; ++i) { - std::memcpy(participants_[i]->destination_buffers[p.local_rank], - p.source_buffers[i], p.chunk_size); - } - return nullptr; - } -}; - -struct AllGatherParticipantData : ParticipantData { - AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - const void* source_buffer; - void* destination_buffer; - size_t chunk_size; - - std::string ToString() const override { - return absl::StrFormat( - "AllGatherParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_size); - } -}; - -class CpuAllGatherRendezvous - : public Rendezvous { - public: - explicit CpuAllGatherRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllGatherParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - char* out = static_cast(p.destination_buffer); - for (int i = 0; i < world_size; ++i, out += p.chunk_size) { - std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); - } - return nullptr; - } -}; - -struct ReduceScatterParticipantData : ParticipantData { - ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - ReductionKind reduction_kind; - PrimitiveType element_type; - const void* source_buffer; - void* destination_buffer; - size_t chunk_elems; - - std::string ToString() const override { - return absl::StrFormat( - "ReduceScatterParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_elems=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_elems); - } -}; - -class CpuReduceScatterRendezvous - : public Rendezvous { - public: - explicit CpuReduceScatterRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const ReduceScatterParticipantData& me) override { - auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); - int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; - - std::vector inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_buffer) + - chunk_offset); - } - - switch (me.element_type) { - case S8: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case S16: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case U16: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case S32: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case U32: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case S64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case U64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case F16: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case F32: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case F64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case C64: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - case C128: - TF_RETURN_IF_ERROR(ReduceScatter( - me.reduction_kind, inputs, me.destination_buffer, me.chunk_elems)); - break; - default: - return absl::UnimplementedError("Unexpected datatype"); - } - - return nullptr; - } -}; - -} // namespace - -struct InProcessCollectivesState { - RefcountingHashMap - all_reduce_rendezvous_map; - RefcountingHashMap - collective_permute_rendezvous_map; - RefcountingHashMap - all_to_all_rendezvous_map; - RefcountingHashMap - all_gather_rendezvous_map; - RefcountingHashMap - reduce_scatter_rendezvous_map; -}; - -InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( - InProcessCollectivesState* state, int rank, int size) - : state_(state), rank_(rank) {} -InProcessCollectivesCommunicator::~InProcessCollectivesCommunicator() = default; - -absl::Status InProcessCollectivesCommunicator::AllReduce( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t num_elements, - const void* const input_buffer, void* const output_buffer, - absl::Duration timeout) { - AllReduceParticipantData participant(key, rank_); - participant.element_count = num_elements; - participant.primitive_type = element_type; - participant.source_data = input_buffer; - participant.destination_data = output_buffer; - participant.reduction_kind = reduction_kind; - - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - - return CpuAllReduceRendezvous::SubmitParticipant( - [&] { - return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::CollectivePermute( - const RendezvousKey& key, size_t num_bytes, std::optional source_rank, - absl::Span target_ranks, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - CollectivePermuteParticipantData participant(key, rank_); - participant.source_buffer = input_buffer; - participant.destination_buffer = output_buffer; - participant.num_bytes = num_bytes; - participant.source_rank = source_rank; - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuCollectivePermuteRendezvous::SubmitParticipant( - [&] { - return state_->collective_permute_rendezvous_map - .GetOrCreateIfAbsent(key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::AllToAll( - const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, absl::Duration timeout) { - AllToAllParticipantData participant(key, rank_); - TF_RET_CHECK(input_buffers.size() == output_buffers.size()); - participant.chunk_size = chunk_bytes; - participant.source_buffers.reserve(input_buffers.size()); - participant.destination_buffers.reserve(output_buffers.size()); - for (const void* input_buffer : input_buffers) { - participant.source_buffers.push_back(input_buffer); - } - for (void* output_buffer : output_buffers) { - participant.destination_buffers.push_back(output_buffer); - } - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllToAllRendezvous::SubmitParticipant( - [&] { - return state_->all_to_all_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::AllGather( - const RendezvousKey& key, size_t chunk_bytes, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - AllGatherParticipantData participant(key, rank_); - participant.chunk_size = chunk_bytes; - participant.source_buffer = input_buffer; - participant.destination_buffer = output_buffer; - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllGatherRendezvous::SubmitParticipant( - [&] { - return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::ReduceScatter( - const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, - void* output_buffer, absl::Duration timeout) { - ReduceScatterParticipantData participant(key, rank_); - participant.element_type = element_type; - participant.reduction_kind = reduction_kind; - participant.chunk_elems = chunk_elems; - participant.source_buffer = input_buffer; - participant.destination_buffer = output_buffer; - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuReduceScatterRendezvous::SubmitParticipant( - [&] { - return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} -InProcessCollectives::InProcessCollectives() - : state_(std::make_unique()) {} -InProcessCollectives::~InProcessCollectives() = default; - -absl::StatusOr> -InProcessCollectives::GetCommunicator(absl::Span devices, - int rank) { - // We don't care about devices here: we share rendezvous state globally. - return std::make_shared(state_.get(), rank, - devices.size()); -} - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/third_party/xla/xla/service/cpu/in_process_collectives.h b/third_party/xla/xla/service/cpu/in_process_collectives.h deleted file mode 100644 index 4551644585a6f7..00000000000000 --- a/third_party/xla/xla/service/cpu/in_process_collectives.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ -#define XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" -#include "xla/service/global_device_id.h" -#include "xla/xla_data.pb.h" - -namespace xla::cpu::runtime { - -struct InProcessCollectivesState; - -class InProcessCollectivesCommunicator : public CollectivesCommunicator { - public: - InProcessCollectivesCommunicator(InProcessCollectivesState* state, int rank, - int size); - ~InProcessCollectivesCommunicator() override; - - absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, - PrimitiveType element_type, size_t num_elements, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - - absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, - std::optional source_rank, - absl::Span target_ranks, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - - absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, - absl::Span input_buffers, - absl::Span output_buffers, - absl::Duration timeout) override; - - absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - - absl::Status ReduceScatter(const RendezvousKey& key, - ReductionKind reduction_kind, - PrimitiveType element_type, size_t chunk_elems, - const void* input_buffer, void* output_buffer, - absl::Duration timeout) override; - - private: - InProcessCollectivesState* state_; - int rank_; -}; - -class InProcessCollectives : public CollectivesInterface { - public: - InProcessCollectives(); - ~InProcessCollectives() override; - - // Thread-safe. - absl::StatusOr> GetCommunicator( - absl::Span devices, int rank) override; - - private: - std::unique_ptr state_; -}; - -} // namespace xla::cpu::runtime - -#endif // XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ diff --git a/third_party/xla/xla/service/cpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/cpu/ir_emission_utils_test.cc index 6babf519fde9b8..b957dde61e3786 100644 --- a/third_party/xla/xla/service/cpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/cpu/ir_emission_utils_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/test.h" #include "xla/service/cpu/target_machine_features_stub.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 37b16f3fb1e5ee..e9006e196b2268 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -67,9 +66,11 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/cpu_instruction_fusion.h" #include "xla/service/cpu/cpu_options.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/service/cpu/dot_op_emitter.h" +#include "xla/service/cpu/elemental_ir_emitter.h" #include "xla/service/cpu/elemental_math_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/cpu/ir_function.h" @@ -114,51 +115,6 @@ bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string) { absl::StrContains(feature_string, "+amx-bf16")); } -class IrEmitter::CpuElementalIrEmitter : public ElementalIrEmitter { - public: - CpuElementalIrEmitter(const HloModuleConfig& module_config, - IrEmitter* ir_emitter, llvm::Module* module) - : ElementalIrEmitter( - module, ir_emitter->b(), - Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/ - !IsNativeConvertSupportedOnTargetCPU( - ir_emitter->target_machine_features_ - .get_target_feature_string())}), - hlo_module_config_(module_config), - ir_emitter_(ir_emitter) {} - - protected: - absl::StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, llvm::Value* rhs, - absl::string_view) override { - return xla::cpu::EmitAtan2(module(), *b(), prim_type, lhs, rhs); - } - - absl::StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) override { - return xla::cpu::EmitTanh(module(), *b(), prim_type, value); - } - - absl::StatusOr EmitErf(PrimitiveType prim_type, - llvm::Value* value) override { - return xla::cpu::EmitErf(module(), *b(), prim_type, value); - } - - absl::StatusOr> EmitThreadLocalCall( - const HloComputation& callee, absl::Span parameters, - absl::string_view name, bool is_reducer) override { - return ir_emitter_->EmitThreadLocalCall(callee, parameters, name, - is_reducer); - } - - bool fast_min_max() override { - return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max(); - } - - const HloModuleConfig& hlo_module_config_; - IrEmitter* ir_emitter_; -}; - IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module, const BufferAssignment& assignment, @@ -182,6 +138,7 @@ IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context, computation_transitively_contains_custom_call_( std::move(computation_transitively_contains_custom_call)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), + hlo_module_(hlo_module), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), target_machine_features_(*target_machine_features), @@ -214,7 +171,8 @@ void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) { } else { CHECK(return_shape.IsTuple()); - llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_); + llvm::Type* tuple_type = + llvm_ir::ShapeToIrType(return_shape, module_->getContext()); for (int i = 0; i < return_shape.tuple_shapes_size(); i++) { const Shape& element_shape = return_shape.tuple_shapes(i); @@ -244,7 +202,13 @@ absl::StatusOr IrEmitter::EmitComputation( std::string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; + + auto cleanup = absl::MakeCleanup( + [saved_allow_reassociation = allow_reassociation_, this]() { + allow_reassociation_ = saved_allow_reassociation; + }); allow_reassociation_ = allow_reassociation; + num_dynamic_loop_bounds_ = 0; auto backend_config_or = computation->root_instruction()->backend_config(); @@ -350,9 +314,24 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { return result_global; } -absl::Status IrEmitter::EmitConstantGlobals() { +absl::Status IrEmitter::EmitSmallConstantGlobals() { + return EmitConstantGlobals(/*max_size_bytes=*/CpuInstructionFusion:: + GetLargeConstantThresholdBytes()); +} + +absl::Status IrEmitter::EmitAllConstantGlobals() { + return EmitConstantGlobals(/*max_size_bytes=*/std::nullopt); +} + +absl::Status IrEmitter::EmitConstantGlobals( + std::optional max_size_bytes) { for (const BufferAllocation& allocation : assignment_.Allocations()) { - if (!allocation.is_constant()) { + // Large constants don't get fused with other instructions, so we don't + // need to emit them as globals. + if (!allocation.is_constant() || + (max_size_bytes && + llvm_ir::LiteralForConstantAllocation(allocation).size_bytes() > + *max_size_bytes)) { continue; } @@ -1599,7 +1578,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( ShardedVectorType sharded_vector_type; llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(element_type, module_); + llvm_ir::PrimitiveTypeToIrType(element_type, module_->getContext()); for (int i = 0, e = 1 + Log2Ceiling(element_count); i < e; i++) { // For every power of two present in element_count, we generate one or more @@ -2211,7 +2190,7 @@ absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + CpuElementalIrEmitter elemental_emitter = ElementalIrEmmiterFactory(); FusedIrEmitter fused_emitter(elemental_emitter); BindFusionArguments(fusion, &fused_emitter); @@ -2221,7 +2200,7 @@ absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) { fusion, GetIrArrayFor(fusion), &fused_emitter, b()); } else if (fusion->IsLoopFusion()) { VLOG(3) << "HandleFusion kLoop"; - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + CpuElementalIrEmitter elemental_emitter = ElementalIrEmmiterFactory(); FusedIrEmitter fused_emitter(elemental_emitter); BindFusionArguments(fusion, &fused_emitter); TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator( @@ -3008,7 +2987,8 @@ absl::Status IrEmitter::HandleWhile(HloInstruction* xla_while) { Load(IrShapeType( xla_while->while_condition()->root_instruction()->shape()), GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), - llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); + llvm::ConstantInt::get( + llvm_ir::PrimitiveTypeToIrType(PRED, module_->getContext()), 0)); // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = @@ -3188,7 +3168,7 @@ struct EncodedInfo { }; template -static EncodedInfo StoreEncodedTypes(std::string_view alloca_name, +static EncodedInfo StoreEncodedTypes(absl::string_view alloca_name, const Args& args, llvm::IRBuilderBase& ir) { // Store the types of `args` into the allocated memory. These types are stored @@ -3217,7 +3197,7 @@ static EncodedInfo StoreEncodedTypes(std::string_view alloca_name, }; template -static EncodedInfo StoreEncodedShapes(std::string_view alloca_name, +static EncodedInfo StoreEncodedShapes(absl::string_view alloca_name, const Args& args, llvm::IRBuilderBase& ir) { // Prepare metadata for all buffers. A tuple shape is flattened to only encode @@ -3343,7 +3323,7 @@ void EmitTransferElements(llvm::Value* target, llvm::Value* source, primitive_type_size, ::xla::cpu::MinimumAlignmentForPrimitiveType(primitive_type))); llvm::Type* primitive_llvm_type = - llvm_ir::PrimitiveTypeToIrType(primitive_type, module); + llvm_ir::PrimitiveTypeToIrType(primitive_type, module->getContext()); if (element_count == 1) { auto* load_instruction = @@ -3439,11 +3419,11 @@ absl::Status IrEmitter::HandleConditional(HloInstruction* conditional) { llvm::LoadInst* pred_value = Load( GetIrArrayFor(branch_index).GetBasePointeeType(), GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = - ICmpNE(pred_value, - llvm::ConstantInt::get( - llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), - "boolean_predicate"); + llvm::Value* pred_cond = ICmpNE( + pred_value, + llvm::ConstantInt::get( + llvm_ir::PrimitiveTypeToIrType(PRED, module_->getContext()), 0), + "boolean_predicate"); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(pred_cond, "conditional", b()); @@ -3814,7 +3794,7 @@ llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { } llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { - return llvm_ir::ShapeToIrType(shape, module_); + return llvm_ir::ShapeToIrType(shape, module_->getContext()); } llvm::Value* IrEmitter::GetProfileCountersArgument() { @@ -4031,13 +4011,13 @@ absl::Status IrEmitter::ElementTypesSameAndSupported( } absl::Status IrEmitter::DefaultAction(HloInstruction* hlo) { - ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + CpuElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { return GetIrArrayFor(operand).EmitReadArrayElement(index, b()); }; } - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + CpuElementalIrEmitter elemental_emitter = ElementalIrEmmiterFactory(); return EmitTargetElementLoop( hlo, "elemental_loop", elemental_emitter.MakeElementGenerator(hlo, operand_to_generator), @@ -4076,7 +4056,7 @@ std::vector IrEmitter::EmitThreadLocalCall( } llvm::Type* return_value_buffer_type = - llvm_ir::ShapeToIrType(return_shape, module_); + llvm_ir::ShapeToIrType(return_shape, module_->getContext()); std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr"); int retval_alignment = is_scalar_return @@ -4175,5 +4155,59 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion, } } +CpuElementalIrEmitter IrEmitter::ElementalIrEmmiterFactory() { + auto thread_local_call_fn = [this](const HloComputation& callee, + absl::Span parameters, + absl::string_view name, bool is_reducer) { + return EmitThreadLocalCall(callee, parameters, name, is_reducer); + }; + + bool use_truncate_f32_to_bf16_conversion = + !IsNativeConvertSupportedOnTargetCPU( + target_machine_features_.get_target_feature_string()); + + return CpuElementalIrEmitter( + module_, b(), std::move(thread_local_call_fn), + use_truncate_f32_to_bf16_conversion, + hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max()); +} + +absl::Status IrEmitter::EmitNestedComputation(const HloComputation& callee, + absl::string_view name, + bool is_reducer) { + // Module must be scheduled to emit thread local computation. + if (!hlo_module_.has_schedule()) { + return absl::InternalError( + "HLO module must be scheduled to emit thread local computation."); + } + + if (is_computation_emitted(callee, is_reducer)) { + return absl::OkStatus(); + } + + for (HloInstruction* instr : callee.instructions()) { + bool nested_is_reducer = instr->opcode() == HloOpcode::kReduce || + instr->opcode() == HloOpcode::kReduceWindow; + for (HloComputation* called_computation : instr->called_computations()) { + // reassociation is transitive so we "or" the caller and the callee. + TF_RETURN_IF_ERROR( + EmitNestedComputation(*called_computation, llvm_ir::IrName(instr), + is_reducer || nested_is_reducer)); + } + } + + if (callee.IsFusionComputation()) { + return absl::OkStatus(); + } + + VLOG(2) << "Emit nested computation: " << callee.name(); + return EmitComputation( + const_cast(&callee), name, false, + hlo_module_.schedule().sequence(&callee).instructions(), + /*allow_reassociation=*/is_reducer, + /*function_attributes=*/{llvm::Attribute::AlwaysInline}) + .status(); +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 6d6108475e5295..f96d43fa40f678 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -49,6 +49,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/elemental_ir_emitter.h" #include "xla/service/cpu/ir_function.h" #include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/alias_analysis.h" @@ -79,7 +80,7 @@ bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string); // classes are part of the new runtime and will eventually replace IrEmitter. class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { - class CpuElementalIrEmitter; + class ElementalIrEmitter; public: using GeneratorForOperandIrArrays = @@ -177,8 +178,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, compute_function_.pop(); } - // Emit an LLVM global variable for every constant buffer allocation. - absl::Status EmitConstantGlobals(); + // Emit LLVM global variable for a small constant buffer allocation. + absl::Status EmitSmallConstantGlobals(); + + // Emit LLVM global variables for all constant buffer allocations. + absl::Status EmitAllConstantGlobals(); // Emits a call to a thread local function (e.g. to the computation nested // within a reduce or a map). Thread local callees (by definition) only write @@ -214,6 +218,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // This is convenient for reusing the same logic with a different builder. class IRBuilderGuard { public: + IRBuilderGuard() = default; explicit IRBuilderGuard(IrEmitter* ir_emitter, llvm::IRBuilderBase* builder) : ir_emitter_(ir_emitter), original_builder_(ir_emitter->current_builder_) { @@ -223,11 +228,15 @@ class IrEmitter : public DfsHloVisitorWithDefault, IRBuilderGuard(IRBuilderGuard&& other) = delete; IRBuilderGuard& operator=(IRBuilderGuard&& other) = delete; - ~IRBuilderGuard() { ir_emitter_->current_builder_ = original_builder_; } + ~IRBuilderGuard() { + if (ir_emitter_ != nullptr) { + ir_emitter_->current_builder_ = original_builder_; + } + } private: - IrEmitter* ir_emitter_; - llvm::IRBuilderBase* original_builder_; + IrEmitter* ir_emitter_ = nullptr; + llvm::IRBuilderBase* original_builder_ = nullptr; }; // WithBuilder is a convenience function that creates and returns a @@ -236,9 +245,15 @@ class IrEmitter : public DfsHloVisitorWithDefault, return IRBuilderGuard(this, &builder); } + absl::Status EmitNestedComputation(const HloComputation& callee, + absl::string_view name, bool is_reducer); + protected: friend class IrEmitter2; + // Emit an LLVM global variable for every constant buffer allocation. + absl::Status EmitConstantGlobals(std::optional max_size_bytes); + // // The following methods implement the DfsHloVisitor interface. // @@ -631,7 +646,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilderBase* current_builder_; std::stack compute_function_; mlir::MLIRContext* mlir_context_; - bool allow_reassociation_; + // The state of allow_reassociation_ is required so that that it is + // transitive to all nested computations. + bool allow_reassociation_ = false; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. @@ -786,6 +803,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Returns a ConstExpr bitcast. llvm::Constant* EmitGlobalForLiteral(const Literal& literal); + CpuElementalIrEmitter ElementalIrEmmiterFactory(); + + const HloModule& hlo_module_; const HloModuleConfig& hlo_module_config_; bool is_top_level_computation_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index 4dad61b6fe6aa7..1890d5377bfb49 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -19,15 +19,14 @@ limitations under the License. #include #include #include -#include -#include #include #include #include "absl/algorithm/container.h" -#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -37,7 +36,6 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" #include "llvm/IR/Attributes.h" -#include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" @@ -49,7 +47,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Support/CodeGen.h" -#include "xla/cpu_function_runtime.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -60,11 +58,12 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/dot_op_emitter.h" -#include "xla/service/cpu/elemental_math_emitter.h" +#include "xla/service/cpu/elemental_ir_emitter.h" #include "xla/service/cpu/ir_emitter.h" #include "xla/service/cpu/parallel_loop_emitter.h" #include "xla/service/cpu/shape_partition.h" #include "xla/service/elemental_ir_emitter.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/dynamic_update_slice_util.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" @@ -73,141 +72,25 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::cpu { -namespace { - -// Following struct types correspond to HostKernel C API. -// See: xla/stream_executor/host/host_kernel_c_api.h - -static llvm::StructType* Dim3StructTy(llvm::LLVMContext& ctx, - std::string_view name) { - auto* i64 = llvm::IntegerType::getInt64Ty(ctx); - return llvm::StructType::create(name, i64, i64, i64); -} - -static llvm::StructType* KernelThreadDimTy(llvm::LLVMContext& ctx) { - return Dim3StructTy(ctx, "SE_HOST_KernelThreadDim"); -} - -static llvm::StructType* KernelThreadTy(llvm::LLVMContext& ctx) { - return Dim3StructTy(ctx, "SE_HOST_KernelThread"); -} - -static llvm::StructType* KernelArgTy(llvm::LLVMContext& ctx) { - auto* ptr = llvm::PointerType::getUnqual(ctx); - auto* i64 = llvm::IntegerType::getInt64Ty(ctx); - return llvm::StructType::create("SE_HOST_KernelArg", ptr, i64); -} -static llvm::StructType* KernelCallFrameTy(llvm::LLVMContext& ctx) { - auto* ptr = llvm::PointerType::getUnqual(ctx); - auto* i64 = llvm::IntegerType::getInt64Ty(ctx); - return llvm::StructType::create("SE_HOST_KernelCallFrame", ptr, ptr, i64, - ptr); -} +namespace { -static llvm::FunctionType* KernelFunctionTy(llvm::LLVMContext& ctx) { - return llvm::FunctionType::get(llvm::PointerType::getUnqual(ctx), - llvm::PointerType::getUnqual(ctx), - /*isVarArg=*/false); +KernelApiIrBuilder::Options KernelApiIrBuilderOptionsFromHloModuleConfig( + const HloModuleConfig& config) { + return KernelApiIrBuilder::Options{ + config.debug_options().xla_llvm_enable_invariant_load_metadata(), + config.debug_options().xla_cpu_prefer_vector_width()}; } } // namespace -//===----------------------------------------------------------------------===// -// ElementalIrEmitter -//===----------------------------------------------------------------------===// - -class IrEmitter2::ElementalIrEmitter : public xla::ElementalIrEmitter { - public: - ElementalIrEmitter(llvm::Module* module, llvm::IRBuilderBase* b, - const HloModule* hlo_module, IrEmitter* nested_ir_emitter, - bool fast_min_max) - : xla::ElementalIrEmitter( - module, b, - Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}), - hlo_module_(hlo_module), - nested_ir_emitter_(nested_ir_emitter), - fast_min_max_(fast_min_max) {} - - protected: - absl::StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, llvm::Value* rhs, - absl::string_view) override { - return xla::cpu::EmitAtan2(module(), *b(), prim_type, lhs, rhs); - } - - absl::StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) override { - return xla::cpu::EmitTanh(module(), *b(), prim_type, value); - } - - absl::StatusOr EmitErf(PrimitiveType prim_type, - llvm::Value* value) override { - return xla::cpu::EmitErf(module(), *b(), prim_type, value); - } - - absl::StatusOr> EmitThreadLocalCall( - const HloComputation& callee, absl::Span parameters, - absl::string_view name, bool is_reducer) override { - // Module must be scheduled to emit thread local computation. - if (!hlo_module_ || !hlo_module_->has_schedule()) { - return absl::InternalError( - "HLO module must be scheduled to emit thread local computation."); - } - - // Create a nested function for thread local computation(s) if it is not - // already created. Nested functions are created with internal linkage. - auto emit_computation = [&](const HloComputation* computation) { - if (!nested_ir_emitter_->is_computation_emitted(*computation, - is_reducer)) { - VLOG(2) << "Emit nested computation: " << computation->name(); - TF_RETURN_IF_ERROR( - nested_ir_emitter_ - ->EmitComputation( - const_cast(computation), name, false, - hlo_module_->schedule() - .sequence(computation) - .instructions(), - /*allow_reassociation=*/is_reducer, - /*function_attributes=*/{llvm::Attribute::AlwaysInline}) - .status()); - } - return absl::OkStatus(); - }; - - // We emit all embedded computations reachable through the `callee` to - // support nested thread local call, i.e., nested map computations. - for (HloComputation* embedded : callee.MakeEmbeddedComputationsList()) { - if (embedded->IsFusionComputation()) continue; - TF_RETURN_IF_ERROR(emit_computation(embedded)); - } - TF_RETURN_IF_ERROR(emit_computation(&callee)); - - // Add a thread local call to the nested computation. - VLOG(2) << "Emit thread local call to: " << callee.name(); - nested_ir_emitter_->b()->SetInsertPoint(b()->GetInsertPoint()); - auto values = nested_ir_emitter_->EmitThreadLocalCall( - callee, parameters, name, is_reducer, /*in_compute_function=*/false); - - return values; - } - - bool fast_min_max() override { return fast_min_max_; } - - private: - const HloModule* hlo_module_; - IrEmitter* nested_ir_emitter_; - bool fast_min_max_; -}; - //===----------------------------------------------------------------------===// // IrEmitter2 //===----------------------------------------------------------------------===// @@ -217,10 +100,9 @@ IrEmitter2::IrEmitter2(const HloModule& hlo_module, llvm::Module* module, : hlo_module_(hlo_module), module_(module), nested_ir_emitter_(nested_ir_emitter), - call_frame_ty_(KernelCallFrameTy(module_->getContext())), - thread_dims_ty_(KernelThreadDimTy(module_->getContext())), - thread_ty_(KernelThreadTy(module_->getContext())), - arg_ty_(KernelArgTy(module_->getContext())) {} + kernel_api_ir_builder_( + module_->getContext(), + KernelApiIrBuilderOptionsFromHloModuleConfig(hlo_module_.config())) {} bool IrEmitter2::fast_min_max() const { return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max(); @@ -233,37 +115,6 @@ IrEmitter2::KernelInfo::KernelInfo(KernelPrototype prototype, thread_dims(thread_dims), invariant_arguments(std::move(prototype.invariant_arguments)) {} -absl::StatusOr IrEmitter2::EmitElementalHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit elemental host kernel: " << instr->name(); - - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - - ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; - for (int64_t i = 0; i < instr->operand_count(); ++i) { - const HloInstruction* operand = instr->operand(i); - operand_to_generator[operand] = [&, i](const llvm_ir::IrArray::Index& idx) { - return kernel_prototype.arguments[i].EmitReadArrayElement(idx, &b); - }; - } - - ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_, - nested_ir_emitter_, fast_min_max()); - llvm_ir::ElementGenerator element_generator = - elemental_emitter.MakeElementGenerator(instr, operand_to_generator); - - TF_ASSIGN_OR_RETURN( - se::ThreadDim thread_dims, - EmitElementalLoops(b, instr, kernel_prototype, element_generator)); - - return kernels_.emplace_back( - KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); -} - absl::StatusOr IrEmitter2::EmitPadHostKernel( const HloInstruction* pad) { VLOG(2) << "Emit Pad host kernel."; @@ -304,7 +155,7 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( } if (fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { - return Internal("Unsupported loop fusion kind for instruction: %s", + return Internal("Unsupported fusion kind for instruction: %s", fusion->ToString()); } @@ -314,8 +165,13 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_, - nested_ir_emitter_, fast_min_max()); + IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); + + HloComputation* nested_computation = fusion->fused_instructions_computation(); + TF_RETURN_IF_ERROR(nested_ir_emitter_->EmitNestedComputation( + *nested_computation, llvm_ir::IrName(fusion), false)); + + CpuElementalIrEmitter elemental_emitter = ElementalIrEmmiterFactory(&b); FusedIrEmitter fused_emitter(elemental_emitter); for (int i = 0; i < fusion->operand_count(); i++) { @@ -352,14 +208,6 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } -absl::StatusOr IrEmitter2::EmitReductionHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit reduction host kernel: " << instr->name(); - - // TODO(ezhulenev): Port vectorized reduction emitter from IrEmitter. - return EmitElementalHostKernel(instr); -} - // Dot (fusion) host kernel only supports strategies that emit LLVM IR. static bool IsDotCodegenStrategy(DotImplementationStrategy strategy) { static std::array kDotCodegenStrategies = { @@ -408,25 +256,20 @@ absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( const HloInstruction* instr) { VLOG(2) << "Emit concatenate host kernel: " << instr->name(); - auto fast_impl_reason = CanDoFastConcatenate(instr); - if (fast_impl_reason.ok()) { - VLOG(1) << "Emitting fast concatenate for " << instr->ToString() << ": " - << fast_impl_reason.message(); - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - llvm::IRBuilder<> ir_builder(module_->getContext()); - ir_builder.SetInsertPoint( - kernel_prototype.function->getEntryBlock().getTerminator()); - - llvm_ir::IrArray output_array = kernel_prototype.results[0]; - TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( - instr, kernel_prototype.arguments, output_array, module_, ir_builder)); - return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), - se::BlockDim(), se::ThreadDim())); - } - VLOG(1) << "Could not emit fast concatenate for " << instr->ToString() << ": " - << fast_impl_reason.message(); - return EmitElementalHostKernel(instr); + DCHECK_OK(CanDoFastConcatenate(instr)); + + VLOG(1) << "Emitting fast concatenate for " << instr->ToString(); + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); + llvm::IRBuilder<> ir_builder(module_->getContext()); + ir_builder.SetInsertPoint( + kernel_prototype.function->getEntryBlock().getTerminator()); + + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( + instr, kernel_prototype.arguments, output_array, module_, ir_builder)); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( @@ -506,26 +349,22 @@ absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( absl::StatusOr IrEmitter2::EmitDynamicUpdateSliceHostKernel(const HloInstruction* instr) { - if (llvm_ir::CanUpdateDynamicSliceInPlace(const_cast(instr), - nested_ir_emitter_->assignment())) { - VLOG(2) << "Emit in-place dynamic-update-slice kernel: " << instr->name(); + DCHECK(CanUpdateDynamicSliceInPlace(instr)); - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); + VLOG(2) << "Emit in-place dynamic-update-slice kernel: " << instr->name(); - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint( - kernel_prototype.function->getEntryBlock().getTerminator()); + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); - TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace( - kernel_prototype.arguments, kernel_prototype.results.front(), - llvm_ir::IrName(instr, "in_place"), &b)); + llvm::IRBuilder<> b(module_->getContext()); + b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), - se::BlockDim(), se::ThreadDim())); - } + TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace( + kernel_prototype.arguments, kernel_prototype.results.front(), + llvm_ir::IrName(instr, "in_place"), &b)); - return EmitElementalHostKernel(instr); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitSortComparator( @@ -560,316 +399,10 @@ absl::StatusOr IrEmitter2::EmitSortComparator( // Building HostKernel prototypes. //===----------------------------------------------------------------------===// -absl::StatusOr IrEmitter2::GetAllocationSlice( - const HloInstruction* instruction, const ShapeIndex& index) { - return nested_ir_emitter_->assignment().GetUniqueSlice(instruction, index); -} - -absl::StatusOr> -IrEmitter2::GetKernelArgumentsParameters(const HloInstruction* instruction) { - std::vector arguments; - - for (HloInstruction* operand : instruction->operands()) { - for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - GetAllocationSlice(operand, indexed.index)); - arguments.push_back(KernelParameter{indexed.shape, slice}); - } - } - return arguments; -} - -absl::StatusOr> -IrEmitter2::GetKernelResultsParameters(const HloInstruction* instruction) { - std::vector results; - for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - GetAllocationSlice(instruction, indexed.index)); - results.push_back(KernelParameter{indexed.shape, slice}); - } - return results; -} - -absl::Status IrEmitter2::VerifyKernelParameters( - absl::Span arguments, - absl::Span results) { - // IMPORTANT: Buffer slice non-overlapping property checked below does not - // necessarily mean that the buffers do not alias. Parameter allocations - // might have different index but at run time might be backed by the same - // memory (or aliased memory). We conservatively do not emit noalias metadata - // for buffers coming from parameter allocations. - - // Check that all kernel arguments are coming from non-overlapping slices. It - // is fine to pass same slice as different arguments. This property is not - // used anywhere during the codegen, it acts mostly as a sanity check for - // the buffer assignment. In the future we might emit better aliasing metadata - // based on this property. - for (size_t i = 0; i < arguments.size(); ++i) { - for (size_t j = i + 1; j < arguments.size(); ++j) { - const KernelParameter& a = arguments[i]; - const KernelParameter& b = arguments[j]; - - if (a.slice != b.slice && a.slice.OverlapsWith(b.slice)) { - return Internal( - "Kernel arguments must not overlap: result #%d (%s) overlaps " - "with result #%d (%s)", - i, a.slice.ToString(), j, b.slice.ToString()); - } - } - } - - // Check that all kernel results are unique and coming from non-overlapping - // slices. We rely on this property to create LLVM `!alias.scope` for each - // kernel result buffer and to construct `!noalias` metadata for arguments. - for (size_t i = 0; i < results.size(); ++i) { - for (size_t j = i + 1; j < results.size(); ++j) { - const KernelParameter& a = results[i]; - const KernelParameter& b = results[j]; - - if (a.slice.OverlapsWith(b.slice)) { - return Internal( - "Kernel results must not overlap: result #%d (%s) overlaps " - "with result #%d (%s)", - i, a.slice.ToString(), j, b.slice.ToString()); - } - } - } - - // Check that results do not overlap with arguments, or if they do, they must - // be the same as one of the arguments, which can happen for inplace kernels. - for (size_t i = 0; i < results.size(); ++i) { - for (size_t j = 0; j < arguments.size(); ++j) { - const KernelParameter& result = results[i]; - const KernelParameter& argument = arguments[j]; - - if (result.slice.OverlapsWith(argument.slice) && - result.slice != argument.slice) { - return Internal( - "Kernel results must not partially overlap with arguments: result " - "#%d (%s) overlaps with argument #%d (%s)", - i, result.slice.ToString(), j, argument.slice.ToString()); - break; - } - } - } - - return absl::OkStatus(); -} - -IrEmitter2::KernelThreadDims IrEmitter2::EmitKernelThreadDims( - llvm::IRBuilderBase& b, llvm::Value* call_frame) { - auto* td_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 0, "tdims_gep"); - auto* tdims = b.CreateLoad(b.getPtrTy(), td_gep, "tdims"); - auto* x_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 0, "tdim_x_gep"); - auto* y_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 1, "tdim_y_gep"); - auto* z_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 2, "tdim_z_gep"); - - return {b.CreateLoad(b.getInt64Ty(), x_gep, "tdim_x"), - b.CreateLoad(b.getInt64Ty(), y_gep, "tdim_y"), - b.CreateLoad(b.getInt64Ty(), z_gep, "tdim_z")}; -} - -IrEmitter2::KernelThread IrEmitter2::EmitKernelThread(llvm::IRBuilderBase& b, - llvm::Value* call_frame) { - auto* t_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 1, "tid_gep"); - auto* tids = b.CreateLoad(b.getPtrTy(), t_gep, "tids"); - auto* x_gep = b.CreateStructGEP(thread_ty_, tids, 0, "tid_x_gep"); - auto* y_gep = b.CreateStructGEP(thread_ty_, tids, 1, "tid_y_gep"); - auto* z_gep = b.CreateStructGEP(thread_ty_, tids, 2, "tid_z_gep"); - - return {b.CreateLoad(b.getInt64Ty(), x_gep, "tid_x"), - b.CreateLoad(b.getInt64Ty(), y_gep, "tid_y"), - b.CreateLoad(b.getInt64Ty(), z_gep, "tid_z")}; -} - -llvm_ir::IrArray IrEmitter2::EmitKernelArgument(llvm::IRBuilderBase& b, - llvm::Value* call_frame, - int64_t index, - const Shape& shape) { - llvm::Type* ptr = llvm::PointerType::get(b.getContext(), 0); - std::string name = absl::StrCat("arg", index); - - auto* args_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 3, "args_gep"); - auto* args = b.CreateLoad(ptr, args_gep, "args"); - auto* data_gep = b.CreateConstGEP2_32(arg_ty_, args, index, 0, name + "_gep"); - auto* data = b.CreateLoad(ptr, data_gep, name); - - // All buffers passed to host kernels are expected to be properly aligned, - // emit metadata to allow LLVM to use that information for optimization. - llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign()); - - // All buffers pointers passed to host kernels are expected to be - // dereferenceable. - IrEmitter::AttachDereferenceableMetadataForLoad(data, ByteSizeOf(shape)); - - // All buffers pointers passed to host kernels are expected to be invariant - // over the whole program. Note the metadata is attached only to loading - // buffer pointers, not to loading actual buffers. - AttachInvariantLoadMetadataForLoad(data); - - return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, module_), shape); -} - -absl::StatusOr IrEmitter2::EmitKernelPrototype( - std::string_view name, absl::Span arguments, - absl::Span results) { - VLOG(3) << "Emit kernel prototype: " << name - << ", #arguments=" << arguments.size() - << ", #results=" << results.size(); - for (const KernelParameter& argument : arguments) { - VLOG(3) << " argument: " << argument.shape.ToString(true) << " in " - << argument.slice.ToString(); - } - for (const KernelParameter& result : results) { - VLOG(3) << " result: " << result.shape.ToString(true) << " in " - << result.slice.ToString(); - } - - TF_RETURN_IF_ERROR(VerifyKernelParameters(arguments, results)); - - llvm::LLVMContext& ctx = module_->getContext(); - llvm::MDBuilder mb(ctx); - llvm::IRBuilder<> b(ctx); - - // Create an alias domain for the host kernel function. - llvm::MDNode* domain = mb.createAliasScopeDomain( - absl::StrFormat("XLA host kernel %s AA domain", name)); - - // Emit alias scopes for all kernel result buffers. We do not emit alias - // scopes for kernel arguments, because it's usually not profitable, and we - // mostly care about avoiding reloading data from read-only buffers. We use - // sorted container to make sure that emitted metadata is deterministic. - absl::btree_map alias_scopes; - for (const KernelParameter& result : results) { - // Skip result buffers that are aliased with entry parameters as we don't - // know if they can alias with any other buffers. - if (result.slice.allocation()->is_parameter_aliased_with_output()) { - continue; - } - alias_scopes[result.slice] = mb.createAliasScope( - absl::StrFormat("result slice: %s", result.slice.ToString()), domain); - } - - // Returns alias scope for the given buffer slice. - auto get_alias_scope = [&](BufferAllocation::Slice slice) -> llvm::MDNode* { - auto it = alias_scopes.find(slice); - return it == alias_scopes.end() ? nullptr - : llvm::MDNode::get(ctx, it->second); - }; - - // Construct !noalias metadata for buffer slice. - auto get_noalias = [&](BufferAllocation::Slice slice) -> llvm::MDNode* { - llvm::SmallVector scopes; - for (const auto& [alias_slice, alias_scope] : alias_scopes) { - if (!slice.OverlapsWith(alias_slice)) { - scopes.push_back(alias_scope); - } - } - return scopes.empty() ? nullptr : llvm::MDNode::get(ctx, scopes); - }; - - // Collect all buffer slices that the kernel writes to. - absl::flat_hash_set result_slices; - result_slices.reserve(results.size()); - for (const KernelParameter& result : results) { - result_slices.insert(result.slice); - } - - // Create a kernel function with HostKernel API. We use external linkage - // because we'll be resolving this function from the XLA runtime. - llvm::Function* function = llvm::Function::Create( - KernelFunctionTy(ctx), llvm::GlobalValue::ExternalLinkage, name, module_); - function->setCallingConv(llvm::CallingConv::C); - - // Generate unwind information so that GDB can crawl through the stack frames - // created by the JIT compiled code. - function->setUWTableKind(llvm::UWTableKind::Default); - - // Set prefer-vector-width attribute to allow LLVM to use wider vector - // registers (by default LLVM uses at most 256-bit registers). - const DebugOptions& debug_options = hlo_module_.config().debug_options(); - function->addFnAttr( - "prefer-vector-width", - absl::StrCat(debug_options.xla_cpu_prefer_vector_width())); - - // Always keep a frame pointer for the host kernel so we can see them in all - // performance profiling tools. - function->addFnAttr("frame-pointer", "all"); - - // Create an entry basic block and set insert point to the end of it. - b.SetInsertPoint(llvm::BasicBlock::Create(ctx, "", function)); - - llvm::Value* call_frame = function->getArg(0); - // Build thread coordinates from the call frame. - KernelThreadDims kernel_thread_dims = EmitKernelThreadDims(b, call_frame); - KernelThread kernel_thread = EmitKernelThread(b, call_frame); - - int64_t idx = 0; - - // A set of invariant (read-only) buffer indices, feeded in the loop array in - // the next section. - absl::flat_hash_set invariant_arguments; - - // IrArrays for the parameters. - std::vector ir_arguments; - for (int64_t i = 0; i < arguments.size(); ++i) { - const KernelParameter& argument = arguments[i]; - auto ir_argument = EmitKernelArgument(b, call_frame, idx++, argument.shape); - if (auto* noalias = get_noalias(argument.slice)) { - ir_argument.AddNoaliasMetadata(noalias); - } - - // If a buffer slice is not a part of result set, then it must be invariant - // (read-only). - if (!result_slices.contains(argument.slice)) { - ir_argument.MarkInvariantOverWholeProgram(&ctx); - invariant_arguments.insert(i); - } - - ir_arguments.push_back(std::move(ir_argument)); - } - - // IrArrays for the results. - std::vector ir_results; - for (const KernelParameter& result : results) { - auto ir_result = EmitKernelArgument(b, call_frame, idx++, result.shape); - if (auto* noalias = get_noalias(result.slice)) { - ir_result.AddNoaliasMetadata(noalias); - } - if (auto* alias_scope = get_alias_scope(result.slice)) { - ir_result.AddAliasScopeMetadata(alias_scope); - } - ir_results.push_back(std::move(ir_result)); - } - - // Return null pointer to signal success as we do not support error handling - // in the compiled host kernel. - llvm::BasicBlock* return_block = - llvm::BasicBlock::Create(ctx, "return", function); - - b.CreateBr(return_block); - - b.SetInsertPoint(return_block); - b.CreateRet( - llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(ctx))); - - return KernelPrototype{function, - return_block, - kernel_thread_dims, - kernel_thread, - std::move(ir_arguments), - std::move(ir_results), - std::move(invariant_arguments)}; -} - absl::StatusOr IrEmitter2::EmitKernelPrototype( const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(std::vector arguments, - GetKernelArgumentsParameters(instr)); - TF_ASSIGN_OR_RETURN(std::vector results, - GetKernelResultsParameters(instr)); - return EmitKernelPrototype(instr->name(), std::move(arguments), - std::move(results)); + return kernel_api_ir_builder_.EmitKernelPrototype( + *module_, instr, &nested_ir_emitter_->assignment()); } std::optional IrEmitter2::GetParallelConfig( @@ -910,10 +443,16 @@ absl::Status IrEmitter2::CanDoFastConcatenate( return absl::OkStatus(); }; +bool IrEmitter2::CanUpdateDynamicSliceInPlace( + const HloInstruction* update) const { + return llvm_ir::CanUpdateDynamicSliceInPlace( + const_cast(update), nested_ir_emitter_->assignment()); +} + IrEmitter2::ParallelPartitionBounds IrEmitter2::EmitParallelPartitionBounds( llvm::IRBuilderBase& b, const KernelPrototype& kernel_prototype, const ParallelConfig& parallel_config, const Shape& shape, - std::string_view name) { + absl::string_view name) { ShapePartitionIterator it(shape, parallel_config.outer_dimension_partitions); size_t num_parallel_dimensions = @@ -957,7 +496,7 @@ IrEmitter2::ParallelPartitionBounds IrEmitter2::EmitParallelPartitionBounds( // Construct IR to load bounds for all parallel dimensions. ParallelPartitionBounds bounds; for (size_t i = 0; i < num_parallel_dimensions; ++i) { - llvm::Value* partition = kernel_prototype.thread.x; + llvm::Value* partition = kernel_prototype.thread_id.x; llvm::Value* parallel_dim = b.getInt32(i); llvm::Value* lower_gep = b.CreateInBoundsGEP( @@ -1039,4 +578,18 @@ void IrEmitter2::AttachInvariantLoadMetadataForLoad( hlo_module_.config()); } +CpuElementalIrEmitter IrEmitter2::ElementalIrEmmiterFactory( + llvm::IRBuilderBase* b) const { + auto thread_local_call_fn = [this](const HloComputation& callee, + absl::Span parameters, + absl::string_view name, bool is_reducer) { + return nested_ir_emitter_->EmitThreadLocalCall( + callee, parameters, name, is_reducer, + /*in_compute_function=*/false); + }; + + return CpuElementalIrEmitter(module_, b, thread_local_call_fn, true, + fast_min_max()); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h index 3c7f874c041f5c..77ea6647d4ec97 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.h +++ b/third_party/xla/xla/service/cpu/ir_emitter2.h @@ -19,25 +19,26 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/elemental_ir_emitter.h" #include "xla/service/cpu/ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/loop_emitter.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -68,34 +69,13 @@ class IrEmitter2 { friend class IrEmitter2Test; private: - struct KernelPrototype; + using KernelParameter = KernelApiIrBuilder::KernelParameter; + using KernelPrototype = KernelApiIrBuilder::KernelPrototype; public: IrEmitter2(const HloModule& hlo_module, llvm::Module* module, IrEmitter* nested_ir_emitter); - // Kernel parameter (argument or result buffer) passed to a kernel function. - // We rely on buffer allocation slice information to infer buffer aliasing - // scopes for LLVM codegen. - struct KernelParameter { - Shape shape; - BufferAllocation::Slice slice; - }; - - // Thread dimensions of the kernel invocation. - struct KernelThreadDims { - llvm::Value* x; - llvm::Value* y; - llvm::Value* z; - }; - - // Thread coordinates of the kernel invocation. - struct KernelThread { - llvm::Value* x; - llvm::Value* y; - llvm::Value* z; - }; - // Emitted kernel information that defines how to launch it at run time. struct KernelInfo { explicit KernelInfo(KernelPrototype prototype, @@ -118,10 +98,6 @@ class IrEmitter2 { absl::Span comparators() const { return comparators_; } - // Emits an elemental host kernel for the given HLO instruction. - absl::StatusOr EmitElementalHostKernel( - const HloInstruction* instr); - // Emits a host kernel for the pad instruction. absl::StatusOr EmitPadHostKernel(const HloInstruction* pad); @@ -129,10 +105,6 @@ class IrEmitter2 { absl::StatusOr EmitFusionHostKernel( const HloFusionInstruction* fusion); - // Emits a host kernel for the given reduction instruction. - absl::StatusOr EmitReductionHostKernel( - const HloInstruction* instr); - // Emits a host kernel for the given dot instruction. Small dot operations // are emitted as LLVM IR directly, while larger ones are emitted as a dot // thunk that calls into libraries. @@ -157,36 +129,12 @@ class IrEmitter2 { // Emits a comparator function for the given sort instruction. absl::StatusOr EmitSortComparator(HloComputation* comparator); + absl::Status CanDoFastConcatenate(const HloInstruction* concatenate) const; + bool CanUpdateDynamicSliceInPlace(const HloInstruction* update) const; + private: class ElementalIrEmitter; - // A kernel function prototype with all the LLVM values that might be needed - // to emit the actual kernel body. - struct KernelPrototype { - llvm::Function* function; - llvm::BasicBlock* return_block; - - // LLVM values identifying kernel invocation thread coordinates. - KernelThreadDims thread_dims; - KernelThread thread; - - // LLVM values corresponding to the kernel arguments and results arrays. All - // tuples are flattened as we do not have any tuples at run time and only - // read and write data from/to leaf arrays. - std::vector arguments; - std::vector results; - - // Set containing all invariant (read-only) buffers indices. A buffer is - // read-only if it is not aliased with any result. - absl::flat_hash_set invariant_arguments; - }; - - // Emits a host kernel prototype and prepares function for emitting kernel - // body into it. - absl::StatusOr EmitKernelPrototype( - std::string_view name, absl::Span arguments, - absl::Span results); - // Emits a host kernel prototype for the given HLO instruction. absl::StatusOr EmitKernelPrototype( const HloInstruction* instr); @@ -203,46 +151,16 @@ class IrEmitter2 { std::vector outer_dimension_partitions; }; - // Returns the buffer allocation slice assigned to the given instruction at - // the given shape index. Instruction must have a unique slice assigned to it! - absl::StatusOr GetAllocationSlice( - const HloInstruction* instruction, const ShapeIndex& index = {}); - - // We do not materialize buffers for tuples at run time, and work only with - // leaf arrays. These are the helper functions to flatten HLO instruction - // parameters and results into a list of leaf shapes. - absl::StatusOr> GetKernelArgumentsParameters( - const HloInstruction* instruction); - absl::StatusOr> GetKernelResultsParameters( - const HloInstruction* instruction); - - // Verifies kernel parameters preconditions that are required for codegen. - absl::Status VerifyKernelParameters( - absl::Span arguments, - absl::Span results); - - KernelThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& b, - llvm::Value* call_frame); - - KernelThread EmitKernelThread(llvm::IRBuilderBase& b, - llvm::Value* call_frame); - - llvm_ir::IrArray EmitKernelArgument(llvm::IRBuilderBase& b, - llvm::Value* call_frame, int64_t index, - const Shape& shape); - // Returns parallel config for the given instruction or std::nullopt if // the instruction has to be compiled to a single threaded loop. std::optional GetParallelConfig(const HloInstruction* instr); - absl::Status CanDoFastConcatenate(const HloInstruction* concatenate) const; - // Emits LLVM IR that computes parallel partition bounds from the call frame's // block and thread dimensions and parallel execution config. ParallelPartitionBounds EmitParallelPartitionBounds( llvm::IRBuilderBase& b, const KernelPrototype& kernel_prototype, const ParallelConfig& parallel_config, const Shape& shape, - std::string_view name); + absl::string_view name); // Emits LLVM IR using elemental loop emitter and the given element generator. // If the instruction is parallelized, it will emit a parallel loop partition @@ -261,6 +179,8 @@ class IrEmitter2 { // load metadata. void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* instr) const; + CpuElementalIrEmitter ElementalIrEmmiterFactory(llvm::IRBuilderBase* b) const; + const HloModule& hlo_module_; llvm::Module* module_; @@ -268,11 +188,7 @@ class IrEmitter2 { // to reductions inside fusions). IrEmitter* nested_ir_emitter_; - // LLVM types defining HostKernel API (see host_kernel_c_api.h). - llvm::StructType* call_frame_ty_; - llvm::StructType* thread_dims_ty_; - llvm::StructType* thread_ty_; - llvm::StructType* arg_ty_; + KernelApiIrBuilder kernel_api_ir_builder_; // Keeps track of all the functions emitted so far. std::vector kernels_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc deleted file mode 100644 index ee2464c7b9cad9..00000000000000 --- a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc +++ /dev/null @@ -1,368 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/ir_emitter2.h" - -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Type.h" -#include "xla/cpu_function_runtime.h" -#include "xla/hlo/analysis/hlo_ordering.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/parser/hlo_parser.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/ir_emitter.h" -#include "xla/service/cpu/target_machine_features_stub.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/logical_buffer.h" -#include "xla/shape_util.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -namespace xla::cpu { - -class IrEmitter2Test : public HloTestBase { - public: - // This is a proxy function that allows us call private method - // IrEmitter2::EmitKernelPrototype. - static auto EmitKernelPrototype( - IrEmitter2& ir_emitter, - const std::vector& arguments, - const std::vector& results) { - return ir_emitter.EmitKernelPrototype("test", arguments, results); - } - - absl::StatusOr MakeIrEmitter2(llvm::Module& module, - const HloModule& hlo) { - TF_ASSIGN_OR_RETURN( - buffer_assignment_, - BufferAssigner::Run( - &hlo, std::make_unique(&hlo), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return /*alignment=*/1; })); - - target_machine_ = std::make_unique( - [](int64_t size) { return 1; }); - - nested_ir_emitter_ = absl::WrapUnique( - new IrEmitter(nullptr, hlo, *buffer_assignment_, &module, {}, {}, {}, - target_machine_.get(), false)); - - return IrEmitter2(hlo, &module, nested_ir_emitter_.get()); - } - - // TODO(abanas): This function could be static. It requires making the - // underlying FindInstruction function static first. - absl::StatusOr EmitElementalHostKernel( - IrEmitter2& ir_emitter, HloModule& hlo, - std::string_view instruction_name) { - HloInstruction* instruction = FindInstruction(&hlo, instruction_name); - - if (instruction == nullptr) { - return absl::InternalError("Instruction not found"); - } - TF_ASSIGN_OR_RETURN(IrEmitter2::KernelInfo kernel, - ir_emitter.EmitElementalHostKernel(instruction)); - return kernel; - } - - private: - // Dependencies of IrEmitter2. These are created in MakeIrEmitter2 and kept - // alive for the duration of the test, because IrEmitter2 does not take - // ownership of them. - std::unique_ptr buffer_assignment_; - std::unique_ptr target_machine_; - std::unique_ptr nested_ir_emitter_; -}; - -namespace { - -TEST_F(IrEmitter2Test, BuildKernelPrototype) { - auto hlo = std::make_unique("test", HloModuleConfig()); - - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - auto shape = ShapeUtil::MakeShape(PrimitiveType::F32, {4, 2}); - - BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); - BufferAllocation::Slice arg0(&alloc, /*offset=*/0, /*size=*/256); - BufferAllocation::Slice arg1(&alloc, /*offset=*/256, /*size=*/256); - BufferAllocation::Slice res0(&alloc, /*offset=*/512, /*size=*/256); - BufferAllocation::Slice res1(&alloc, /*offset=*/768, /*size=*/256); - - std::vector arguments = {{shape, arg0}, - {shape, arg1}}; - std::vector results = {{shape, res0}, - {shape, res1}}; - - IrEmitter2 ir_emitter(*hlo, module.get(), /*nested_ir_emitter=*/nullptr); - TF_ASSERT_OK_AND_ASSIGN(auto prototype, - EmitKernelPrototype(ir_emitter, arguments, results)); - - llvm::IRBuilder<> b(context); - b.SetInsertPoint(prototype.function->getEntryBlock().getTerminator()); - - auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); - llvm_ir::IrArray::Index index(zero, shape, &b); - - // Emit loads from arguments and results buffers to test alias scope metadata. - EXPECT_NE(prototype.arguments[0].EmitReadArrayElement(index, &b), nullptr); - EXPECT_NE(prototype.arguments[1].EmitReadArrayElement(index, &b), nullptr); - EXPECT_NE(prototype.results[0].EmitReadArrayElement(index, &b), nullptr); - EXPECT_NE(prototype.results[1].EmitReadArrayElement(index, &b), nullptr); - - // clang-format off - ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), - absl::StrCat(R"( - CHECK: define ptr @test(ptr %0) #0 { - - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2 - CHECK: load i64 - CHECK: load i64 - CHECK: load i64 - - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1 - CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2 - CHECK: load i64 - CHECK: load i64 - CHECK: load i64 - - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 - CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0 - CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.+]], !align ![[ALIGNMENT:.+]] - - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 - CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0 - CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 - CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0 - CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - - CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 - CHECK: load ptr - CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0 - CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - - CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]] - CHECK: load float, ptr %[[PTR0]], align 4, - CHECK-SAME: !invariant.load ![[SCOPE0]], - CHECK-SAME: !noalias ![[SCOPE1:.+]] - - CHECK-NEXT: %[[PTR1:.+]] = getelementptr inbounds float, ptr %[[ARG1]] - CHECK: load float, ptr %[[PTR1]], align 4, - CHECK-SAME: !invariant.load ![[SCOPE0]], - CHECK-SAME: !noalias ![[SCOPE1]] - - CHECK-NEXT: %[[PTR2:.+]] = getelementptr inbounds float, ptr %[[ARG2]] - CHECK: load float, ptr %[[PTR2]], align 4, !alias.scope ![[SCOPE2:.+]], - CHECK: !noalias ![[SCOPE3:.+]] - - CHECK-NEXT: %[[PTR3:.+]] = getelementptr inbounds float, ptr %[[ARG3]] - CHECK: load float, ptr %[[PTR3]], align 4, !alias.scope ![[SCOPE3]], - CHECK: !noalias ![[SCOPE2]] - - CHECK: ret ptr null - CHECK: } - - #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" } - CHECK-DAG: ![[ALIGNMENT]] = !{i64 )", cpu_function_runtime::MinAlign(), R"(} - CHECK-DAG: ![[SCOPE0]] = !{} - CHECK-DAG: ![[SCOPE1]] = !{![[RES0:.+]], ![[RES1:.+]]} - CHECK-DAG: ![[SCOPE2]] = !{![[RES0]]} - CHECK-DAG: ![[SCOPE3]] = !{![[RES1]]} - CHECK-DAG: ![[RES0]] = !{!"{{.*}}, offset:512, {{.*}}", ![[DOMAIN:.+]]} - CHECK-DAG: ![[RES1]] = !{!"{{.*}}, offset:768, {{.*}}", ![[DOMAIN]]} - CHECK-DAG: ![[DOMAIN]] = !{!"XLA host kernel test AA domain"} - )"))); - // clang-format on - - // Match for dereferenceable metadata in separate check, because depending on - // the alignment value, it may be the same scope as align, and may be a - // separate one. It's impossible to match both these cases in one FileCheck. - ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( - CHECK: {{.+}} = load ptr, {{.*}}, !dereferenceable ![[DEREF_BYTES:.+]], - CHECK: ![[DEREF_BYTES]] = !{i64 32} - )")); -} - -TEST_F(IrEmitter2Test, EmitElementalKernel) { - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - const char* hlo_text = R"( - HloModule m - ENTRY main { - p0 = f32[2,2] parameter(0) - ROOT convert = s32[2,2] convert(p0) - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - EmitElementalHostKernel(ir_emitter, *hlo, "convert")); - - ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( - CHECK: define ptr @convert(ptr %0) #0 { - CHECK: fptosi float {{.*}} to i32 - CHECK: } - )")); -} - -TEST_F(IrEmitter2Test, EmitParallelKernel) { - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - const char* hlo_text = R"( - HloModule m - ENTRY main { - p0 = f32[1,2,1,16384,256] parameter(0) - ROOT convert = s32[1,2,1,16384,256] convert(p0), - backend_config={"outer_dimension_partitions":["1","2","1","4"]} - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - EmitElementalHostKernel(ir_emitter, *hlo, "convert")); - - ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( - CHECK: @convert_parallel_bounds = private constant [8 x [4 x [2 x i64]]] - - CHECK: define ptr @convert(ptr %0) #0 { - CHECK: %lo_dim_0_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 0, i32 0 - CHECK: %up_dim_0_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 0, i32 1 - CHECK: %lo_dim_1_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 1, i32 0 - CHECK: %up_dim_1_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 1, i32 1 - CHECK: %lo_dim_2_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 2, i32 0 - CHECK: %up_dim_2_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 2, i32 1 - CHECK: %lo_dim_3_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 3, i32 0 - CHECK: %up_dim_3_gep = getelementptr{{.*}} i32 0, i64 %tid_x, i32 3, i32 1 - CHECK: fptosi float {{.*}} to i32 - CHECK: } - )")); -} - -using IrEmitter2InvariantBuffersTest = IrEmitter2Test; - -TEST_F(IrEmitter2InvariantBuffersTest, AllInvariantBuffers) { - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - const char* hlo_text = R"( - HloModule m - ENTRY main { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT add.0 = f32[2,2] add(p0, p1) - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); - - ASSERT_EQ(kernel.invariant_arguments.size(), 2); -} - -TEST_F(IrEmitter2InvariantBuffersTest, InvariantBufferPassedTwice) { - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - const char* hlo_text = R"( - HloModule m - ENTRY main { - p0 = f32[2,2] parameter(0) - ROOT add.0 = f32[2,2] add(p0, p0) - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); - - // Invariant buffers contains indices of both arguments, even though it is the - // same buffer slice. - ASSERT_EQ(kernel.invariant_arguments.size(), 2); -} - -TEST_F(IrEmitter2InvariantBuffersTest, NoInvariantBuffers) { - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - const char* hlo_text = R"( - HloModule m, input_output_alias={ {}: (0, {}, must-alias) } - ENTRY main { - p0 = f32[2,2] parameter(0) - ROOT add.0 = f32[2,2] add(p0, p0) - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); - - ASSERT_EQ(kernel.invariant_arguments.size(), 0); -} - -TEST_F(IrEmitter2InvariantBuffersTest, MixedBuffers) { - llvm::LLVMContext context; - auto module = std::make_unique("test", context); - - const char* hlo_text = R"( - HloModule m, input_output_alias={ {}: (1, {}, must-alias) } - ENTRY main { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT add.0 = f32[2,2] add(p0, p1) - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); - TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); - - // The first argument is invariant, the second is not because it's aliased to - // the output. - EXPECT_EQ(kernel.invariant_arguments.size(), 1); - EXPECT_TRUE(kernel.invariant_arguments.contains(0)); -} - -} // namespace -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter_test.cc b/third_party/xla/xla/service/cpu/ir_emitter_test.cc index 9b98e1f966d3db..d41cad880a38bf 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter_test.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter_test.cc @@ -15,11 +15,17 @@ limitations under the License. #include "xla/service/cpu/ir_emitter.h" +#include #include +#include #include #include #include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" @@ -29,17 +35,39 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/backends/cpu/codegen/cpu_features.h" +#include "xla/backends/cpu/codegen/ir_compiler.h" +#include "xla/backends/cpu/codegen/jit_compiler.h" +#include "xla/backends/cpu/codegen/target_machine_features.h" +#include "xla/cpu_function_runtime.h" #include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/buffer_value.h" +#include "xla/service/cpu/cpu_compiler.h" +#include "xla/service/cpu/cpu_executable.h" +#include "xla/service/cpu/cpu_options.h" #include "xla/service/cpu/ir_function.h" +#include "xla/service/cpu/runtime_symbol_generator.h" #include "xla/service/cpu/target_machine_features_stub.h" #include "xla/service/hlo_module_config.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/logical_buffer.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" namespace xla::cpu { namespace { @@ -179,5 +207,136 @@ TEST_F(IrEmitterTest, CheckNativeConvertSupportOnTargetCPU) { ASSERT_TRUE(IsNativeConvertSupportedOnTargetCPU(srf_feature_string)); } +// Used to keep all dependencies of IrEmitter alive. +struct IrEmitterWrapper { + std::unique_ptr ir_emitter; + std::unique_ptr buffer_assignment; + std::unique_ptr target_machine_features; + std::unique_ptr mlir_context; +}; + +static absl::StatusOr> +CreateIrEmitterForConstantEmissionTests(HloModule& module, + llvm::Module& llvm_module) { + const DebugOptions& debug_options = module.config().debug_options(); + + const HloModuleConfig& config = module.config(); + + // Options for compiling LLVM IR to machine code. + IrCompiler::Options ir_compiler_options{ + /*optimization_level=*/llvm::CodeGenOptLevel::Default, + /*optimize_for_size=*/options::OptimizeForSizeRequested(config), + /*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(config), + /*disable_expensive_passes=*/ + debug_options.xla_llvm_disable_expensive_passes(), + /*slp_vectorizer_disabled=*/options::SlpVectorizerDisabled(config), + }; + + // Definition generator to link with XLA:CPU host runtime symbols. + JitCompiler::DefinitionGenerator definition_generator = + [](llvm::TargetMachine* target_machine) { + return std::make_unique( + target_machine->createDataLayout()); + }; + + // Options for orchestrating the JIT compilation process. + JitCompiler::Options jit_compiler_options{ + std::move(ir_compiler_options), + {}, + /*num_dylibs=*/1, + /*definition_generator=*/std::move(definition_generator), + /*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()), + }; + + llvm::TargetOptions target_options; + target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast; + + // Returns a global (per-process) thread pool for XLA CPU compilation tasks. + auto compilation_task_runner = [](cpu::JitCompiler::Task task) { + static auto* thread_pool = + new tsl::thread::ThreadPool(tsl::Env::Default(), "ir-emitter-test", 1); + + thread_pool->Schedule(std::move(task)); + }; + + TF_ASSIGN_OR_RETURN( + JitCompiler jit_compiler, + JitCompiler::Create(target_options, std::move(jit_compiler_options), + compilation_task_runner)); + + auto scheduler = + debug_options.xla_cpu_enable_concurrency_optimized_scheduler() + ? BFSMemoryScheduler + : DFSMemoryScheduler; + + auto buffer_size_bytes_function = [](const BufferValue& buffer) { + return CpuExecutable::ShapeSizeBytes(buffer.shape()); + }; + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + ScheduleModule(&module, buffer_size_bytes_function, + ComputationSchedulerToModuleScheduler(scheduler))); + TF_RETURN_IF_ERROR(module.set_schedule(schedule)); + + auto memory_alignment = [](LogicalBuffer::Color) { + return cpu_function_runtime::MinAlign(); + }; + // Run buffer allocation on the HLO graph. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run(&module, + std::make_unique(schedule), + buffer_size_bytes_function, memory_alignment, + /*allocate_buffers_for_constants=*/true)); + + auto target_machine_features = + std::make_unique(jit_compiler.target_machine()); + + std::unique_ptr mlir_context; + auto ir_emitter = std::make_unique( + mlir_context.get(), module, *assignment, &llvm_module, + absl::flat_hash_map{}, + absl::flat_hash_map{}, + absl::flat_hash_map{}, + target_machine_features.get(), + /*emit_code_for_msan=*/false); + + return std::make_unique(IrEmitterWrapper{ + std::move(ir_emitter), std::move(assignment), + std::move(target_machine_features), std::move(mlir_context)}); +} + +TEST_F(IrEmitterTest, SmallConstantsAreEmittedAsGlobalsLargeAreNot) { + constexpr size_t kNumberOfSmallConstants = 1; + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[1000,1000]{1,0} parameter(0) + b = f32[1000,1000]{1,0} constant({...}) + a_plus_b = f32[1000,1000]{1,0} add(a, b) + c = f32[1,1]{1,0} constant({...}) + broadcast = f32[1000,1000]{1,0} broadcast(c), dimensions={} + ROOT result = f32[1000,1000]{1,0} add(a_plus_b, broadcast) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(module_string)); + + auto llvm_context = std::make_unique(); + auto llvm_module = std::make_unique("test", *llvm_context); + + TF_ASSERT_OK_AND_ASSIGN( + auto wrapped_ir_emitter, + CreateIrEmitterForConstantEmissionTests(*module, *llvm_module)); + + TF_ASSERT_OK(wrapped_ir_emitter->ir_emitter->EmitSmallConstantGlobals()); + + EXPECT_EQ( + std::distance(llvm_module->global_begin(), llvm_module->global_end()), + kNumberOfSmallConstants); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/metrics.cc b/third_party/xla/xla/service/cpu/metrics.cc index ab0289fba24092..4dd25432330460 100644 --- a/third_party/xla/xla/service/cpu/metrics.cc +++ b/third_party/xla/xla/service/cpu/metrics.cc @@ -24,15 +24,25 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/lib/monitoring/counter.h" #include "tsl/platform/stacktrace.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace cpu { +namespace { + +using ::tsl::profiler::TraceMe; +using ::tsl::profiler::TraceMeEncode; + +} // namespace + auto* cpu_compiler_stacktrace_count = tsl::monitoring::Counter<1>::New( "/xla/service/cpu/compiler_stacktrace_count", "The number of times a compiler stacktrace was called.", "stacktrace"); void RecordCpuCompilerStacktrace() { + TraceMe trace( + [&] { return TraceMeEncode("RecordCpuCompilerStacktrace", {}); }); std::string tsl_stacktrace = tsl::CurrentStackTrace(); // tsl::CurrentStackTrace() adds a prefix and postfix lines, so remove them. diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index d09f0e2a6e5025..6a359c9a4d91e3 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -770,18 +770,18 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } absl::Status HandleMaximum(HloInstruction* instr) override { - HloInstruction* matmul_call; + HloInstruction* contraction; HloInstruction* intermediate_instr = nullptr; HloInstruction* optional_bitcast = nullptr; - // Attempt to elide maximum and fuse ReLU activation into GEMM, including - // when slicing or bitcasting is applied to the result. + // Attempt to elide maximum and fuse ReLU activation into GEMM / Conv, + // including when slicing or bitcasting is applied to the result. if (Match(instr, m::MaximumAnyOrder(ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, - OneDnnMatmulInstr(&matmul_call)) + OneDnnFusibleInstr(&contraction)) .WithOneUser(), BcastConstScalar(0)))) { - return FuseActivation(OneDnnFusionConfig::RELU, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::RELU, instr, contraction, intermediate_instr, optional_bitcast); } return absl::OkStatus(); @@ -801,59 +801,59 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } absl::Status HandleSelect(HloInstruction* instr) override { - HloInstruction* matmul_call; + HloInstruction* contraction; HloInstruction* intermediate_instr = nullptr; HloInstruction* optional_bitcast = nullptr; HloInstruction* src; - // Attempt to elide ELU subgraph and fuse ELU activation into GEMM, + // Attempt to elide ELU subgraph and fuse ELU activation into GEMM / Conv, // including when slicing or bitcasting is applied to the result. if (ELUActivation(instr, &src)) { if (Match(src, ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, - OneDnnMatmulInstr(&matmul_call)))) { - return FuseActivation(OneDnnFusionConfig::ELU, instr, matmul_call, - intermediate_instr); + OneDnnFusibleInstr(&contraction)))) { + return FuseActivation(OneDnnFusionConfig::ELU, instr, contraction, + intermediate_instr, optional_bitcast); } } return absl::OkStatus(); } absl::Status HandleTanh(HloInstruction* instr) override { - HloInstruction* matmul_call; + HloInstruction* contraction; HloInstruction* intermediate_instr = nullptr; HloInstruction* optional_bitcast = nullptr; - // Attempt to elide Tanh and fuse Tanh activation into GEMM, including - // when slicing or bitcasting is applied to the result. + // Attempt to elide Tanh and fuse Tanh activation into GEMM / Conv, + // including when slicing or bitcasting is applied to the result. if (Match(instr, m::Tanh(ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, - OneDnnMatmulInstr(&matmul_call)) + OneDnnFusibleInstr(&contraction)) .WithOneUser()))) { - return FuseActivation(OneDnnFusionConfig::TANH, instr, matmul_call, - intermediate_instr); + return FuseActivation(OneDnnFusionConfig::TANH, instr, contraction, + intermediate_instr, optional_bitcast); } return absl::OkStatus(); } absl::Status HandleClamp(HloInstruction* instr) override { - HloInstruction* matmul_call; + HloInstruction* contraction; HloInstruction* intermediate_instr = nullptr; HloInstruction* optional_bitcast = nullptr; - // Attempt to elide RELU6 and fuse RELU6 activation into GEMM, including - // when slicing or bitcasting is applied to the result. + // Attempt to elide RELU6 and fuse RELU6 activation into GEMM / Conv, + // including when slicing or bitcasting is applied to the result. if (Match(instr, m::Clamp(BcastConstScalar(0), ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, - OneDnnMatmulInstr(&matmul_call)) + OneDnnFusibleInstr(&contraction)) .WithOneUser(), BcastConstScalar(6)))) { - return FuseActivation(OneDnnFusionConfig::RELU6, instr, matmul_call, - intermediate_instr); + return FuseActivation(OneDnnFusionConfig::RELU6, instr, contraction, + intermediate_instr, optional_bitcast); } return absl::OkStatus(); } absl::Status HandleMultiply(HloInstruction* instr) override { - HloInstruction* matmul_call; + HloInstruction* contraction; HloInstruction* intermediate_instr = nullptr; HloInstruction* src; auto activation = GELUActivation(instr, &src); @@ -861,24 +861,25 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { HloInstruction* optional_bitcast = nullptr; if (Match(src, ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, - OneDnnMatmulInstr(&matmul_call)))) { - return FuseActivation(activation, instr, matmul_call, + OneDnnFusibleInstr(&contraction)))) { + return FuseActivation(activation, instr, contraction, intermediate_instr, optional_bitcast); } } - HloInstruction *dot, *constant; + HloInstruction* constant; HloInstruction* optional_convert = nullptr; - auto pattern = m::Op(&instr) - .WithOpcode(HloOpcode::kMultiply) - .WithBinaryOperandsAnyOrder( - m::AnyOf( - pu::SupportedConvert(&optional_convert, - OneDnnMatmulInstr(&dot)) - .WithElementType(PrimitiveType::F32), - OneDnnMatmulInstr(&dot)) - .WithOneUser(), - m::Broadcast(m::Constant(&constant))); + auto pattern = + m::Op(&instr) + .WithOpcode(HloOpcode::kMultiply) + .WithBinaryOperandsAnyOrder( + m::AnyOf( + pu::SupportedConvert(&optional_convert, + OneDnnFusibleInstr(&contraction)) + .WithElementType(PrimitiveType::F32), + OneDnnFusibleInstr(&contraction)) + .WithOneUser(), + m::Broadcast(m::Constant(&constant))); if (Match(instr, pattern)) { std::vector new_operands; @@ -887,31 +888,28 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - for (auto operand : dot->operands()) { + for (auto operand : contraction->operands()) { new_operands.push_back(operand); } - auto matmul_call = Cast(instr->AddInstruction( - dot->CloneWithNewOperands(instr->shape(), new_operands))); - auto backend_config = matmul_call->backend_config(); - backend_config->mutable_onednn_matmul_config() - ->mutable_fusions() - ->add_ops(OneDnnFusionConfig::LINEAR); + auto custom_call = Cast(instr->AddInstruction( + contraction->CloneWithNewOperands(instr->shape(), new_operands))); + auto backend_config = custom_call->backend_config(); + auto fusions_config = GetFusionsConfig(&backend_config); + fusions_config->add_ops(OneDnnFusionConfig::LINEAR); // Casting to int32 because of issues in proto config for decimal types // handling. - backend_config->mutable_onednn_matmul_config() - ->mutable_fusions() - ->set_alpha_typecast( - *(reinterpret_cast(&constant_value.value()))); - TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + fusions_config->set_alpha_typecast( + *(reinterpret_cast(&constant_value.value()))); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config)); HloInstruction* new_instr; if (optional_convert != nullptr && optional_convert->opcode() == HloOpcode::kConvert) { - new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert( + new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType( - matmul_call->shape(), optional_convert->shape().element_type()), - matmul_call)); + custom_call->shape(), optional_convert->shape().element_type()), + custom_call)); } else { - new_instr = matmul_call; + new_instr = custom_call; } TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr)); @@ -927,16 +925,16 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } absl::Status HandleDivide(HloInstruction* instr) override { - HloInstruction* matmul_call; + HloInstruction* contraction; HloInstruction* intermediate_instr = nullptr; HloInstruction* optional_bitcast = nullptr; HloInstruction* src; if (SigmoidActivation(instr, &src)) { if (Match(src, ElementwiseSafeIntermediates( &intermediate_instr, &optional_bitcast, - OneDnnMatmulInstr(&matmul_call)) + OneDnnFusibleInstr(&contraction)) .WithOneUser())) { - return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, matmul_call, + return FuseActivation(OneDnnFusionConfig::SIGMOID, instr, contraction, intermediate_instr, optional_bitcast); } } @@ -945,25 +943,25 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { absl::Status FuseActivation(OneDnnFusionConfig_FusionKind kind, HloInstruction* activation, - HloInstruction* matmul, + HloInstruction* contraction, HloInstruction* intermediate_instr = nullptr, HloInstruction* optional_bitcast = nullptr) { - TF_ASSIGN_OR_RETURN(auto backend_config, - matmul->backend_config()); - auto* matmul_config = backend_config.mutable_onednn_matmul_config(); - matmul_config->mutable_fusions()->add_ops(kind); - TF_RETURN_IF_ERROR(matmul->set_backend_config(backend_config)); - std::unique_ptr output = matmul->Clone(); + auto backend_config = contraction->backend_config(); + auto fusions_config = GetFusionsConfig(&backend_config); + fusions_config->add_ops(kind); + TF_RETURN_IF_ERROR(contraction->set_backend_config(*backend_config)); + std::unique_ptr output = contraction->Clone(); if (optional_bitcast != nullptr && optional_bitcast->opcode() == HloOpcode::kBitcast) { HloInstruction* new_instr = nullptr; if (intermediate_instr != nullptr && intermediate_instr->opcode() == HloOpcode::kConvert) { auto bitcast_call = - matmul->AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::ChangeElementType(optional_bitcast->shape(), - matmul->shape().element_type()), - matmul)); + contraction->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::ChangeElementType( + optional_bitcast->shape(), + contraction->shape().element_type()), + contraction)); new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType( bitcast_call->shape(), @@ -974,7 +972,7 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } else if (intermediate_instr) { output = intermediate_instr->CloneWithNewOperands( intermediate_instr->shape(), - {matmul->parent()->AddInstruction(std::move(output))}); + {contraction->parent()->AddInstruction(std::move(output))}); } return ReplaceWithNewInstruction(activation, std::move(output)); @@ -1278,6 +1276,8 @@ EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch, absl::StatusOr OneDnnContractionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + XLA_VLOG_LINES( + 3, "OneDnnContractionRewriter::Run(), before:\n" + module->ToString()); OneDnnContractionRewriteVisitor visitor; TF_ASSIGN_OR_RETURN(auto result, visitor.RunOnModule(module, execution_threads)); @@ -1286,7 +1286,8 @@ absl::StatusOr OneDnnContractionRewriter::Run( compile_threadpool_); TF_ASSIGN_OR_RETURN(auto result2, reorder_visitor.RunOnModule(module, execution_threads)); - + XLA_VLOG_LINES( + 3, "OneDnnContractionRewriter::Run(), after:\n" + module->ToString()); return {result || result2}; } diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.cc b/third_party/xla/xla/service/cpu/onednn_convolution.cc index 30e91fb4aae3e7..46b4f17a570f18 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution.cc @@ -185,44 +185,22 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( std::vector fused_bufs; for (int64_t i = 0; i < num_fused_operands; ++i) { MemrefInfo operand_minfo(args[arg_indx++]); - fused_mds.push_back(operand_minfo.GetOneDnnMemDesc()); + auto mem_desc = operand_minfo.GetOneDnnMemDesc(); + if (mem_desc.get_ndims() == new_res_md.get_ndims()) { + mem_desc = mem_desc.permute_axes(out_axes); + } + fused_mds.push_back(mem_desc); fused_bufs.push_back(operand_minfo.Data()); } std::vector> postop_args; + FusedOperandsRef fused_operands_ref{fused_bufs, postop_args}; auto bias_md = memory::desc(); - dnnl::post_ops post_ops; - int fused_operand_idx = 0; - for (auto& fused_op : conv_config.fusions().ops()) { - switch (fused_op) { - case OneDnnFusionConfig::BIAS: { - bias_md = fused_mds.at(fused_operand_idx); - postop_args.emplace_back( - DNNL_ARG_BIAS, - dnnl::memory(bias_md, cpu_engine, fused_bufs[fused_operand_idx])); - fused_operand_idx++; - } break; - case OneDnnFusionConfig::BINARY_ADD: { - auto binary_md = fused_mds.at(fused_operand_idx); - binary_md = binary_md.permute_axes(out_axes); - auto arg_idx = - DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1; - postop_args.emplace_back( - arg_idx, - dnnl::memory(binary_md, cpu_engine, fused_bufs[fused_operand_idx])); - post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); - fused_operand_idx++; - } break; - default: - LOG(FATAL) - << __FILE__ << ":" << __LINE__ - << " Attempt to call OneDNN Convolution runtime library with " - "unsupported post op." - << std::endl; - } - } + dnnl::post_ops post_ops = + PopulateOneDnnPostOps(cpu_engine, fused_mds, &conv_config.fusions(), + &fused_operands_ref, &bias_md); auto any_ker_md = memory::desc(new_ker_md.get_dims(), new_ker_md.get_data_type(), diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index cf954686d4fadd..ebb394768e0d5f 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/util/onednn_threadpool.h" +#include "tsl/platform/cpu_info.h" #include "tsl/platform/logging.h" #define EIGEN_USE_THREADS @@ -222,6 +223,12 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( TRANSPOSE_LAST_TWO_DIMS_IF( matmul_config.transpose_b() && weights_md.get_ndims() > 1, weights_md); auto output_md = output_minfo.GetOneDnnMemDesc(); + + Literal* reordered_weights_literal = nullptr; + void* rhs_data = weights_minfo.Data(); + + auto weight_format = tsl::port::IsAarch64CPU() ? memory::format_tag::any + : memory::format_tag::ab; if (matmul_config.optimization_config().weights_prepacked()) { // Weight pre-packing is supported for 2D weights only. // Since prepacked weights array is flattened, try to infer the dims from @@ -230,8 +237,48 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( // array. weights_md = memory::desc({input_md.get_dims().back(), output_md.get_dims().back()}, - weights_md.get_data_type(), memory::format_tag::ab); + weights_md.get_data_type(), weight_format); + } else if (tsl::port::IsAarch64CPU()) { + // Weights are not pre-packed, and this scenario requires + // weights reordering on ARM64 platform + auto weights_mem = + dnnl::memory{weights_md, cpu_engine, weights_minfo.Data()}; + + auto bias_md = dnnl::memory::desc{}; + + if (absl::c_count(matmul_config.fusions().ops(), OneDnnFusionConfig::BIAS) > + 0) { + MemrefInfo bias_minfo(args[arg_indx]); + bias_md = bias_minfo.GetOneDnnMemDesc(); + } + + // extend bias rank to match result rank + if (!bias_md.is_zero()) { + auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); + XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); + if (missed_rank > 0) { + auto bias_dims = bias_md.get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + bias_md = bias_md.reshape(bias_dims); + } + } + auto reordered_weights_md = OneDnnMatMulOptWeightsDesc( + cpu_engine, input_md, weights_md, bias_md, output_md); + + auto reordered_weights_shape = + MemDescToXlaShapeFlattened(reordered_weights_md); + reordered_weights_literal = new Literal(reordered_weights_shape); + + rhs_data = reordered_weights_literal->untyped_data(); + auto reordered_weights_mem = + dnnl::memory{reordered_weights_md, cpu_engine, rhs_data}; + + dnnl::reorder rdr{weights_mem, reordered_weights_mem}; + rdr.execute(onednn_stream, weights_mem, reordered_weights_mem); + onednn_stream.wait(); + weights_md = reordered_weights_md; } + const int64_t num_fused_operands = num_args - arg_indx; std::vector fused_mds; std::vector fused_bufs; @@ -250,8 +297,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); auto lhs_mem = memory(input_md, cpu_engine, input_minfo.Data()); - auto rhs_mem = - memory(matmul_pd->weights_desc(), cpu_engine, weights_minfo.Data()); + auto rhs_mem = memory(matmul_pd->weights_desc(), cpu_engine, rhs_data); auto result_mem = memory(output_md, cpu_engine, output_minfo.Data()); if (std::strstr(matmul_pd->impl_info_str(), "ref") != nullptr) { @@ -275,6 +321,11 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMul( matmul_args.insert(postop_args.begin(), postop_args.end()); matmul_prim.execute(onednn_stream, matmul_args); + + if (reordered_weights_literal != nullptr) { + delete reordered_weights_literal; + reordered_weights_literal = nullptr; + } } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnMatMulReorder( diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.cc b/third_party/xla/xla/service/cpu/onednn_memory_util.cc index 587c61963193fa..bfb879dd69cffe 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.cc @@ -73,18 +73,34 @@ MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) { return CreateMemrefFromShape(shape, buf); } +std::pair, std::vector> GetDimsStrides( + const Shape& shape) { + // oneDNN handles scalar as a vector of size 1. + const bool is_scalar = shape.rank() == 0; + int64_t rank = is_scalar ? 1 : shape.rank(); + std::vector strides(rank); + std::vector scalar_shape(1, 1); + absl::Span dimensions = + is_scalar ? scalar_shape : shape.dimensions(); + std::vector dims(dimensions.begin(), dimensions.end()); + if (is_scalar) { + strides[0] = 1; + } else { + int64_t stride = 1; + for (int i : shape.layout().minor_to_major()) { + strides.at(i) = stride; + stride *= dims.at(i); + } + } + return std::make_pair(dims, strides); +} + StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder, const llvm_ir::IrArray& ir_array) { const Shape& shape = ir_array.GetShape(); - int64_t rank = shape.rank(); - absl::Span dims = shape.dimensions(); - - std::vector strides(rank); - int64_t stride = 1; - for (int i : shape.layout().minor_to_major()) { - strides.at(i) = stride; - stride *= dims.at(i); - } + // oneDNN handles scalar as a vector of size 1. + int64_t rank = shape.rank() == 0 ? 1 : shape.rank(); + auto [dims, strides] = GetDimsStrides(shape); // Type of struct llvm::Type* i64_type = builder.getInt64Ty(); @@ -184,17 +200,10 @@ absl::StatusOr TransposeLastTwoDims( } dnnl::memory::desc ShapeToMemDesc(const Shape& shape) { - auto dimensions = shape.dimensions(); - if (dimensions.empty()) { + auto [dims, strides] = GetDimsStrides(shape); + if (dims.empty()) { return dnnl::memory::desc{}; } - auto dims = dnnl::memory::dims(dimensions.begin(), dimensions.end()); - dnnl::memory::dims strides(dims.size()); - dnnl::memory::dim stride = 1; - for (auto i : shape.layout().minor_to_major()) { - strides.at(i) = stride; - stride *= dims.at(i); - } auto dt = ToOneDnnDataType(static_cast(shape.element_type())); return dnnl::memory::desc(dims, dt, strides); } diff --git a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc index ec94eb695d2397..5251835aa2d044 100644 --- a/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_ops_rewriter.cc @@ -576,8 +576,12 @@ class OneDnnOpsRewriterVisitor : public DfsHloRewriteVisitor { absl::StatusOr OneDnnOpsRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + XLA_VLOG_LINES(3, "OneDnnOpsRewriter::Run(), before:\n" + module->ToString()); OneDnnOpsRewriterVisitor visitor; - return visitor.RunOnModule(module, execution_threads); + TF_ASSIGN_OR_RETURN(auto result, + visitor.RunOnModule(module, execution_threads)); + XLA_VLOG_LINES(3, "OneDnnOpsRewriter::Run(), after:\n" + module->ToString()); + return result; } } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/parallel_task_assignment_test.cc b/third_party/xla/xla/service/cpu/parallel_task_assignment_test.cc index 7c76e5f271ca91..2dd12755c25dc2 100644 --- a/third_party/xla/xla/service/cpu/parallel_task_assignment_test.cc +++ b/third_party/xla/xla/service/cpu/parallel_task_assignment_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/cpu/target_machine_features_stub.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/cpu/runtime_fork_join.cc b/third_party/xla/xla/service/cpu/runtime_fork_join.cc index 50f7814e09b769..bf30ddfebd15f0 100644 --- a/third_party/xla/xla/service/cpu/runtime_fork_join.cc +++ b/third_party/xla/xla/service/cpu/runtime_fork_join.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #define EIGEN_USE_THREADS @@ -32,7 +33,6 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/executable_run_options.h" #include "xla/service/custom_call_status_internal.h" -#include "tsl/platform/blocking_counter.h" #include "tsl/platform/logging.h" using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, @@ -91,7 +91,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( std::vector statuses(num_partitions); // Dispatch 'num_partitions - 1' compute functions to run in parallel. - tsl::BlockingCounter bc(num_partitions - 1); + absl::BlockingCounter bc(num_partitions - 1); for (int32_t i = 1; i < num_partitions; ++i) { const int64_t offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc index 874d9b3fe1b508..65a000c524a472 100644 --- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "absl/algorithm/container.h" @@ -95,8 +94,8 @@ void BuildRetBuffers(absl::Span types, int64_t* encoded_dims, } static absl::Status BuildAndCallFfi( - const xla::ExecutableRunOptions* run_options, std::string_view target_name, - std::string_view backend_config, absl::Span outputs, + const xla::ExecutableRunOptions* run_options, absl::string_view target_name, + absl::string_view backend_config, absl::Span outputs, absl::Span inputs, absl::Span result_types, int64_t* result_dims, absl::Span operand_types, int64_t* operand_dims) { diff --git a/third_party/xla/xla/service/cpu/shape_partition_test.cc b/third_party/xla/xla/service/cpu/shape_partition_test.cc index 5a8d152bc37ca4..e5684a69fa7d5c 100644 --- a/third_party/xla/xla/service/cpu/shape_partition_test.cc +++ b/third_party/xla/xla/service/cpu/shape_partition_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include #include "absl/algorithm/container.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index 5c2bf8289dec2e..7b30e8719615ac 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -41,7 +41,7 @@ cc_library( xla_cc_test( name = "cpu_aot_export_test", srcs = ["cpu_aot_export_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", @@ -137,12 +137,12 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/service:buffer_assignment", "//xla/service:logical_buffer", "//xla/service/llvm_ir:alias_analysis", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", - "//xla/tests:filecheck", "@com_google_absl//absl/status", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", @@ -374,11 +374,11 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:test_helpers", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", "//xla/service:cpu_plugin", "//xla/service/cpu:onednn_contraction_rewriter", "//xla/service/cpu:onednn_util", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", @@ -395,11 +395,11 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:test_helpers", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", "//xla/service:cpu_plugin", "//xla/service/cpu:onednn_contraction_rewriter", "//xla/service/cpu:onednn_util", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc index 42a561b12cb35f..837451a3128245 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/logical_buffer.h" #include "xla/shape_util.h" -#include "xla/tests/filecheck.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc index a710898ca8350f..7fa7e1e8a82d4e 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/str_replace.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/cpu/onednn_contraction_rewriter.h" @@ -25,7 +26,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/cpu_info.h" @@ -170,6 +170,23 @@ TEST_P(ConvolutionTest, Simple2DTest1) { RunCompareAndMatchOptimizedHlo(outline, {}); } +TEST_P(ConvolutionTest, SimpleScalarTest) { + const absl::string_view outline = R"( + HloModule convolution.test + + ENTRY convolution.test { + arg.0 = $dtype[1,22,22,1] parameter(0) + arg.1 = $dtype[1] parameter(1) + reshape.1 = $dtype[1,1,1,1] reshape(arg.1) + convolution.0 = $dtype[1,14,14,1] convolution(arg.0, reshape.1), + window={size=1x1 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + tuple.0 = ($dtype[1,14,14,1]) tuple(convolution.0) + ROOT gte.0 = $dtype[1,14,14,1] get-tuple-element(tuple.0), index=0 + })"; + + RunCompareAndMatchOptimizedHlo(outline, {}); +} + TEST_P(ConvolutionTest, Simple3DTest1) { const absl::string_view outline = R"( HloModule convolution.test @@ -201,6 +218,45 @@ TEST_P(ConvolutionTest, Conv3DWithBiasTest) { RunCompareAndMatchOptimizedHlo(outline, {"BIAS"}); } +TEST_P(ConvolutionTest, Conv3DReluTest) { + const absl::string_view outline = R"( + HloModule convolution.test.with.relu + + ENTRY convolution.test.with.relu { + arg.0 = $dtype[15,4,5,5,28] parameter(0) + arg.1 = $dtype[3,3,3,28,64] parameter(1) + conv = $dtype[15,4,5,5,64] convolution(arg.0, arg.1), + window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + const.1 = $pdtype[] constant(0) + convert.0 = $dtype[] convert(const.1) + bcast.2 = $dtype[15,4,5,5,64] broadcast(convert.0), dimensions={} + ROOT maximum.1 = $dtype[15,4,5,5,64] maximum(conv, bcast.2) +})"; + + RunCompareAndMatchOptimizedHlo(outline, {"RELU"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBiasAndReluTest) { + const absl::string_view outline = R"( + HloModule convolution.bias.relu.test + + ENTRY convolution.bias.relu.test { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[8,8,1,10] parameter(1) + convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + const.0 = $dtype[10] constant(15) + bcast.1 = $dtype[1,11,11,10] broadcast(const.0), dimensions={3} + add.0 = $dtype[1,11,11,10] add(convolution.0, bcast.1) + const.1 = $pdtype[] constant(0) + convert.0 = $dtype[] convert(const.1) + bcast.2 = $dtype[1,11,11,10] broadcast(convert.0), dimensions={} + ROOT maximum.1 = $dtype[1,11,11,10] maximum(add.0, bcast.2) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "RELU"}); +} + TEST_P(ConvolutionTest, Conv2DWithBinaryAddTest) { const absl::string_view outline = R"( HloModule convolution.test.with.binaryadd @@ -241,6 +297,319 @@ TEST_P(ConvolutionTest, Conv2DWithBiasAndBinaryAddTest) { RunCompareAndMatchOptimizedHlo(outline, {"BIAS"}); } +TEST_P(ConvolutionTest, ToeplitzConstrcutionTest) { + if (dtype_ == BF16 || dtype_ == F16) { + GTEST_SKIP() << "Skipping test for " << dtypeString_ + << ". HLO Binary Complex instruction expects F32 inputs and " + "Unary Real and Imag instructions output F32 shapes only."; + } + + const absl::string_view outline = R"( + HloModule toeplitz.construction.test + + ENTRY toeplitz.construction.test { + Arg_0.1 = c64[1,23,1] parameter(0) + real.3 = $dtype[1,23,1] real(Arg_0.1) + imag.4 = $dtype[1,23,1] imag(Arg_0.1) + add.7 = $dtype[1,23,1] add(real.3, imag.4) + Arg_1.2 = c64[1,3,3] parameter(1) + real.5 = $dtype[1,3,3] real(Arg_1.2) + convolution.8 = $dtype[1,21,3] convolution(add.7, real.5), + window={size=3}, dim_labels=b0f_io0->b0f + imag.6 = $dtype[1,3,3] imag(Arg_1.2) + add.11 = $dtype[1,3,3] add(real.5, imag.6) + convolution.12 = $dtype[1,21,3] convolution(imag.4, add.11), + window={size=3}, dim_labels=b0f_io0->b0f + subtract.13 = $dtype[1,21,3] subtract(convolution.8, convolution.12) + subtract.9 = $dtype[1,3,3] subtract(imag.6, real.5) + convolution.10 = $dtype[1,21,3] convolution(real.3, subtract.9), + window={size=3}, dim_labels=b0f_io0->b0f + add.14 = $dtype[1,21,3] add(convolution.8, convolution.10) + ROOT complex.15 = c64[1,21,3] complex(subtract.13, add.14) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BINARY_ADD"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBiasAndTanhTest) { + const absl::string_view outline = R"( + HloModule convolution.bias.tanh.test + + ENTRY convolution.bias.tanh.test { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[8,8,1,10] parameter(1) + convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + const.0 = $dtype[10] constant(15) + bcast.1 = $dtype[1,11,11,10] broadcast(const.0), dimensions={3} + add.0 = $dtype[1,11,11,10] add(convolution.0, bcast.1) + tanh.0 = $dtype[1,11,11,10] tanh(add.0) + tuple.0 = ($dtype[1,11,11,10]) tuple(tanh.0) + ROOT gte.0 = $dtype[1,11,11,10] get-tuple-element(tuple.0), index=0 + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "TANH"}); +} + +TEST_P(ConvolutionTest, Conv2DWithLinearAndBinaryAddTest) { + const absl::string_view outline = R"( + HloModule convolution.test.linear.binaryadd + + ENTRY convolution.test.linear.binaryadd { + arg0.1 = $dtype[1,22,22,1] parameter(0) + constant.3 = $dtype[] constant(1) + broadcast.4 = $dtype[8,8,1,1] broadcast(constant.3), dimensions={} + convolution.0 = $dtype[1,11,11,1] convolution(arg0.1, broadcast.4), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + constant.4 = $pdtype[] constant(0.044715) + convert.0 = $dtype[] convert(constant.4) + broadcast.5 = $dtype[1,11,11,1] broadcast(convert.0), dimensions={} + multiply.0 = $dtype[1,11,11,1] multiply(convolution.0,broadcast.5) + constant.5 = $dtype[] constant(15) + broadcast.6 = $dtype[1] broadcast(constant.5), dimensions={} + broadcast.9 = $dtype[1,11,11,1] broadcast(broadcast.6), dimensions={3} + ROOT add.10 = $dtype[1,11,11,1] add(multiply.0, broadcast.9) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"LINEAR", "BINARY_ADD"}); +} + +TEST_P(ConvolutionTest, Conv3DWithBiasAndRelu6Test) { + const absl::string_view outline = R"( + HloModule convolution.test.bias.relu6 + + ENTRY convolution.test.bias.relu6 { + arg.0 = $dtype[15,4,5,5,28] parameter(0) + arg.1 = $dtype[3,3,3,28,64] parameter(1) + conv = $dtype[15,4,5,5,64] convolution(arg.0, arg.1), + window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + bias = $dtype[64] parameter(2) + broadcasted_bias = $dtype[15,4,5,5,64] broadcast(bias), dimensions={4} + add = $dtype[15,4,5,5,64] add(conv, broadcasted_bias) + const.0 = $pdtype[] constant(0) + convert.0 = $dtype[] convert(const.0) + broadcast.0 = $dtype[15,4,5,5,64] broadcast(convert.0), dimensions={} + const.1 = $pdtype[] constant(6) + convert.1 = $dtype[] convert(const.1) + broadcast.1 = $dtype[15,4,5,5,64] broadcast(convert.1), dimensions={} + ROOT clamp.0 = $dtype[15,4,5,5,64] clamp(broadcast.0, add, broadcast.1) +})"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "RELU6"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBiasAndSigmoidTest) { + const absl::string_view outline = R"( + HloModule convolution.bias.sigmoid.test + + ENTRY convolution.bias.sigmoid.test { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[8,8,1,10] parameter(1) + convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + const.0 = $dtype[10] constant(15) + bcast.1 = $dtype[1,11,11,10] broadcast(const.0), dimensions={3} + add.0 = $dtype[1,11,11,10] add(convolution.0, bcast.1) + const.1 = $pdtype[] constant(1) + convert.0 = $dtype[] convert(const.1) + bcast.2 = $dtype[1,11,11,10] broadcast(convert.0), dimensions={} + negate.0 = $dtype[1,11,11,10] negate(add.0) + exponential.0 = $dtype[1,11,11,10] exponential(negate.0) + add.1 = $dtype[1,11,11,10] add(bcast.2, exponential.0) + divide.0 = $dtype[1,11,11,10] divide(bcast.2, add.1) + tuple.0 =($dtype[1,11,11,10]) tuple(divide.0) + ROOT gte.0 = $dtype[1,11,11,10] get-tuple-element(tuple.0), index=0 + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "SIGMOID"}); +} + +TEST_P(ConvolutionTest, Conv3DWithBiasAndEluTest) { + const absl::string_view outline = R"( + HloModule convolution.test.bias.elu + + ENTRY convolution.test.bias.elu { + arg.0 = $dtype[15,4,5,5,28] parameter(0) + arg.1 = $dtype[3,3,3,28,64] parameter(1) + conv = $dtype[15,4,5,5,64] convolution(arg.0, arg.1), + window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + bias = $dtype[64] parameter(2) + broadcasted_bias = $dtype[15,4,5,5,64] broadcast(bias), dimensions={4} + add = $dtype[15,4,5,5,64] add(conv, broadcasted_bias) + const.0 = $pdtype[] constant(0) + convert.0 = $dtype[] convert(const.0) + broadcast.0 = $dtype[15,4,5,5,64] broadcast(convert.0), dimensions={} + compare.0 = pred[15,4,5,5,64] compare(add, broadcast.0), direction=GT + exp-min-one.0 = $dtype[15,4,5,5,64] exponential-minus-one(add) + ROOT select.0 = $dtype[15,4,5,5,64] select(compare.0, add, exp-min-one.0) +})"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "ELU"}); +} + +TEST_P(ConvolutionTest, Conv2DWithGeluApproxTest) { + const absl::string_view outline = R"( + HloModule convolution.gelu.approx.test + + ENTRY convolution.gelu.approx.test { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[8,8,1,10] parameter(1) + convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + mul.0 = $dtype[1,11,11,10] multiply(convolution.0, convolution.0) + mul.1 = $dtype[1,11,11,10] multiply(convolution.0, mul.0) + const.0 = $pdtype[] constant(0.044715) + convert.0 = $dtype[] convert(const.0) + bcast.0 = $dtype[1,11,11,10] broadcast(convert.0), dimensions={} + mul.2 = $dtype[1,11,11,10] multiply(mul.1, bcast.0) + add.0 = $dtype[1,11,11,10] add(convolution.0, mul.2) + const.1 = $pdtype[] constant(0.797884583) + convert.1 = $dtype[] convert(const.1) + bcast.1 = $dtype[1,11,11,10] broadcast(convert.1), dimensions={} + mul.3 = $dtype[1,11,11,10] multiply(add.0, bcast.1) + tanh = $dtype[1,11,11,10] tanh(mul.3) + const.2 = $pdtype[] constant(1) + convert.2 = $dtype[] convert(const.2) + bcast.2 = $dtype[1,11,11,10] broadcast(convert.2), dimensions={} + add.2 = $dtype[1,11,11,10] add(tanh, bcast.2) + const.3 = $pdtype[] constant(0.5) + convert.3 = $dtype[] convert(const.3) + bcast.3 = $dtype[1,11,11,10] broadcast(convert.3), dimensions={} + mul.4 = $dtype[1,11,11,10] multiply(add.2, bcast.3) + ROOT out = $dtype[1,11,11,10] multiply(convolution.0, mul.4) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"GELU_TANH"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBiasAndGeluApproxTest) { + const absl::string_view outline = R"( + HloModule convolution.bias.gelu.approx.test + + ENTRY convolution.bias.gelu.approx.test { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[8,8,1,10] parameter(1) + convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + constant.0 = $dtype[10] constant(15) + bcast.1 = $dtype[1,11,11,10] broadcast(constant.0), dimensions={3} + add.0 = $dtype[1,11,11,10] add(convolution.0, bcast.1) + constant.12 = $pdtype[] constant(0.044715) + convert.0 = $dtype[] convert(constant.12) + broadcast.13 = $dtype[1,11,11,10] broadcast(convert.0), dimensions={} + multiply.14 = $dtype[1,11,11,10] multiply(broadcast.13, add.0) + multiply.11 = $dtype[1,11,11,10] multiply(add.0, add.0) + multiply.15 = $dtype[1,11,11,10] multiply(multiply.14, multiply.11) + add.16 = $dtype[1,11,11,10] add(add.0, multiply.15) + constant.17 = $pdtype[] constant(0.797884583) + convert.1 = $dtype[] convert(constant.17) + broadcast.18 = $dtype[1,11,11,10] broadcast(convert.1), dimensions={} + multiply.19 = $dtype[1,11,11,10] multiply(add.16, broadcast.18) + tanh.20 = $dtype[1,11,11,10] tanh(multiply.19) + constant.21 = $pdtype[] constant(1) + convert.2 = $dtype[] convert(constant.21) + broadcast.22 = $dtype[1,11,11,10] broadcast(convert.2), dimensions={} + add.23 = $dtype[1,11,11,10] add(tanh.20, broadcast.22) + constant.24 = $pdtype[] constant(0.5) + convert.3 = $dtype[] convert(constant.24) + broadcast.25 = $dtype[1,11,11,10] broadcast(convert.3), dimensions={} + multiply.26 = $dtype[1,11,11,10] multiply(add.23, broadcast.25) + ROOT multiply.27 = $dtype[1,11,11,10] multiply(add.0, multiply.26) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "GELU_TANH"}); +} + +TEST_P(ConvolutionTest, Conv3DWithGeluExactTest) { + const absl::string_view outline = R"( + HloModule convolution.gelu.exact.test + + ENTRY convolution.gelu.exact.test { + arg.0 = $dtype[15,4,5,5,28] parameter(0) + arg.1 = $dtype[3,3,3,28,64] parameter(1) + conv = $dtype[15,4,5,5,64] convolution(arg.0, arg.1), + window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + const.0 = $pdtype[] constant(0.707106769) + convert.0 = $dtype[] convert(const.0) + bcast.0 = $dtype[15,4,5,5,64] broadcast(convert.0), dimensions={} + mul.0 = $dtype[15,4,5,5,64] multiply(conv, bcast.0) + erf.0 = $dtype[15,4,5,5,64] erf(mul.0) + const.1 = $pdtype[] constant(1) + convert.1 = $dtype[] convert(const.1) + bcast.1 = $dtype[15,4,5,5,64] broadcast(convert.1), dimensions={} + add.0 = $dtype[15,4,5,5,64] add(erf.0, bcast.1) + const.2 = $pdtype[] constant(0.5) + convert.2 = $dtype[] convert(const.2) + bcast.2 = $dtype[15,4,5,5,64] broadcast(convert.2), dimensions={} + mul.1 = $dtype[15,4,5,5,64] multiply(add.0, bcast.2) + ROOT out = $dtype[15,4,5,5,64] multiply(conv, mul.1) +})"; + + RunCompareAndMatchOptimizedHlo(outline, {"GELU_ERF"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBiasAndGeluExactPattern1Test) { + const absl::string_view outline = R"( + HloModule convolution.test.with.bias.gelu.exact + + ENTRY convolution.test.with.bias.gelu.exact { + arg.0 = $dtype[1,22,22,1] parameter(0) + arg.1 = $dtype[8,8,1,10] parameter(1) + conv = $dtype[1,11,11,10] convolution(arg.0, arg.1), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + bias = $dtype[10] parameter(2) + broadcasted_bias = $dtype[1,11,11,10] broadcast(bias), dimensions={3} + add = $dtype[1,11,11,10] add(conv, broadcasted_bias) + const.0 = $pdtype[] constant(0.70703125) + convert.0 = $dtype[] convert(const.0) + bcast.0 = $dtype[1,11,11,10] broadcast(convert.0), dimensions={} + mul.0 = $dtype[1,11,11,10] multiply(add, bcast.0) + erf.0 = $dtype[1,11,11,10] erf(mul.0) + const.1 = $pdtype[] constant(1) + convert.1 = $dtype[] convert(const.1) + bcast.1 = $dtype[1,11,11,10] broadcast(convert.1), dimensions={} + add.0 = $dtype[1,11,11,10] add(erf.0, bcast.1) + const.2 = $pdtype[] constant(0.5) + convert.2 = $dtype[] convert(const.2) + bcast.2 = $dtype[1,11,11,10] broadcast(convert.2), dimensions={} + mul.1 = $dtype[1,11,11,10] multiply(add.0, bcast.2) + ROOT out = $dtype[1,11,11,10] multiply(add, mul.1) +})"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "GELU_ERF"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBiasAndGeluExactPattern2Test) { + const absl::string_view outline = R"( + HloModule convolution.test.with.bias.gelu.exact + + ENTRY convolution.test.with.bias.gelu.exact { + arg.0 = $dtype[1,22,22,1] parameter(0) + arg.1 = $dtype[8,8,1,10] parameter(1) + conv = $dtype[1,11,11,10] convolution(arg.0, arg.1), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + bias = $dtype[10] parameter(2) + broadcasted_bias = $dtype[1,11,11,10] broadcast(bias), dimensions={3} + add = $dtype[1,11,11,10] add(conv, broadcasted_bias) + constant.384 = $pdtype[] constant(0.707182348) + convert.0 = $dtype[] convert(constant.384) + broadcast.385 = $dtype[1,11,11,10] broadcast(convert.0), dimensions={} + multiply.386 = $dtype[1,11,11,10] multiply(broadcast.385, add) + erf.387 = $dtype[1,11,11,10] erf(multiply.386) + constant.388 = $pdtype[] constant(1) + convert.1 = $dtype[] convert(constant.388) + broadcast.389 = $dtype[1,11,11,10] broadcast(convert.1), dimensions={} + add.390 = $dtype[1,11,11,10] add(erf.387, broadcast.389) + multiply.393 = $dtype[1,11,11,10] multiply(add.390, add) + constant.391 = $pdtype[] constant(0.5) + convert.2 = $dtype[] convert(constant.391) + broadcast.392 = $dtype[1,11,11,10] broadcast(convert.2) + ROOT mul.394 = $dtype[1,11,11,10] multiply(multiply.393, broadcast.392) +})"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS", "GELU_ERF"}); +} + INSTANTIATE_TEST_SUITE_P( OneDnnConvolutionTestSuite, ConvolutionTest, ::testing::Values(F32, BF16, F16), diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc index 46ad5a3a0fe575..9497c17333ab72 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/cpu/onednn_contraction_rewriter.h" @@ -24,7 +25,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/cpu_info.h" diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 4a3c473a5031c0..a5d0aeade482f0 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -24,10 +24,14 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/codegen/elemental_kernel_emitter.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/backends/cpu/runtime/all_gather_thunk.h" #include "xla/backends/cpu/runtime/all_reduce_thunk.h" @@ -52,6 +56,11 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/topk_thunk.h" #include "xla/backends/cpu/runtime/while_thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h" +#include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" +#include "xla/backends/cpu/xnn_emitter.h" +#include "xla/codegen/kernel_spec.h" +#include "xla/codegen/llvm_ir_kernel_source.h" #include "xla/comparison_util.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -74,10 +83,11 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "tsl/platform/casts.h" #include "tsl/platform/statusor.h" namespace xla::cpu { @@ -281,6 +291,9 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( return EmitConcatenateKernelThunk(instruction); case HloOpcode::kFusion: + if (instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) { + return EmitXnnFusionThunk(instruction); + } return EmitFusionKernelThunk(instruction); case HloOpcode::kReduce: @@ -513,6 +526,13 @@ absl::StatusOr ThunkEmitter::EmitCallThunk( absl::StatusOr ThunkEmitter::EmitConcatenateKernelThunk( const HloInstruction* instruction) { + if (absl::Status status = ir_emitter_.CanDoFastConcatenate(instruction); + !status.ok()) { + VLOG(1) << "Could not emit fast concatenate for " << instruction->ToString() + << ": " << status.message(); + return EmitElementalKernelThunk(instruction); + } + auto* concatenate = Cast(instruction); TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitConcatenateHostKernel(concatenate)); @@ -607,12 +627,20 @@ absl::StatusOr ThunkEmitter::EmitCopyThunk( absl::StatusOr ThunkEmitter::EmitElementalKernelThunk( const HloInstruction* instruction) { - TF_ASSIGN_OR_RETURN(auto kernel, - ir_emitter_.EmitElementalHostKernel(instruction)); - TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); + ElementalKernelEmitter emitter(instruction, &buffer_assignment_, + &target_machine_features_); + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel_spec, + emitter.EmitKernelSpec()); + auto llvm_ir_kernel_spec = absl::WrapUnique( + tsl::down_cast(kernel_spec.release())); + + LlvmIrKernelSource& kernel_source = llvm_ir_kernel_spec->kernel_source(); + std::string kernel_name = kernel_source.kernel_name(); + kernels_.push_back( + {std::move(kernel_name), std::move(kernel_source).thread_safe_module()}); return MakeKernelThunkSequence( - instruction, buffers, kernel, + instruction, std::move(llvm_ir_kernel_spec), /*min_alignment=*/cpu_function_runtime::MinAlign()); } @@ -640,13 +668,8 @@ absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( const HloInstruction* instruction) { - TF_ASSIGN_OR_RETURN(auto kernel, - ir_emitter_.EmitReductionHostKernel(instruction)); - TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - - return MakeKernelThunkSequence( - instruction, buffers, kernel, - /*min_alignment=*/cpu_function_runtime::MinAlign()); + // TODO(ezhulenev): Port vectorized reduction emitter from IrEmitter. + return EmitElementalKernelThunk(instruction); } absl::StatusOr ThunkEmitter::EmitRngThunk( @@ -813,9 +836,23 @@ absl::StatusOr ThunkEmitter::EmitDotThunk( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice out_slice, GetAllocationSlice(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), dnums, lhs_slice, lhs->shape(), rhs_slice, - rhs->shape(), out_slice, instruction->shape()); + // Decide whether to use XNNPACK or Eigen. + bool use_xnn = hlo_module_config_.debug_options().xla_cpu_use_xnnpack(); + if (use_xnn) { + TF_ASSIGN_OR_RETURN( + use_xnn, XnnDotThunk::IsSupported(dnums, lhs->shape(), rhs->shape(), + instruction->shape())); + } + + if (use_xnn) { + return ThunkSequence::Of( + ThunkInfo(instruction), dnums, lhs_slice, lhs->shape(), rhs_slice, + rhs->shape(), out_slice, instruction->shape()); + } else { + return ThunkSequence::Of( + ThunkInfo(instruction), dnums, lhs_slice, lhs->shape(), rhs_slice, + rhs->shape(), out_slice, instruction->shape()); + } } } } @@ -967,9 +1004,9 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( // Get backend config and buffer assignments. auto backend_config = custom_call->backend_config(); if (!backend_config.ok()) { - LOG(WARNING) << "Unable to parse backend config for custom call: " - << backend_config.status().message() << "\n" - << "Fall back to parse the opaque str."; + VLOG(3) << "Unable to parse backend config for custom call: " + << backend_config.status().message() << "\n" + << "Fall back to parse the opaque str."; } auto& backend_config_str = !backend_config.ok() @@ -1006,6 +1043,12 @@ absl::StatusOr ThunkEmitter::EmitSliceThunk( absl::StatusOr ThunkEmitter::EmitDynamicUpdateSliceThunk( const HloInstruction* instruction) { + if (!ir_emitter_.CanUpdateDynamicSliceInPlace(instruction)) { + VLOG(2) << "Could not emit in-place dynamic-update-slice kernel: " + << instruction->name(); + return EmitElementalKernelThunk(instruction); + } + TF_ASSIGN_OR_RETURN( auto kernel, ir_emitter_.EmitDynamicUpdateSliceHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); @@ -1103,6 +1146,52 @@ absl::StatusOr ThunkEmitter::EmitSortThunk( return thunks; } +absl::StatusOr ThunkEmitter::EmitXnnFusionThunk( + const HloInstruction* instruction) { + auto* fusion = Cast(instruction); + + // Fusion must have backend config with __xnn_fusion kind. + TF_RET_CHECK(fusion->has_backend_config()) + << "Fusion must have backend config"; + TF_ASSIGN_OR_RETURN(auto backend_config, + fusion->backend_config()); + TF_RET_CHECK(backend_config.has_fusion_config()) + << "Backend config must have fusion config"; + + const FusionBackendConfig& fusion_config = backend_config.fusion_config(); + TF_RET_CHECK(fusion_config.kind() == "__xnn_fusion") + << "Backend config must have __xnn_fusion kind"; + + // Collect XNNPACK fusion arguments. + std::vector arguments; + for (HloInstruction* operand : instruction->operands()) { + for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) { + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice slice, + buffer_assignment_.GetUniqueSlice(operand, indexed.index)); + arguments.push_back(XnnFusionThunk::Argument{slice, indexed.shape}); + } + } + + // Collect XNNPACK fusion results. + std::vector results; + for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) { + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice slice, + buffer_assignment_.GetUniqueSlice(instruction, indexed.index)); + results.push_back(XnnFusionThunk::Result{slice, indexed.shape}); + } + + // Construct XNNPACK subgraph builder from the fusion computation. + TF_ASSIGN_OR_RETURN( + auto builder, + EmitXnnFusionBuilder(fusion->fused_instructions_computation())); + + return ThunkSequence::Of( + ThunkInfo(instruction), std::move(arguments), std::move(results), + [b = std::move(builder)](auto, auto) mutable { return b(); }); +} + absl::StatusOr ThunkEmitter::GetHostKernelAllocationSlices(const HloInstruction* instruction) { HostKernelAllocationSlices slices; @@ -1154,4 +1243,12 @@ absl::StatusOr ThunkEmitter::MakeKernelThunkSequence( kernel.thread_dims, kernel.invariant_arguments, min_alignment); } +absl::StatusOr ThunkEmitter::MakeKernelThunkSequence( + const HloInstruction* instruction, + std::unique_ptr kernel_spec, + std::optional min_alignment) { + return ThunkSequence::Of(ThunkInfo(instruction), + std::move(kernel_spec), min_alignment); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 6a5f50698996cc..787254356aaf62 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -19,12 +19,15 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "xla/backends/cpu/codegen/llvm_ir_kernel_spec.h" #include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/sort_thunk.h" @@ -50,6 +53,11 @@ namespace xla::cpu { // multiple LLVM modules compiled to object files). class ThunkEmitter { public: + struct EmittedKernel { + std::string kernel_name; + llvm::orc::ThreadSafeModule module; + }; + ThunkEmitter(IrEmitter2& ir_emitter, const BufferAssignment& buffer_assignment, const TargetMachineFeatures& target_machine_features, @@ -58,6 +66,8 @@ class ThunkEmitter { // Emits HLO module entry computation as a sequence of thunks. absl::StatusOr EmitEntryComputation(const HloModule& module); + std::vector& kernels() { return kernels_; } + private: struct HostKernelAllocationSlices { std::vector arguments; @@ -184,6 +194,9 @@ class ThunkEmitter { absl::StatusOr EmitSortThunk( const HloInstruction* instruction); + absl::StatusOr EmitXnnFusionThunk( + const HloInstruction* instruction); + // Returns the list of buffer allocation slices assigned to the given // instruction that will be passed to the host kernel as arguments: a // flattened list of all the leaf buffers for all operands and result. We do @@ -206,6 +219,11 @@ class ThunkEmitter { const IrEmitter2::KernelInfo& kernel, std::optional min_alignment = std::nullopt); + static absl::StatusOr MakeKernelThunkSequence( + const HloInstruction* instruction, + std::unique_ptr kernel_spec, + std::optional min_alignment = std::nullopt); + IrEmitter2& ir_emitter_; const BufferAssignment& buffer_assignment_; @@ -220,6 +238,8 @@ class ThunkEmitter { // create a separate resource for each unique allocation slice. absl::flat_hash_map> token_resources_; + + std::vector kernels_; }; } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc b/third_party/xla/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc index a78dcd6c992f02..d4ed1883f396e6 100644 --- a/third_party/xla/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc +++ b/third_party/xla/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc @@ -30,10 +30,10 @@ limitations under the License. #include "xla/backends/cpu/codegen/target_machine_features.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/testlib/test.h" #include "xla/service/compiler.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/cpu/xfeed_manager.cc b/third_party/xla/xla/service/cpu/xfeed_manager.cc index 36f2c9c7c308a4..d7d40ff09e1b9b 100644 --- a/third_party/xla/xla/service/cpu/xfeed_manager.cc +++ b/third_party/xla/xla/service/cpu/xfeed_manager.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace xla { namespace cpu { @@ -31,27 +32,19 @@ namespace runtime { void XfeedQueueManager::EnqueueBuffersAtomically( absl::Span buffers) { absl::MutexLock l(&mu_); - bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { VLOG(3) << "Enqueueing " << queue_name_ << " buffer (of " << buffers.size() << " buffers) with length: " << b->length(); enqueued_buffers_.push_back(b); } - if (was_empty && !buffers.empty()) { - // This has the potential to suffer from the notified thread - // immediately trying and failing to acquire mu_, but seems - // preferable to the alternative of notifying outside the lock - // on every enqueue. - cv_.Signal(); - } } XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { - absl::MutexLock l(&mu_); VLOG(3) << "Waiting for an available buffer."; - while (enqueued_buffers_.empty()) { - cv_.Wait(&mu_); - } + auto available_buffer = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return !enqueued_buffers_.empty(); + }; + absl::MutexLock l(&mu_, absl::Condition(&available_buffer)); VLOG(3) << "A buffer is available!"; CHECK(current_buffer_ == nullptr); current_buffer_ = enqueued_buffers_.front(); diff --git a/third_party/xla/xla/service/cpu/xfeed_manager.h b/third_party/xla/xla/service/cpu/xfeed_manager.h index 19664ba9f4cbab..3dee7629fdc220 100644 --- a/third_party/xla/xla/service/cpu/xfeed_manager.h +++ b/third_party/xla/xla/service/cpu/xfeed_manager.h @@ -86,10 +86,6 @@ class XfeedQueueManager { absl::Mutex mu_; - // Condition variable that is signaled every time a buffer is - // enqueued to an empty queue. - absl::CondVar cv_; - // XfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. std::deque enqueued_buffers_; diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc index a1cbf20f4d838f..07cad836859731 100644 --- a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc +++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc @@ -28,22 +28,6 @@ limitations under the License. namespace xla { -namespace { -absl::Status VerifyS4U4Usage(HloInstruction* instruction) { - return ShapeUtil::ForEachSubshapeWithStatus( - instruction->shape(), [&](const Shape& shape, const ShapeIndex&) { - if (primitive_util::IsSubByteNonPredType(shape.element_type()) && - IsCollective(instruction)) { - return absl::InvalidArgumentError( - absl::StrFormat("Int4 is not supported in collective operations, " - "but got instruction: %s", - instruction->ToString())); - } - return absl::OkStatus(); - }); -} -} // namespace - absl::Status CpuGpuShapeVerifier::Preprocess(HloInstruction* hlo) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( hlo->shape(), [&](const Shape& shape, const ShapeIndex&) { @@ -64,7 +48,6 @@ absl::Status CpuGpuShapeVerifier::Preprocess(HloInstruction* hlo) { return absl::OkStatus(); })); - TF_RETURN_IF_ERROR(VerifyS4U4Usage(hlo)); return ShapeVerifier::Preprocess(hlo); } diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc index 0ee97bc5db508f..d460db645aaa0a 100644 --- a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc +++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc @@ -44,28 +44,6 @@ class CpuGpuShapeVerifierTest : public HloTestBase { } }; -TEST_F(CpuGpuShapeVerifierTest, Int4UnsupportedCollectiveInstruction) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY main { - p0 = u4[2,5] parameter(0) - ROOT out = u4[2,10] all-gather(p0), dimensions={1} - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string, config)); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.message(), HasSubstr("Int4 is not supported in collective " - "operations, but got instruction: ")); -} - TEST_F(CpuGpuShapeVerifierTest, InvalidElementSize) { const char* const hlo_string = R"( HloModule Module diff --git a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc index ad4d0648528ca1..94ccde71b560b4 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/hlo_runner.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc index 83acce7980d4f2..29e3724bcd79c6 100644 --- a/third_party/xla/xla/service/dynamic_padder_test.cc +++ b/third_party/xla/xla/service/dynamic_padder_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/log/check.h" @@ -2392,7 +2391,7 @@ ENTRY gds { } TEST_F(DynamicPadderTest, ShardingDynamicShapes) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main ENTRY main { diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index dfd0267907245a..59ecbd9f69968d 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -32,12 +32,14 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/macros.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/FloatingPointMode.h" #include "llvm/IR/BasicBlock.h" @@ -231,55 +233,75 @@ absl::StatusOr EmitReducePrecisionIR( namespace { -template -llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, - llvm::Value* f8_bits, - llvm::IRBuilderBase* b) { +template +llvm::Value* handle_halfway_points_FxToF8(llvm::Value* fx_abs_bits, + llvm::Value* f8_bits, + llvm::IRBuilderBase* b) { + using llvm::APFloat; using llvm::APInt; using llvm::Value; + static_assert(fx_type == F16 || fx_type == F32 || fx_type == F64); static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + const llvm::fltSemantics* fx_semantics; + llvm::IntegerType* ix_type; + + if constexpr (fx_type == F16) { + fx_semantics = &llvm::APFloat::IEEEhalf(); + ix_type = b->getInt16Ty(); + } else if constexpr (fx_type == F32) { + fx_semantics = &llvm::APFloat::IEEEsingle(); + ix_type = b->getInt32Ty(); + } else if constexpr (fx_type == F64) { + fx_semantics = &llvm::APFloat::IEEEdouble(); + ix_type = b->getInt64Ty(); + } + + auto ix_const = [fx_semantics, ix_type](APFloat val) { + bool losesInfo; + val.convert(*fx_semantics, llvm::RoundingMode::NearestTiesToEven, + &losesInfo); + return llvm::ConstantInt::get(ix_type, val.bitcastToAPInt()); + }; + llvm::IntegerType* i8_type = b->getInt8Ty(); - llvm::IntegerType* i16_type = b->getInt16Ty(); auto i8_const = [i8_type](int val) { return llvm::ConstantInt::get(i8_type, val); }; - auto i16_const = [i16_type](int val) { - return llvm::ConstantInt::get(i16_type, val); - }; + // F16 values that are halfway between denormal F8 values. This is used to // determine how to round to denormal F8 values. - const int halfway_points_e4[8] = { - 0x1400, // 0x1.0p-10 ; halfway between [0/8 * 2^-6, 1/8 * 2^-6] - 0x1A00, // 0x1.8p-9 ; halfway between [1/8 * 2^-6, 2/8 * 2^-6] - 0x1D00, // 0x1.4p-8 ; halfway between [2/8 * 2^-6, 3/8 * 2^-6] - 0x1F00, // 0x1.Cp-8 ; halfway between [3/8 * 2^-6, 4/8 * 2^-6] - 0x2080, // 0x1.2p-7 ; halfway between [4/8 * 2^-6, 5/8 * 2^-6] - 0x2180, // 0x1.6p-7 ; halfway between [5/8 * 2^-6, 6/8 * 2^-6] - 0x2280, // 0x1.Ap-7 ; halfway between [6/8 * 2^-6, 7/8 * 2^-6] - 0x2380, // 0x1.Ep-7 ; halfway between [7/8 * 2^-6, 8/8 * 2^-6] + const APFloat halfway_points_e4[8] = { + APFloat(0x1.0p-10), // halfway between [0/8 * 2^-6, 1/8 * 2^-6] + APFloat(0x1.8p-9), // halfway between [1/8 * 2^-6, 2/8 * 2^-6] + APFloat(0x1.4p-8), // halfway between [2/8 * 2^-6, 3/8 * 2^-6] + APFloat(0x1.Cp-8), // halfway between [3/8 * 2^-6, 4/8 * 2^-6] + APFloat(0x1.2p-7), // halfway between [4/8 * 2^-6, 5/8 * 2^-6] + APFloat(0x1.6p-7), // halfway between [5/8 * 2^-6, 6/8 * 2^-6] + APFloat(0x1.Ap-7), // halfway between [6/8 * 2^-6, 7/8 * 2^-6] + APFloat(0x1.Ep-7) // halfway between [7/8 * 2^-6, 8/8 * 2^-6] }; - const int halfway_points_e3[16] = { - 0x2000, // 0x1.0p-7; halfway between [0/16 * 2^-2, 1/16 * 2^-2] - 0x2600, // 0x1.8p-6; halfway between [1/16 * 2^-2, 2/16 * 2^-2] - 0x2900, // 0x1.4p-5; halfway between [2/16 * 2^-2, 3/16 * 2^-2] - 0x2B00, // 0x1.Cp-5; halfway between [3/16 * 2^-2, 4/16 * 2^-2] - 0x2C80, // 0x1.2p-4; halfway between [4/16 * 2^-2, 5/16 * 2^-2] - 0x2D80, // 0x1.6p-4; halfway between [5/16 * 2^-2, 6/16 * 2^-2] - 0x2E80, // 0x1.Ap-4; halfway between [6/16 * 2^-2, 7/16 * 2^-2] - 0x2F80, // 0x1.Ep-4; halfway between [7/16 * 2^-2, 8/16 * 2^-2] - 0x3040, // 0x1.1p-3; halfway between [8/16 * 2^-2, 9/16 * 2^-2] - 0x30C0, // 0x1.3p-3; halfway between [9/16 * 2^-2, 10/16 * 2^-2] - 0x3140, // 0x1.5p-3; halfway between [10/16 * 2^-2, 11/16 * 2^-2] - 0x31C0, // 0x1.7p-3; halfway between [11/16 * 2^-2, 12/16 * 2^-2] - 0x3240, // 0x1.9p-3; halfway between [12/16 * 2^-2, 13/16 * 2^-2] - 0x32C0, // 0x1.Bp-3; halfway between [13/16 * 2^-2, 14/16 * 2^-2] - 0x3340, // 0x1.Dp-3; halfway between [14/16 * 2^-2, 15/16 * 2^-2] - 0x33C0, // 0x1.Fp-3; halfway between [15/16 * 2^-2, 16/16 * 2^-2] + const APFloat halfway_points_e3[16] = { + APFloat(0x1.0p-7), // halfway between [0/16 * 2^-2, 1/16 * 2^-2] + APFloat(0x1.8p-6), // halfway between [1/16 * 2^-2, 2/16 * 2^-2] + APFloat(0x1.4p-5), // halfway between [2/16 * 2^-2, 3/16 * 2^-2] + APFloat(0x1.Cp-5), // halfway between [3/16 * 2^-2, 4/16 * 2^-2] + APFloat(0x1.2p-4), // halfway between [4/16 * 2^-2, 5/16 * 2^-2] + APFloat(0x1.6p-4), // halfway between [5/16 * 2^-2, 6/16 * 2^-2] + APFloat(0x1.Ap-4), // halfway between [6/16 * 2^-2, 7/16 * 2^-2] + APFloat(0x1.Ep-4), // halfway between [7/16 * 2^-2, 8/16 * 2^-2] + APFloat(0x1.1p-3), // halfway between [8/16 * 2^-2, 9/16 * 2^-2] + APFloat(0x1.3p-3), // halfway between [9/16 * 2^-2, 10/16 * 2^-2] + APFloat(0x1.5p-3), // halfway between [10/16 * 2^-2, 11/16 * 2^-2] + APFloat(0x1.7p-3), // halfway between [11/16 * 2^-2, 12/16 * 2^-2] + APFloat(0x1.9p-3), // halfway between [12/16 * 2^-2, 13/16 * 2^-2] + APFloat(0x1.Bp-3), // halfway between [13/16 * 2^-2, 14/16 * 2^-2] + APFloat(0x1.Dp-3), // halfway between [14/16 * 2^-2, 15/16 * 2^-2] + APFloat(0x1.Fp-3), // halfway between [15/16 * 2^-2, 16/16 * 2^-2] }; - const int* halfway_points; + const APFloat* halfway_points; int arr_sz; if constexpr (f8_exponent_bits == 4) { halfway_points = halfway_points_e4; @@ -305,13 +327,17 @@ llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, // } for (int i = arr_sz - 1; i >= 0; i--) { Value* comparison; + llvm::Constant* half_way_point = ix_const(halfway_points[i]); + if (i % 2 == 0) { - comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); + comparison = b->CreateICmpULE(fx_abs_bits, half_way_point); } else { - comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); + comparison = b->CreateICmpULT(fx_abs_bits, half_way_point); } + f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); } + return f8_bits; } @@ -337,86 +363,115 @@ llvm::Value* EmitF8e5m2ToF16(llvm::Value* f8_value, llvm::IRBuilderBase* b) { return b->CreateBitCast(shifted, b->getHalfTy()); } -template -absl::StatusOr EmitF16ToF8e(llvm::Value* f16_value, - llvm::IRBuilderBase* b) { +// Convert a float "fx_value" of type "fx_type" to an F8e "f8_exponent_bits" +// bits wide. +template +absl::StatusOr EmitFxToF8e(llvm::Value* fx_value, + llvm::IRBuilderBase* b) { + static_assert(fx_type == F16 || fx_type == F32 || fx_type == F64); static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + + const uint64_t fx_width = primitive_util::BitWidth(fx_type); + const uint64_t fx_bias = primitive_util::ExponentBias(fx_type); + const uint64_t fx_mantissa_bits = + primitive_util::SignificandWidth(fx_type) - 1; + + const uint64_t exponent_bias_difference = fx_bias - f8_bias; + using llvm::APInt; using llvm::Value; - llvm::IntegerType* i8_type = b->getInt8Ty(); - llvm::IntegerType* i16_type = b->getInt16Ty(); - auto i16_const = [i16_type](int val) { - return llvm::ConstantInt::get(i16_type, val); + const llvm::fltSemantics* fx_semantics; + llvm::IntegerType* ix_type; + + if constexpr (fx_type == F16) { + ix_type = b->getInt16Ty(); + fx_semantics = &llvm::APFloat::IEEEhalf(); + } else if constexpr (fx_type == F32) { + ix_type = b->getInt32Ty(); + fx_semantics = &llvm::APFloat::IEEEsingle(); + } else if constexpr (fx_type == F64) { + ix_type = b->getInt64Ty(); + fx_semantics = &llvm::APFloat::IEEEdouble(); + } + + auto ix_const = [ix_type](uint64_t val) { + return llvm::ConstantInt::get(ix_type, val); }; + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::Constant* infinity = llvm::ConstantInt::get( + ix_type, llvm::APFloat::getInf(*fx_semantics).bitcastToAPInt()); + llvm::ConstantInt* nosign_mask = + ix_const(ix_type->getBitMask() ^ ix_type->getSignBit()); + llvm::ConstantInt* sign_mask = ix_const(ix_type->getSignBit()); + llvm::ConstantInt* sign_shift = ix_const(fx_width - 8); + llvm::ConstantInt* fx_exponent_bias_difference = + ix_const(exponent_bias_difference << fx_mantissa_bits); + llvm::ConstantInt* fx_doubled_exponent_bias_difference = + ix_const(exponent_bias_difference << (fx_mantissa_bits + 1)); + llvm::ConstantInt* mantissa_bits_difference = + ix_const(fx_mantissa_bits - f8_mantissa_bits); + llvm::ConstantInt* min_normal_value = + ix_const((exponent_bias_difference + 1) << fx_mantissa_bits); + // Cast the input value to an integer for bitwise manipulation. Get the // absolute value of the input value. - // f16_as_int = bitcast(f16_value, int) - // f16_abs_bits = f16_as_int & 0x7FFF - Value* f16_as_int = b->CreateBitCast(f16_value, i16_type); - llvm::Value* f16_abs_bits = b->CreateAnd(f16_as_int, i16_const(0x7FFF)); + // fx_as_int = bitcast(fx_value, int) + // fx_abs_bits = fx_as_int & nosign_mask + Value* fx_as_int = b->CreateBitCast(fx_value, ix_type); + llvm::Value* fx_abs_bits = b->CreateAnd(fx_as_int, nosign_mask); // Get the sign. - // f8_sign = (f16_as_int & 0x8000) >> 8 - Value* f16_sign = b->CreateAnd(f16_as_int, i16_const(0x8000)); - f16_sign = b->CreateLShr(f16_sign, i16_const(8)); - Value* f8_sign = b->CreateTrunc(f16_sign, i8_type); + // f8_sign = (fx_as_int & sign_mask) >> sign_shift + Value* fx_sign = b->CreateAnd(fx_as_int, sign_mask); + fx_sign = b->CreateLShr(fx_sign, sign_shift); + Value* f8_sign = b->CreateTrunc(fx_sign, i8_type); // Truncate the mantissa to f8 mantissa bits and exponent to f8 exponent bits // Denormal values are not handled properly here and are // dealt with later in this function. - absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( - /*src_ty=*/F16, f16_value, + absl::StatusOr fx_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/fx_type, fx_value, /*dest_exponent_bits=*/f8_exponent_bits, /*dest_mantissa_bits=*/f8_mantissa_bits, /*quiet_nans=*/true, b); - CHECK_OK(f16_reduced_statusor.status()); // Crash OK - Value* f16_reduced = f16_reduced_statusor.value(); - f16_reduced = b->CreateBitCast(f16_reduced, i16_type); + CHECK_OK(fx_reduced_statusor.status()); // Crash OK + Value* fx_reduced = fx_reduced_statusor.value(); + fx_reduced = b->CreateBitCast(fx_reduced, ix_type); // Remove the sign bit. - // f16_reduced = f16_reduced & 0x7FFF - f16_reduced = b->CreateAnd(f16_reduced, i16_const(0x7FFF)); - - // F16 inf in binary: 0 11111 0000000000 - constexpr int f16_inf_value = 0x7C00; - constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; - constexpr int exponent_bias_difference = 15 - f8_bias; - constexpr int f16_mantissa_bits = 10; // e5m10 - constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; - constexpr int min_normal_value = (exponent_bias_difference + 1) - << f16_mantissa_bits; + // fx_reduced = fx_reduced & nosign_mask + fx_reduced = b->CreateAnd(fx_reduced, nosign_mask); // Round values smaller than the smallest F8 normal value up to the smallest // F8 normal value. The case where we round to a denormal value is handled // later. - // f16_reduced = max(f16_reduced, min_normal_value) - f16_reduced = b->CreateSelect( - b->CreateICmpULT(f16_reduced, i16_const(min_normal_value)), - i16_const(min_normal_value), f16_reduced); + // fx_reduced = max(fx_reduced, min_normal_value) + fx_reduced = b->CreateSelect(b->CreateICmpULT(fx_reduced, min_normal_value), + min_normal_value, fx_reduced); // Adjust the exponent by subtracting the difference in exponent bias: - // f16_reduced -= (exponent_bias_difference << f16_mantissa_bits) + // fx_reduced -= (exponent_bias_difference << fx_mantissa_bits) // For infinity/NaN values, subtract twice the difference in exponent bias - // to ensure the leading exponent bit(s) of f16_reduced are set to zero. - f16_reduced = b->CreateSub( - f16_reduced, - b->CreateSelect( - b->CreateICmpULT(f16_reduced, i16_const(f16_inf_value)), - i16_const(exponent_bias_difference << f16_mantissa_bits), - i16_const(exponent_bias_difference << (f16_mantissa_bits + 1)))); + // to ensure the leading exponent bit(s) of fx_reduced are set to zero. + fx_reduced = b->CreateSub( + fx_reduced, b->CreateSelect(b->CreateICmpULT(fx_reduced, infinity), + fx_exponent_bias_difference, + fx_doubled_exponent_bias_difference)); // Shift to convert to F8. - // f16_reduced = f16_reduced >> mantissa_bits_difference; - f16_reduced = b->CreateLShr(f16_reduced, i16_const(mantissa_bits_difference)); + // fx_reduced = fx_reduced >> mantissa_bits_difference; + fx_reduced = b->CreateLShr(fx_reduced, mantissa_bits_difference); - Value* f8_bits = b->CreateTrunc(f16_reduced, i8_type); + Value* f8_bits = b->CreateTrunc(fx_reduced, i8_type); - // Handle F16 values that are halfway between denormal F8 values. - f8_bits = - handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); + // Handle Fx values that are halfway between denormal F8 values. + f8_bits = handle_halfway_points_FxToF8(fx_abs_bits, + f8_bits, b); // Set the sign bit. // f8_bits |= f8_sign @@ -636,8 +691,8 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilderBase* b) { i8_const(0x7F), f8_bits); // Handle F16 values that are halfway between denormal F8 values. - f8_bits = - handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); + f8_bits = handle_halfway_points_FxToF8(f16_abs_bits, + f8_bits, b); // Set the sign bit. // f8_bits |= f8_sign @@ -814,13 +869,13 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType to_type, llvm::Module* module, llvm::IRBuilderBase* b) { if (primitive_util::IsSignedIntegralType(from_type)) { - return b->CreateSIToFP(integer_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module)); + return b->CreateSIToFP(integer_value, llvm_ir::PrimitiveTypeToIrType( + to_type, module->getContext())); } else { CHECK(primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED); - return b->CreateUIToFP(integer_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module)); + return b->CreateUIToFP(integer_value, llvm_ir::PrimitiveTypeToIrType( + to_type, module->getContext())); } } @@ -870,12 +925,13 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( return b_->CreateZExt( ICmpNE(operand_value, llvm::ConstantInt::get(operand_value->getType(), 0)), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + llvm_ir::PrimitiveTypeToIrType(PRED, module_->getContext())); } if (primitive_util::IsIntegralType(to_type)) { - return IntCast(operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(from_type)); + return IntCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_->getContext()), + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == F8E5M2) { @@ -885,7 +941,7 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_); } if (to_type == F8E4M3) { - return EmitF16ToF8e<4>( + return EmitFxToF8e( EmitIntegralToFloating(operand_value, from_type, F16, module_, b_), b_); @@ -910,7 +966,7 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( to_type, b_); } if (to_type == F8E3M4) { - return EmitF16ToF8e<3>( + return EmitFxToF8e( EmitIntegralToFloating(operand_value, from_type, F16, module_, b_), b_); @@ -920,7 +976,8 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( - primitive_util::ComplexComponentType(to_type), module_); + primitive_util::ComplexComponentType(to_type), + module_->getContext()); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( op, SIToFP(operand_value, to_ir_component_type), nullptr); @@ -944,8 +1001,8 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return BitCast(operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, llvm_ir::PrimitiveTypeToIrType( + to_type, module_->getContext())); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " @@ -958,8 +1015,8 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); if (is_signed) { - auto type = - llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); + auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), + module_->getContext()); auto cmp = ICmpSGE(operand_value, GetZero(type)); return Select(cmp, operand_value, Neg(operand_value)); } else { @@ -975,8 +1032,8 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kSign: { CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type())) << op->shape().element_type(); - auto type = - llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); + auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), + module_->getContext()); auto cmp = ICmpEQ(operand_value, GetZero(type)); auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1); return Select(cmp, GetZero(type), Or(ashr, 1)); @@ -989,8 +1046,9 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt( + Not(Trunc(operand_value, b_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_->getContext())); } else if (primitive_util::IsIntegralType(type)) { return Not(operand_value); } @@ -1134,7 +1192,8 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitComposeComplex( op, FPCast(operand_value, - llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), + llvm_ir::PrimitiveTypeToIrType(to_component_type, + module_->getContext())), nullptr); } if (to_type == BF16) { @@ -1148,23 +1207,36 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // Cast to F16 first. Casts to F8E5M2 must be from F16. if (from_type != F16) { operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); } return EmitF16ToF8e5m2(operand_value, b_); } if (to_type == F8E4M3) { - // Cast to F16 first. Casts to F8E4M3 must be from F16. - if (from_type != F16) { - operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + switch (from_type) { + case F16: + return EmitFxToF8e(operand_value, b_); + case F32: + return EmitFxToF8e(operand_value, b_); + case F64: + return EmitFxToF8e(operand_value, b_); + case BF16: + operand_value = b_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); + return EmitFxToF8e(operand_value, b_); + default: + return InvalidArgument("Unsupported conversion from %s to %s", + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } - return EmitF16ToF8e<4>(operand_value, b_); } if (to_type == F8E4M3FN) { // Cast to F16 first. Casts to F8E4M3FN must be from F16. if (from_type != F16) { operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); } return EmitF16ToF8e4m3fn(operand_value, b_); } @@ -1172,7 +1244,8 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // Cast to F16 first. Casts to F8E4M3B11FNUZ must be from F16. if (from_type != F16) { operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } @@ -1180,24 +1253,37 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } if (to_type == F8E3M4) { - // Cast to F16 first. Casts to F8E3M4 must be from F16. - if (from_type != F16) { - operand_value = b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + switch (from_type) { + case F16: + return EmitFxToF8e(operand_value, b_); + case F32: + return EmitFxToF8e(operand_value, b_); + case F64: + return EmitFxToF8e(operand_value, b_); + case BF16: + operand_value = b_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext())); + return EmitFxToF8e(operand_value, b_); + default: + return InvalidArgument("Unsupported conversion from %s to %s", + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } - return EmitF16ToF8e<3>(operand_value, b_); } if (to_type == PRED) { return b_->CreateZExt( FCmpUNE(operand_value, llvm::ConstantFP::get(operand_value->getType(), 0.0)), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + llvm_ir::PrimitiveTypeToIrType(PRED, module_->getContext())); } - auto* to_ir_type = llvm_ir::PrimitiveTypeToIrType(to_type, module_); + auto* to_ir_type = + llvm_ir::PrimitiveTypeToIrType(to_type, module_->getContext()); if (primitive_util::IsFloatingPointType(to_type)) { return FPCast(operand_value, to_ir_type); } - auto* from_ir_type = llvm_ir::PrimitiveTypeToIrType(from_type, module_); + auto* from_ir_type = + llvm_ir::PrimitiveTypeToIrType(from_type, module_->getContext()); int to_width = primitive_util::BitWidth(to_type); if (primitive_util::IsSignedIntegralType(to_type)) { int64_t min_int = llvm::minIntN(to_width); @@ -1207,8 +1293,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int); auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int); auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int); - auto clamped = FPToSI(operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + auto clamped = FPToSI( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_->getContext())); // x <= static_cast(INT_MIN) ? INT_MIN : ... clamped = Select(FCmpOLE(operand_value, min_value_float), min_value_int, clamped); @@ -1227,8 +1314,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto max_value_int = llvm::ConstantInt::get(to_ir_type, max_int); auto min_value_float = llvm::ConstantFP::get(from_ir_type, min_int); auto max_value_float = llvm::ConstantFP::get(from_ir_type, max_int); - auto clamped = FPToUI(operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + auto clamped = FPToUI( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_->getContext())); // (x <= 0.0 || isnan(x)) ? 0 : ... clamped = Select(FCmpULE(operand_value, min_value_float), min_value_int, clamped); @@ -1250,8 +1338,8 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return BitCast(operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, llvm_ir::PrimitiveTypeToIrType( + to_type, module_->getContext())); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " @@ -1327,8 +1415,8 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = FCmpONE(abs_value, infinity); - return b_->CreateZExt(not_infinite, - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt(not_infinite, llvm_ir::PrimitiveTypeToIrType( + PRED, module_->getContext())); } case HloOpcode::kNegate: return FNeg(operand_value); @@ -1424,8 +1512,8 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); - auto to_ir_component_type = - llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); + auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( + to_component_type, module_->getContext()); return EmitComposeComplex( op, FPCast(EmitExtractReal(operand_value), to_ir_component_type), FPCast(EmitExtractImag(operand_value), to_ir_component_type)); @@ -2335,7 +2423,8 @@ absl::StatusOr ElementalIrEmitter::EmitComplexBinaryOp( TF_ASSIGN_OR_RETURN( auto sqrt_x_squared_plus_y_squared, EmitComplexSqrt(op, component_type, x_squared_plus_y_squared)); - auto type = llvm_ir::PrimitiveTypeToIrType(component_type, module_); + auto type = + llvm_ir::PrimitiveTypeToIrType(component_type, module_->getContext()); auto zero = llvm::ConstantFP::get(type, 0.0); auto one = llvm::ConstantFP::get(type, 1.0); auto i = EmitComposeComplex(op, zero, one); @@ -2376,7 +2465,7 @@ absl::StatusOr ElementalIrEmitter::EmitLog( absl::StatusOr ElementalIrEmitter::EmitLog1p( PrimitiveType prim_type, llvm::Value* value) { auto x = value; - auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_->getContext()); auto one = llvm::ConstantFP::get(type, 1.0); auto negative_half = llvm::ConstantFP::get(type, -0.5); // When x is large, the naive evaluation of ln(x + 1) is more @@ -2450,7 +2539,7 @@ absl::StatusOr ElementalIrEmitter::EmitCos( absl::StatusOr ElementalIrEmitter::EmitCosm1( PrimitiveType prim_type, llvm::Value* value) { auto x = value; - auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_->getContext()); auto negative_half = llvm::ConstantFP::get(type, -0.5); auto negative_one = llvm::ConstantFP::get(type, -1.0); @@ -2495,7 +2584,7 @@ absl::StatusOr ElementalIrEmitter::EmitExp( absl::StatusOr ElementalIrEmitter::EmitExpm1( PrimitiveType prim_type, llvm::Value* value) { auto x = value; - auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_->getContext()); auto one = llvm::ConstantFP::get(type, 1.0); auto half = llvm::ConstantFP::get(type, 0.5); auto zero = llvm::ConstantFP::get(type, 0.0); @@ -2527,7 +2616,7 @@ absl::StatusOr ElementalIrEmitter::EmitPow( absl::StatusOr ElementalIrEmitter::EmitCbrt( PrimitiveType prim_type, llvm::Value* value) { - auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_->getContext()); auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); auto abs_value = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); @@ -2909,9 +2998,10 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( } llvm_ir::SetToFirstInsertPoint(exit_block, b_); - llvm::PHINode* output = b_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - hlo->operands().size()); + llvm::PHINode* output = + b_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + module_->getContext()), + hlo->operands().size()); auto prior_insert_point = b_->GetInsertPoint(); b_->SetInsertPoint(init_block); @@ -3272,7 +3362,8 @@ ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // if (slice_intersection) -> return data from 'update'. // else -> return data from 'input'. llvm::AllocaInst* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + module_->getContext()), "ret_value_addr", b_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_); @@ -3336,7 +3427,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( // ret_value = *operand1; // padding // } llvm::AllocaInst* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + module_->getContext()), "pad_result_addr", b_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); @@ -3400,9 +3492,10 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDot( SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_); PrimitiveType primitive_type = hlo->shape().element_type(); llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_->getContext()); if (primitive_type == BF16) { - primitive_type_llvm = llvm_ir::PrimitiveTypeToIrType(F32, module_); + primitive_type_llvm = + llvm_ir::PrimitiveTypeToIrType(F32, module_->getContext()); } llvm::AllocaInst* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); @@ -3627,7 +3720,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( if (primitive_util::IsIntegralType(component_element_type)) { iota_result = b_->CreateIntCast( elem_index_linear, - llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), + llvm_ir::PrimitiveTypeToIrType(component_element_type, + module_->getContext()), /*isSigned=*/false); } else { TF_RET_CHECK( @@ -3635,12 +3729,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( << component_element_type; llvm::Type* float_ir_type; if (component_element_type == F8E4M3FNUZ) { - float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); } else if (component_element_type == F8E5M2FNUZ) { - float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); - } else { float_ir_type = - llvm_ir::PrimitiveTypeToIrType(component_element_type, module_); + llvm_ir::PrimitiveTypeToIrType(F16, module_->getContext()); + } else { + float_ir_type = llvm_ir::PrimitiveTypeToIrType( + component_element_type, module_->getContext()); } llvm::Value* float_val = b_->CreateUIToFP(elem_index_linear, float_ir_type); @@ -3796,8 +3892,8 @@ llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, llvm::Value* imag) { - auto cplx_type = - llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); + auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), + module_->getContext()); auto complex = InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { @@ -3862,11 +3958,12 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( auto operand = reduce_window->inputs()[operand_index]; PrimitiveType operand_element_type = operand->shape().element_type(); operand_element_types.push_back(operand_element_type); - llvm::Type* llvm_type = - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_); + llvm::Type* llvm_type = llvm_ir::PrimitiveTypeToIrType( + operand_element_type, module_->getContext()); accum_types.push_back(llvm_type); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, + module_->getContext()), "reduce_window_accum_ptr", b_); accum_ptrs.push_back(accum_ptr); { @@ -3988,7 +4085,7 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduce( is_variadic ? out_shape.tuple_shapes(i) : out_shape; PrimitiveType accumulator_type = element_shape.element_type(); llvm::Type* accumulator_llvm_type = - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_->getContext()); accumulator_types.push_back(accumulator_llvm_type); // Initialize an accumulator with init_value. @@ -4102,7 +4199,7 @@ absl::StatusOr ElementalIrEmitter::EmitConvolution( // at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); llvm::Type* lhs_llvm_type = - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_->getContext()); // Upcast the accumulator to F32 from F16 for increased precision. llvm::Type* accumulator_type = lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type; diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc index 71847a88ca518a..f947aa8ada14c0 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -445,27 +446,28 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatsToFloat) { auto tname = this->TypeName(); - if (std::is_same() || - std::is_same() || + const int n = 10; + if (std::is_same() || std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } - const auto hlo_text = absl::StrReplaceAll(R"( + const auto hlo_text = + absl::StrReplaceAll(R"( HloModule m ENTRY main { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = ${tname}[] convert(f16_) - converted_f32 = ${tname}[] convert(f32_) - converted_f64 = ${tname}[] convert(f64_) - converted_bf16 = ${tname}[] convert(bf16_) - ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( + f16_ = f16[$n] parameter(0) + f32_ = f32[$n] parameter(1) + f64_ = f64[$n] parameter(2) + bf16_ = bf16[$n] parameter(3) + converted_f16 = ${tname}[$n] convert(f16_) + converted_f32 = ${tname}[$n] convert(f32_) + converted_f64 = ${tname}[$n] convert(f64_) + converted_bf16 = ${tname}[$n] convert(bf16_) + ROOT tuple = (${tname}[$n], ${tname}[$n], ${tname}[$n], ${tname}[$n]) tuple( converted_f16, converted_f32, converted_f64, converted_bf16) } )", - {{"${tname}", tname}}); + {{"${tname}", tname}, {"$n", absl::StrCat(n)}}); ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } diff --git a/third_party/xla/xla/service/executable_test.cc b/third_party/xla/xla/service/executable_test.cc index 8c21dbe3603517..3c896a016396ee 100644 --- a/third_party/xla/xla/service/executable_test.cc +++ b/third_party/xla/xla/service/executable_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -56,7 +55,7 @@ TEST_F(ExecutableTest, HloProtoGetterIsThreadCompatible) { // thread-compatible way. // Note that this test needs to run with --config=tsan to reliably // detect any potential data races. - constexpr std::string_view kHloModule = R"( + constexpr absl::string_view kHloModule = R"( HloModule module ENTRY main { diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index ff43a34c1d44b6..42459742ce1e61 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -135,6 +135,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla/service:platform_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/log", @@ -243,6 +244,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -258,7 +260,9 @@ xla_cc_test( srcs = ["target_util_test.cc"], deps = [ ":target_util", + "//xla:xla_data_proto_cc", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", @@ -338,7 +342,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_clique_locking", "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", @@ -381,6 +384,7 @@ cc_library( "//xla/service/gpu/runtime:nccl_collective_thunk", "//xla/service/gpu/runtime:nccl_group_thunk", "//xla/service/gpu/runtime:nccl_p2p_thunk_common", + "//xla/service/gpu/runtime:nccl_ragged_all_to_all_thunk", "//xla/service/gpu/runtime:nccl_recv_thunk", "//xla/service/gpu/runtime:nccl_send_thunk", "//xla/service/gpu/runtime:norm_thunk", @@ -499,6 +503,7 @@ cc_library( "TENSORFLOW_USE_ROCM=1", ]), deps = [ + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", @@ -573,7 +578,7 @@ cc_library( "//xla:util", "//xla/backends/gpu/collectives:gpu_clique", "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_clique_locking", + "//xla/backends/gpu/collectives:gpu_cliques", "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", @@ -838,10 +843,10 @@ xla_cc_test( deps = [ ":triton_fusion_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service/gpu/transforms:gemm_fusion", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", @@ -894,12 +899,12 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_verifier", "//xla/service:layout_assignment", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", @@ -943,6 +948,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], @@ -994,6 +1000,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], @@ -1134,9 +1141,9 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:verified_hlo_module", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1222,8 +1229,10 @@ cc_library( srcs = ["gpu_float_support.cc"], hdrs = ["gpu_float_support.h"], deps = [ + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", "//xla/service:float_support", "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor:device_description", @@ -1314,7 +1323,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:cpu_gpu_shape_verifier", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_cse", @@ -1340,7 +1349,7 @@ cc_library( "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:copy_insertion", "//xla/service:cpu_gpu_shape_verifier", "//xla/service:hlo_verifier", @@ -1405,6 +1414,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@llvm-project//llvm:AsmParser", @@ -1431,53 +1441,53 @@ cc_library( "//xla/hlo/transforms/collectives:collective_quantizer", "//xla/hlo/transforms/collectives:collectives_schedule_linearizer", "//xla/hlo/transforms/collectives:convert_async_collectives_to_sync", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:all_reduce_folder", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:all_reduce_folder", + "//xla/hlo/transforms/simplifiers:broadcast_canonicalizer", + "//xla/hlo/transforms/simplifiers:conditional_canonicalizer", + "//xla/hlo/transforms/simplifiers:convert_mover", + "//xla/hlo/transforms/simplifiers:dot_merger", + "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/hlo/transforms/simplifiers:gather_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_computation_deduplicator", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_rematerialization", + "//xla/hlo/transforms/simplifiers:host_memory_transfer_asyncifier", + "//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias", + "//xla/hlo/transforms/simplifiers:reduce_window_rewriter", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:result_caster", + "//xla/hlo/transforms/simplifiers:simplify_fp_conversions", + "//xla/hlo/transforms/simplifiers:slice_sinker", + "//xla/hlo/transforms/simplifiers:sort_simplifier", + "//xla/hlo/transforms/simplifiers:sub_byte_normalization", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination", "//xla/hlo/transforms:bitcast_dtypes_expander", - "//xla/hlo/transforms:broadcast_canonicalizer", "//xla/hlo/transforms:comparison_expander", - "//xla/hlo/transforms:conditional_canonicalizer", "//xla/hlo/transforms:convert_memory_placement_to_internal_annotations", - "//xla/hlo/transforms:convert_mover", "//xla/hlo/transforms:convolution_4d_expander", "//xla/hlo/transforms:convolution_pred_expander", "//xla/hlo/transforms:dot_decomposer", - "//xla/hlo/transforms:dot_merger", - "//xla/hlo/transforms:dynamic_dimension_simplifier", "//xla/hlo/transforms:dynamic_index_splitter", "//xla/hlo/transforms:eigh_expander", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:float_normalization", - "//xla/hlo/transforms:gather_simplifier", - "//xla/hlo/transforms:hlo_computation_deduplicator", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:hlo_rematerialization", - "//xla/hlo/transforms:host_memory_transfer_asyncifier", "//xla/hlo/transforms:host_offload_legalize", "//xla/hlo/transforms:host_offloader", "//xla/hlo/transforms:logistic_expander", "//xla/hlo/transforms:operand_upcaster", "//xla/hlo/transforms:optimization_barrier_expander", - "//xla/hlo/transforms:optimize_input_output_buffer_alias", "//xla/hlo/transforms:qr_expander", "//xla/hlo/transforms:real_imag_expander", "//xla/hlo/transforms:reduce_decomposer", - "//xla/hlo/transforms:reduce_window_rewriter", "//xla/hlo/transforms:reshape_decomposer", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:result_caster", "//xla/hlo/transforms:rng_bit_generator_expander", "//xla/hlo/transforms:rng_expander", - "//xla/hlo/transforms:simplify_fp_conversions", - "//xla/hlo/transforms:slice_sinker", - "//xla/hlo/transforms:sort_simplifier", "//xla/hlo/transforms:stable_sort_expander", "//xla/hlo/transforms:stochastic_convert_decomposer", - "//xla/hlo/transforms:sub_byte_normalization", - "//xla/hlo/transforms:tuple_simplifier", "//xla/hlo/transforms:while_loop_trip_count_annotator", - "//xla/hlo/transforms:zero_sized_hlo_elimination", "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/hlo/translate/mhlo_to_hlo:location_exporter", "//xla/hlo/utils:hlo_query", @@ -1487,6 +1497,7 @@ cc_library( "//xla/service/gpu/fusions/triton:triton_support", "//xla/service/gpu/model:gpu_cost_model_stats_collection", "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:sol_gpu_cost_model_stats_collection", "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:algorithm_checker", @@ -1606,7 +1617,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -1648,6 +1658,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:compiler", "//xla/service:executable", "//xla/service:hlo_module_config", @@ -1658,10 +1669,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/lib/monitoring:collected_metrics", @@ -1693,8 +1702,8 @@ xla_test( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_memory_scheduler", - "//xla/hlo/transforms:hlo_rematerialization", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:hlo_rematerialization", "//xla/hlo/utils:hlo_matchers", "//xla/service:buffer_value", "//xla/service:hlo_cost_analysis", @@ -1715,6 +1724,7 @@ xla_test( backends = ["gpu"], tags = ["no_oss"], # TODO(b/277355322): Make autosharding work in OSS deps = [ + "//xla:xla_data_proto_cc", "//xla/hlo/experimental/auto_sharding:auto_sharding_option", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", @@ -1722,6 +1732,7 @@ xla_test( "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:logging", ], @@ -1777,14 +1788,14 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:convert_mover", - "//xla/hlo/transforms:dot_dimension_merger", - "//xla/hlo/transforms:float_normalization", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:convert_mover", + "//xla/hlo/transforms/simplifiers:dot_dimension_merger", + "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:call_inliner", "//xla/service:dump", @@ -1796,7 +1807,7 @@ cc_library( "//xla/service/gpu/autotuning:conv_algorithm_picker", "//xla/service/gpu/autotuning:gemm_algorithm_picker", "//xla/service/gpu/autotuning:gemm_fusion_autotuner", - "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/llvm_gpu_backend:nvptx_backend", "//xla/service/gpu/llvm_gpu_backend:nvptx_utils", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:conv_padding_legalization", @@ -1954,6 +1965,7 @@ xla_cc_test( ":amdgpu_compiler_impl", ]) + [ ":gpu_transfer_manager", + "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/service:compiler", @@ -1965,6 +1977,7 @@ xla_cc_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -2010,17 +2023,18 @@ cc_library( ":gpu_compiler", ":target_constants", "//xla:util", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:convert_mover", - "//xla/hlo/transforms:dot_dimension_merger", - "//xla/hlo/transforms:float_normalization", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:convert_mover", + "//xla/hlo/transforms/simplifiers:dot_dimension_merger", + "//xla/hlo/transforms/simplifiers:float_normalization", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:call_inliner", "//xla/service:float_support", @@ -2030,7 +2044,7 @@ cc_library( "//xla/service/gpu/autotuning:conv_algorithm_picker", "//xla/service/gpu/autotuning:gemm_algorithm_picker", "//xla/service/gpu/autotuning:gemm_fusion_autotuner", - "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:conv_padding_legalization", "//xla/service/gpu/transforms:conv_rewriter", @@ -2113,15 +2127,15 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/utils:hlo_query", "//xla/service:buffer_value", "//xla/service:collective_ops_utils", - "//xla/service:collective_utils", "//xla/service:latency_hiding_scheduler", "//xla/service:p2p_schedule_preparation", "//xla/service:profile_guided_latency_estimator", "//xla/service/gpu/model:analytical_latency_estimator", + "//xla/service/gpu/model:sol_latency_estimator", "//xla/service/gpu/transforms:pgle_accuracy_checker", "//xla/service/gpu/transforms:schedule_postprocessing", "//xla/service/gpu/transforms:scheduling_instruction_annotator", @@ -2156,14 +2170,14 @@ xla_test( "//xla:shape_util", "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:hlo_module_config", "//xla/stream_executor:device_description", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", @@ -2206,9 +2220,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:filecheck", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log:check", @@ -2226,13 +2240,13 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:sort_simplifier", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:hlo_constant_splitter", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:sort_simplifier", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/service:conditional_simplifier", "//xla/service:gather_expander", "//xla/service:hlo_module_config", @@ -2263,7 +2277,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/service:hlo_module_config", "//xla/service/spmd/shardy:constants", "//xla/stream_executor:device_description", @@ -2309,6 +2323,7 @@ gpu_kernel_library( "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:rocm_config", ]), ) @@ -2316,6 +2331,7 @@ cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], hdrs = ["stream_executor_util.h"], + compatible_with = get_compatible_with_portable(), copts = tsl_copts(), deps = [ ":cublas_cudnn", @@ -2468,6 +2484,10 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:statusor", + ]) + if_rocm_is_configured([ + # keep sorted + "@local_config_rocm//rocm:rocm_config", + "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -2478,7 +2498,9 @@ gpu_kernel_library( deps = [ "//xla:shape_util", "//xla:types", - ], + ] + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) xla_test( @@ -2666,7 +2688,7 @@ xla_cc_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms/simplifiers:float_normalization", "//xla/service:hlo_verifier", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -2701,6 +2723,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/monitoring:collected_metrics", "//xla/tsl/lib/monitoring:collection_registry", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:test", ], ) @@ -2719,6 +2742,8 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/gpu:gpu_stream_header", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", @@ -2947,6 +2972,7 @@ xla_test( cc_library( name = "gpu_symbol_repository", hdrs = ["gpu_symbol_repository.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//xla:autotune_results_proto_cc", "//xla:xla_proto_cc", @@ -3024,6 +3050,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) @@ -3061,13 +3088,13 @@ cc_library( hdrs = ["gpu_collective_combiner_utils.h"], deps = [ ":backend_configs_cc", + ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:collective_utils", "//xla/stream_executor:device_description", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", ], ) @@ -3078,24 +3105,19 @@ xla_cc_test( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "//xla/service:collective_pipeliner", - "//xla/service:collective_utils", "//xla/service:hlo_module_config", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -3106,7 +3128,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/collectives:all_gather_combiner", @@ -3145,7 +3166,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_domain_map", @@ -3182,7 +3202,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/collectives:all_reduce_combiner", @@ -3217,7 +3236,7 @@ cc_library( srcs = ["ptx_compile_options_from_debug_options.cc"], hdrs = ["ptx_compile_options_from_debug_options.h"], deps = [ - "//xla:xla_proto_cc_impl", + "//xla:xla_proto_cc", "//xla/stream_executor/cuda:compilation_options", ], ) @@ -3252,7 +3271,7 @@ xla_cc_test( deps = [ ":flag_utils", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:collective_pipeliner", "//xla/service:hlo_module_config", "//xla/service:latency_hiding_scheduler", diff --git a/third_party/xla/xla/service/gpu/all_gather_combiner.cc b/third_party/xla/xla/service/gpu/all_gather_combiner.cc index 996d3a1fe83bed..96f10d43113b5c 100644 --- a/third_party/xla/xla/service/gpu/all_gather_combiner.cc +++ b/third_party/xla/xla/service/gpu/all_gather_combiner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_gather_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "tsl/platform/statusor.h" @@ -78,8 +77,7 @@ absl::StatusOr GpuAllGatherCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllGather, pointer_size_); + *module, device_info_, HloOpcode::kAllGather, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/third_party/xla/xla/service/gpu/all_reduce_combiner.cc b/third_party/xla/xla/service/gpu/all_reduce_combiner.cc index 108d10cee3e5d3..5fb3d960bb2371 100644 --- a/third_party/xla/xla/service/gpu/all_reduce_combiner.cc +++ b/third_party/xla/xla/service/gpu/all_reduce_combiner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_reduce_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "tsl/platform/statusor.h" @@ -76,8 +75,7 @@ absl::StatusOr GpuAllReduceCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size_); + *module, device_info_, HloOpcode::kAllReduce, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index ad70672c659a04..48d959e02d2229 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -44,7 +44,7 @@ limitations under the License. #include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h" #include "xla/service/gpu/target_constants.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/gpu/transforms/conv_padding_legalization.h" @@ -62,6 +62,7 @@ limitations under the License. #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" diff --git a/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc index e363c3f0883d6f..ad5a80d836ea2b 100644 --- a/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/auto_sharding_gpu_compiler_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -26,6 +27,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 690645e71573ce..6ec1f18631090b 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -7,6 +7,7 @@ load( ) load("//xla:xla.bzl", "xla_cc_test") load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") load( "//xla/tsl/platform:build_config.bzl", "tf_proto_library", @@ -126,7 +127,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms/simplifiers:float_normalization", "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_traversal", "//xla/pjrt/distributed:key_value_store_interface", @@ -159,9 +160,12 @@ cc_library( "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream", + "//xla/stream_executor/cuda:ptx_compiler_helpers", "//xla/stream_executor/gpu:redzone_allocator", + "//xla/stream_executor/integrations:tf_allocator_adapter", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", + "//xla/tsl/platform:errors", "//xla/tsl/util/proto:proto_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -176,7 +180,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -228,10 +231,8 @@ xla_test( "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:status_test_util", @@ -297,6 +298,7 @@ cc_library( name = "autotuner_status_key", srcs = ["autotuner_status_key.cc"], hdrs = ["autotuner_status_key.h"], + compatible_with = get_compatible_with_portable(), deps = ["@com_google_absl//absl/strings"], ) @@ -304,7 +306,7 @@ cc_library( name = "autotuner_util", srcs = ["autotuner_util.cc"], hdrs = ["autotuner_util.h"], - tags = ["gpu"], + compatible_with = get_compatible_with_portable(), deps = [ ":autotuner_status_key", "//xla:autotune_results_proto_cc", @@ -322,7 +324,6 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:redzone_allocator", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -516,7 +517,7 @@ xla_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:platform_util", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc index d8ebefe52ba083..544f6738e6b81c 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include #include -#include #include #include @@ -34,7 +34,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/maybe_owning_device_memory.h" @@ -100,7 +99,7 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, opts_.set_xla_gpu_kernel_cache_file(""); } -absl::StatusOr> +absl::StatusOr AutotunerCompileUtil::ProfileExecutable( Executable* executable, se::Stream* stream, absl::Span input_buffers, @@ -110,18 +109,8 @@ AutotunerCompileUtil::ProfileExecutable( ExecutionInputsFromBuffers(input_buffers, input_shapes); // Warmup: in and out buffers are reused while probing different configs, // so GPU caches should be in some comparable states during measurements. - absl::StatusOr execution_output = - Execute(*executable, std::move(execution_inputs)); - if (!execution_output.ok()) { - // Treat register allocation error gracefully. If the compilation happens - // with the driver during execution then the error could surface here. - // It's enough to check this once here. - if (execution_output.status().code() == - absl::StatusCode::kResourceExhausted) { - return {std::nullopt}; - } - return execution_output.status(); - } + TF_ASSIGN_OR_RETURN(ExecutionOutput execution_output, + Execute(*executable, std::move(execution_inputs))); TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } @@ -134,9 +123,8 @@ AutotunerCompileUtil::ProfileExecutable( TF_ASSIGN_OR_RETURN( ExecutionOutput execution_output, Execute(*executable, std::move(execution_inputs), &profile)); - return std::make_optional( - absl::Nanoseconds(profile.compute_time_ns()), - execution_output.Commit().ConsumeResult()); + return ProfilingOutput(absl::Nanoseconds(profile.compute_time_ns()), + execution_output.Commit().ConsumeResult()); } absl::StatusOr> AutotunerCompileUtil::Compile( @@ -168,11 +156,11 @@ absl::StatusOr> AutotunerCompileUtil::ExtractModule( return extractor(opts_); } -/*static*/ absl::StatusOr> -AutotunerCompileUtil::Create(const AutotuneConfig& config, - const DebugOptions& opts) { +/*static*/ absl::StatusOr AutotunerCompileUtil::Create( + const AutotuneConfig& config, const DebugOptions& opts) { if (config.IsDeviceless()) { - return std::nullopt; + return absl::InvalidArgumentError( + "Deviceless autotuning is not supported."); } se::StreamExecutor* stream_exec = config.GetExecutor(); se::DeviceMemoryAllocator* allocator = config.GetAllocator(); @@ -208,11 +196,13 @@ absl::StatusOr RedzoneBuffers::FromInstruction( const HloInstruction& instruction, const AutotuneConfig& config, const DebugOptions& debug_options, BuffersToCreate buffers_to_create) { RedzoneBuffers buffers; - - TF_ASSIGN_OR_RETURN(auto rz_allocator, AutotunerUtil::CreateRedzoneAllocator( - config, debug_options)); - buffers.redzone_allocator_ = - std::make_unique(std::move(rz_allocator)); + TF_ASSIGN_OR_RETURN(se::Stream * stream, config.GetStream()); + buffers.redzone_allocator_ = std::make_unique( + stream, config.GetAllocator(), + /*memory_limit=*/std::numeric_limits::max(), + /*redzone_size=*/config.should_check_correctness() + ? debug_options.xla_gpu_redzone_padding_bytes() + : 0); int64_t rng_state = 0; @@ -235,8 +225,8 @@ absl::Status RedzoneBuffers::CreateInputs(const HloInstruction& instruction, for (const auto* operand : instruction.operands()) { TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase buf, - AutotunerUtil::CreateBuffer(*redzone_allocator_, operand->shape(), - config, rng_state)); + redzone_allocator_->CreateBuffer( + operand->shape(), config.should_init_buffers(), rng_state)); input_buffers_.push_back(buf); input_shapes_.push_back(operand->shape()); } @@ -251,8 +241,8 @@ absl::Status RedzoneBuffers::CreateOutputs(const HloInstruction& instruction, if (!instruction.shape().IsTuple()) { TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase buf, - AutotunerUtil::CreateBuffer(*redzone_allocator_, instruction.shape(), - config, rng_state)); + redzone_allocator_->CreateBuffer( + instruction.shape(), config.should_init_buffers(), rng_state)); output_buffers_.push_back(buf); output_shape_ = instruction.shape(); return absl::OkStatus(); @@ -275,8 +265,8 @@ absl::Status RedzoneBuffers::CreateOutputs(const HloInstruction& instruction, } TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase buf, - AutotunerUtil::CreateBuffer(*redzone_allocator_, *current_shape_it, - config, rng_state)); + redzone_allocator_->CreateBuffer( + *current_shape_it, config.should_init_buffers(), rng_state)); output_buffers_.push_back(buf); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h index 08061abb40a1c8..0e0fcc712a6eb9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include @@ -61,10 +60,7 @@ class AutotunerCompileUtil { const DebugOptions&)>; // Generates a compile util for a platform associated with the `stream`. - // - // Returns an empty optional if the AutotuneConfig is deviceless, as - // autotuning is impossible in that case. - static absl::StatusOr> Create( + static absl::StatusOr Create( const AutotuneConfig& config, const DebugOptions& opts); struct ProfilingOutput { @@ -79,9 +75,8 @@ class AutotunerCompileUtil { // `extractor`. // // Runs the resulting executable with the given extractor, cached with - // `(cache_key, config)`. Returns `std::nullopt` on expected failure, bad - // `Status` otherwise. - absl::StatusOr> ProfileExecutable( + // `(cache_key, config)`. + absl::StatusOr ProfileExecutable( Executable* executable, se::Stream* stream, absl::Span input_buffers, absl::Span input_shapes); diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc index fd55b439c3b0ce..94bc2ede39315d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -46,14 +45,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/dump.h" #include "xla/service/gpu/autotuning/autotuner_status_key.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" -#include "xla/stream_executor/stream.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/base64.h" @@ -127,7 +120,7 @@ ResultAndInserted AddResultToInMemoryCache(const AutotuneCacheKey& key, absl::Status AddResultToFileBasedCacheIfEnabled( const AutotuneCacheKey& key, AutotuneResult result, - std::string_view cache_dir, + absl::string_view cache_dir, DebugOptions::AutotuneCacheMode autotune_cache_mode) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { if (cache_dir.empty() || @@ -170,7 +163,7 @@ absl::Status AddResultToFileBasedCacheIfEnabled( absl::StatusOr AddResultToCaches( const AutotuneCacheKey& key, AutotuneResult result, - std::string_view cache_dir, + absl::string_view cache_dir, DebugOptions::AutotuneCacheMode autotune_cache_mode) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result); @@ -299,18 +292,6 @@ void SerializeAutotuneEntry(AutotuneResults* results, const AutotuneCacheKey& k, return autotune_cache.empty(); } -/* static*/ absl::StatusOr AutotunerUtil::CreateBuffer( - se::RedzoneAllocator& allocator, const Shape& shape, - const AutotuneConfig& config, int64_t& rng_state) { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, - allocator.AllocateBytes(ShapeUtil::ByteSizeOf(shape))); - if (config.should_init_buffers()) { - InitializeBuffer(allocator.stream(), shape.element_type(), &rng_state, - buffer); - } - return buffer; -} - namespace { std::string ToCanonicalString(const HloInstruction* instr) { auto options = HloPrintOptions::Canonical(); @@ -576,18 +557,6 @@ bool IsTextProtoPath(absl::string_view file_path) { return absl::OkStatus(); } -/*static*/ absl::StatusOr -AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, - const DebugOptions& opts) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, config.GetStream()); - return se::RedzoneAllocator( - stream, config.GetAllocator(), - /*memory_limit=*/std::numeric_limits::max(), - /*redzone_size=*/config.should_check_correctness() - ? opts.xla_gpu_redzone_padding_bytes() - : 0); -} - /*static*/ AutotunerUtil::CacheStats AutotunerUtil::GetCacheStats() { absl::MutexLock lock(&autotune_cache_mu); return autotune_cache_stats; diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index 8aa7f9bfb5be96..3dd57d9df4a4d8 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -33,11 +33,8 @@ limitations under the License. #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/shape.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/xla.pb.h" @@ -203,12 +200,6 @@ class AutotuneConfig { using AutotuneNoCacheFn = std::function()>; struct AutotunerUtil { - // Create a buffer for a given operation using redzone checker, initialize - // based on a given rng state. - static absl::StatusOr CreateBuffer( - se::RedzoneAllocator& allocator, const Shape& shape, - const AutotuneConfig& config, int64_t& rng_state); - static absl::StatusOr Autotune( const HloInstruction* instr, const AutotuneConfig& config, const AutotuneNoCacheFn& autotune_fn); @@ -234,10 +225,6 @@ struct AutotunerUtil { AutotuneResult result, const AutotuneConfig& config); - // Creates a RedzoneAllocator from a given config. - static absl::StatusOr CreateRedzoneAllocator( - const AutotuneConfig& config, const DebugOptions& opts); - // Functions to save/load XLA's autotuning results. // // This is used for ahead-of-time autotuning. Specifically: diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc index 14407847cdfeb2..1e1aff8f70172e 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -348,7 +347,7 @@ void PrintPlatformInfo(const se::Stream* stream) { // "input/output" or "scratch". absl::StatusOr CheckRedzones(const se::RedzoneAllocator& allocator, se::Stream* stream, absl::string_view name, - std::string_view instr_str, + absl::string_view instr_str, AutotuneResult* result) { XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones", 2); @@ -585,10 +584,14 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( "Disqualified for implicit RELU."); } - TF_ASSIGN_OR_RETURN( - se::RedzoneAllocator scratch_allocator, - AutotunerUtil::CreateRedzoneAllocator( - config_, runtime_arguments.hlo_module_config.debug_options())); + TF_ASSIGN_OR_RETURN(se::Stream * stream, config_.GetStream()); + se::RedzoneAllocator scratch_allocator( + stream, config_.GetAllocator(), + /*memory_limit=*/std::numeric_limits::max(), + /*redzone_size=*/config_.should_check_correctness() + ? runtime_arguments.hlo_module_config.debug_options() + .xla_gpu_redzone_padding_bytes() + : 0); se::dnn::ProfileResult profile_result; VLOG(4) << "Trying algorithm " << alg.ToString() << " for " << instr_str; @@ -627,8 +630,6 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( std::vector result_buffers = runtime_arguments.rz_buffers.output_buffers(); - TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); - // Dry-run to warmup the plan. launch_status = RunGpuConv(config, operand_buffers, result_buffers, scratch_memory, stream, options); diff --git a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc index 164252eb83312a..67c8496b8a3557 100644 --- a/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc @@ -97,13 +97,12 @@ absl::StatusOr>> ProfileKernels( *fusion_instruction, autotune_config, debug_options, RedzoneBuffers::kAllInputs)); - std::optional reference_buffer; - std::optional profiling_output; - TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable( - executable->get(), stream, - rz_buffers.input_buffers(), - rz_buffers.input_shapes())); - results.push_back({i, profiling_output->duration}); + TF_ASSIGN_OR_RETURN( + AutotunerCompileUtil::ProfilingOutput profiling_output, + compile_util.ProfileExecutable(executable->get(), stream, + rz_buffers.input_buffers(), + rz_buffers.input_shapes())); + results.push_back({i, profiling_output.duration}); } return results; } @@ -225,9 +224,8 @@ absl::StatusOr CustomKernelFusionAutotuner::Run( } const DebugOptions& debug_options = module->config().debug_options(); - TF_ASSIGN_OR_RETURN(std::optional compile_util, + TF_ASSIGN_OR_RETURN(AutotunerCompileUtil compile_util, AutotunerCompileUtil::Create(config_, debug_options)); - TF_RET_CHECK(compile_util.has_value()); bool hlo_changed = false; for (const HloComputation* computation : module->computations()) { @@ -235,7 +233,7 @@ absl::StatusOr CustomKernelFusionAutotuner::Run( TF_ASSIGN_OR_RETURN( bool instruction_changed, AutotuneCustomKernelFusion(computation->FusionInstruction(), config_, - compile_util.value(), debug_options)); + compile_util, debug_options)); if (instruction_changed) { hlo_changed = true; } diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h index 138209106b5bc0..40a57e0293a947 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc index 6526e3338fb6c5..fab06bc5bdec35 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc @@ -68,7 +68,7 @@ class GemmAlgorithmPickerTest : public HloTestBase, } void SetUp() override { - std::string_view name = + absl::string_view name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); // We need special handling for BlasGetVersion test. bool blas_get_version = name.rfind("BlasGetVersion") == 0; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 4b7fd260d02608..39f31c0dcf0b5f 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" @@ -86,19 +87,21 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/cuda/ptx_compiler_helpers.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/tools/hlo_decomposer.h" #include "xla/tsl/lib/core/bits.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -316,8 +319,7 @@ absl::StatusOr> CublasGemmAutotuneExtractor( const AutotuneConfig& config, const se::DeviceDescription& gpu_device_info, const se::SemanticVersion& toolkit_version, const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { - const HloComputation* fusion_computation = - fusion->called_computations().at(0); + const HloComputation* fusion_computation = fusion->called_computation(); std::unique_ptr new_module = ExtractComputationIntoNewModule(*fusion_computation); new_module->mutable_config().set_debug_options(debug_opts); @@ -482,17 +484,18 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, if (result.has_algorithm()) { return CuDnnFusionExtractor(*fusion, debug_opts, result.algorithm().algo_id()); - } else if (result.has_triton()) { + } + if (result.has_triton()) { return TritonGemmAutotuneExtractor( triton_gemm_config, device_desc, fusion, debug_opts, /*allow_filtering_kernels_spilling_registers=*/true); - } else if (result.has_gemm()) { + } + if (result.has_gemm()) { return CublasGemmAutotuneExtractor(autotune_config, device_desc, toolkit_version, fusion, debug_opts); - } else { - LOG(FATAL) << "Unknown result type: " << result.DebugString(); } + LOG(FATAL) << "Unknown result type: " << result.DebugString(); })); module->set_name(std::string(fusion->name())); // Using the original module for its debug info and name in the first @@ -508,19 +511,32 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, return absl::OkStatus(); } +std::string ConfigToString(const BackendConfig& config) { + if (std::holds_alternative(config)) { + return std::get(config).ToString(); + } + if (std::holds_alternative(config)) { + return absl::StrFormat( + "cuDNN plan %d", + std::get(config).plan_id); + } + if (std::holds_alternative(config)) { + return "reference (cublas)"; + } + LOG(FATAL) << "Unsupported config type: " << config.index(); +} + std::string Serialize(const BackendConfig& config) { - if (auto triton_config = std::get_if(&config)) { + if (auto* triton_config = std::get_if(&config)) { tsl::protobuf::TextFormat::Printer printer; printer.SetSingleLineMode(true); std::string result; printer.PrintToString(triton_config->ToProto(), &result); return result; } - return GemmFusionAutotunerImpl::ToString(config); + return ConfigToString(config); } -} // anonymous namespace - absl::Status RewriteGemmFusionToCall(HloInstruction* fusion_instr) { // Falling back to cuBLAS: Converting the fusion to a Call, so that it // can be inlined back again. @@ -563,6 +579,8 @@ absl::Status HandleTritonGemm(HloInstruction* fusion_instr, return absl::OkStatus(); } +} // anonymous namespace + absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion( HloInstruction* fusion_instr) { TF_ASSIGN_OR_RETURN(auto gpu_config, @@ -672,24 +690,9 @@ bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { !debug_options_.xla_gpu_deterministic_ops(); } -/*static*/ std::string GemmFusionAutotunerImpl::ToString( - const BackendConfig& config) { - if (std::holds_alternative(config)) { - return std::get(config).ToString(); - } else if (std::holds_alternative(config)) { - return absl::StrFormat("cuDNN plan %d", - std::get(config).plan_id); - } else if (std::holds_alternative(config)) { - return "reference (cublas)"; - } else { - LOG(FATAL) << "Unsupported config type: " << config.index(); - } -} - -std::vector GenerateCustomKernelFusionConfigs( +static std::vector GenerateCustomKernelFusionConfigs( const HloFusionInstruction& fusion, se::DeviceDescription device_description) { - std::vector configs; const CustomKernelFusionPatternRegistry* patterns = CustomKernelFusionPatternRegistry::Default(); HloComputation* computation = fusion.called_computation(); @@ -700,53 +703,60 @@ std::vector GenerateCustomKernelFusionConfigs( patterns->Match(device_description, dot_instruction); // For Cutlass we expect only one match for a GEMM fusion. - if (match.size() == 1) { - CustomKernelFusionRegistry* registry = - CustomKernelFusionRegistry::Default(); - auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); - - // If custom fusion is not found it means that some of the build targets - // might not be statically linked into the binary. - if (custom_kernel_fusion != nullptr) { - // There can be multiple kernels for a single fusion pattern, which are - // selected by the kernel_index. - // To get the number of kernels we can rewrite the fusion to custom kernel - // fusion and count the number of loaded kernels. - const HloComputation* fusion_computation = fusion.called_computation(); - std::unique_ptr new_module = - ExtractComputationIntoNewModule(*fusion_computation); - CustomKernelFusionRewriter rewriter(&device_description); - absl::StatusOr changed = rewriter.Run(new_module.get()); - if (!changed.ok() || !changed.value()) { - VLOG(2) << "Skip custom kernel config. Failed to rewrite custom kernel " - "fusion: " - << changed.status(); - return configs; - } + if (match.size() != 1) { + return {}; + } - HloInstruction* custom_kernel_fusion_instr = - hlo_query::GetFirstInstructionWithOpcode( - *new_module->entry_computation(), HloOpcode::kFusion); - if (custom_kernel_fusion_instr == nullptr) { - VLOG(2) << "Skip custom kernel config. Failed to find custom kernel " - "fusion instruction in the rewritten module."; - return configs; - } - absl::StatusOr> kernels = - custom_kernel_fusion->LoadKernels( - device_description, - custom_kernel_fusion_instr->fused_instructions_computation()); - if (!kernels.ok()) { - VLOG(2) << "Skip custom kernel config. Failed to load custom kernels: " - << kernels.status(); - } else { - for (int i = 0; i < kernels.value().size(); ++i) { - GemmFusionAutotunerImpl::CustomKernelFusionConfig config{ - /*kernel_index=*/i}; - configs.push_back(config); - } - } - } + CustomKernelFusionRegistry* registry = CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); + + // If custom fusion is not found it means that some of the build targets + // might not be statically linked into the binary. + if (custom_kernel_fusion == nullptr) { + return {}; + } + + // There can be multiple kernels for a single fusion pattern, which are + // selected by the kernel_index. + // To get the number of kernels we can rewrite the fusion to custom kernel + // fusion and count the number of loaded kernels. + const HloComputation* fusion_computation = fusion.called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + CustomKernelFusionRewriter rewriter(&device_description); + absl::StatusOr changed = rewriter.Run(new_module.get()); + if (!changed.ok() || !*changed) { + VLOG(2) << "Skip custom kernel config. Failed to rewrite custom kernel " + "fusion: " + << changed.status(); + return {}; + } + + HloInstruction* custom_kernel_fusion_instr = + hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), + HloOpcode::kFusion); + if (custom_kernel_fusion_instr == nullptr) { + VLOG(2) << "Skip custom kernel config. Failed to find custom kernel " + "fusion instruction in the rewritten module."; + return {}; + } + + absl::StatusOr> kernels = + custom_kernel_fusion->LoadKernels( + device_description, + custom_kernel_fusion_instr->fused_instructions_computation()); + if (!kernels.ok()) { + VLOG(2) << "Skip custom kernel config. Failed to load custom kernels: " + << kernels.status(); + return {}; + } + + std::vector configs; + configs.reserve(kernels.value().size()); + for (int i = 0; i < kernels.value().size(); ++i) { + GemmFusionAutotunerImpl::CustomKernelFusionConfig config{ + /*kernel_index=*/i}; + configs.push_back(config); } return configs; @@ -756,7 +766,7 @@ absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { const HloDotInstruction* dot = Cast(hlo_query::GetFirstInstructionWithOpcode( - *fusion.called_computations().at(0), HloOpcode::kDot)); + *fusion.called_computation(), HloOpcode::kDot)); std::vector configs; if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) { @@ -902,7 +912,7 @@ absl::StatusOr> results; @@ -931,48 +941,41 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, auto compile = [&](const HloFusionInstruction* fusion, const BackendConfig& config, bool allow_filtering_kernels_spilling_registers) - -> absl::StatusOr { - std::unique_ptr executable; + -> absl::StatusOr> { if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN(executable, - compile_util.Compile([&](const DebugOptions& opts) { - return TritonGemmAutotuneExtractor( - std::get(config), - config_.GetDeviceDescription(), fusion, opts, - allow_filtering_kernels_spilling_registers); - })); - } else if (std::holds_alternative(config)) { - executable = - compile_util - .Compile([&](const DebugOptions& opts) { - return CuDnnFusionExtractor( - *fusion, opts, std::get(config).plan_id); - }) - .value_or(nullptr); - } else if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN( - executable, compile_util.Compile([&](const DebugOptions& opts) { - return CublasGemmAutotuneExtractor(config_, - config_.GetDeviceDescription(), - toolkit_version_, fusion, opts); - })); - } else if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN(executable, - compile_util.Compile([&](const DebugOptions& opts) { - return CustomFusionKernelAutotuneExtractor( - std::get(config), - config_, toolkit_version_, fusion, opts); - })); + return compile_util.Compile([&](const DebugOptions& opts) { + return TritonGemmAutotuneExtractor( + std::get(config), config_.GetDeviceDescription(), + fusion, opts, allow_filtering_kernels_spilling_registers); + }); + } - } else { - LOG(FATAL) << "Unsupported config type: " << config.index(); + if (std::holds_alternative(config)) { + return compile_util + .Compile([&](const DebugOptions& opts) { + return CuDnnFusionExtractor(*fusion, opts, + std::get(config).plan_id); + }) + .value_or(nullptr); } - if (executable != nullptr) { - absl::MutexLock lock(&results_mu); - results[fusion].push_back({config, std::move(executable)}); - return true; + + if (std::holds_alternative(config)) { + return compile_util.Compile([&](const DebugOptions& opts) { + return CublasGemmAutotuneExtractor(config_, + config_.GetDeviceDescription(), + toolkit_version_, fusion, opts); + }); } - return false; + + if (std::holds_alternative(config)) { + return compile_util.Compile([&](const DebugOptions& opts) { + return CustomFusionKernelAutotuneExtractor( + std::get(config), config_, + toolkit_version_, fusion, opts); + }); + } + + LOG(FATAL) << "Unsupported config type: " << config.index(); }; // If the thread pool has only one thread, then it is actually slower to @@ -988,7 +991,8 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, << " fusions on " << thread_pool_->NumThreads() << " threads."; } - tsl::BlockingCounter counter(config_count); + absl::BlockingCounter counter(config_count); + absl::Mutex results_mu; for (const auto& key_value : task) { const HloFusionInstruction* fusion = key_value.first; const std::vector& gemm_config_set = key_value.second; @@ -1005,14 +1009,18 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, "last configuration printed out might not be the one " "causing issues! Use " "--xla_gpu_force_compilation_parallelism=1 to fix."; - absl::StatusOr has_executable = + absl::StatusOr> executable = compile(fusion, config, gemm_config_set.size() > 1); - TF_CHECK_OK(has_executable.status()) + TF_CHECK_OK(executable.status()) << " - Failure occured when compiling fusion " << fusion->name() - << " with config '" << ToString(config) + << " with config '" << ConfigToString(config) << "'\nFused HLO computation:\n" << fusion->fused_instructions_computation()->ToString(); - log(has_executable.value()); + log(*executable != nullptr); + if (*executable != nullptr) { + absl::MutexLock lock(&results_mu); + results[fusion].push_back({config, std::move(*executable)}); + } counter.DecrementCount(); }); } @@ -1037,9 +1045,12 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, "--xla_gpu_override_gemm_autotuner='" << Serialize(config) << "'"; TF_ASSIGN_OR_RETURN( - bool has_executable, + std::unique_ptr executable, compile(fusion, config, gemm_config_set.size() > 1)); - log(has_executable); + log(executable != nullptr); + if (executable != nullptr) { + results[fusion].push_back({config, std::move(executable)}); + } } } } @@ -1052,8 +1063,7 @@ absl::Status GemmFusionAutotunerImpl::CompareBuffers( const HloFusionInstruction& fusion, const ScopedShapedBuffer& reference_buffer, const ScopedShapedBuffer& buffer, AutotuneResult& res) { - const HloComputation* fusion_computation = fusion.called_computations().at(0); - const HloInstruction& root = *fusion_computation->root_instruction(); + const HloInstruction& root = *fusion.called_computation_root(); BufferComparator comparator(root.shape(), debug_options_.xla_gpu_autotune_gemm_rtol()); TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); @@ -1089,8 +1099,7 @@ absl::StatusOr GemmFusionAutotunerImpl::CheckRedZones( return false; } -absl::StatusOr> -GemmFusionAutotunerImpl::MeasurePerformance( +absl::StatusOr GemmFusionAutotunerImpl::MeasurePerformance( AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion, const ExecutableCandidate& candidate, std::optional& reference_buffer) { @@ -1100,40 +1109,36 @@ GemmFusionAutotunerImpl::MeasurePerformance( } TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); - VLOG(5) << "Trying : " << ToString(candidate.config); + VLOG(5) << "Trying : " << ConfigToString(candidate.config); AutotuneResult res = FromConfig(candidate.config); - const HloComputation* fusion_computation = fusion.called_computations().at(0); + const HloComputation* fusion_computation = fusion.called_computation(); TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction( *fusion_computation->FusionInstruction(), config_, debug_options_, RedzoneBuffers::kAllInputs)); - std::optional profiling_output; - TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable( - candidate.executable.get(), stream, - rz_buffers.input_buffers(), - rz_buffers.input_shapes())); - - if (!profiling_output) { - VLOG(5) << "Skipping this tiling." << ToString(candidate.config); - return std::nullopt; - } - VLOG(5) << "Running the kernel took: " << profiling_output->duration; - LOG_IF(WARNING, profiling_output->duration >= absl::Seconds(1)) - << "Slow kernel for " << fusion.called_computations()[0]->ToString() - << " took: " << profiling_output->duration << ". " - << ToString(candidate.config); + TF_ASSIGN_OR_RETURN( + ProfilingOutput profiling_output, + compile_util.ProfileExecutable(candidate.executable.get(), stream, + rz_buffers.input_buffers(), + rz_buffers.input_shapes())); + + VLOG(5) << "Running the kernel took: " << profiling_output.duration; + LOG_IF(WARNING, profiling_output.duration >= absl::Seconds(1)) + << "Slow kernel for " << fusion.called_computation()->ToString() + << " took: " << profiling_output.duration << ". " + << ConfigToString(candidate.config); *res.mutable_run_time() = - tsl::proto_utils::ToDurationProto(profiling_output->duration); + tsl::proto_utils::ToDurationProto(profiling_output.duration); if (!config_.should_check_correctness()) { return res; } if (std::holds_alternative(candidate.config)) { - reference_buffer = std::move(profiling_output->output); + reference_buffer = std::move(profiling_output.output); return res; } @@ -1144,7 +1149,7 @@ GemmFusionAutotunerImpl::MeasurePerformance( if (!rz_ok) return res; TF_RETURN_IF_ERROR(CompareBuffers(fusion, *reference_buffer, - profiling_output->output, res)); + profiling_output.output, res)); } return res; } @@ -1156,19 +1161,35 @@ absl::StatusOr> GemmFusionAutotunerImpl::Profile( return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#", fusion.name()); }); + VLOG(2) << "Profiling " << fusion.name() << "."; std::vector results; std::optional reference_buffer; - for (const ExecutableCandidate& candidate : candidates) { - TF_ASSIGN_OR_RETURN( - auto result, - MeasurePerformance(compile_util, fusion, candidate, reference_buffer)); - VLOG(2) << "Ran " << results.size() + 1 << " configs of " - << candidates.size() << "."; - if (result.has_value()) { - results.push_back(std::move(*result)); + for (int i = 0; i < candidates.size(); ++i) { + absl::StatusOr result = MeasurePerformance( + compile_util, fusion, candidates[i], reference_buffer); + // Treat register allocation error gracefully. If the compilation happens + // with the driver during execution then the error could surface here. + // It's enough to check this once here. + if (stream_executor::IsPtxRegisterAllocationError(result.status())) { + VLOG(5) << "Skipping candidate: " << ConfigToString(candidates[i].config) + << ": " << result.status(); + continue; + } + + if (stream_executor::IsMemoryAllocationError(result.status()) && + reference_buffer.has_value()) { + LOG(WARNING) + << "Autotuning candidate failed with out of memory error. Consider " + "disabling correctness checking (i.e. --xla_gpu_autotune_level=3) " + "to reduce autotuning memory usage."; } + + VLOG(2) << "Ran " << i + 1 << " configs out of " << candidates.size() + << "."; + TF_RETURN_IF_ERROR(result.status()); + results.push_back(std::move(*result)); } - VLOG(2) << "Done running."; + VLOG(2) << "Done profiling " << fusion.name() << "."; return results; } @@ -1222,8 +1243,8 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { return configs; } -absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts, - const AutotuningLogs& autotuning_logs) { +static absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts, + const AutotuningLogs& autotuning_logs) { if (absl::string_view file_path = debug_opts.xla_gpu_dump_autotune_logs_to(); !file_path.empty()) { std::string resolved_path; @@ -1269,8 +1290,7 @@ absl::StatusOr GemmFusionAutotunerImpl::Autotune( results.erase(results.begin()); } - const HloInstruction* root = - fusion->called_computations().at(0)->root_instruction(); + const HloInstruction* root = fusion->called_computation_root(); TF_ASSIGN_OR_RETURN( AutotuneResult best, PickBestResult(results, root->ToString(), root->GetModule()->config())); @@ -1334,10 +1354,11 @@ static BackendConfigs TrimConfigs(const BackendConfigs& gemm_config_sets, } // Exchange the results with the other ranks. -absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store, - const AutotuneCacheKeySet& keys_to_send, - absl::string_view fusion_set_fingerprint, - const int shard_index, const int shard_count) { +static absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store, + const AutotuneCacheKeySet& keys_to_send, + absl::string_view fusion_set_fingerprint, + const int shard_index, + const int shard_count) { AutotuneResults results; TF_RETURN_IF_ERROR( AutotunerUtil::SerializeAutotuneResults(&results, &keys_to_send)); @@ -1413,9 +1434,8 @@ absl::StatusOr GemmFusionAutotuner::Run( TF_RETURN_IF_ERROR(AutotunerUtil::AddResult(key, res, config_).status()); } } else if (!config_.IsDeviceless()) { - TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, + TF_ASSIGN_OR_RETURN(AutotunerCompileUtil compile_util, AutotunerCompileUtil::Create(config_, debug_options)); - TF_RET_CHECK(opt_compile_util.has_value()); std::string correctness_check_str = config_.should_check_correctness() ? "(with correctness check)" : "(without correctness check)"; @@ -1447,7 +1467,7 @@ absl::StatusOr GemmFusionAutotuner::Run( gemm_config_sets.size(), total_fusion_count, module->name(), correctness_check_str); TF_ASSIGN_OR_RETURN(const AutotuneCacheKeySet added_keys, - autotuner.Autotune(*opt_compile_util, gemm_config_sets, + autotuner.Autotune(compile_util, gemm_config_sets, std::move(fusion_count_map))); VLOG(1) << "Done autotuning."; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 8b86a2d553388b..87dddee19589a9 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -156,7 +156,6 @@ class GemmFusionAutotunerImpl { // Helper methods. const AutotuneConfig& GetConfig() const { return config_; } bool IsAutotuningEnabled() const; - static std::string ToString(const BackendConfig& config); static const int64_t BLAS_GEMM_DEFAULT; @@ -168,7 +167,7 @@ class GemmFusionAutotunerImpl { // // If the candidate is not cuBLAS, this will check the redzones and compare // the outputs with the reference buffer. - absl::StatusOr> MeasurePerformance( + absl::StatusOr MeasurePerformance( AutotunerCompileUtil& compile_util, const HloFusionInstruction& fusion, const ExecutableCandidate& candidate, std::optional& reference_buffer); diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc index 6689ccb96004f9..cc9eec78bbb811 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc @@ -52,10 +52,13 @@ bool GemmFusionAutotunerImpl::AddLibConfigs( std::vector& configs) { // Add cuDNN plans, if available. auto cc = std::get(GetComputeCapability()); - bool is_hopper = !config_.IsDeviceless() && cc.IsAtLeastHopper(); bool is_cudnn_enabled = - debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper && - GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9; + !config_.IsDeviceless() && + GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9 && + ((cc.IsAtLeastAmpere() && + debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 1) || + (cc.IsAtLeastBlackwell() && + debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0)); if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) || (IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled && algorithm_util::IsSupportedByCudnn( diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 5411f53bdb02a0..f2c2e726be8719 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -1210,7 +1210,7 @@ TEST_F(GemmFusionAutotunerTest, SplitKFLoatNormalization) { GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(), GetDebugOptionsForTest(), nullptr); TF_ASSERT_OK_AND_ASSIGN( - auto compile_util, + AutotunerCompileUtil compile_util, AutotunerCompileUtil::Create(autotune_config, GetDebugOptionsForTest())) std::unique_ptr module = ParseAndReturnVerifiedModule(R"( @@ -1241,7 +1241,7 @@ ENTRY entry { /*num_stages=*/1, /*num_warps=*/4, /*num_ctas=*/1))}); - CHECK_OK(autotuner.CompileAll(*compile_util, configs)); + CHECK_OK(autotuner.CompileAll(compile_util, configs)); } TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index 8ac46231e7fb11..906baaa33d512c 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -127,6 +127,8 @@ message CollectiveBackendConfig { // Determines whether the collective op of interested has been pipelined // within a loop. bool is_pipelined = 3; + // Cost model prediction. + ReificationCost reification_cost = 4; } // Backend config for cost model estimates. @@ -268,6 +270,11 @@ message CudnnfMHABackendConfig { // Sliding window length // ignored if the value <= 0 int32 sliding_window_length = 24; + + // The maximum number of segments in each batch + // Only used with packed layout + // ignored if the valued <= 1 + int32 max_seg_per_batch = 25; } // Backend config for a general custom call instruction, e.g. XLA FFI. diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cc index f6a942b93ea7e9..5c4ea65739938f 100644 --- a/third_party/xla/xla/service/gpu/buffer_comparator.cc +++ b/third_party/xla/xla/service/gpu/buffer_comparator.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -60,7 +59,7 @@ struct ComparisonParams { // // Returns `true` if two buffers are equal, `false` otherwise. template -static absl::StatusOr DeviceCompare(std::string_view kernel_name, +static absl::StatusOr DeviceCompare(absl::string_view kernel_name, void* kernel_symbol, const ComparisonParams& params) { se::StreamExecutor* executor = params.stream->parent(); @@ -92,8 +91,8 @@ static absl::StatusOr DeviceCompare(std::string_view kernel_name, CalculateLaunchDimensions(*params.shape, gpu_device_info); se::DeviceMemory as_uint64(out.memory()); - TF_RETURN_IF_ERROR(params.stream->ThenLaunch( - dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel, + TF_RETURN_IF_ERROR(comparison_kernel.Launch( + dim.thread_counts_per_block(), dim.block_counts(), params.stream, current_typed, expected_typed, static_cast(params.relative_tol), buffer_size, as_uint64)); @@ -163,7 +162,7 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { template static absl::StatusOr CompareEqualParameterized( - std::string_view kernel_name, void* kernel_symbol, + absl::string_view kernel_name, void* kernel_symbol, const ComparisonParams& params) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc b/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc index 12bdf5194b2dbf..a5cfa74b066006 100644 --- a/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc +++ b/third_party/xla/xla/service/gpu/conv_layout_normalization_test.cc @@ -60,9 +60,15 @@ HloModule TestModule } )"; - MatchOptimizedHlo(hlo, R"( + if (!IsRocm() && GetCudaComputeCapability().IsAtLeastHopper()) { + MatchOptimizedHlo(hlo, R"( +// CHECK: (f32[1,23,136]{2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[fusion_1_0:%[^ ]+]], [[transpose_1_1:%[^ ]+]]), window={size=31 stride=2 pad=23_23}, dim_labels=b0f_o0i->b0f, custom_call_target="__cudnn$convBackwardInput" + )"); + } else { + MatchOptimizedHlo(hlo, R"( // CHECK: (f32[1,136,23]{2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[fusion_1_0:%[^ ]+]], [[transpose_1_1:%[^ ]+]]), window={size=31 stride=2 pad=23_23}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardInput" )"); + } } TEST_F(ConvolutionLayoutNormalizationTest, Forward) { @@ -76,9 +82,15 @@ ENTRY %TestComputation { } )"; - MatchOptimizedHlo(hlo, R"( + if (!IsRocm() && GetCudaComputeCapability().IsAtLeastHopper()) { + MatchOptimizedHlo(hlo, R"( +// CHECK: (f32[2,1,378,128]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[param_0_0:%[^ ]+]], [[bitcast_5_1:%[^ ]+]]), window={size=1x5 pad=0_0x2_2}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward" + )"); + } else { + MatchOptimizedHlo(hlo, R"( // CHECK: (f32[2,128,1,378]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call([[param_0_0:%[^ ]+]], [[bitcast_5_1:%[^ ]+]]), window={size=1x5 pad=0_0x2_2}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward" - )"); + )"); + } } TEST_F(ConvolutionLayoutNormalizationTest, FusedConv3D) { diff --git a/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc b/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc index 48ca7ef87579d3..3491563ce5eb7c 100644 --- a/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc @@ -30,12 +30,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index f6f1eb6475ca42..3605fbce26d3fd 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #if GOOGLE_CUDA @@ -546,7 +545,7 @@ TEST_F(CustomCallTest, ExportedFfiOpaque) { } static absl::Status CheckTokens(std::vector args, - std::string_view pattern) { + absl::string_view pattern) { if (args.size() != pattern.size()) { return absl::InternalError("Incorrect number of arguments"); } @@ -573,7 +572,7 @@ static absl::Status CheckTokens(std::vector args, static absl::Status FfiTokens(ffi::RemainingArgs inputs, ffi::RemainingRets outputs, - std::string_view pattern) { + absl::string_view pattern) { std::vector types; for (auto i = 0; i < inputs.size(); ++i) { types.push_back(inputs.get(i).value().element_type()); @@ -586,7 +585,7 @@ static absl::Status FfiTokens(ffi::RemainingArgs inputs, XLA_FFI_DEFINE_HANDLER( kFfiTokens, FfiTokens, - ffi::Ffi::Bind().RemainingArgs().RemainingRets().Attr( + ffi::Ffi::Bind().RemainingArgs().RemainingRets().Attr( "pattern")); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens", PLATFORM, diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc index a524f2baee64e2..746056f545e8dd 100644 --- a/third_party/xla/xla/service/gpu/determinism_test.cc +++ b/third_party/xla/xla/service/gpu/determinism_test.cc @@ -223,6 +223,7 @@ TEST_F(DeterminismTest, ExcludingNonDeterministicOpsDoesNotDisableAutotuning) { } debug_options_.set_xla_gpu_cublas_fallback(false); + debug_options_.set_xla_gpu_cudnn_gemm_fusion_level(0); ASSERT_TRUE(debug_options_.xla_gpu_exclude_nondeterministic_ops()); ASSERT_FALSE(debug_options_.xla_gpu_deterministic_ops()); AutotunerUtil::ClearAutotuneResults(); diff --git a/third_party/xla/xla/service/gpu/executable.proto b/third_party/xla/xla/service/gpu/executable.proto index 0c384beca953fc..2d57bb228a40a5 100644 --- a/third_party/xla/xla/service/gpu/executable.proto +++ b/third_party/xla/xla/service/gpu/executable.proto @@ -25,7 +25,7 @@ message CompilationResultProto { BufferAssignmentProto buffer_assignment = 2; string asm_text = 3; bytes binary = 4; - map dnn_compiled_graphs = 5; + map dnn_compiled_graphs = 5; } message LaunchDimensionsProto { diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc b/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc index 0cd51f656f006d..6785887bd9badd 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/execution_stream_assignment.h" #include -#include #include #include @@ -107,21 +106,21 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { // to `2`. ExpectExecutionStreamForSyncInstructions( assignment, FindComputation(module.get(), "entry"), ExecutionStreamId(0)); - for (std::string_view instruction : {"start1", "update1", "done1"}) { + for (absl::string_view instruction : {"start1", "update1", "done1"}) { EXPECT_THAT(assignment.GetAsyncExecutionStreamIds(Cast( FindInstruction(module.get(), instruction))), IsOkAndHolds(AsyncExecutionStreamIds{ /*source_stream_id=*/ExecutionStreamId(0), /*destination_stream_id=*/ExecutionStreamId(1)})); } - for (std::string_view instruction : {"start2", "update2", "done2"}) { + for (absl::string_view instruction : {"start2", "update2", "done2"}) { EXPECT_THAT(assignment.GetAsyncExecutionStreamIds(Cast( FindInstruction(module.get(), instruction))), IsOkAndHolds(AsyncExecutionStreamIds{ /*source_stream_id=*/ExecutionStreamId(0), /*destination_stream_id=*/ExecutionStreamId(2)})); } - for (std::string_view instruction : {"start3", "update3", "done3"}) { + for (absl::string_view instruction : {"start3", "update3", "done3"}) { EXPECT_THAT(assignment.GetAsyncExecutionStreamIds(Cast( FindInstruction(module.get(), instruction))), IsOkAndHolds(AsyncExecutionStreamIds{ @@ -158,7 +157,7 @@ TEST_F(ExecutionStreamAssignmentTest, CopyStartStreamIdTest) { ExecutionStreamAssignment assignment(module.get()); - for (std::string_view instruction : {"copy-start"}) { + for (absl::string_view instruction : {"copy-start"}) { EXPECT_THAT( assignment.GetAsyncExecutionStreamIds(Cast( FindInstruction(module.get(), instruction))), @@ -200,7 +199,7 @@ TEST_F(ExecutionStreamAssignmentTest, FusionComputations) { // Computations only reachable through fusion nodes should have no assigned // `ExecutionStreamId`. - for (std::string_view computation : {"reduce", "fusion"}) { + for (absl::string_view computation : {"reduce", "fusion"}) { for (const HloInstruction* instruction : FindComputation(module.get(), computation)->instructions()) { EXPECT_THAT(assignment.GetSyncExecutionStreamId(instruction), diff --git a/third_party/xla/xla/service/gpu/flag_utils.h b/third_party/xla/xla/service/gpu/flag_utils.h index 90f191a16f8ecf..527057465dc8d1 100644 --- a/third_party/xla/xla/service/gpu/flag_utils.h +++ b/third_party/xla/xla/service/gpu/flag_utils.h @@ -25,12 +25,6 @@ limitations under the License. namespace xla { namespace gpu { -// Returns compile time optimization effort in range [-1.0, 1.0] where values < -// 0.0 indicate skipping passes which might optimize the final runtime (thus -// improving compile time), and values > 0.0 indicate running additional passes -// which may improve runtime at the cost of compilation time. -float ExecTimeOptimizationEffort(const HloModuleConfig& config); - // Defines the optimization effort to trigger additional passes which optimize // communication compute overlap. constexpr float kExtraCollectiveOptimizations = 0.2; diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump.cc b/third_party/xla/xla/service/gpu/fusion_process_dump.cc index 9863a3a7b63ef8..c0bb7c71fd75a6 100644 --- a/third_party/xla/xla/service/gpu/fusion_process_dump.cc +++ b/third_party/xla/xla/service/gpu/fusion_process_dump.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/fusion_process_dump.h" #include -#include #include #include "absl/container/flat_hash_map.h" @@ -46,7 +45,7 @@ namespace { HloInstruction* AddFusionInstruction(HloInstruction* producer, HloInstruction* consumer, HloComputation* computation, - std::string_view fusion_name) { + absl::string_view fusion_name) { if (consumer->opcode() == HloOpcode::kFusion) { return consumer; } @@ -66,7 +65,7 @@ HloInstruction* AddFusionInstruction(HloInstruction* producer, HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer, HloComputation* computation, - std::string_view fusion_name) { + absl::string_view fusion_name) { HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer, computation, fusion_name); if (producer->opcode() == HloOpcode::kFusion) { diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 749cf6e10df81c..1e2c206646b72f 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -1,6 +1,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//xla:xla.bzl", "xla_cc_test") load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tsl:tsl.bzl", "if_google") load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") package( @@ -8,6 +9,37 @@ package( licenses = ["notice"], ) +cc_library( + name = "emitter_loc_op_builder", + srcs = ["emitter_loc_op_builder.cc"], + hdrs = ["emitter_loc_op_builder.h"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform", + ] + if_google(["@com_google_absl//absl/types:source_location"]), +) + +xla_test( + name = "emitter_loc_op_builder_test", + srcs = ["emitter_loc_op_builder_test.cc"], + backends = ["gpu"], + deps = [ + ":emitter_loc_op_builder", + "//xla/hlo/testlib:filecheck", + "//xla/service/gpu/fusions/triton:triton_fusion_emitter", + "//xla/service/llvm_ir:llvm_util", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "in_place_dynamic_update_slice_mlir", srcs = ["in_place_dynamic_update_slice_mlir.cc"], @@ -15,6 +47,8 @@ cc_library( deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", @@ -22,8 +56,6 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -63,20 +95,20 @@ cc_library( deps = [ ":fusion_emitter", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service:buffer_assignment", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:hlo_proto_cc", - "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:hlo_fusion_analysis", @@ -135,6 +167,7 @@ xla_test( "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:constants", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/service:custom_call_target_registry", "//xla/service:executable", "//xla/service:hlo_module_config", @@ -149,7 +182,6 @@ xla_test( "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "@com_google_absl//absl/algorithm:container", @@ -240,15 +272,15 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -267,30 +299,35 @@ cc_library( hdrs = ["scatter_mlir.h"], deps = [ "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", + "//xla/codegen/emitters:type_util", "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/service:scatter_simplifier", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:VectorDialect", ], ) @@ -304,16 +341,16 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", + "//xla/codegen/emitters:type_util", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir:type_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -427,6 +464,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:dump", "//xla/service:executable", "//xla/service:pattern_matcher", @@ -436,11 +475,10 @@ xla_test( "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/tests:filecheck", - "//xla/tests:verified_hlo_module", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -499,6 +537,10 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", + "//xla/codegen/emitters:type_util", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", @@ -506,11 +548,7 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:reduction_utils", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir:type_util", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -550,13 +588,13 @@ cc_library( hdrs = ["concatenate_mlir.h"], deps = [ "//xla:shape_util", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -577,14 +615,14 @@ cc_library( deps = [ "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:computation_partitioner", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc index 8d5b52b6eb7fe7..51c7c0134dea07 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc @@ -33,12 +33,12 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" @@ -49,6 +49,7 @@ namespace gpu { namespace { using llvm::SmallVector; +using mlir::ImplicitLocOpBuilder; using mlir::Value; using mlir::ValueRange; @@ -103,22 +104,22 @@ MlirConcatenateFusion::ComputeThreadIdToInputIndexing( largest_shape_, ctx); } -std::vector +std::vector MlirConcatenateFusion::GetEpilogues(const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const { - return {mlir_converter::EpilogueSpecification::FromIdentityIndexing( + return {emitters::EpilogueSpecification::FromIdentityIndexing( &analysis_.fusion_hero(0).instruction(), &analysis_.fusion_root(0).instruction(), mlir_context)}; } absl::Status MlirConcatenateFusion::EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const { const auto& root_computation = computations.FindPartitionedComputation( fusion.fused_instructions_computation()); - mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); builder.setInsertionPointToStart(entry_function.addEntryBlock()); auto thread_and_block_ids = EmitThreadAndBlockIds(builder); auto* ctx = entry_function.getContext(); @@ -152,32 +153,34 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction( auto loop_nest_body_builder = [&, operand_index = operand_index]( - ValueRange symbol_values, ValueRange output_indices, + ImplicitLocOpBuilder& nested_b, ValueRange symbol_values, + ValueRange output_indices, ValueRange output_tensors) -> SmallVector { - auto input_indices = mlir_converter::ApplyIndexing( - thread_id_to_input_map, thread_and_block_ids, symbol_values, builder); + auto input_indices = + emitters::ApplyIndexing(thread_id_to_input_map, thread_and_block_ids, + symbol_values, nested_b); - auto result_scalar = mlir_converter::ProvideParameter( + auto result_scalar = emitters::ProvideParameter( root_computation, concat, operand_index, input_indices, call_targets, - entry_function, builder); + entry_function, nested_b); absl::flat_hash_map> hero_value{{concat, result_scalar}}; auto result_scalars = EmitEpilogue( /*epilogue_index=*/0, computations, entry_function, hero_value, - output_indices, builder)[&analysis_.fusion_root(0).instruction()]; + output_indices, nested_b)[&analysis_.fusion_root(0).instruction()]; SmallVector result_tensors; result_tensors.reserve(output_tensor_args.size()); for (auto [tensor, value] : llvm::zip(output_tensors, result_scalars)) { result_tensors.push_back( - builder + nested_b .create(value, tensor, output_indices) .getResult()); } return result_tensors; }; - result_tensors = mlir_converter::EmitXlaLoopOp( + result_tensors = emitters::EmitXlaLoopOp( builder, thread_and_block_ids, result_tensors, thread_id_to_output_map, loop_nest_body_builder); } diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h index 7db4624797bad9..ffe33ae0a912c7 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h @@ -24,10 +24,10 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" @@ -51,12 +51,12 @@ class MlirConcatenateFusion : public MlirFusionEmitterBase { protected: absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override; - std::vector GetEpilogues( + std::vector GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index fb206ef9dd5506..1a4929ce70aee8 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -29,6 +29,8 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/primitive_util.h" #include "xla/service/dump.h" #include "xla/service/executable.h" @@ -38,11 +40,10 @@ limitations under the License. #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -63,14 +64,14 @@ class CuDnnFusionTest : public GpuCodegenTest { // Let this group of tests just use first available plan skipping // autotuning. debug_options.set_xla_gpu_autotune_level(0); - debug_options.set_xla_gpu_cudnn_gemm_fusion_level(1); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(2); return debug_options; } - bool IsAtLeastHopperWithCuDnn9() { + bool IsAtLeastAmpereWithCuDnn9() { se::StreamExecutor* executor = backend().default_stream_executor(); return executor->GetDeviceDescription() .cuda_compute_capability() - .IsAtLeastHopper() && + .IsAtLeastAmpere() && GetDnnVersionInfoOrDefault(executor).major_version() >= 9; } bool IsAtLeastCuDnn91() { @@ -82,9 +83,9 @@ class CuDnnFusionTest : public GpuCodegenTest { protected: void SetUp() override { - if (!IsAtLeastHopperWithCuDnn9()) { + if (!IsAtLeastAmpereWithCuDnn9()) { GTEST_SKIP() - << "cuDNN GEMM fusion is not enabled before Hopper / cuDNN 9."; + << "cuDNN GEMM fusion is not tested before Ampere / cuDNN 9."; } } }; @@ -609,17 +610,7 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -class CuDnnFusionLevel2Test : public CuDnnFusionExecutionTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = - CuDnnFusionExecutionTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_cudnn_gemm_fusion_level(2); - return debug_options; - } -}; - -TEST_F(CuDnnFusionLevel2Test, BroadcastToDim2ExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, BroadcastToDim2ExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { p0 = f16[16,32,128] parameter(0) @@ -642,7 +633,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, BroadcastToDim1ExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, BroadcastToDim1ExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { p0 = f16[16,32,128] parameter(0) @@ -665,7 +656,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, BroadcastToDim0ExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, BroadcastToDim0ExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { p0 = bf16[32,128] parameter(0) @@ -685,7 +676,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, BroadcastTo2DimsExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, BroadcastTo2DimsExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { p0 = f16[16,32,128] parameter(0) @@ -708,7 +699,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, BroadcastTo3DimsExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, BroadcastTo3DimsExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { p0 = f16[16,32,128] parameter(0) @@ -731,7 +722,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, ConstantExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, ConstantExecutesCorrectly) { if (!IsAtLeastCuDnn91()) { GTEST_SKIP() << "Fused scalar constants require cuDNN 9.1+."; } @@ -760,7 +751,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, ClampExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, ClampExecutesCorrectly) { if (!IsAtLeastCuDnn91()) { GTEST_SKIP() << "Clamp test requires cuDNN 9.1+."; } @@ -789,7 +780,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, DotF8ExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, DotF8ExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { @@ -814,7 +805,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel2Test, SlicingExecutesCorrectly) { +TEST_F(CuDnnFusionExecutionTest, SlicingExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { p0 = f16[11,23,64] parameter(0) @@ -834,17 +825,7 @@ ENTRY e { ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -class CuDnnFusionLevel3Test : public CuDnnFusionExecutionTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = - CuDnnFusionExecutionTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_cudnn_gemm_fusion_level(3); - return debug_options; - } -}; - -TEST_F(CuDnnFusionLevel3Test, +TEST_F(CuDnnFusionExecutionTest, DotWithSplitNonContractingInputExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { @@ -867,7 +848,7 @@ ENTRY r { ErrorSpec{/*aabs=*/1, /*arel=*/1e-3})); } -TEST_F(CuDnnFusionLevel3Test, +TEST_F(CuDnnFusionExecutionTest, DotWithSplitNonContractingInOutExecutesCorrectly) { EXPECT_TRUE(RunAndCompare(R"( fusion1 { @@ -1098,7 +1079,6 @@ class CuDnnFusionRewriteTest : public CuDnnFusionTest { // Reset autotuning level to default. debug_options.set_xla_gpu_autotune_level( GetDebugOptionsFromFlags().xla_gpu_autotune_level()); - debug_options.set_xla_gpu_cudnn_gemm_fusion_level(1); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; } @@ -1131,6 +1111,12 @@ TEST_F(CuDnnFusionRewriteTest, AutotuningPicksCuDnnForS8BF16OnHopper) { // The test case relies on measurements by the autotuner and current // performance comparison of the backends. May need to be updated if // the situation changes. + if (backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability() != se::CudaComputeCapability::Hopper()) { + GTEST_SKIP() << "The test is for Hopper."; + } MatchOptimizedHlo(R"( e { p0 = bf16[720,720,720] parameter(0) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index de6fdc970b55c2..a0d194b6355f4e 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -45,6 +46,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_traversal.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -69,13 +71,11 @@ limitations under the License. #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo.pb.h" -#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/stream.h" #include "xla/util.h" -#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -86,8 +86,6 @@ namespace { constexpr unsigned kGEMMOutputBufferIndex = 0; constexpr unsigned kGEMMWorkspaceBufferIndex = 1; -namespace m = ::xla::match; - absl::StatusOr> BuildCustomKernelThunkForFusion( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, CustomKernel custom_kernel) { @@ -200,25 +198,24 @@ absl::Status CollectSliceInfo( const auto* param = Cast(idx_op); const auto* offset_value = fusion_instr.operand(param->parameter_number()); - if (auto* cst = DynCast(offset_value)) { + VLOG(2) << "Offset value:" << offset_value->ToString(); + + // Try to evaluate the offset value, maybe it is simple arithmetic. + absl::StatusOr offset_literal = HloEvaluator().Evaluate( + /*instruction=*/offset_value, + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true); + + if (offset_literal.ok()) { // Loop offset is defined by a constant scalar value. - if (ShapeUtil::IsScalarWithElementType(cst->shape(), - PrimitiveType::S32)) { - arg_offsets.emplace_back() = - static_cast(cst->literal().data()[0]); - } else if (ShapeUtil::IsScalarWithElementType(cst->shape(), - PrimitiveType::S64)) { - arg_offsets.emplace_back() = - static_cast(cst->literal().data()[0]); - } else if (ShapeUtil::IsScalarWithElementType(cst->shape(), - PrimitiveType::U32)) { - arg_offsets.emplace_back() = cst->literal().data()[0]; - } else if (ShapeUtil::IsScalarWithElementType(cst->shape(), - PrimitiveType::U64)) { - arg_offsets.emplace_back() = cst->literal().data()[0]; + std::optional offset_value = + LiteralUtil::LiteralAsScalarInt64(offset_literal.value()); + if (offset_value.has_value()) { + arg_offsets.emplace_back() = *offset_value; } else { - return absl::InternalError(absl::StrCat( - "Unsupported constant offset shape: ", cst->shape().ToString())); + return absl::InternalError( + absl::StrCat("Unsupported constant offset shape: ", + offset_literal->shape().ToString())); } } else { diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index 047ee4385e9962..3fb6bcd75aa315 100644 --- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_executable.h" @@ -42,7 +43,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/stream.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -3190,6 +3190,85 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDynamicSlice) { false, true, error)); } +TEST_F(DynamicSliceFusionTest, + OffsetsThatCanBeEvaluatedSuccessfullyAreCorrectlyEmbeddedIntoThunks) { + const char* hlo_opt = R"( + HloModule test, replica_count=2 + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a,b) + } + dynamic-slice-fusion { + src = s32[32,32] parameter(0) + dest = s32[32,32] parameter(1) + offset1 = s32[] parameter(2) + offset2 = s32[] parameter(3) + rs = s32[16,32] reduce-scatter(src), dimensions={0}, replica_groups={{0,1}}, to_apply=add + ROOT dus = s32[32,32] dynamic-update-slice(dest, rs, offset1, offset2) + } + ENTRY main { + src = s32[32,32] parameter(0) + dest = s32[32,32] parameter(1) + c0 = s32[] constant(0) + c5 = s32[] constant(5) + add = s32[] add(c5, c5) + ROOT fusion = s32[32,32] fusion(src, dest, add, c0), kind=kCustom, calls=dynamic-slice-fusion, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + } + )"; + + const char* hlo_ref = R"( + HloModule test, replica_count=2 + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a,b) + } + ENTRY main { + src = s32[32,32] parameter(0) + dest = s32[32,32] parameter(1) + c0 = s32[] constant(0) + c5 = s32[] constant(5) + add = s32[] add(c5, c5) + rs.1 = ((s32[32,32]), s32[16,32]) reduce-scatter-start(src), dimensions={0}, replica_groups={{0,1}}, to_apply=add + rs = s32[16,32] reduce-scatter-done(rs.1) + ROOT dus = s32[32,32] dynamic-update-slice(dest, rs, add, c0) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_ref, + ParseAndReturnVerifiedModule(hlo_ref)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_opt, + ParseAndReturnVerifiedModule(hlo_opt)); + + // Check that the offset value in the thunk is an evaluated constant even if + // no simplification passes are executed. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr exec, + CreateExecutable(/*module=*/module_opt->Clone(), + /*run_hlo_passes=*/false)); + GpuExecutable* gpu_exec = dynamic_cast(exec.get()); + ASSERT_NE(gpu_exec, nullptr); + const SequentialThunk& thunk = gpu_exec->GetThunk(); + auto dynamic_slice_thunk = + absl::c_find_if(thunk.thunks(), [](const std::unique_ptr& thunk) { + return thunk->kind() == Thunk::kDynamicSlice; + }); + ASSERT_NE(dynamic_slice_thunk, thunk.thunks().end()); + std::vector>> offsets = + dynamic_cast(dynamic_slice_thunk->get()) + ->get_offsets(); + ASSERT_EQ(offsets.size(), 2); + ASSERT_TRUE(offsets[1].has_value()); + ASSERT_EQ(offsets[1].value()[0], DynamicSliceThunk::Offset(10l)); + ASSERT_EQ(offsets[1].value()[1], DynamicSliceThunk::Offset(0l)); + + ErrorSpec error{1e-3, 1e-3}; + EXPECT_TRUE(RunAndCompareTwoModulesReplicated( + /*module_0=*/std::move(module_ref), /*module_1=*/std::move(module_opt), + /*run_hlo_passes=*/false, /*use_threads=*/true, error)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder.cc b/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder.cc new file mode 100644 index 00000000000000..d3a24e92667428 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LLVM.h" + +namespace xla::gpu { + +// Aligns the annotations to the Nth character of the lines. +constexpr size_t kAnnotationPadding = 100ul; + +/* static */ std::string EmitterLocOpBuilder::FormatTritonIrWithAnnotations( + absl::string_view mlir_ir) { + auto triton_with_annotations = absl::StrSplit(mlir_ir, '\n'); + std::vector formatted_lines; + for (auto& line : triton_with_annotations) { + std::vector line_and_annotation = absl::StrSplit(line, '"'); + constexpr int kInstructionLineFragments = 3; + if (line_and_annotation.size() != kInstructionLineFragments) { + // The line does not matches with the pattern: + // x = instruction(y, z) "annotation" + // So we just add it to the output as is. + formatted_lines.emplace_back(line); + continue; + } + auto text_size = + std::min(line_and_annotation[0].size(), kAnnotationPadding); + auto new_line = + absl::StrCat(line_and_annotation[0], + std::string(kAnnotationPadding - text_size, ' '), "\"", + line_and_annotation[1], "\"", line_and_annotation[2]); + formatted_lines.emplace_back(new_line); + } + return absl::StrJoin(formatted_lines, "\n"); +} + +mlir::Location EmitterLocOpBuilder::Loc( + EmitterLocOpBuilder::SourceLocation location) const { + if (!annotate_loc_ || location.line() == 0) { + return current_loc_; + } + std::vector file_name = + absl::StrSplit(location.file_name(), '/'); + std::string previous_loc; + if (mlir::isa(current_loc_)) { + auto name_loc = mlir::cast(current_loc_); + previous_loc = name_loc.getName().str(); + } + + const std::string text = absl::StrCat(previous_loc, " -> ", file_name.back(), + ":", location.line()); + return mlir::NameLoc::get(mlir::StringAttr::get(getContext(), text)); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder.h b/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder.h new file mode 100644 index 00000000000000..151f05e9678d98 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder.h @@ -0,0 +1,206 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_EMITTER_LOC_OP_BUILDER_H_ +#define XLA_SERVICE_GPU_FUSIONS_EMITTER_LOC_OP_BUILDER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "tsl/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) +// The source_location.h is not available in open source. +#include "absl/types/source_location.h" +#else +#include +#endif + +namespace xla::gpu { + +// The builder that could add the NameLoc attribute to the newly created +// operations and fills this attribute with the SourceLocation(file:line) of the +// create(...) calls. The location info will be added to the current_loc_ +// location that the builder got through the constructor. The copy constructor +// also remembers the source location where the copy was created. +// +// Why: it is useful for tracking up the emitter file and line from the +// generated MLIR. +// +// How: +// 1. create(...) functions have absl::SourceLocation as the last +// argument with the default value of SourceLocation::current(). Every time they +// construct a new NameLoc attribute that contains the string from the +// current_loc_ and file:line from the source location parameter. +// +// 2. The copy constructor also gets the source location as the argument and +// remembers it in the current_loc_ as a join of the original current_loc_ and +// the place where the copy was created. +class EmitterLocOpBuilder : public mlir::ImplicitLocOpBuilder { + public: + // TODO(b/382419919): Remove ifdefs once we have absl::SourceLocation in absl + // OSS builds. +#if defined(PLATFORM_GOOGLE) + using SourceLocation = absl::SourceLocation; + constexpr static bool kSourceLocationSupported = true; +#else + // Mimicking absl::SourceLocation and doing nothing. + class FakeSourceLocation { + public: + static FakeSourceLocation current() { return FakeSourceLocation(); } + absl::string_view file_name() const { return ""; } + int line() const { return 0; } + }; + using SourceLocation = FakeSourceLocation; + constexpr static bool kSourceLocationSupported = false; +#endif + + // Constructor that takes the op builder and a flag indicating whether to + // annotate the location of the operations. + EmitterLocOpBuilder(mlir::ImplicitLocOpBuilder& op_builder, bool annotate_loc) + : mlir::ImplicitLocOpBuilder(op_builder), + annotate_loc_(annotate_loc), + current_loc_(op_builder.getLoc()) {} + + // A few constructors below that could be used when we replace the + // mlir::ImplicitLocOpBuilder and mlir::OpBuilder one by one. + // The intent is to use EmitterLocOpBuilder everywhere in the emitters. + + // The constructor that should be used instead of mlir::ImplicitLocOpBuilder. + EmitterLocOpBuilder(mlir::Location loc, mlir::OpBuilder& op_builder, + bool annotate_loc = false) + : mlir::ImplicitLocOpBuilder(loc, op_builder), + + annotate_loc_(annotate_loc), + current_loc_(loc) {} + + // The constructor that should be used instead of mlir::ImplicitLocOpBuilder. + EmitterLocOpBuilder(mlir::Location loc, mlir::MLIRContext* mlir_context, + bool annotate_loc = false) + : mlir::ImplicitLocOpBuilder(loc, mlir_context), + annotate_loc_(annotate_loc), + current_loc_(loc) {} + + EmitterLocOpBuilder& operator=(const EmitterLocOpBuilder&) = delete; + + // Copy constructor that also remembers the source location where the copy + // was created. If the helper functions that gets the builder as the argument + // receives the argument by value then the current location points to the + // place where the copy was created. + EmitterLocOpBuilder(const EmitterLocOpBuilder& builder, + SourceLocation location = SourceLocation::current()) + : mlir::ImplicitLocOpBuilder(builder), + annotate_loc_(builder.annotate_loc_), + current_loc_(builder.Loc(location)) {} + + // Formats the MLIR IR with annotations to make it easier to read. + static std::string FormatTritonIrWithAnnotations(absl::string_view mlir_ir); + + // Below is the set of create() methods that are used to create operations. + // These are all templated to allow for the creation of operations with + // different numbers of arguments. + // + // For some reason the version of create that accepts the variadic arguments + // and a source location with the default value does not work. + + template + OpTy create(SourceLocation location = SourceLocation::current()) { + return OpBuilder::create(Loc(location)); + } + + // Creates an operation with the given type and one argument. + template + OpTy create(Arg0&& arg, SourceLocation location = SourceLocation::current()) { + return OpBuilder::create(Loc(location), std::forward(arg)); + } + + template + OpTy create(Arg0&& arg0, Arg1&& arg1, + SourceLocation location = SourceLocation::current()) { + return OpBuilder::create(Loc(location), std::forward(arg0), + std::forward(arg1)); + } + + template + OpTy create(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2, + SourceLocation location = SourceLocation::current()) { + return OpBuilder::create(Loc(location), std::forward(arg0), + std::forward(arg1), + std::forward(arg2)); + } + + template + OpTy create(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2, Arg3&& arg3, + SourceLocation location = SourceLocation::current()) { + return OpBuilder::create( + Loc(location), std::forward(arg0), std::forward(arg1), + std::forward(arg2), std::forward(arg3)); + } + + template + OpTy create(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2, Arg3&& arg3, Arg4&& arg4, + SourceLocation location = SourceLocation::current()) { + return OpBuilder::create( + Loc(location), std::forward(arg0), std::forward(arg1), + std::forward(arg2), std::forward(arg3), + std::forward(arg4)); + } + + template + OpTy create(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2, Arg3&& arg3, Arg4&& arg4, + Arg5&& arg5, + SourceLocation location = SourceLocation::current()) { + return OpBuilder::create( + Loc(location), std::forward(arg0), std::forward(arg1), + std::forward(arg2), std::forward(arg3), + std::forward(arg4), std::forward(arg5)); + } + + template + OpTy create(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2, Arg3&& arg3, Arg4&& arg4, + Arg5&& arg5, Arg6&& arg6, + SourceLocation location = SourceLocation::current()) { + return OpBuilder::create( + Loc(location), std::forward(arg0), std::forward(arg1), + std::forward(arg2), std::forward(arg3), + std::forward(arg4), std::forward(arg5), + std::forward(arg6)); + } + + mlir::Location current_loc() const { return current_loc_; } + + bool annotate_loc() const { return annotate_loc_; } + + private: + // Helper function to create a location from a source location. + mlir::Location Loc(SourceLocation location) const; + + // Keep the current location of the builder and use it for annotating the + // newly created operations. + const bool annotate_loc_; + const mlir::Location current_loc_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_FUSIONS_EMITTER_LOC_OP_BUILDER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder_test.cc b/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder_test.cc new file mode 100644 index 00000000000000..d5691f31ec94c9 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/emitter_loc_op_builder_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" + +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +using mlir::NameLoc; +using mlir::StringAttr; +using ::tsl::testing::IsOkAndHolds; + +class EmitterLocOpBuilderTest : public ::testing::Test { + protected: + void SetUp() override { LoadMlirDialectsForTriton(context_); } + + mlir::MLIRContext context_; +}; + +NameLoc NameLoc(mlir::MLIRContext& context, absl::string_view name) { + return NameLoc::get(StringAttr::get(&context, name)); +} + +mlir::OwningOpRef MakeModuleWithOneOp( + mlir::MLIRContext& context, EmitterLocOpBuilder& b) { + auto loc = NameLoc(context, "module"); + auto triton_module = llvm_ir::CreateMlirModuleOp(loc); + b.setInsertionPointToEnd(triton_module->getBody()); + auto i32_type = b.getI32Type(); + auto attr = b.getIntegerAttr(i32_type, 42); + b.create(attr); + return triton_module; +} + +TEST_F(EmitterLocOpBuilderTest, IRWithAnnotations) { + auto loc = NameLoc(context_, "IRWithAnnotations"); + EmitterLocOpBuilder b(loc, &context_, /*annotate_loc=*/true); + auto triton_module = MakeModuleWithOneOp(context_, b); + std::string ir = DumpTritonIR(triton_module.get(), /*dump_annotations=*/true); + if constexpr (EmitterLocOpBuilder::kSourceLocationSupported) { + EXPECT_THAT(RunFileCheck(ir, R"( + CHECK: "IRWithAnnotations -> [[FILE:.*_test.cc]]:[[LINE:[0-9]+]]" + )"), + IsOkAndHolds(true)); + } else { + EXPECT_THAT(RunFileCheck(ir, R"( + CHECK: "IRWithAnnotations" + )"), + IsOkAndHolds(true)); + } +} + +TEST_F(EmitterLocOpBuilderTest, IRWithoutAnnotations) { + auto loc = NameLoc(context_, "IRWithoutAnnotations"); + EmitterLocOpBuilder b(loc, &context_, /*annotate_loc=*/false); + auto triton_module = MakeModuleWithOneOp(context_, b); + std::string ir = + DumpTritonIR(triton_module.get(), /*dump_annotations=*/false); + EXPECT_THAT(RunFileCheck(ir, R"( + CHECK-NOT: IRWithoutAnnotations + )"), + IsOkAndHolds(true)); +} + +} // namespace + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 084493e0b0b252..849779f7da3535 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -297,7 +297,7 @@ BuildKernelPrototypeFromUniqueName(IrEmitterContext& ir_emitter_context, llvm::Argument& llvm_arg = *kernel->getArg(to_llvm_arg_no[arg_no]); llvm::Type* ir_type = - llvm_ir::ShapeToIrType(kernel_argument.shape(), llvm_module); + llvm_ir::ShapeToIrType(kernel_argument.shape(), context); llvm_ir::IrArray ir_array(&llvm_arg, ir_type, kernel_argument.shape()); if (!kernel_argument.written()) { diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 05da9a663c9e9b..cb3df1889bfd0b 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -108,15 +108,15 @@ std::unique_ptr GetFusionEmitter( return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kReduction: - return CreateMlirReductionFusion(analysis); + return CreateMlirReductionFusion(analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: { - return std::make_unique(analysis); + return CreateMlirScatterFusion(analysis); } case HloFusionAnalysis::EmitterFusionKind::kTranspose: { - return std::make_unique(analysis); + return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kConcatenate: { - return std::make_unique(analysis); + return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kTriton: return std::make_unique(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index 223580fcdc4560..f853a7df22c53b 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -31,6 +31,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -38,8 +40,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/xla_data.pb.h" @@ -48,6 +48,11 @@ namespace xla { namespace gpu { namespace { +using emitters::ApplyIndexing; +using emitters::CallTargetProvider; +using emitters::ClampIndex; +using emitters::PartitionedComputations; +using emitters::ProvideParameter; using llvm::SmallVector; using mlir::ImplicitLocOpBuilder; using mlir::Value; @@ -55,11 +60,6 @@ using mlir::ValueRange; using mlir::arith::AddIOp; using mlir::func::ReturnOp; using mlir::tensor::InsertOp; -using mlir_converter::ApplyIndexing; -using mlir_converter::CallTargetProvider; -using mlir_converter::ClampIndex; -using mlir_converter::PartitionedComputations; -using mlir_converter::ProvideParameter; constexpr int kDUSUpdateIndex = 1; @@ -89,17 +89,16 @@ MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( update_shape, indexing_context); } -std::vector +std::vector MlirInPlaceDynamicUpdateSliceFusion::GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const { // We don't actually support epilogues for DUS, but this is how we tell // the base class that we don't want it to generate code for the DUS. - std::vector epilogues; + std::vector epilogues; for (const auto& [dus_op, root] : llvm::zip(dus_ops_, analysis_.fusion_roots())) { - epilogues.push_back( - mlir_converter::EpilogueSpecification::FromIdentityIndexing( - &dus_op.instruction(), &root.instruction(), mlir_context)); + epilogues.push_back(emitters::EpilogueSpecification::FromIdentityIndexing( + &dus_op.instruction(), &root.instruction(), mlir_context)); } return epilogues; } @@ -126,9 +125,10 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( const auto& root_computation = computations.FindPartitionedComputation( fusion.fused_instructions_computation()); - auto result_tensors = mlir_converter::EmitXlaLoopOp( + auto result_tensors = emitters::EmitXlaLoopOp( b, thread_and_block_ids, output_tensor_args, indexing, - [&](ValueRange symbol_values, ValueRange input_indices, + [&](ImplicitLocOpBuilder& nested_b, ValueRange symbol_values, + ValueRange input_indices, ValueRange output_tensors) -> llvm::SmallVector { llvm::SmallVector results; for (auto [instr, root, output] : @@ -140,7 +140,7 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( auto start_indices = ProvideParameterRange( root_computation, dus_instr, dus_instr->first_index_operand_number(), update_shape.rank(), {}, - call_targets, entry_function, b); + call_targets, entry_function, nested_b); for (int i = 0; i < update_shape.rank(); ++i) { int64_t update_size = update_shape.dimensions(i); auto start_index = ClampIndex( @@ -150,23 +150,23 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( ->operand(i + dus_instr->first_index_operand_number()) ->shape() .element_type()), - dus_instr->shape().dimensions(i) - update_size, b); + dus_instr->shape().dimensions(i) - update_size, nested_b); update_indices.push_back( - b.create(input_indices[i], start_index)); + nested_b.create(input_indices[i], start_index)); } - auto updated_value = - ProvideParameter(root_computation, dus_instr, kDUSUpdateIndex, - input_indices, call_targets, entry_function, b); + auto updated_value = ProvideParameter( + root_computation, dus_instr, kDUSUpdateIndex, input_indices, + call_targets, entry_function, nested_b); // Handle bitcasts under the DUS. if (dus_instr->shape() != root.shape()) { update_indices = ApplyIndexing( GetBitcastMap(dus_instr->shape(), root.shape(), b.getContext()), - update_indices, {}, b); + update_indices, {}, nested_b); } - results.push_back( - b.create(updated_value[0], output, update_indices)); + results.push_back(nested_b.create(updated_value[0], output, + update_indices)); } return results; }); diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h index d0803e1d044cc0..8be5fdbabe14c6 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -22,11 +22,11 @@ limitations under the License. #include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_traversal.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -68,12 +68,12 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { protected: absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override; - std::vector GetEpilogues( + std::vector GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc index 89acdcca62489b..91d9a23b4954f9 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc @@ -35,15 +35,15 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_traversal.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/xla_data.pb.h" @@ -51,6 +51,7 @@ namespace xla { namespace gpu { using llvm::SmallVector; +using mlir::ImplicitLocOpBuilder; using mlir::Value; using mlir::ValueRange; @@ -67,7 +68,7 @@ MlirInputSlicesFusion::ComputeThreadIdToOutputIndexing( .begin(); } -std::vector +std::vector MlirInputSlicesFusion::GetEpilogues(const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const { std::vector roots; @@ -91,8 +92,8 @@ LaunchDimensions MlirInputSlicesFusion::launch_dimensions() const { } absl::Status MlirInputSlicesFusion::EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const { mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); @@ -109,9 +110,10 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction( auto output_tensor_args = entry_function.getArguments().drop_front(num_inputs); - auto result_tensors = mlir_converter::EmitXlaLoopOp( + auto result_tensors = emitters::EmitXlaLoopOp( builder, thread_and_block_ids, output_tensor_args, input_indexing, - [&](ValueRange symbol_values, ValueRange map_results, + [&](ImplicitLocOpBuilder nested_b, ValueRange symbol_values, + ValueRange map_results, ValueRange output_tensors) -> SmallVector { SmallVector input_operands( entry_function.getArguments().take_front(num_inputs)); @@ -124,7 +126,7 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction( const auto* arg = root.instruction().operand(0); if (auto& value = input_values[arg]; !value) { value = - builder.create(call_targets(arg), input_operands) + nested_b.create(call_targets(arg), input_operands) .getResult(0); } } @@ -132,14 +134,14 @@ absl::Status MlirInputSlicesFusion::EmitEntryFunction( for (auto [output_index, output] : llvm::enumerate(output_tensors)) { auto output_indexing = ComputeThreadIdToOutputIndexing( output_index, entry_function.getContext()); - mlir::Value in_bounds = mlir_converter::CheckConstraints( - *output_indexing, thread_and_block_ids, symbol_values, builder); - auto if_op = builder.create( + mlir::Value in_bounds = emitters::CheckConstraints( + *output_indexing, thread_and_block_ids, symbol_values, nested_b); + auto if_op = nested_b.create( in_bounds, [&, output_index = output_index, output = output]( mlir::OpBuilder b, mlir::Location loc) { mlir::ImplicitLocOpBuilder then_builder(loc, b); - auto output_indices = mlir_converter::ApplyIndexing( + auto output_indices = emitters::ApplyIndexing( *output_indexing, thread_and_block_ids, symbol_values, then_builder); const auto* arg = analysis_.fusion_root(output_index) diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h index 14bf9aa30d76da..fa6a26d9aac1ea 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" @@ -54,12 +54,12 @@ class MlirInputSlicesFusion : public MlirFusionEmitterBase { protected: absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override; - std::vector GetEpilogues( + std::vector GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc index 10760a0df22fcd..d820077de404a8 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc @@ -32,15 +32,15 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_traversal.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/shape.h" @@ -52,6 +52,7 @@ namespace gpu { namespace { using llvm::SmallVector; +using mlir::ImplicitLocOpBuilder; using mlir::Value; using mlir::ValueRange; @@ -99,11 +100,11 @@ LaunchDimensions MlirLoopFusion::launch_dimensions() const { } absl::Status MlirLoopFusion::EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const { - mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); + ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function); builder.setInsertionPointToStart(entry_function.addEntryBlock()); auto thread_and_block_ids = EmitThreadAndBlockIds(builder); @@ -125,7 +126,8 @@ absl::Status MlirLoopFusion::EmitEntryFunction( } } - auto body_builder = [&](ValueRange symbol_values, ValueRange map_results, + auto body_builder = [&](ImplicitLocOpBuilder& nested_b, + ValueRange symbol_values, ValueRange map_results, ValueRange output_tensors) -> SmallVector { auto root_fn = call_targets( fusion.fused_instructions_computation()->root_instruction()); @@ -135,25 +137,25 @@ absl::Status MlirLoopFusion::EmitEntryFunction( entry_function.getArguments().take_front(num_inputs)); absl::c_copy(map_results, std::back_inserter(operands)); auto result_scalars = - builder.create(root_fn, operands).getResults(); + nested_b.create(root_fn, operands).getResults(); SmallVector result_tensors; result_tensors.reserve(output_tensor_args.size()); for (auto [root_shape, tensor, value] : llvm::zip(result_shapes, output_tensors, result_scalars)) { - llvm::SmallVector output_indices = mlir_converter::ApplyIndexing( + llvm::SmallVector output_indices = emitters::ApplyIndexing( GetBitcastMap(*result_shapes.front(), *root_shape, - builder.getContext()), - map_results, {}, builder); - result_tensors.push_back(builder.create( + nested_b.getContext()), + map_results, {}, nested_b); + result_tensors.push_back(nested_b.create( value, tensor, output_indices)); } return result_tensors; }; - builder.create(mlir_converter::EmitXlaLoopOp( - builder, thread_and_block_ids, output_tensor_args, *indexing, - body_builder)); + builder.create( + emitters::EmitXlaLoopOp(builder, thread_and_block_ids, output_tensor_args, + *indexing, body_builder)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h index b43fd2bfb61e73..e983e386026317 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h @@ -21,9 +21,9 @@ limitations under the License. #include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -48,8 +48,8 @@ class MlirLoopFusion : public MlirFusionEmitterBase { protected: absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index c4bd5923db38cb..d0d73d1c38915a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -13,146 +13,20 @@ package_group( ], ) -cc_library( - name = "computation_partitioner", - srcs = ["computation_partitioner.cc"], - hdrs = ["computation_partitioner.h"], - deps = [ - ":type_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/ir:hlo", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:DataLayoutInterfaces", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:Support", - ], -) - -xla_cc_test( - name = "computation_partitioner_test", - srcs = ["computation_partitioner_test.cc"], - deps = [ - ":computation_partitioner", - "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "elemental_hlo_to_mlir", - srcs = ["elemental_hlo_to_mlir.cc"], - hdrs = ["elemental_hlo_to_mlir.h"], - deps = [ - ":computation_partitioner", - ":type_util", - "//xla:comparison_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/codegen/ir:xla", - "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", - "//xla/hlo/utils:hlo_traversal", - "//xla/mlir_hlo", - "//xla/mlir_hlo:map_mhlo_to_scalar_op", - "//xla/service:algorithm_util", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:DataLayoutInterfaces", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:VectorDialect", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "elemental_hlo_to_mlir_test", - srcs = ["elemental_hlo_to_mlir_test.cc"], - deps = [ - ":computation_partitioner", - ":elemental_hlo_to_mlir", - "//xla:status_macros", - "//xla/codegen/ir:xla", - "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/parser:hlo_parser", - "//xla/hlo/testlib:filecheck", - "//xla/mlir_hlo", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/llvm_ir:llvm_util", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_googletest//:gtest", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:DLTIDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "mlir_fusion_emitter", srcs = ["mlir_fusion_emitter.cc"], hdrs = ["mlir_fusion_emitter.h"], deps = [ - ":computation_partitioner", - ":elemental_hlo_to_mlir", - ":type_util", "//xla:shape_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/backends/gpu/codegen/transforms:passes", + "//xla/codegen/emitters:computation_partitioner", + "//xla/codegen/emitters:elemental_hlo_to_mlir", + "//xla/codegen/emitters:type_util", "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", @@ -169,8 +43,6 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:target_util", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/transforms:passes", "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", @@ -226,15 +98,15 @@ xla_cc_test( name = "mlir_fusion_emitter_test", srcs = ["mlir_fusion_emitter_test.cc"], deps = [ - ":computation_partitioner", ":mlir_fusion_emitter", + "//xla/codegen/emitters:computation_partitioner", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/mlir_hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:launch_dimensions", "//xla/stream_executor:device_description", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status", @@ -264,32 +136,3 @@ xla_cc_test( "@local_tsl//tsl/platform:statusor", ], ) - -cc_library( - name = "type_util", - srcs = ["type_util.cc"], - hdrs = ["type_util.h"], - deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", - "//xla/mlir/utils:type_util", - "@com_google_absl//absl/log:check", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - -xla_cc_test( - name = "type_util_test", - srcs = ["type_util_test.cc"], - deps = [ - ":type_util", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 712e060ab71dfc..d1a43a4811e0da 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -77,6 +76,11 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/type_util.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -88,11 +92,6 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/dump.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/type_util.h" -#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_arguments.h" @@ -389,11 +388,11 @@ MlirFusionEmitterBase::CreateMLIRModule( int arg_index = 0; for (auto* param : fusion.operands()) { param_types.push_back( - mlir_converter::TensorShapeToMlirType(param->shape(), builder)); + emitters::TensorShapeToMlirType(param->shape(), builder)); TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), get_arg_attrs(arg_index++)); } - auto result_types = mlir_converter::ShapeToMlirTypes(fusion.shape(), builder); + auto result_types = emitters::ShapeToMlirTypes(fusion.shape(), builder); param_types.append(result_types.begin(), result_types.end()); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( fusion.shape(), [&](const auto& shape, const ShapeIndex& index) { @@ -417,13 +416,13 @@ MlirFusionEmitterBase::CreateMLIRModule( return module; } -mlir_converter::EpilogueSpecification +emitters::EpilogueSpecification MlirFusionEmitterBase::GetEpilogueForOutputIndexing( const HloFusionAnalysis& analysis, const std::vector& heroes, const std::vector& roots, mlir::MLIRContext* mlir_context) const { - mlir_converter::EpilogueSpecification result; + emitters::EpilogueSpecification result; absl::flat_hash_map root_to_hero; @@ -464,9 +463,9 @@ MlirFusionEmitterBase::GetEpilogueForOutputIndexing( absl::Status MlirFusionEmitterBase::EmitMlir( mlir::ModuleOp module, FuncOp entry_function, const HloFusionInstruction& fusion) const { - std::vector epilogues = + std::vector epilogues = GetEpilogues(fusion, module->getContext()); - mlir_converter::PartitionedComputations computations( + emitters::PartitionedComputations computations( fusion.fused_instructions_computation(), module->getContext(), epilogues); auto subgraph_to_mlir_fn = computations.DeclareFunctions(module); @@ -496,14 +495,14 @@ absl::Status MlirFusionEmitterBase::EmitMlir( for (const auto& comp : computations.partitioned_computations()) { for (const auto& subgraph : comp.subgraphs()) { if (subgraph_to_mlir_fn.contains(&subgraph)) { - TF_RETURN_IF_ERROR(mlir_converter::SubgraphToMlirFunction( + TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( comp, subgraph, subgraph_to_mlir_fn[&subgraph], call_targets)); } } } for (const auto& epilogue : computations.epilogues()) { if (epilogue.roots.empty()) continue; - TF_RETURN_IF_ERROR(mlir_converter::SubgraphToMlirFunction( + TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( computations.FindPartitionedComputation( fusion.fused_instructions_computation()), epilogue, subgraph_to_mlir_fn[&epilogue], call_targets)); @@ -523,8 +522,7 @@ absl::Status MlirFusionEmitterBase::EmitMlir( absl::flat_hash_map MlirFusionEmitterBase::EmitEpilogue( - int epilogue_index, - const mlir_converter::PartitionedComputations& computations, + int epilogue_index, const emitters::PartitionedComputations& computations, FuncOp entry_fn, const absl::flat_hash_map>& injected, @@ -609,7 +607,7 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm, // opportunities for LICM. This would not be necessary if LICM also moved // instructions over ifs. pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addNestedPass(CreateVectorizeLoadsAndStoresPass()); + pm.addNestedPass(CreateVectorizeLoadsAndStoresPass(device)); pm.addNestedPass(CreateOptimizeLoopsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h index 542b168460407d..cdb621e2f8771b 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -16,6 +16,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_MLIR_MLIR_FUSION_EMITTER_H_ #include +#include #include #include @@ -34,13 +35,13 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/PassManager.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/mlir/tools/mlir_replay/public/compiler_trace.pb.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/stream_executor/device_description.h" @@ -74,7 +75,7 @@ class MlirFusionEmitterBase : public KernelFusionInterface { // Returns the set of instructions that will be isolated in the partitioned, // i.e., they will get their own subgraph. We won't automatically emit // functions for these instructions. - virtual std::vector GetEpilogues( + virtual std::vector GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const { return {}; @@ -82,23 +83,22 @@ class MlirFusionEmitterBase : public KernelFusionInterface { // Creates an epilogue with the raw thread/block/symbol indices, as defined // by the fusion's thread->output mapping. - mlir_converter::EpilogueSpecification GetEpilogueForOutputIndexing( + emitters::EpilogueSpecification GetEpilogueForOutputIndexing( const HloFusionAnalysis& analysis, const std::vector& heroes, const std::vector& roots, mlir::MLIRContext* mlir_context) const; virtual absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const = 0; // Evaluates the epilogue of the fusion. Returns the results for each epilogue // root. absl::flat_hash_map EmitEpilogue( - int epilogue_index, - const mlir_converter::PartitionedComputations& computations, + int epilogue_index, const emitters::PartitionedComputations& computations, mlir::func::FuncOp entry_fn, const absl::flat_hash_map>& injected, diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index 208f4a3e7ba869..671860aeaa6454 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -44,15 +44,15 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -76,8 +76,8 @@ class DummyCopyFusionEmitter : public MlirFusionEmitterBase { protected: absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override { mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index be764922a0840d..e22f37c7c30f83 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -46,16 +46,16 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/type_util.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_traversal.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/fusions/reduction_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -70,6 +70,7 @@ limitations under the License. namespace xla { namespace gpu { +using emitters::PartitionedComputations; using llvm::SmallVector; using mlir::AffineExpr; using mlir::AffineMap; @@ -77,7 +78,6 @@ using mlir::ImplicitLocOpBuilder; using mlir::MLIRContext; using mlir::Value; using mlir::ValueRange; -using mlir_converter::PartitionedComputations; constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; constexpr int kRowMinorReduced = ReductionDimensions::kRowMinorReducedDimension; @@ -96,7 +96,7 @@ struct MlirReductionFusion::EmitterState { mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion, const PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_target) + const emitters::CallTargetProvider& call_target) : owner(owner), entry_function(entry_function), fusion(fusion), @@ -164,9 +164,9 @@ struct MlirReductionFusion::EmitterState { mlir::func::FuncOp entry_function; const HloFusionInstruction& fusion; const PartitionedComputations& computations; - const mlir_converter::CallTargetProvider& call_target; + const emitters::CallTargetProvider& call_target; ImplicitLocOpBuilder builder; - const mlir_converter::PartitionedComputation& computation; + const emitters::PartitionedComputation& computation; absl::flat_hash_map fusion_result_index_starts; absl::flat_hash_map root_indices; SmallVector thread_and_block_ids; @@ -193,23 +193,24 @@ PerThreadOutputs MlirReductionFusion::EmitterState::EmitPerThreadElements( iter_arg_inits.append(init); } - auto body_builder = [&](ValueRange symbol_values, ValueRange map_results, + auto body_builder = [&](ImplicitLocOpBuilder& nested_b, + ValueRange symbol_values, ValueRange map_results, ValueRange iter_args) -> SmallVector { llvm::SmallVector results = iter_args; for (auto* reduction : reductions) { int arity = reduction->operand_count() / 2; int start = iter_arg_starts[reduction]; SmallVector reduce_args = iter_args.slice(start, arity); - auto indices = mlir_converter::ApplyIndexing( + auto indices = emitters::ApplyIndexing( GetBitcastMap(owner.input_shape_, reduction->operand(0)->shape(), - builder.getContext()), - map_results, {}, builder); + nested_b.getContext()), + map_results, {}, nested_b); reduce_args.append(ProvideParameterRange(computation, reduction, 0, arity, indices, call_target, - entry_function, builder)); + entry_function, nested_b)); const auto& reducer = GetReducer(reduction); absl::c_copy( - builder.create(reducer, reduce_args).getResults(), + nested_b.create(reducer, reduce_args).getResults(), results.begin() + start); } struct SideOutput { @@ -218,12 +219,12 @@ PerThreadOutputs MlirReductionFusion::EmitterState::EmitPerThreadElements( }; llvm::SmallVector side_output_values; for (auto* side_output : side_outputs) { - auto indices = mlir_converter::ApplyIndexing( + auto indices = emitters::ApplyIndexing( GetBitcastMap(owner.input_shape_, side_output->shape(), builder.getContext()), map_results, {}, builder); auto* root_tuple = fusion.fused_expression_root(); - Value value = mlir_converter::ProvideParameter( + Value value = emitters::ProvideParameter( computation, root_tuple, root_tuple->operand_index(side_output), indices, call_target, entry_function, builder)[0]; side_output_values.push_back({std::move(indices), value}); @@ -238,9 +239,9 @@ PerThreadOutputs MlirReductionFusion::EmitterState::EmitPerThreadElements( return results; }; - auto results_vector = mlir_converter::EmitXlaLoopOp( - builder, thread_and_block_ids, iter_arg_inits, tile_indexing, - body_builder, vectorize); + auto results_vector = + emitters::EmitXlaLoopOp(builder, thread_and_block_ids, iter_arg_inits, + tile_indexing, body_builder, vectorize); mlir::ValueRange results = results_vector; PerThreadOutputs scalars_and_outputs; @@ -274,16 +275,16 @@ SmallVector MlirReductionFusion::EmitterState::WriteToSharedMemory( auto tile_shape = ShapeUtil::MakeShapeWithDescendingLayout( reduction->operand(i)->shape().element_type(), shape); tiles.push_back(builder.create( - mlir_converter::TensorShapeToMlirType(tile_shape, builder))); + emitters::TensorShapeToMlirType(tile_shape, builder))); } } - auto written_tiles = mlir_converter::EmitLoopNest( + auto written_tiles = emitters::EmitLoopNest( builder, {thread_and_block_ids[0]}, tiles, map, [&](mlir::ValueRange iter_args, mlir::ValueRange dim_values, mlir::ValueRange symbol_values) { - auto indices = mlir_converter::ApplyIndexing(map, dim_values, - symbol_values, builder); + auto indices = + emitters::ApplyIndexing(map, dim_values, symbol_values, builder); int shared_index = 0; SmallVector written = iter_args; for (auto* hero : reductions) { @@ -339,14 +340,14 @@ mlir::ValueRange MlirReductionFusion::EmitterState::ReduceViaSharedMemory( auto tiles = WriteToSharedMemory(reductions, per_thread.reduction_scalars, padding); - return mlir_converter::EmitLoopNest( + return emitters::EmitLoopNest( builder, {thread_and_block_ids[0]}, per_thread.outputs, loop_indexing, [&](ValueRange outputs, ValueRange dim_values, ValueRange symbol_values) -> SmallVector { - auto read_condition = mlir_converter::CheckConstraints( + auto read_condition = emitters::CheckConstraints( read_indexing, dim_values, symbol_values, builder); - auto indices = mlir_converter::ApplyIndexing(read_indexing, dim_values, - symbol_values, builder); + auto indices = emitters::ApplyIndexing(read_indexing, dim_values, + symbol_values, builder); int64_t tile_index = 0; HloValueMap reduce_args; @@ -438,10 +439,9 @@ LaunchDimensions MlirReductionFusion::launch_dimensions() const { /*y=*/1, /*z=*/1)}; } -std::vector -MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion, - MLIRContext* mlir_context) const { - std::vector epilogues; +std::vector MlirReductionFusion::GetEpilogues( + const HloFusionInstruction& fusion, MLIRContext* mlir_context) const { + std::vector epilogues; epilogues.reserve(reduction_heroes_.size()); for (const auto& [heroes, roots] : llvm::zip(reduction_heroes_, reduction_roots_)) { @@ -452,9 +452,8 @@ MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion, // get "fused" into the tuple function. for (const auto& roots : side_output_roots_) { for (const auto* root : roots) { - epilogues.push_back( - mlir_converter::EpilogueSpecification::FromIdentityIndexing( - root, root, mlir_context)); + epilogues.push_back(emitters::EpilogueSpecification::FromIdentityIndexing( + root, root, mlir_context)); } } return epilogues; @@ -462,7 +461,7 @@ MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion, absl::Status MlirReductionFusion::EmitEntryFunction( const PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const { EmitterState state{*this, entry_function, fusion, computations, call_targets}; @@ -567,13 +566,13 @@ SmallVector MlirReductionFusion::EvaluateEpilogue( auto values = EmitEpilogue(group_id, state.computations, state.entry_function, results, epilogue_input_indices, b); int first_root_index = state.root_indices[epilogue.roots.front()]; - auto thread_has_output = mlir_converter::CheckConstraints( + auto thread_has_output = emitters::CheckConstraints( *ComputeThreadIdToOutputIndexing(first_root_index, b.getContext()), state.thread_and_block_ids, symbol_values, b); for (auto [index, root] : llvm::enumerate(epilogue.roots)) { - auto output_indices = mlir_converter::ApplyIndexing( - epilogue.root_indexing[index], state.thread_and_block_ids, - symbol_values, b); + auto output_indices = + emitters::ApplyIndexing(epilogue.root_indexing[index], + state.thread_and_block_ids, symbol_values, b); for (auto [result_index, result] : llvm::enumerate(values.at(root))) { auto& output = outputs[state.OutputIndex(root, result_index)]; output = b.create(thread_has_output, result, output, diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index 8d56895b09169b..77c931dff445db 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -31,10 +31,10 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/fusions/reduction_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -74,12 +74,12 @@ class MlirReductionFusion : public MlirFusionEmitterBase { HloValueMap GetInits(int group_id, EmitterState& state) const; absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override; - std::vector GetEpilogues( + std::vector GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index 00d7e9735eb7cb..eceeb699996b6d 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/scatter_mlir.h" +#include #include +#include +#include #include #include #include @@ -22,11 +25,18 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -35,61 +45,372 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/type_util.h" #include "xla/codegen/ir/xla_ops.h" +#include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/scatter_simplifier.h" #include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { namespace gpu { namespace { -namespace ma = ::mlir::arith; +namespace arith = ::mlir::arith; namespace scf = ::mlir::scf; +namespace vector = ::mlir::vector; +namespace tensor = ::mlir::tensor; +using emitters::CallTargetProvider; +using emitters::EmitXlaLoopOp; +using emitters::PartitionedComputations; +using emitters::ProvideParameter; +using llvm::APFloat; +using llvm::APInt; using llvm::SmallVector; +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::DenseElementsAttr; +using mlir::getAffineDimExpr; +using mlir::getAffineSymbolExpr; +using mlir::ImplicitLocOpBuilder; using mlir::Location; +using mlir::MLIRContext; using mlir::OpBuilder; using mlir::Value; using mlir::ValueRange; +using mlir::VectorType; +using mlir::func::FuncOp; using mlir::func::ReturnOp; -using mlir_converter::CallTargetProvider; -using mlir_converter::PartitionedComputations; -using mlir_converter::ProvideParameter; +using primitive_util::IsUnsignedIntegralType; + +constexpr int64_t kNumWarpsPerBlock = 4; +constexpr int64_t kMaxVectorizedBits = 128; +constexpr int64_t kScatterOperandIndex = 0; +constexpr int64_t kScatterIndicesIndex = 1; +constexpr int64_t kScatterUpdateIndex = 2; + +// Emit +// if (condition) { +// updated_values = updated_values_fn(); +// yield updated_values; +// } else { +// yield values; +// } +ValueRange EmitUpdateIf( + ImplicitLocOpBuilder& b, Value condition, ValueRange values, + llvm::function_ref(ImplicitLocOpBuilder&)> + updated_values_fn) { + return b + .create( + condition, + [&](OpBuilder& then_b, Location then_loc) -> void { + ImplicitLocOpBuilder implicit_then_b(then_loc, then_b); + then_b.create(then_loc, + updated_values_fn(implicit_then_b)); + }, + [&](OpBuilder& else_b, Location else_loc) -> void { + else_b.create(else_loc, values); + }) + .getResults(); +} + +// Computes if the slice with the sizes `slice_shape` and the offsets `offsets` +// can be inserted into the operand with the shape `operand_shape`. +Value EmitBoundsCheck(ImplicitLocOpBuilder& b, + absl::Span slice_shape, + absl::Span operand_shape, + ValueRange offsets) { + Value in_bounds = b.create(1, b.getI1Type()); + for (auto [update_dim, operand_dim, offset] : + llvm::zip(slice_shape, operand_shape, offsets)) { + Value ub = b.create(operand_dim - update_dim); + // One bounds check is enough even for signed indices: `sge 0` is + // implied by `ule ub`, because `ub >= 0`. + in_bounds = b.createOrFold( + in_bounds, + b.createOrFold(arith::CmpIPredicate::ule, offset, ub)); + } + return in_bounds; +} + +Value EmitInequalityCheck(ImplicitLocOpBuilder& b, ValueRange lhs, + ValueRange rhs) { + Value not_equal = b.create(0, b.getI1Type()); + for (auto [lhs_elem, rhs_elem] : llvm::zip(lhs, rhs)) { + not_equal = b.createOrFold( + not_equal, b.createOrFold(arith::CmpIPredicate::ne, + lhs_elem, rhs_elem)); + } + return not_equal; +} + +Value UpdateIsInbounds(ImplicitLocOpBuilder& b, Value is_inbounds, + Value offsets_changed, ValueRange offsets, + absl::Span slice_shape, + absl::Span operand_shape) { + return EmitUpdateIf(b, offsets_changed, is_inbounds, + [&](ImplicitLocOpBuilder& if_b) -> SmallVector { + return {EmitBoundsCheck(if_b, slice_shape, + operand_shape, offsets)}; + }) + .front(); +} + +SmallVector Pack(llvm::ArrayRef ranges) { + int64_t total_size = 0; + for (auto& range : ranges) { + total_size += range.size(); + } + SmallVector result; + result.reserve(total_size); + for (auto range : ranges) { + result.append(range.begin(), range.end()); + } + return result; +} + +SmallVector Unpack(ValueRange range, + llvm::ArrayRef sizes) { + int64_t total_size = 0; + for (auto& size : sizes) { + total_size += size; + } + assert(total_size == range.size()); + SmallVector result; + result.reserve(sizes.size()); + for (int64_t size : sizes) { + result.push_back(range.take_front(size)); + range = range.drop_front(size); + } + return result; +} + +// Pads the given values with zeros to the given container size. +SmallVector PadWithZeros(ValueRange values, int64_t size, + ImplicitLocOpBuilder& b) { + SmallVector padded_values(values.begin(), values.end()); + if (values.size() >= size) return padded_values; + auto zero = b.create(0); + for (int i = values.size(); i < size; ++i) { + padded_values.push_back(zero); + } + return padded_values; +} + +// Creates a new indexing map that is the same as `map` but with the range +// variable at `range_var_index` replaced with the new dimension variable at +// `dimension_{dim_var_size)`. Potentially, it can be moved to indexing_map.h. +IndexingMap ConvertRangeVariableToDimension(const IndexingMap& map, + int64_t range_var_index) { + auto* mlir_context = map.GetMLIRContext(); + + AffineMap affine_map = map.GetAffineMap(); + // Update the affine map. + SmallVector symbol_replacements; + symbol_replacements.reserve(affine_map.getNumSymbols()); + for (int i = 0; i < affine_map.getNumSymbols(); ++i) { + if (i == range_var_index) { + symbol_replacements.push_back( + getAffineDimExpr(affine_map.getNumDims(), mlir_context)); + } else { + symbol_replacements.push_back( + getAffineSymbolExpr(i > range_var_index ? i - 1 : i, mlir_context)); + } + } + + AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols( + {}, symbol_replacements, affine_map.getNumDims() + 1, + affine_map.getNumSymbols() - 1); + + // Update the constraints. + std::vector> constraints; + constraints.reserve(map.GetConstraintsCount()); + for (auto constraint : map.GetConstraints()) { + constraints.push_back({constraint.first.replaceSymbols(symbol_replacements), + constraint.second}); + } + // Update the variables. + std::vector dims = map.GetDimVars(); + std::vector range_vars = map.GetRangeVars(); + std::vector rt_vars = map.GetRTVars(); + + dims.push_back(range_vars[range_var_index]); + range_vars.erase(range_vars.begin() + range_var_index); + return IndexingMap{converted_affine_map, std::move(dims), + std::move(range_vars), std::move(rt_vars), constraints}; +} } // namespace -MlirScatterFusion::MlirScatterFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis) { - const auto& scatter = analysis_.fusion_hero(0).instruction(); - auto& scatter_update_shape = scatter.operands().back()->shape(); - config_ = ComputeLoopFusionConfig(analysis, scatter_update_shape); +class EmitterHelper { + public: + EmitterHelper(const ScatterDescription& description, + const PartitionedComputations* computations, + const CallTargetProvider* call_targets, FuncOp entry_function, + const HloFusionInstruction& fusion) + : description_(&description), + entry_function_(entry_function), + call_targets_(call_targets), + root_computation_(&computations->FindPartitionedComputation( + fusion.fused_instructions_computation())) {} + + Value GetOperandElement(ImplicitLocOpBuilder& b, ValueRange indices) const { + return GetElement(b, kScatterOperandIndex, indices); + } + + Value GetIndicesElement(ImplicitLocOpBuilder& b, ValueRange indices) const { + return GetElement(b, kScatterIndicesIndex, indices); + } + + Value GetUpdateElement(ImplicitLocOpBuilder& b, ValueRange indices) const { + return GetElement(b, kScatterUpdateIndex, indices); + } + + FuncOp GetReducer() const { + return (*call_targets_)( + description_->scatter->called_computations()[0]->root_instruction()); + } + + SmallVector ExtractOffsets(ImplicitLocOpBuilder& b, + Value slice_id) const; + + Value EmitScatterComputation(ImplicitLocOpBuilder& b, ValueRange indices, + Value update_elem, Value output_tensor) const; + + SmallVector WriteAccumulatedElementToOutput( + ImplicitLocOpBuilder& b, Value accumulator, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange offsets, Value output_tensor) const; + + Value WriteAccumulatorToOutput(ImplicitLocOpBuilder& b, + Value write_to_output_required, + ValueRange thread_and_block_ids, Value iv, + const IndexingMap& slice_indexing, + ValueRange offsets, Value accumulator, + Value output_tensor) const; + + private: + Value GetElement(ImplicitLocOpBuilder& b, int operand_index, + ValueRange indices) const; + + const ScatterDescription* description_; + FuncOp entry_function_; + const emitters::CallTargetProvider* call_targets_; + const emitters::PartitionedComputation* root_computation_; +}; + +SmallVector EmitterHelper::ExtractOffsets(ImplicitLocOpBuilder& b, + Value slice_id) const { + auto index_type = b.getIndexType(); + SmallVector offsets; + offsets.reserve(description_->index_vector_length); + for (int i = 0; i < description_->index_vector_length; ++i) { + SmallVector indices_tensor_indices = { + slice_id, b.create(i)}; + auto index = GetIndicesElement(b, indices_tensor_indices); + index = + IsUnsignedIntegralType( + description_->scatter->scatter_indices()->shape().element_type()) + ? b.create(index_type, index).getResult() + : b.create(index_type, index).getResult(); + offsets.push_back(index); + } + return offsets; +} + +Value EmitterHelper::EmitScatterComputation(ImplicitLocOpBuilder& b, + ValueRange indices, + Value update_elem, + Value output_tensor) const { + FuncOp reducer = GetReducer(); + if (description_->scatter->unique_indices()) { + auto operand_elem = GetOperandElement(b, indices); + auto reduced_val = emitters::InlineBlock(b, reducer.getBody().front(), + {operand_elem, update_elem})[0]; + return b.create(reduced_val, output_tensor, indices); + } + auto atomic_rmw = b.create(output_tensor, indices); + OpBuilder body_b = atomic_rmw.getBodyBuilder(); + auto reduced_val = + emitters::InlineBlock(body_b, reducer.getBody().front(), + {atomic_rmw.getCurrentValue(), update_elem})[0]; + body_b.create(reducer->getLoc(), reduced_val); + return atomic_rmw->getResult(0); +} + +SmallVector EmitterHelper::WriteAccumulatedElementToOutput( + ImplicitLocOpBuilder& b, Value accumulator, ValueRange accumulator_indices, + ValueRange slice_indices, ValueRange offsets, Value output_tensor) const { + Value accumulator_elem = b.create( + accumulator, mlir::getAsOpFoldResult(accumulator_indices)); + + SmallVector output_indices(offsets.begin(), offsets.end()); + for (int i = 0; i < output_indices.size(); ++i) { + output_indices[i] = + b.create(slice_indices[i + 1], output_indices[i]); + } + return {EmitScatterComputation(b, output_indices, accumulator_elem, + output_tensor)}; } -std::optional MlirScatterFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { - return std::nullopt; +Value EmitterHelper::WriteAccumulatorToOutput( + ImplicitLocOpBuilder& b, Value write_to_output_required, + ValueRange thread_and_block_ids, Value iv, + const IndexingMap& slice_indexing, ValueRange offsets, Value accumulator, + Value output_tensor) const { + SmallVector dims = Pack({thread_and_block_ids, iv}); + return EmitUpdateIf( + b, write_to_output_required, output_tensor, + [&](ImplicitLocOpBuilder& if_builder) -> SmallVector { + return EmitXlaLoopOp( + if_builder, dims, output_tensor, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange output_tensors) -> SmallVector { + return WriteAccumulatedElementToOutput( + update_loop_b, accumulator, accumulator_indices, + slice_indices, offsets, output_tensors.front()); + }); + }) + .front(); } +Value EmitterHelper::GetElement(ImplicitLocOpBuilder& b, int operand_index, + ValueRange indices) const { + return ProvideParameter(*root_computation_, description_->scatter, + operand_index, indices, *call_targets_, + entry_function_, b)[0]; +} + +MlirScatterFusion::MlirScatterFusion(const HloFusionAnalysis& analysis, + const ScatterDescription& description, + int64_t vector_size) + : analysis_(analysis), + description_(description), + warp_size_(WarpSize(analysis_.device_info())), + vector_size_(vector_size) {} + std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - const auto* scatter = - DynCast(&analysis_.fusion_hero(0).instruction()); - CHECK(ScatterSimplifier::IsSimplifiedScatter(scatter)) + int64_t root_index, int64_t hero_operand_index, MLIRContext* ctx) const { + CHECK(ScatterSimplifier::IsSimplifiedScatter(description_.scatter)) << "Non-simplified HLO Scatter is not supported."; - int64_t scatter_operand_count = scatter->scatter_operand_count(); + + int64_t scatter_operand_count = description_.scatter->scatter_operand_count(); // Scatter operands a packed in the following way: // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`. // Operand ID scatter_operand_count for `scatter indices`. @@ -100,190 +421,546 @@ std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( if (hero_operand_index < scatter_operand_count) { return std::nullopt; } - // Compute thread id mapping based on the first update operand. - Shape scatter_update_shape = scatter->scatter_updates().front()->shape(); - // TODO(jreiffers): There are scatters where vectorization makes sense, but we - // cannot currently detect them. Add a heuristic. - IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap( - launch_dimensions(), /*unroll_factor=*/1, scatter_update_shape, ctx); - - // For scatter indices we project indexing for scatter updates and take the - // first result of the affine map only, because they coincide. - if (hero_operand_index == scatter_operand_count) { - Shape scatter_indices_shape = scatter->scatter_indices()->shape(); - CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString(); - // Create a map from scatter update to scatter indices. - IndexingMap updates_to_indices_map{ - mlir::AffineMap::get( - /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1, - {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)}, - ctx), - DimVarsFromTensorSizes(scatter_update_shape.dimensions()), - RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}), - /*rt_vars=*/{}}; - auto scatter_indices_map = scatter_update_map * updates_to_indices_map; - scatter_indices_map.Simplify(); - return scatter_indices_map; + bool is_indices_operand = hero_operand_index == scatter_operand_count; + auto map = IndexingMap::GetUndefined(); + if (is_indices_operand) { + ComputeIndexing(ctx, /*updates_map=*/nullptr, &map); + return map; } - return scatter_update_map; -} - -LaunchDimensions MlirScatterFusion::launch_dimensions() const { - const auto& scatter = analysis_.fusion_hero(0).instruction(); - // Compute thread id mapping based on the shape of update operand. - auto& scatter_update_shape = scatter.operands().back()->shape(); - return CalculateLaunchDimensions(scatter_update_shape, - analysis_.device_info()); + ComputeIndexing(ctx, &map, /*indices_map=*/nullptr); + return map; } -std::vector -MlirScatterFusion::GetEpilogues(const HloFusionInstruction& fusion, - mlir::MLIRContext* mlir_context) const { +std::vector MlirScatterFusion::GetEpilogues( + const HloFusionInstruction& fusion, MLIRContext* mlir_context) const { // We don't actually support epilogues for scatter, but this is how we tell // the base class that we don't want it to generate code for the scatter. - return {mlir_converter::EpilogueSpecification::FromIdentityIndexing( + return {emitters::EpilogueSpecification::FromIdentityIndexing( &analysis_.fusion_hero(0).instruction(), &analysis_.fusion_root(0).instruction(), mlir_context)}; } -mlir::Value EmitScatterComputation( - const HloInstruction* scatter, ValueRange indices, Value update_elem, - Value output_tensor, - const mlir_converter::PartitionedComputation& root_computation, - const mlir_converter::CallTargetProvider& call_targets, - mlir::func::FuncOp entry_function, mlir::ImplicitLocOpBuilder& b) { - constexpr int kScatterOperandIndex = 0; - auto reducer = - call_targets(scatter->called_computations()[0]->root_instruction()); - if (scatter->unique_indices()) { - auto operand_elem = - ProvideParameter(root_computation, scatter, kScatterOperandIndex, - indices, call_targets, entry_function, b)[0]; - auto reduced_val = mlir_converter::InlineBlock( - b, reducer.getBody().front(), {operand_elem, update_elem})[0]; - - return b.create(reduced_val, output_tensor, - indices); +ScatterWithDistributedUpdates::ScatterWithDistributedUpdates( + const HloFusionAnalysis& analysis, const ScatterDescription& description, + int64_t vector_size) + : MlirScatterFusion(analysis, description, vector_size) { + // We have to make sure that there is no thread that processes elements of + // two different update slice. + auto launch_dimensions = CalculateLaunchDimensions( + description_.update_shape, analysis_.device_info(), + {static_cast(vector_size_)}); + num_blocks_ = launch_dimensions.num_blocks(); + num_warps_ = CeilOfRatio( + static_cast(launch_dimensions.num_threads_per_block()), + warp_size_); +} + +void ScatterWithDistributedUpdates::ComputeIndexing( + MLIRContext* ctx, IndexingMap* updates_map, + IndexingMap* indices_map) const { + // Compute thread id mapping based on the first update operand. + IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap( + launch_dimensions(), vector_size_, description_.update_shape, ctx); + + // For scatter indices we project indexing for scatter updates and take the + // first result of the affine map only, because they coincide. + if (indices_map) { + // Create a map from scatter update to scatter indices. + *indices_map = IndexingMap{ + AffineMap::get(6, 1, + {scatter_update_map.GetAffineMap().getResult(0), + getAffineSymbolExpr(0, ctx)}, + ctx), + DimVarsFromGPUGrid({num_warps_ * warp_size_, 1, 1, num_blocks_, 1, 1}), + RangeVarsFromTensorSizes({description_.index_vector_length}), + /*rt_vars=*/{}}; + indices_map->Simplify(); + } + if (updates_map) { + *updates_map = std::move(scatter_update_map); } - auto atomic_rmw = b.create(output_tensor, indices); - mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder(); - auto reduced_val = mlir_converter::InlineBlock( - body_builder, reducer.getBody().front(), - {atomic_rmw.getCurrentValue(), update_elem})[0]; - body_builder.create(reducer->getLoc(), reduced_val); - return atomic_rmw->getResult(0); } -// The scatter has to be canonicalized with `scatter_simplifier` pass. absl::Status MlirScatterFusion::EmitEntryFunction( const PartitionedComputations& computations, - const CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, + const CallTargetProvider& call_targets, FuncOp entry_function, const HloFusionInstruction& fusion) const { - constexpr int kScatterOperandIndex = 0; - constexpr int kScatterIndicesIndex = 1; - constexpr int kScatterUpdateIndex = 2; - const auto* scatter = &analysis_.fusion_hero(0).instruction(); - const HloInstruction* scatter_operand = - scatter->operand(kScatterOperandIndex); - const HloInstruction* scatter_indices = - scatter->operand(kScatterIndicesIndex); - const HloInstruction* scatter_update = scatter->operand(kScatterUpdateIndex); - - mlir::MLIRContext* mlir_context = entry_function.getContext(); - auto thread_id_to_update_map = - ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/kScatterUpdateIndex, - mlir_context) - .value(); - thread_id_to_update_map.Simplify(); - thread_id_to_update_map.RemoveUnusedSymbols(); - - auto thread_id_to_update_id_map = - IndexingMap(thread_id_to_update_map.GetAffineMap().getMajorSubMap(1), - thread_id_to_update_map.GetDimVars(), - thread_id_to_update_map.GetRangeVars(), /*rt vars = */ {}); - thread_id_to_update_id_map.RemoveUnusedSymbols(); - - const auto& root_computation = computations.FindPartitionedComputation( - fusion.fused_instructions_computation()); - mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); - b.setInsertionPointToStart(entry_function.addEntryBlock()); + EmitterHelper helper(description_, &computations, &call_targets, + entry_function, fusion); + // Prepare the entry function. + ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); + b.setInsertionPointToStart(entry_function.addEntryBlock()); auto thread_and_block_ids = EmitThreadAndBlockIds(b); + Value output_tensor = entry_function.getArguments().back(); + + // Compute indexing maps. + MLIRContext* mlir_context = entry_function.getContext(); + IndexingMap updates_map = IndexingMap::GetUndefined(); + IndexingMap indices_map = IndexingMap::GetUndefined(); + ComputeIndexing(mlir_context, &updates_map, &indices_map); + updates_map.Simplify(); + + return EmitEntryFunctionImpl(b, helper, updates_map, indices_map, + thread_and_block_ids, output_tensor); +} + +// Emits an inbounds check and a loop over updates inside it. Does not do any +// accumulation. +void EmitNaiveImplementation(ImplicitLocOpBuilder& b, + const ScatterDescription& description, + const EmitterHelper& helper, + const IndexingMap& updates_map, + const IndexingMap& indices_map, + ValueRange thread_and_block_ids, + Value output_tensor) { + MLIRContext* mlir_context = b.getContext(); + auto thread_id_to_update_id_map = IndexingMap( + AffineMap::get(6, 0, {updates_map.GetAffineMap().getResult(0)}, + mlir_context), + updates_map.GetDimVars(), + /*range_vars = */ {}, /*rt vars = */ {}); Value thread_id_to_index_id_value = - mlir_converter::ApplyIndexing(thread_id_to_update_id_map, - thread_and_block_ids, {}, b) + emitters::ApplyIndexing(thread_id_to_update_id_map, thread_and_block_ids, + {}, b) .front(); - SmallVector result_tensors{entry_function.getArguments().back()}; + SmallVector update_offsets = + helper.ExtractOffsets(b, thread_id_to_index_id_value); - // Extract slice offsets from scatter_indices operand, compute if the - // whole slice of scatter_update operand will fit into the output. - mlir::Value in_bounds = b.create(1, b.getI1Type()); + Value in_bounds = EmitBoundsCheck(b, description.slice_shape, + description.output_shape, update_offsets); - Value zero = b.create(0); - SmallVector update_offsets(scatter->shape().rank(), zero); - for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) { - SmallVector indices_tensor_indices = { - thread_id_to_index_id_value, b.create(i)}; - auto index = ProvideParameter(root_computation, scatter, - kScatterIndicesIndex, indices_tensor_indices, - call_targets, entry_function, b)[0]; - if (primitive_util::IsUnsignedIntegralType( - scatter->operand(kScatterIndicesIndex)->shape().element_type())) { - index = b.create(b.getIndexType(), index); - } else { - index = b.create(b.getIndexType(), index); - } - Value ub = b.create( - scatter_operand->shape().dimensions(i) - - scatter_update->shape().dimensions(i + 1)); - // One bounds check is enough even for signed indices: `sge 0` is - // implied by `ule ub`, because `ub >= 0`. - in_bounds = b.create( - in_bounds, b.create(ma::CmpIPredicate::ule, index, ub)); - update_offsets[i] = index; - } Value predicated_update = - b.create( - in_bounds, - [&](OpBuilder& then_builder, Location then_loc) -> void { - mlir::ImplicitLocOpBuilder implicit_then_builder(then_loc, - then_builder); - auto scatter_result = mlir_converter::EmitXlaLoopOp( - implicit_then_builder, thread_and_block_ids, result_tensors, - thread_id_to_update_map, - [&](ValueRange symbol_values, ValueRange map_results, - ValueRange output_tensors) -> SmallVector { - // Extract update element. - auto update_elem = ProvideParameter( - root_computation, scatter, kScatterUpdateIndex, - map_results, call_targets, entry_function, - implicit_then_builder)[0]; - - auto output_indices = std::move(update_offsets); - for (int i = 0; i < output_indices.size(); ++i) { - output_indices[i] = - implicit_then_builder.create( - map_results[i + 1], output_indices[i]); - } - Value output_tensor = output_tensors.front(); - Value updated_output = EmitScatterComputation( - scatter, output_indices, update_elem, output_tensor, - root_computation, call_targets, entry_function, - implicit_then_builder); - return {updated_output}; - }); - implicit_then_builder.create(scatter_result); - }, - [&](OpBuilder& else_b, Location else_loc) { - else_b.create(else_loc, result_tensors.front()); - }) - .getResult(0); + EmitUpdateIf( + b, in_bounds, {output_tensor}, + [&](ImplicitLocOpBuilder& nested_b) -> SmallVector { + return EmitXlaLoopOp( + nested_b, thread_and_block_ids, {output_tensor}, updates_map, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange symbol_values, ValueRange map_results, + ValueRange output_tensors) -> SmallVector { + // Extract update element. + auto update_elem = + helper.GetUpdateElement(update_loop_b, map_results); + auto output_indices = std::move(update_offsets); + int64_t output_rank = description.output_shape.size(); + output_indices = + PadWithZeros(output_indices, output_rank, update_loop_b); + for (int i = 0; i < output_indices.size(); ++i) { + output_indices[i] = update_loop_b.create( + map_results[i + 1], output_indices[i]); + } + Value output_tensor = output_tensors.front(); + Value updated_output = helper.EmitScatterComputation( + update_loop_b, output_indices, update_elem, + output_tensor); + return {updated_output}; + }); + }) + .front(); b.create(predicated_update); +} + +absl::Status ScatterWithDistributedUpdates::EmitEntryFunctionImpl( + ImplicitLocOpBuilder& b, const EmitterHelper& helper, + const IndexingMap& updates_map, const IndexingMap& indices_map, + ValueRange thread_and_block_ids, Value output_tensor) const { + if (VLOG_IS_ON(5)) { + llvm::errs() << "Settings for ScatterWithDistributedUpdates: \n" + << "vector_size_: " << vector_size_ << "\n" + << "num_warps_: " << num_warps_ << "\n" + << "num_blocks_: " << num_blocks_; + } + EmitNaiveImplementation(b, description_, helper, updates_map, indices_map, + thread_and_block_ids, output_tensor); return absl::OkStatus(); } +ScatterWithDistributedIndices::ScatterWithDistributedIndices( + const HloFusionAnalysis& analysis, const ScatterDescription& description, + int64_t vector_size, int64_t num_warps_per_slice, + int64_t num_indices_per_warp) + : MlirScatterFusion(analysis, description, vector_size), + num_warps_per_slice_(num_warps_per_slice), + num_indices_per_warp_(num_indices_per_warp) { + num_warps_ = kNumWarpsPerBlock; + num_blocks_ = CeilOfRatio(description.num_slices * num_warps_per_slice_, + num_indices_per_warp_ * num_warps_); +} + +void ScatterWithDistributedIndices::ComputeIndexing( + MLIRContext* ctx, IndexingMap* updates_map, + IndexingMap* indices_map) const { + // Compute thread id mapping based on the first update operand. + auto thread_x = getAffineDimExpr( + KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx); + auto block_x = + getAffineDimExpr(KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx); + auto warp_id = thread_x.floorDiv(warp_size_); + auto slice_id = + (block_x * num_warps_ + warp_id).floorDiv(num_warps_per_slice_); + auto warp_id_in_slice = + (block_x * num_warps_ + warp_id) % num_warps_per_slice_; + auto lane_id = thread_x % warp_size_; + auto index_id_loop = getAffineSymbolExpr(0, ctx); + + auto index_id_expr = slice_id * num_indices_per_warp_ + index_id_loop; + std::pair index_id_constraint = + std::make_pair(index_id_expr, Interval{0, description_.num_slices - 1}); + + auto grid_vars = + DimVarsFromGPUGrid({num_warps_ * warp_size_, 1, 1, num_blocks_, 1, 1}); + if (indices_map) { + auto index_dim_loop = getAffineSymbolExpr(1, ctx); + *indices_map = IndexingMap{ + AffineMap::get(6, 2, {index_id_expr, index_dim_loop}, ctx), + grid_vars, + {IndexingMap::Variable{{0, num_indices_per_warp_ - 1}, "index_id_loop"}, + IndexingMap::Variable{{0, description_.index_vector_length - 1}, + "index_dim"}}, + /*rt_vars=*/{}, + {index_id_constraint}}; + + indices_map->Simplify(); + } + + if (updates_map) { + auto update_dim_loop = getAffineSymbolExpr(1, ctx); + auto vector_id = getAffineSymbolExpr(2, ctx); + auto num_elements_per_slice = Product(description_.slice_shape); + + auto linear_slice_index = + warp_id_in_slice * warp_size_ * vector_size_ + + update_dim_loop * vector_size_ * warp_size_ * num_warps_per_slice_ + + lane_id * vector_size_ + vector_id; + + SmallVector updates_indexing = {index_id_expr}; + updates_indexing.append( + DelinearizeInBoundsIndex(linear_slice_index, description_.slice_shape)); + + *updates_map = IndexingMap{ + AffineMap::get(6, 3, updates_indexing, ctx), + grid_vars, + {IndexingMap::Variable{{0, num_indices_per_warp_ - 1}, "index_id_loop"}, + IndexingMap::Variable{ + {0, CeilOfRatio(num_elements_per_slice, + num_warps_per_slice_ * warp_size_ * vector_size_) - + 1}, + "update_loop"}, + IndexingMap::Variable{{0, vector_size_ - 1}, "vector_id"}}, + /*rt_vars=*/{}, + std::vector>{ + index_id_constraint, + std::make_pair(linear_slice_index, + Interval{0, num_elements_per_slice - 1})}}; + + updates_map->Simplify(); + } +} + +DenseElementsAttr GetShapedZeroConstantAttr(VectorType vector_type) { + auto elem_type = vector_type.getElementType(); + if (auto float_type = mlir::dyn_cast(elem_type)) { + std::vector values( + vector_type.getNumElements(), + APFloat::getZero(float_type.getFloatSemantics())); + return DenseElementsAttr::get(vector_type, values); + } + if (auto int_type = mlir::dyn_cast(elem_type)) { + std::vector values( + vector_type.getNumElements(), + APInt::getZero(int_type.getIntOrFloatBitWidth())); + return DenseElementsAttr::get(vector_type, values); + } + llvm_unreachable("Unsupported vector element type"); +} + +Value ScatterWithDistributedIndices::InitializeAccumulator( + ImplicitLocOpBuilder& b) const { + auto elem_type = emitters::PrimitiveTypeToMlirType(description_.elem_type, b); + auto num_elements_per_slice = Product(description_.slice_shape); + auto update_iterations_per_thread = CeilOfRatio( + num_elements_per_slice, num_warps_per_slice_ * warp_size_ * vector_size_); + auto accumulator_type = + VectorType::get({update_iterations_per_thread, vector_size_}, elem_type); + return b.create( + accumulator_type, GetShapedZeroConstantAttr(accumulator_type)); +} + +absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( + ImplicitLocOpBuilder& b, const EmitterHelper& helper, + const IndexingMap& updates_map, const IndexingMap& indices_map, + ValueRange thread_and_block_ids, Value output_tensor) const { + if (VLOG_IS_ON(5)) { + llvm::errs() << "Settings for ScatterWithDistributedIndices: \n" + << "vector_size_: " << vector_size_ << "\n" + << "num_warps_: " << num_warps_ << "\n" + << "num_blocks_: " << num_blocks_ + << "num_warps_per_slice_: " << num_warps_per_slice_ << "\n" + << "num_indices_per_warp_: " << num_indices_per_warp_; + } + if (num_indices_per_warp_ == 1) { + EmitNaiveImplementation(b, description_, helper, updates_map, indices_map, + thread_and_block_ids, output_tensor); + return absl::OkStatus(); + } + MLIRContext* mlir_context = b.getContext(); + + auto thread_id_to_update_id_map = IndexingMap( + AffineMap::get(6, 1, {updates_map.GetAffineMap().getResult(0)}, + mlir_context), + updates_map.GetDimVars(), + /*range_vars = */ {updates_map.GetRangeVars().front()}, + /*rt vars = */ {}); + IndexingMap slice_indexing = ConvertRangeVariableToDimension(updates_map, 0); + + // Prepare loop initial values. Inits are packed as + // [index_changed, is_inbounds, index_0, ..., accumulator]. + Value is_inbounds_init = b.create(0, b.getI1Type()); + Value slice_id_init = b.create(0); + std::vector indices_init(description_.index_vector_length, + b.create(-1)); + Value accumulator_init = InitializeAccumulator(b); + SmallVector inits = + Pack({slice_id_init, indices_init, is_inbounds_init, accumulator_init, + output_tensor}); + + int64_t output_rank = description_.output_shape.size(); + + auto loop_over_indices_fn = + [&](ImplicitLocOpBuilder& nested_b, ValueRange ivs, + ValueRange thread_id_to_index_id_value, + ValueRange outer_iter_args) -> SmallVector { + // Unpack the iter_args. + SmallVector iter_args_unpack = + Unpack(outer_iter_args, {1, description_.index_vector_length, 1, 1, 1}); + ValueRange trimmed_offsets = iter_args_unpack[1]; + Value iter_is_inbounds = iter_args_unpack[2].front(); + Value iter_acc = iter_args_unpack[3].front(); + Value iter_output = iter_args_unpack[4].front(); + Value iter_slice_id = ivs.front(); + + SmallVector offsets = + PadWithZeros(trimmed_offsets, output_rank, nested_b); + + auto new_trimmed_offsets = + helper.ExtractOffsets(nested_b, thread_id_to_index_id_value.front()); + + // Check if the offsets changed. + Value offsets_changed = + EmitInequalityCheck(nested_b, trimmed_offsets, new_trimmed_offsets); + + for (int i = 0; i < description_.index_vector_length; ++i) { + new_trimmed_offsets[i] = nested_b.create( + offsets_changed, new_trimmed_offsets[i], trimmed_offsets[i]); + } + + auto new_offsets = PadWithZeros(new_trimmed_offsets, output_rank, nested_b); + + // Write accumulated values into the tensor if the offsets changed. + Value is_not_first_iteration = + b.create(arith::CmpIPredicate::ne, iter_slice_id, + b.create(0)); + Value write_to_output_required = b.create( + is_not_first_iteration, + b.create(offsets_changed, iter_is_inbounds)); + iter_output = helper.WriteAccumulatorToOutput( + b, write_to_output_required, thread_and_block_ids, iter_slice_id, + slice_indexing, offsets, iter_acc, iter_output); + + // Update `is_inbounds` if the offsets changed. + Value new_is_inbounds = UpdateIsInbounds( + nested_b, iter_is_inbounds, offsets_changed, new_offsets, + description_.slice_shape, description_.output_shape); + + // Emits a loop that overwrites the accumulator with the new update elements + // if the offsets changed. + auto emit_overwrite_accumulator_fn = [&](OpBuilder& then_b, + Location then_loc) -> void { + ImplicitLocOpBuilder implicit_then_b(then_loc, then_b); + auto then_results = EmitXlaLoopOp( + implicit_then_b, Pack({thread_and_block_ids, iter_slice_id}), + {iter_acc}, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange inner_iter_args) -> SmallVector { + Value acc_arg = inner_iter_args.front(); + auto update_elem = + helper.GetUpdateElement(update_loop_b, slice_indices); + auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); + return update_loop_b + .create(then_loc, update_elem, acc_arg, + acc_ind_opfold) + ->getResults(); + }); + implicit_then_b.create(then_loc, then_results); + }; + // Emits a loop that combines the accumulator with the new update elements + // if the offsets did not change. + auto emit_combine_accumulator_fn = [&](OpBuilder& else_b, + Location else_loc) -> void { + ImplicitLocOpBuilder implicit_else_b(else_loc, else_b); + auto else_results = EmitXlaLoopOp( + implicit_else_b, Pack({thread_and_block_ids, iter_slice_id}), + {iter_acc}, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange inner_iter_args) -> SmallVector { + Value acc_arg = inner_iter_args.front(); + auto update_elem = + helper.GetUpdateElement(update_loop_b, slice_indices); + auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); + Value accumulator_elem = update_loop_b.create( + acc_arg, acc_ind_opfold); + auto reduced_val = emitters::InlineBlock( + update_loop_b, helper.GetReducer().getBody().front(), + {accumulator_elem, update_elem})[0]; + return update_loop_b + .create(reduced_val, acc_arg, acc_ind_opfold) + ->getResults(); + }); + implicit_else_b.create(else_results); + }; + auto updated_accumulator = + EmitUpdateIf(nested_b, new_is_inbounds, {iter_acc}, + [&](ImplicitLocOpBuilder& if_b) { + return nested_b + .create(offsets_changed, + emit_overwrite_accumulator_fn, + emit_combine_accumulator_fn) + .getResults(); + }) + .front(); + SmallVector updated_if_loop_results = + Pack({iter_slice_id, new_trimmed_offsets, new_is_inbounds, + updated_accumulator, iter_output}); + return updated_if_loop_results; + }; + auto loop_over_indices_results = + EmitXlaLoopOp(b, thread_and_block_ids, inits, thread_id_to_update_id_map, + loop_over_indices_fn); + + // Write the accumulator to the output tensor. + SmallVector loop_over_indices_results_unpacked = + Unpack(loop_over_indices_results, + {1, description_.index_vector_length, 1, 1, 1}); + Value result_slice_id = loop_over_indices_results_unpacked[0].front(); + auto result_offsets = + PadWithZeros(loop_over_indices_results_unpacked[1], output_rank, b); + Value result_is_inbounds = loop_over_indices_results_unpacked[2].front(); + Value result_acc = loop_over_indices_results_unpacked[3].front(); + Value result_output = loop_over_indices_results_unpacked[4].front(); + result_output = helper.WriteAccumulatorToOutput( + b, result_is_inbounds, thread_and_block_ids, result_slice_id, + slice_indexing, result_offsets, result_acc, result_output); + + b.create(result_output); + return absl::OkStatus(); +} + +ScatterDescription GetScatterDescription(const HloFusionAnalysis& analysis) { + auto* hero = &analysis.fusion_hero(0).instruction(); + CHECK_NE(hero, nullptr); + auto* scatter = Cast(hero); + auto indices_shape = scatter->scatter_indices()->shape(); + auto update_shape = scatter->scatter_updates().front()->shape(); + auto output_shape = scatter->scatter_operands().front()->shape(); + + return ScatterDescription{ + scatter, + indices_shape.dimensions(0), + indices_shape.dimensions(1), + output_shape.element_type(), + update_shape, + SmallVector(update_shape.dimensions().begin() + 1, + update_shape.dimensions().end()), + SmallVector(output_shape.dimensions().begin(), + output_shape.dimensions().end()), + }; +} + +// Compute the maximal vector size that can be used to process the given number +// of elements in a single slice. +int64_t GetSingleSliceVectorSize(int64_t num_elements_in_slice, + int64_t max_vectorized_elements, + int64_t warp_size) { + int64_t vector_size = + std::gcd(num_elements_in_slice, max_vectorized_elements); + int64_t num_processed_elememts_per_warp = warp_size * vector_size; + while (vector_size > 1 && + num_processed_elememts_per_warp > num_elements_in_slice) { + vector_size /= 2; + num_processed_elememts_per_warp /= 2; + } + return vector_size; +} + +int64_t GetNumPossibleValidIndices(absl::Span slice_shape, + absl::Span output_shape, + int64_t index_vector_length) { + int64_t num_possible_valid_indices = 1; + for (int64_t i = 0; i < index_vector_length; ++i) { + num_possible_valid_indices *= output_shape[i] - slice_shape[i] + 1; + } + return num_possible_valid_indices; +} + +std::unique_ptr CreateMlirScatterFusion( + const HloFusionAnalysis& analysis) { + auto description = GetScatterDescription(analysis); + int64_t warp_size = WarpSize(analysis.device_info()); + int64_t num_elements_per_slice = Product(description.slice_shape); + int64_t num_slices = description.num_slices; + + // Initialize the vector size with the maximum allowed vector size that does + // not require masking/padding. + int64_t elem_type_bits = primitive_util::BitWidth(description.elem_type); + CHECK_EQ(kMaxVectorizedBits % elem_type_bits, 0); + int64_t max_vectorized_elements = kMaxVectorizedBits / elem_type_bits; + int64_t vector_size = GetSingleSliceVectorSize( + num_elements_per_slice, max_vectorized_elements, warp_size); + int64_t num_active_threads_per_warp = + std::min(warp_size, num_elements_per_slice / vector_size); + + int64_t max_active_warps = + kNumWarpsPerBlock * analysis.device_info().core_count(); + // For sorted scatter, we try to estimate the number of updates per warp by + // computing the ratio of the number of the given updates to the number of the + // possible valid indices. If we do not have multiple updates per warp, there + // is no reason to use this algorithm. + // TODO(b/385081952): Investigate why bf16 and f64 leads to incorrect results. + if (description.scatter->indices_are_sorted() && + description.elem_type != BF16 && num_slices > 2 * max_active_warps) { + int64_t num_indices_per_warp = CeilOfRatio( + num_slices, GetNumPossibleValidIndices( + description.slice_shape, description.output_shape, + description.index_vector_length)); + int64_t num_warps_per_slice = 1; + if (num_indices_per_warp > 2 && + num_active_threads_per_warp > warp_size / 2) { + return std::make_unique( + analysis, description, vector_size, num_warps_per_slice, + num_indices_per_warp); + } + } + // If we have enough data, we assign each warp to process a single + // slice. + if (num_slices > max_active_warps && + num_active_threads_per_warp > warp_size / 2) { + return std::make_unique( + analysis, description, vector_size, + /*num_warps_per_slice=*/1, /*num_indices_per_warp=*/1); + } + // Otherwise, we distribute the linearized updates tensor. + vector_size = + std::gcd(num_elements_per_slice, + ComputeLoopFusionConfig(analysis, description.update_shape) + .unroll_factor); + return std::make_unique(analysis, description, + vector_size); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h index 1ce89296984f01..3b4a5b412e3158 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h @@ -16,54 +16,201 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_SCATTER_MLIR_H_ #include +#include #include #include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/shape.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" namespace xla { namespace gpu { -// Generic loop fusion. Lowers to LLVM via MLIR. +class EmitterHelper; + +// Full description of the scatter operation. +// The shape of the indices tensor is . +// The shape of the updates tensor is . +struct ScatterDescription { + const HloScatterInstruction* scatter; + int64_t num_slices; + int64_t index_vector_length; + PrimitiveType elem_type; + // The shape of the updates tensor + Shape update_shape; + llvm::SmallVector slice_shape; + llvm::SmallVector output_shape; +}; +ScatterDescription GetScatterDescription(const HloFusionAnalysis& analysis); + class MlirScatterFusion : public MlirFusionEmitterBase { public: - explicit MlirScatterFusion(const HloFusionAnalysis& analysis); + explicit MlirScatterFusion(const HloFusionAnalysis& analysis, + const ScatterDescription& description, + int64_t vector_size); + + absl::Status EmitEntryFunction( + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const override; - LaunchDimensions launch_dimensions() const override; + LaunchDimensions launch_dimensions() const override { + return LaunchDimensions(num_blocks_, num_warps_ * warp_size_); + } std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; + int64_t root_index, mlir::MLIRContext* ctx) const override { + // Since the access pattern to the output is not statically known, we cannot + // compute the output->input indexing map. + return std::nullopt; + } std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, mlir::MLIRContext* ctx) const override; protected: - absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, - mlir::func::FuncOp entry_function, - const HloFusionInstruction& fusion) const override; + virtual absl::Status EmitEntryFunctionImpl( + mlir::ImplicitLocOpBuilder& b, const EmitterHelper& helper, + const IndexingMap& updates_map, const IndexingMap& indices_map, + mlir::ValueRange thread_and_block_ids, + mlir::Value output_tensor) const = 0; - std::vector GetEpilogues( + virtual void ComputeIndexing(mlir::MLIRContext* ctx, IndexingMap* updates_map, + IndexingMap* indices_map) const = 0; + + std::vector GetEpilogues( const HloFusionInstruction& fusion, - mlir::MLIRContext* mlir_context) const override; + mlir::MLIRContext* mlir_context) const final; - private: const HloFusionAnalysis& analysis_; - LaunchDimensionsConfig config_; + ScatterDescription description_; + + // The grid is {num_warps_ * WarpSize(), 1, 1, num_blocks_, 1, 1}. + int64_t warp_size_; + int64_t num_warps_; + int64_t num_blocks_; + + // The number of elements that every thread will read from the updates tensor + // and write to the output tensor. + int64_t vector_size_; +}; + +// The distribution happens similarly to the loop emitter, but the iteration +// space corresponds to the shape of the updates tensor. In this case, GPU +// performs a grid-stride loop over the updates and every warp computes at what +// index to scatter an element(s) of the update. +class ScatterWithDistributedUpdates : public MlirScatterFusion { + public: + explicit ScatterWithDistributedUpdates(const HloFusionAnalysis& analysis, + const ScatterDescription& description, + int64_t vector_size); + + protected: + absl::Status EmitEntryFunctionImpl(mlir::ImplicitLocOpBuilder& b, + const EmitterHelper& helper, + const IndexingMap& updates_map, + const IndexingMap& indices_map, + mlir::ValueRange thread_and_block_ids, + mlir::Value output_tensor) const override; + + void ComputeIndexing(mlir::MLIRContext* ctx, IndexingMap* updates_map, + IndexingMap* indices_map) const override; +}; + +// Every warp will process one or more indices, i.e. there won't be two threads +// in a warp that scatter different indices at a time. In this case, every warp +// iterates its fraction of the indices, and then computes what updates to +// scatter. +// It implements the following algorithm: + +/* + %indices = -1 + %inbounds = false + %acc = vector + + // #indices_map + %updated_accumulator, %updated_out = for %i = 0 to %num_indices_per_warp_ { + %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %i)) + %indices_changed = EmitInequalityCheck(%new_indices, %indices) + if (%indices_changed && %i != 0) { + %output_tensor = WriteAccumulatorToOutput(%current_acc, %current_out); + } + if (%indices_changed) { + %inbounds = EmitBoundsCheck(%new_indices, %slice_shape, %output_shape) + } + if (%inbounds) { + if (%indices_changed) { + // updates_map(%i) + for %j = 0 to %num_slice_iterations_per_warp step 1 { + for %k = 0 to %vector_size step 1 { + %update_elem = GetUpdateElement + %acc = %update_elem + } + } + } else { + // updates_map(%i) + for %j = 0 to %num_slice_iterations_per_warp step 1 { + for %k = 0 to %vector_size step 1 { + %update_elem = GetUpdateElement + %acc = Reduce(%update_elem, %acc) + } + } + } + } +} +%final_out = WriteAccumulatorToOutput(%updated_accumulator, %updated_out); +*/ +class ScatterWithDistributedIndices : public MlirScatterFusion { + public: + explicit ScatterWithDistributedIndices(const HloFusionAnalysis& analysis, + const ScatterDescription& description, + int64_t vector_size, + int64_t num_warps_per_slice, + int64_t num_indices_per_warp); + + protected: + void ComputeIndexing(mlir::MLIRContext* ctx, IndexingMap* updates_map, + IndexingMap* indices_map) const override; + + absl::Status EmitEntryFunctionImpl(mlir::ImplicitLocOpBuilder& b, + const EmitterHelper& helper, + const IndexingMap& updates_map, + const IndexingMap& indices_map, + mlir::ValueRange thread_and_block_ids, + mlir::Value output_tensor) const override; + + private: + // Creates a 2D vector to store the accumulated updates in each thread. + mlir::Value InitializeAccumulator(mlir::ImplicitLocOpBuilder& b) const; + + // The number of warps that process a single slice of the update. + int64_t num_warps_per_slice_; + // The number of indices that every warp iterates over. This is a useful + // setting, if we know that the indices tensor is sorted. + int64_t num_indices_per_warp_; }; +std::unique_ptr CreateMlirScatterFusion( + const HloFusionAnalysis& analysis); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo index d289ab87cf1fd7..125cd72209ec03 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo @@ -9,9 +9,9 @@ fusion { ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} } // CHECK-DAG: #[[MAP:.*]] = #xla.indexing_map<"(th_x, bl_x) -> (bl_x * 128 + th_x) -// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x) -// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 200) -// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 600) +// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla.indexing_map<"(th_x, bl_x)[s0, s1] -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla.indexing_map<"(th_x, bl_x)[s0, s1] -> (bl_x * 128 + th_x + 200) +// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla.indexing_map<"(th_x, bl_x)[s0, s1] -> (bl_x * 128 + th_x + 600) // CHECK: func.func @main // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/add_vectorized.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/add_vectorized.hlo new file mode 100644 index 00000000000000..915dc5545f15a8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/add_vectorized.hlo @@ -0,0 +1,27 @@ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=scatter:2 + +add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %sum = f32[] add(%p0, %p1) +} +scatter { + %operand = f32[40,1500] parameter(0) + %indices = s32[24,1] parameter(1) + %update = f32[24,20,1000] parameter(2) + + ROOT %scatter = f32[40,1500] scatter( + f32[40,1500] %operand, + s32[24,1] %indices, + f32[24,20,1000] %update + ), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + unique_indices=false, + to_apply=add +} +// CHECK: vector.transfer_read {{.*}} : tensor<480000xf32>, vector<4xf32> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo new file mode 100644 index 00000000000000..332eb543af61b0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo @@ -0,0 +1,29 @@ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ +// RUN: | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=scatter:2 + +add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %sum = f32[] add(%p0, %p1) +} +scatter { + %operand = f32[100] parameter(0) + %indices = s32[2001,1] parameter(1) + %update = f32[2001,32] parameter(2) + + ROOT %scatter = f32[100] scatter( + f32[100] %operand, + s32[2001,1] %indices, + f32[2001,32] %update + ), + update_window_dims={1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + indices_are_sorted=true, + unique_indices=false, + to_apply=add +} +// CHECK-LABEL: func.func @main +// CHECK: arith.constant dense<0.000000e+00> : vector<1x1xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/sorted_indices_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/sorted_indices_small.hlo new file mode 100644 index 00000000000000..05b33ce5271554 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/sorted_indices_small.hlo @@ -0,0 +1,37 @@ +// RUN: fusion_to_mlir %s | emitters_opt -xla-gpu-test-optimize \ +// RUN: | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=scatter:2 + +add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %sum = f32[] add(%p0, %p1) +} +scatter { + %operand = f32[100] parameter(0) + %indices = s32[200,1] parameter(1) + %update = f32[200,32] parameter(2) + + ROOT %scatter = f32[100] scatter( + f32[100] %operand, + s32[200,1] %indices, + f32[200,32] %update + ), + update_window_dims={1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + indices_are_sorted=true, + unique_indices=false, + to_apply=add +} +// When there is not enough indices per warp, we fall back to the naive impl, +// when one warp processes one slice. +// CHECK: #xla.indexing_map<"(th_x, bl_x)[s0, s1] +// CHECK-SAME: -> (bl_x * 4 + th_x floordiv 32, th_x mod 32), +// CHECK-SAME: domain: th_x in [0, 127], +// CHECK-SAME: bl_x in [0, 49], +// CHECK-LABEL: func.func @main +// CHECK: xla.loop +// CHECK-NOT: xla.loop + diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo index 7ebd717c8f6c63..360ea0a1183385 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo @@ -5,7 +5,7 @@ fusion { %p0 = s8[160,170,3] parameter(0) ROOT %transpose = s8[170,160,3] transpose(%p0), dimensions={1,0,2} -} +} // CHECK: func.func @main( // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x3xi8> diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD index 28225e093b5043..1a8b02e8698890 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD @@ -5,42 +5,6 @@ package( licenses = ["notice"], ) -xla_cc_binary( - name = "mlir_fusions_opt", - srcs = ["mlir_fusions_opt.cc"], - # We want to use this tool for lit tests. Due to hermetic cuda, we need to - # set linkopts in such a way that dynamic libraries are found, which are - # symlinked from the lit_lib directory. - linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], - visibility = ["//xla/service/gpu/fusions:__subpackages__"], - deps = [ - "//xla/mlir_hlo", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/transforms:passes", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:DLTIDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:LLVMIRTransforms", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - ], -) - cc_library( name = "test_lib", testonly = 1, @@ -48,12 +12,12 @@ cc_library( hdrs = ["test_lib.h"], deps = [ "//xla:status_macros", + "//xla/backends/gpu/codegen/ir:xla_gpu", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/stream_executor:device_description", "//xla/tools:hlo_module_loader", @@ -127,3 +91,18 @@ xla_cc_binary( "@local_tsl//tsl/platform:statusor", ], ) + +xla_cc_binary( + name = "fusion_wrapper", + testonly = 1, + srcs = ["fusion_wrapper.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + ":test_lib", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:platform_port", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/tools/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/fusions/tools/fusion_wrapper.cc new file mode 100644 index 00000000000000..8165c343e6037a --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/fusion_wrapper.cc @@ -0,0 +1,41 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "llvm/Support/raw_ostream.h" +#include "xla/service/gpu/fusions/tools/test_lib.h" +#include "xla/tsl/platform/statusor.h" +#include "tsl/platform/init_main.h" + +namespace xla { +namespace gpu { + +absl::Status Run(const std::string& filename) { + TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); + llvm::outs() << module->ToString(); + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla + +int main(int argc, char** argv) { + tsl::port::InitMain(argv[0], &argc, &argv); + CHECK_EQ(argc, 2) << "Must specify an input file"; + CHECK_OK(xla::gpu::Run(argv[1])); + return 0; +} diff --git a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc deleted file mode 100644 index 43a1f708286456..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/STLFunctionalExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "mlir/Transforms/Passes.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" -#include "xla/service/gpu/fusions/transforms/passes.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" - -int main(int argc, char** argv) { - mlir::DialectRegistry registry; - registry.insert(); - mlir::func::registerAllExtensions(registry); - mlir::LLVM::registerInlinerInterface(registry); - mlir::registerCanonicalizerPass(); - mlir::registerCSEPass(); - mlir::registerInliner(); - xla::gpu::registerGpuFusionTransformsPasses(); - mlir::registerPassPipeline( - "xla-gpu-test-optimize", - "Test pipeline of passes up to inlining. No vectorization, also does not " - "lower xla_gpu. Intended to simplify IR in tests.", - [=](mlir::OpPassManager& pm, llvm::StringRef options, - llvm::function_ref - errorHandler) { - if (!options.empty()) return mlir::failure(); - - xla::gpu::AddXlaGpuOpsOptimizationPasses(pm); - return mlir::success(); - }, - [](llvm::function_ref) {}); - mlir::registerPassPipeline( - "xla-gpu-test-transform-loops", - "Test pipeline for vectorization. Should run after " - "xla-gpu-test-to-inline.", - [=](mlir::OpPassManager& pm, llvm::StringRef options, - llvm::function_ref - errorHandler) { - if (!options.empty()) return mlir::failure(); - xla::gpu::AddLoopTransformationPasses( - pm, xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo()); - return mlir::success(); - }, - [](llvm::function_ref) {}); - - return mlir::failed( - MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry)); -} diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc index dc1955a432c686..867131681ad81e 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -40,7 +41,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 766d2a5dbcc955..dbd19bcc57c196 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -40,6 +40,10 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/codegen/emitters/computation_partitioner.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" +#include "xla/codegen/emitters/type_util.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_computation.h" @@ -48,10 +52,6 @@ limitations under the License. #include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" @@ -64,15 +64,16 @@ namespace xla { namespace gpu { namespace { +using emitters::ApplyIndexing; using llvm::SmallVector; using mlir::AffineExpr; +using mlir::ImplicitLocOpBuilder; using mlir::MLIRContext; using mlir::RankedTensorType; using mlir::Value; using mlir::ValueRange; using mlir::func::FuncOp; using mlir::func::ReturnOp; -using mlir_converter::ApplyIndexing; constexpr int kNumRows = 4; constexpr int kNumThreadsPerBlock = 128; @@ -223,8 +224,8 @@ IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( mlir::ImplicitLocOpBuilder& builder, FuncOp entry_function, const HloFusionInstruction& fusion, - const mlir_converter::PartitionedComputation& root_computation, - const mlir_converter::CallTargetProvider& call_target_provider, + const emitters::PartitionedComputation& root_computation, + const emitters::CallTargetProvider& call_target_provider, ValueRange output_args, mlir::ValueRange thread_and_block_ids) const { MLIRContext* ctx = builder.getContext(); auto shmem_tensor_size = block_sizes_; @@ -264,7 +265,7 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( operand_shape.dimensions().end()); SmallVector shmem_tensors; for (auto* transpose : shmem_transposes_) { - auto elem_type = mlir_converter::PrimitiveTypeToMlirType( + auto elem_type = emitters::PrimitiveTypeToMlirType( transpose->shape().element_type(), builder); auto shmem = builder.create( RankedTensorType::get(shmem_tensor_size, elem_type)); @@ -295,11 +296,12 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( for (int index : side_output_root_indices_) { side_output_inits.push_back(entry_function.getArgument(num_inputs + index)); } - auto body_builder = [&](ValueRange symbol_values, ValueRange map_results, + auto body_builder = [&](ImplicitLocOpBuilder& nested_b, + ValueRange symbol_values, ValueRange map_results, ValueRange output_tensors) -> SmallVector { auto input_indices = [&](const HloInstruction* instr) { return ApplyIndexing(GetIndexing(/*input=*/true, instr->shape(), ctx), - thread_and_block_ids, symbol_values, builder); + thread_and_block_ids, symbol_values, nested_b); }; SmallVector side_outputs; @@ -307,10 +309,10 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( auto* root_tuple = fusion.fused_expression_root(); for (auto root : side_output_roots_) { side_output_indices.push_back(input_indices(root)); - ValueRange param_values = mlir_converter::ProvideParameter( + ValueRange param_values = emitters::ProvideParameter( root_computation, root_tuple, root_tuple->operand_index(root), side_output_indices.back(), call_target_provider, entry_function, - builder); + nested_b); side_outputs.append(param_values.begin(), param_values.end()); } @@ -318,16 +320,16 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( for (const auto& [value, indices, output] : llvm::zip(side_outputs, side_output_indices, output_tensors)) { result_tensors.push_back( - builder.create(value, output, indices)); + nested_b.create(value, output, indices)); } return result_tensors; }; mlir::ValueRange side_output_vector; if (!side_output_inits.empty()) { - side_output_vector = mlir_converter::EmitXlaLoopOp( - builder, thread_and_block_ids, side_output_inits, indexing, - body_builder); + side_output_vector = + emitters::EmitXlaLoopOp(builder, thread_and_block_ids, + side_output_inits, indexing, body_builder); } WriteResult result; @@ -346,39 +348,40 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( void MlirTransposeFusion::EmitReadFromShMemMlir( mlir::ImplicitLocOpBuilder& builder, FuncOp entry_function, const HloFusionInstruction& fusion, - const mlir_converter::PartitionedComputations& computations, + const emitters::PartitionedComputations& computations, const WriteResult& written, mlir::ValueRange thread_and_block_ids) const { auto* mlir_context = builder.getContext(); auto output_indexing = *ComputeThreadIdToOutputIndexing( shmem_transpose_root_indices_[0], mlir_context); auto shmem_read_indexing = GetSharedMemoryIndexing(/*read=*/true, mlir_context); - auto result_tensors = mlir_converter::EmitXlaLoopOp( + auto result_tensors = emitters::EmitXlaLoopOp( builder, thread_and_block_ids, written.updated_outputs, output_indexing, - [&](ValueRange symbol_values, ValueRange map_results, + [&](ImplicitLocOpBuilder& nested_b, ValueRange symbol_values, + ValueRange map_results, ValueRange output_tensors) -> SmallVector { auto shmem_indices = ApplyIndexing( - shmem_read_indexing, thread_and_block_ids, symbol_values, builder); + shmem_read_indexing, thread_and_block_ids, symbol_values, nested_b); absl::flat_hash_map> transpose_values; for (auto [transpose, shmem] : llvm::zip(shmem_transposes_, written.shmem_tensors)) { transpose_values[transpose].push_back( - builder.create(shmem, shmem_indices)); + nested_b.create(shmem, shmem_indices)); } llvm::SmallVector epilogue_indices = thread_and_block_ids; absl::c_copy(symbol_values, std::back_inserter(epilogue_indices)); auto result_scalars = EmitEpilogue(/*epilogue_index=*/0, computations, entry_function, - transpose_values, epilogue_indices, builder); + transpose_values, epilogue_indices, nested_b); SmallVector results = output_tensors; for (auto [root, indexing, root_index] : llvm::zip(shmem_transpose_roots_, computations.epilogues().front().root_indexing, shmem_transpose_root_indices_)) { llvm::SmallVector indices = ApplyIndexing( - indexing, thread_and_block_ids, symbol_values, builder); - results[root_index] = builder.create( + indexing, thread_and_block_ids, symbol_values, nested_b); + results[root_index] = nested_b.create( result_scalars.at(root).front(), results[root_index], indices); } return results; @@ -387,25 +390,23 @@ void MlirTransposeFusion::EmitReadFromShMemMlir( builder.create(result_tensors); } -std::vector -MlirTransposeFusion::GetEpilogues(const HloFusionInstruction& fusion, - MLIRContext* mlir_context) const { - std::vector epilogues{ +std::vector MlirTransposeFusion::GetEpilogues( + const HloFusionInstruction& fusion, MLIRContext* mlir_context) const { + std::vector epilogues{ GetEpilogueForOutputIndexing(analysis_, shmem_transposes_, shmem_transpose_roots_, mlir_context)}; // Add empty epilogues for the side outputs. This ensures their roots don't // get "fused" into the tuple function. for (const auto* root : side_output_roots_) { - epilogues.push_back( - mlir_converter::EpilogueSpecification::FromIdentityIndexing( - root, root, mlir_context)); + epilogues.push_back(emitters::EpilogueSpecification::FromIdentityIndexing( + root, root, mlir_context)); } return epilogues; } absl::Status MlirTransposeFusion::EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const { const auto& root_computation = computations.FindPartitionedComputation( diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index f21451106adce1..ea70d188079d9a 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -30,10 +30,10 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/codegen/emitters/computation_partitioner.h" #include "xla/hlo/analysis/indexing_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -65,12 +65,12 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { protected: absl::Status EmitEntryFunction( - const mlir_converter::PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, + const emitters::PartitionedComputations& computations, + const emitters::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion) const override; - std::vector GetEpilogues( + std::vector GetEpilogues( const HloFusionInstruction& fusion, mlir::MLIRContext* mlir_context) const override; @@ -84,14 +84,14 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { WriteResult EmitWriteToShMemMlir( mlir::ImplicitLocOpBuilder& builder, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion, - const mlir_converter::PartitionedComputation& root_computation, - const mlir_converter::CallTargetProvider& call_target_provider, + const emitters::PartitionedComputation& root_computation, + const emitters::CallTargetProvider& call_target_provider, mlir::ValueRange output_args, mlir::ValueRange thread_and_block_ids) const; void EmitReadFromShMemMlir( mlir::ImplicitLocOpBuilder& builder, mlir::func::FuncOp entry_function, const HloFusionInstruction& fusion, - const mlir_converter::PartitionedComputations& computations, + const emitters::PartitionedComputations& computations, const WriteResult& written, mlir::ValueRange thread_and_block_ids) const; private: diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 0b2095bc2bfe6c..da184c160a73f8 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -26,7 +26,9 @@ package_group( cc_library( name = "emitter_helpers", srcs = ["emitter_helpers.cc"], - hdrs = ["emitter_helpers.h"], + hdrs = [ + "emitter_helpers.h", + ], deps = [ "//xla:literal", "//xla:shape_util", @@ -37,6 +39,7 @@ cc_library( "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/mlir_hlo:transformation_helpers", "//xla/service/gpu:target_util", + "//xla/service/gpu/fusions:emitter_loc_op_builder", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", "@com_google_absl//absl/log", @@ -57,17 +60,59 @@ cc_library( ) cc_library( - name = "triton_fusion_emitter", + name = "compilation_pipeline", srcs = if_gpu_is_configured( - ["triton_fusion_emitter.cc"], - ["triton_fusion_emitter_stub.cc"], + [], + ["compilation_pipeline_stub.cc"], ) + if_cuda_is_configured([ "compilation_pipeline_cuda.cc", ]) + if_rocm_is_configured([ "compilation_pipeline_rocm.cc", ]), + hdrs = ["compilation_pipeline.h"], + deps = [ + "@com_google_absl//absl/status", + "@llvm-project//mlir:Pass", + ] + if_gpu_is_configured([ + ":xla_triton_passes", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Transforms", + "//xla/service:hlo_module_config", + "//xla/service/gpu:matmul_utils", + "//xla/stream_executor:device_description", + "@triton//:TritonDialects", + "@triton//:TritonGPUToLLVM", + "@triton//:TritonGPUTransforms", + "@triton//:TritonLLVMIR", + "@triton//:TritonNvidiaGPUTransforms", + "@triton//:TritonToTritonGPU", + "@triton//:TritonTransforms", + ]) + if_cuda_is_configured([ + "//xla/service/gpu/llvm_gpu_backend:nvptx_backend", + "//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ]) + if_rocm_is_configured([ + "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend", + "@local_tsl//tsl/platform:rocm_rocdl_path", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + ]), +) + +cc_library( + name = "triton_fusion_emitter", + srcs = if_gpu_is_configured( + ["triton_fusion_emitter.cc"], + ["triton_fusion_emitter_stub.cc"], + ), hdrs = ["triton_fusion_emitter.h"], deps = [ + ":compilation_pipeline", ":emitter_helpers", ":passes", ":triton_fusion_emitter_legacy_matmul", @@ -81,6 +126,9 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/gpu/codegen/ir:xla_gpu", + "//xla/backends/gpu/codegen/transforms:passes", + "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/ir:xla", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", @@ -95,9 +143,7 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:triton_fusion_analysis", - "//xla/service/gpu/fusions/ir:xla_gpu", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", - "//xla/service/gpu/fusions/transforms:passes", + "//xla/service/gpu/fusions:emitter_loc_op_builder", "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/model:triton_emitter_constraints", @@ -152,10 +198,11 @@ cc_library( ]) + if_cuda_is_configured([ "@triton//third_party/nvidia:NVGPUToLLVM", "//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path", + "//xla/service/gpu/llvm_gpu_backend:nvptx_backend", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", ]) + if_rocm_is_configured([ "@local_tsl//tsl/platform:rocm_rocdl_path", - "//xla/service/gpu/llvm_gpu_backend:llvm_gpu_backend", + "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend", "@triton//third_party/amd:TritonAMDGPUToLLVM", "@triton//third_party/amd:TritonAMDGPUTransforms", ]), @@ -190,6 +237,7 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/service/gpu:triton_fusion_analysis", "//xla/service/gpu:triton_tiling_propagation", + "//xla/service/gpu/fusions:emitter_loc_op_builder", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", @@ -222,10 +270,12 @@ cc_library( cc_library( name = "triton_fusion_emitter_stub_for_testing", srcs = [ + "compilation_pipeline_stub.cc", "triton_fusion_emitter_legacy_matmul_stub.cc", "triton_fusion_emitter_stub.cc", ], hdrs = [ + "compilation_pipeline.h", "triton_fusion_emitter.h", "triton_fusion_emitter_legacy_matmul.h", ], @@ -237,6 +287,7 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu/fusions:emitter_loc_op_builder", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", @@ -261,6 +312,7 @@ xla_cc_test( "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", + "//xla/service/gpu/fusions:emitter_loc_op_builder", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", @@ -324,6 +376,7 @@ gentbl_cc_library( cc_library( name = "xla_triton_passes", srcs = [ + "xla_triton_int4_passes.cc", "xla_triton_prevent_mmav3_loop_unrolling_pass.cc", "xla_triton_sparse_passes.cc", ], @@ -333,9 +386,12 @@ cc_library( deps = [ ":xla_triton", ":xla_triton_passes_inc_gen", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:IR", @@ -343,6 +399,7 @@ cc_library( "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@triton//:TritonAnalysis", @@ -455,6 +512,29 @@ cc_library( ], ) +xla_test( + name = "triton_fusion_emitter_deviceless_test", + srcs = ["triton_fusion_emitter_deviceless_test.cc"], + backends = ["gpu"], + deps = [ + ":triton_fusion_emitter", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/fusions:emitter_loc_op_builder", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + xla_test( name = "triton_fusion_emitter_device_legacy_test", srcs = if_gpu_is_configured(["triton_fusion_emitter_device_legacy_test.cc"]), @@ -495,17 +575,55 @@ xla_test( "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + ], +) + +xla_test( + name = "triton_fusion_emitter_int4_device_test", + srcs = if_gpu_is_configured(["triton_fusion_emitter_int4_device_test.cc"]), + # TODO(b/372714955): Fix the memory leak! + backend_args = if_google( + { + "gpu_h100": ["--heap_check="], + "gpu_a100": ["--heap_check="], + }, + {}, + ), + backends = [ + "gpu_a100", + "gpu_h100", + "gpu_b100", + "gpu_amd_any", + ], + tags = [ + "no_mac", + ], + deps = [ + "//xla:autotuning_proto_cc", + "//xla:error_spec", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:path", ], ) @@ -582,12 +700,13 @@ xla_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", @@ -646,7 +765,9 @@ cc_library( "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/transforms/simplifiers:float_normalization", "//xla/hlo/utils:hlo_query", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", @@ -655,9 +776,7 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/stream_executor:device_description", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -681,12 +800,12 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/fusions:emitter_loc_op_builder", "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/model:triton_emitter_constraints", "//xla/service/llvm_ir:llvm_util", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline.h b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline.h new file mode 100644 index 00000000000000..e6a8b2f1aca0fc --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_ + +#include + +#include "absl/status/status.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir::triton::nvidia_gpu { + +// Forward declaration to avoid including a GPU-only header. +struct ClusterInfo; + +} // namespace mlir::triton::nvidia_gpu + +namespace xla { +namespace gpu { + +// Creates a Triton compilation pipeline. +// +// `out_cluster_info` must be kept alive at least until pm.run() is called. +// It should be read after that. We have to pass the cluster dims to +// LaunchDimensions. Triton currently uses this as an out-parameter to return +// the cluster dims determined based on `config.num_ctas` and a heuristic. There +// are some signs that show that this was intended to be used as an in-out +// parameter which would give a hint to Triton which cluster dims we prefer to +// use, but that's not the case currently. +absl::Status CreateTritonPipeline( + mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas, + int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info, + bool is_xla_fusion); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_COMPILATION_PIPELINE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index 8ad50e305721d0..2e1e6dcde49a99 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" @@ -26,7 +27,6 @@ limitations under the License. #include "mlir/Transforms/Passes.h" #include "xla/service/gpu/fusions/triton/xla_triton_passes.h" #include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" -#include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -41,95 +41,97 @@ namespace gpu { namespace mt = ::mlir::triton; namespace mt_xla = ::mlir::triton::xla; -absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, - const BlockLevelParameters& block_level_parameters, - mt::nvidia_gpu::ClusterInfo& out_cluster_info) { - auto ccCuda = std::get(cc); - const int ccAsInt = ccCuda.major * 10 + ccCuda.minor; +absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, + std::string arch_name, int num_warps, + int num_ctas, int num_stages, + mt::nvidia_gpu::ClusterInfo& out_cluster_info, + bool is_xla_fusion) { + auto cc = se::CudaComputeCapability(std::move(arch_name)); + const int ccAsInt = cc.major * 10 + cc.minor; const int threadsPerWarp = 32; + if (is_xla_fusion) { + pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass()); + } + // Based on make_ttir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mlir::createInlinerPass()); - pm.addPass(mt::createRewriteTensorPointerPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mt::createCombineOpsPass()); - pm.addPass(mt::createReorderBroadcastPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(mt::createLoopUnrollPass()); + pm->addPass(mlir::createInlinerPass()); + pm->addPass(mt::createRewriteTensorPointerPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mt::createCombineOpsPass()); + pm->addPass(mt::createReorderBroadcastPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createLoopInvariantCodeMotionPass()); + pm->addPass(mlir::createSymbolDCEPass()); + pm->addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mt::createConvertTritonToTritonGPUPass( - absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps, - threadsPerWarp, block_level_parameters.num_ctas)); - pm.addPass(mt_xla::CreateSparseAddEncodingPass( - block_level_parameters.num_warps, threadsPerWarp, - block_level_parameters.num_ctas)); - pm.addPass(mt::gpu::createTritonGPUCoalesce()); - if (ccCuda.IsAtLeastAmpere()) { - pm.addPass(mt::gpu::createTritonGPUF32DotTC()); + pm->addPass(mt::createConvertTritonToTritonGPUPass( + absl::StrFormat("cuda:%u", ccAsInt), num_warps, threadsPerWarp, + num_ctas)); + pm->addPass( + mt_xla::CreateSparseAddEncodingPass(num_warps, threadsPerWarp, num_ctas)); + pm->addPass(mt::gpu::createTritonGPUCoalesce()); + if (cc.IsAtLeastAmpere()) { + pm->addPass(mt::gpu::createTritonGPUF32DotTC()); } - pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); - pm.addPass(mt_xla::CreateSparseBlockedToMMAPass()); - pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass( - mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()})); - pm.addPass(mlir::createCSEPass()); + pm->addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info)); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + pm->addPass(mt_xla::CreateSparseBlockedToMMAPass()); + pm->addPass(mt::gpu::createTritonGPUAccelerateMatmul()); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass( + mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()})); + pm->addPass(mlir::createCSEPass()); // Even though we don't run on pre-Ampere architectures anymore, we keep this // check for consistency with the upstream pipeline - if (ccCuda.IsAtLeastAmpere()) { - pm.addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf()); - pm.addPass(mt::gpu::createTritonGPULoopScheduling( - {block_level_parameters.num_stages})); - pm.addPass( - mt::gpu::createTritonGPUPipeline({block_level_parameters.num_stages})); + if (cc.IsAtLeastAmpere()) { + pm->addPass(mt::gpu::createTritonGPUCombineTensorSelectAndIf()); + pm->addPass(mt::gpu::createTritonGPULoopScheduling({num_stages})); + pm->addPass(mt::gpu::createTritonGPUPipeline({num_stages})); } - pm.addPass(mt::gpu::createTritonGPUPrefetch()); - pm.addPass( - mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()})); - pm.addPass(mt::gpu::createTritonGPUCoalesceAsyncCopy()); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass(mt_xla::CreateSparseRemoveLayoutConversionPass()); - pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication()); - pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); - if (ccCuda.IsAtLeastHopper()) { - pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); - pm.addPass(mlir::createTritonNvidiaGPUTMALoweringPass()); + pm->addPass(mt::gpu::createTritonGPUPrefetch()); + pm->addPass( + mt::gpu::createTritonGPUOptimizeDotOperands({cc.IsAtLeastAmpere()})); + pm->addPass(mt::gpu::createTritonGPUCoalesceAsyncCopy()); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass(mt_xla::CreateSparseRemoveLayoutConversionPass()); + pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication()); + pm->addPass(mt::gpu::createTritonGPUReorderInstructions()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createSymbolDCEPass()); + if (cc.IsAtLeastHopper()) { + pm->addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt)); + pm->addPass(mlir::createTritonNvidiaGPUTMALoweringPass()); } - pm.addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createCanonicalizerPass()); // Based on make_llir() in // @triton//:third_party/nvidia/backend/compiler.py - pm.addPass(mt::NVIDIA::createDecomposeUnsupportedConversionsPass()); + pm->addPass(mt::NVIDIA::createDecomposeUnsupportedConversionsPass()); // This pass reduces Hopper compile time extensively: b/344841434. - if (ccCuda.IsAtLeastHopper()) { - pm.addPass(mt_xla::CreatePreventMmaV3LoopUnrollingPass()); + if (cc.IsAtLeastHopper()) { + pm->addPass(mt_xla::CreatePreventMmaV3LoopUnrollingPass()); } - pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createConvertIndexToLLVMPass()); - pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); - pm.addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass()); - pm.addPass(mt_xla::CreateSparseLocalLoadToLLVMPass()); - pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt)); + pm->addPass(mlir::createConvertSCFToCFPass()); + pm->addPass(mlir::createConvertIndexToLLVMPass()); + pm->addPass(mt::gpu::createAllocateSharedMemoryPass()); + pm->addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass()); + pm->addPass(mt_xla::CreateSparseLocalLoadToLLVMPass()); + pm->addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt)); // The triton_xla.sparse_dot ops need to be rewritten after // ModuleAxisInfoAnalysis inside convert-triton-gpu-to-llvm. - pm.addPass(mt_xla::CreateSparseDotOpToLLVMPass()); - pm.addPass(mt::createConvertNVGPUToLLVMPass()); - pm.addPass(mt_xla::CreateSparseWGMMAOpToLLVMPass()); - pm.addPass(mlir::createArithToLLVMConversionPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); + pm->addPass(mt_xla::CreateSparseDotOpToLLVMPass()); + pm->addPass(mt::createConvertNVGPUToLLVMPass()); + pm->addPass(mt_xla::CreateSparseWGMMAOpToLLVMPass()); + pm->addPass(mlir::createArithToLLVMConversionPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createSymbolDCEPass()); // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 187d96657e34af..4fc127382dd10b 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ // TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is // included in build. +#include +#include + #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h" #include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -22,10 +25,10 @@ limitations under the License. #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h" #include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/rocm_rocdl_path.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" @@ -52,81 +55,79 @@ using ::mlir::Type; using ::mlir::Value; using mlir::ValueRange; -absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, - const BlockLevelParameters& block_level_parameters, - mt::nvidia_gpu::ClusterInfo& out_cluster_info) { +absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, + std::string arch_name, int num_warps, + int num_ctas, int num_stages, + mt::nvidia_gpu::ClusterInfo& out_cluster_info, + bool is_xla_fusion) { // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64. const int threadsPerWarp = 32; - auto ccRocm = std::get(cc); + auto cc = se::RocmComputeCapability(std::move(arch_name)); // Based on make_ttir() in // @triton//:third_party/amd/backend/compiler.py - pm.addPass(mlir::createInlinerPass()); - pm.addPass(mt::createRewriteTensorPointerPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mt::createCombineOpsPass()); - pm.addPass(mt::createReorderBroadcastPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(mt::createLoopUnrollPass()); + pm->addPass(mlir::createInlinerPass()); + pm->addPass(mt::createRewriteTensorPointerPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mt::createCombineOpsPass()); + pm->addPass(mt::createReorderBroadcastPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createLoopInvariantCodeMotionPass()); + pm->addPass(mlir::createSymbolDCEPass()); + pm->addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/amd/backend/compiler.py - pm.addPass(mt::createConvertTritonToTritonGPUPass( - absl::StrCat("hip:", ccRocm.gfx_version()), - block_level_parameters.num_warps, threadsPerWarp, - block_level_parameters.num_ctas)); - pm.addPass(mt::gpu::createTritonGPUCoalesce()); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); - pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul()); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass(mt::createConvertTritonToTritonGPUPass( + absl::StrCat("hip:", cc.gfx_version()), num_warps, threadsPerWarp, + num_ctas)); + pm->addPass(mt::gpu::createTritonGPUCoalesce()); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass(mt::gpu::createTritonGPUOptimizeThreadLocality()); + pm->addPass(mt::gpu::createTritonGPUAccelerateMatmul()); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); // TODO ROCm Check if we want to compare MI100 and greater - pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass()); - pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); - if (block_level_parameters.num_stages == kAmdDoubleBuffering && - ccRocm.has_amd_matrix_core()) { - pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass( - block_level_parameters.num_stages, /*stream_prefetch=*/true)); - pm.addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass()); + pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); + if (num_stages == kAmdDoubleBuffering && cc.has_amd_matrix_core()) { + pm->addPass(mlir::createTritonAMDGPUStreamPipelinePass( + num_stages, /*stream_prefetch=*/true)); + pm->addPass(mlir::createCanonicalizerPass()); } - pm.addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass()); - pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); - pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); - pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication()); - if (block_level_parameters.num_stages != kAmdDoubleBuffering) { - pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); + pm->addPass(mt::createTritonAMDGPUInsertInstructionSchedHintsPass()); + pm->addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); + pm->addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm->addPass(mt::gpu::createTritonGPUReduceDataDuplication()); + if (num_stages != kAmdDoubleBuffering) { + pm->addPass(mt::gpu::createTritonGPUReorderInstructions()); } - pm.addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); + pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createSymbolDCEPass()); // Based on make_llir() in // @triton//:third_party/amd/backend/compiler.py - pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( - ccRocm.gfx_version())); + pm->addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( + cc.gfx_version())); const int custom_lds_size = 0; - pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(), - custom_lds_size)); - pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createConvertIndexToLLVMPass()); - pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); - pm.addPass( - mt::createConvertTritonAMDGPUToLLVMPass(ccRocm.gfx_version(), true)); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); + pm->addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(cc.gfx_version(), + custom_lds_size)); + pm->addPass(mlir::createConvertSCFToCFPass()); + pm->addPass(mlir::createConvertIndexToLLVMPass()); + pm->addPass(mt::gpu::createAllocateSharedMemoryPass()); + pm->addPass(mt::createConvertTritonAMDGPUToLLVMPass(cc.gfx_version(), true)); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createCSEPass()); // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass. - pm.addPass(mlir::createConvertControlFlowToLLVMPass()); - pm.addPass(mlir::createArithToLLVMConversionPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass( - ccRocm.gfx_version(), block_level_parameters.num_stages, "default")); - pm.addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true)); + pm->addPass(mlir::createConvertControlFlowToLLVMPass()); + pm->addPass(mlir::createArithToLLVMConversionPass()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createSymbolDCEPass()); + pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass( + cc.gfx_version(), num_stages, "default")); + pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true)); // There is no clusters in ROCm for now. out_cluster_info.clusterDimX = 1; out_cluster_info.clusterDimY = 1; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc new file mode 100644 index 00000000000000..9a732b91ae4ac5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_stub.cc @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "mlir/Pass/PassManager.h" +#include "xla/service/gpu/fusions/triton/compilation_pipeline.h" + +namespace xla { +namespace gpu { + +absl::Status CreateTritonPipeline( + mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas, + int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info, + bool is_xla_fusion) { + return absl::UnimplementedError("not supported for this build configuration"); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc index cbc9962ca6e6ff..5904e4f39008aa 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -186,7 +185,7 @@ class TritonAlgorithmTest : public AlgorithmTest { }; TEST_F(AlgorithmTest, Algorithm3xBF16) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Algorithm3xBF16 ENTRY e { @@ -202,7 +201,7 @@ TEST_F(AlgorithmTest, Algorithm3xBF16) { } TEST_F(AlgorithmTest, Algorithm6xBF16) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Algorithm6xBF16 ENTRY e { @@ -225,7 +224,7 @@ TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Algorithm_BF16_BF16_F32 ENTRY main { @@ -284,7 +283,7 @@ TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32_X3) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Algorithm_BF16_BF16_F32_X3 ENTRY main { @@ -339,7 +338,7 @@ TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32_X6) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Algorithm_BF16_BF16_F32_X6 ENTRY main { @@ -395,7 +394,7 @@ TEST_F(BlasAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { // We check that the algorithm is propagated to the BLAS call. // We also check that the kernel name matches the algorithm for Ampere. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Algorithm_TF32_TF32_F32_X3 ENTRY main { @@ -449,7 +448,7 @@ TEST_F(BlasAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { } TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Emit6xBF16GemmWhenBothInputsAreF32 triton_dot { @@ -491,7 +490,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Emit6xBF16GemmWhenBothInputsAreF32 triton_dot { @@ -532,7 +531,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Triton6xBF16GemmWorksForLongContractingDimension triton_dot { @@ -564,7 +563,7 @@ TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Emit6xBF16GemmEndToEnd ENTRY e { @@ -636,7 +635,7 @@ class Triton3xBF16GemmTestWithFlag : public AlgorithmTest { }; TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Emit3xBF16GemmWhenBothInputsAreF32 triton_dot { @@ -678,7 +677,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Emit3xBF16GemmWhenBothInputsAreF32 triton_dot { @@ -719,7 +718,7 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule NoEmit3xBF16GemmWhenBothInputsAreNotF32 triton_dot { @@ -747,7 +746,7 @@ CHECK-NOT: tt.dot } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Triton3xBF16GemmWorksForLongContractingDimension triton_dot { @@ -779,7 +778,7 @@ TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule Emit3xBF16GemmEndToEnd ENTRY e { @@ -953,9 +952,9 @@ TEST_F(TritonAlgorithmTest, Dot_BF16_X6_WithConst) { ENTRY %entry_computation { %p_0 = f32[1,258]{1,0} parameter(0) - ROOT %dot = f32[258]{0} fusion(f32[1,258]{1,0} %p_0), - kind=kCustom, - calls=%triton_fusion_dot, + ROOT %dot = f32[258]{0} fusion(f32[1,258]{1,0} %p_0), + kind=kCustom, + calls=%triton_fusion_dot, backend_config={ "operation_queue_id":"0", "wait_on_operation_queues":[], @@ -1121,9 +1120,9 @@ class BlasCanHandle return absl::StrFormat(kHloTextTemplate, HloModuleTestName(), algorithm_); } - static constexpr std::string_view kPattern = R"(CHECK: __cublas$gemm)"; + static constexpr absl::string_view kPattern = R"(CHECK: __cublas$gemm)"; - static constexpr std::string_view kReferenceHloText = R"( + static constexpr absl::string_view kReferenceHloText = R"( HloModule %s ENTRY e { @@ -1158,7 +1157,7 @@ class BlasCanHandle } private: - static constexpr std::string_view kHloTextTemplate = R"( + static constexpr absl::string_view kHloTextTemplate = R"( HloModule %s ENTRY e { @@ -1188,10 +1187,10 @@ class TritonCanHandle return absl::StrFormat(kHloTextTemplate, HloModuleTestName(), algorithm_); } - static constexpr std::string_view kPattern = R"(CHECK: __triton_gemm)"; + static constexpr absl::string_view kPattern = R"(CHECK: __triton_gemm)"; private: - static constexpr std::string_view kHloTextTemplate = R"( + static constexpr absl::string_view kHloTextTemplate = R"( HloModule %s triton_dot { @@ -1373,8 +1372,8 @@ class CSVWriter { } // Returns the results in CSV format. - std::string GetResult(std::string_view title, - std::string_view delimiter = ", ", + std::string GetResult(absl::string_view title, + absl::string_view delimiter = ", ", bool separate_first_row = true) const { std::vector sizes; size_t columns = 0; @@ -1432,7 +1431,7 @@ class AlgorithmsSupportTest } absl::StatusOr> GetModule( - std::string_view hlo_template, + absl::string_view hlo_template, const std::vector>& args, const DebugOptions& options) { auto config = GetModuleConfig(options); @@ -1476,7 +1475,7 @@ class AlgorithmsSupportTest algorithm_ = AlgorithmToString(std::get<0>(GetParam())); } - std::string GetTestName(std::string_view delimiter) const { + std::string GetTestName(absl::string_view delimiter) const { auto test_info = ::testing::UnitTest::GetInstance()->current_test_info(); auto suite_name = test_info->test_suite_name(); std::string test_name = test_info->name(); @@ -1484,7 +1483,7 @@ class AlgorithmsSupportTest {{"/", "_"}}); } - void DumpResults(const CSVWriter& csv, std::string_view suffix) { + void DumpResults(const CSVWriter& csv, absl::string_view suffix) { auto title = absl::StrCat("Test name: ", GetTestName(".")); auto result = csv.GetResult(title, ", "); LOG(ERROR) << "result: \n" << result; @@ -1501,8 +1500,8 @@ class AlgorithmsSupportTest std::string algorithm_; - static constexpr std::string_view kBlasPattern = "__cublas$gemm"; - static constexpr std::string_view kTritonGemmPattern = "__triton_gemm"; + static constexpr absl::string_view kBlasPattern = "__cublas$gemm"; + static constexpr absl::string_view kTritonGemmPattern = "__triton_gemm"; static constexpr int kMaxSize = 8192; static constexpr int kStepSize = 8; static constexpr int kMaxK = kMaxSize; @@ -1532,8 +1531,8 @@ TEST_P(AlgorithmsSupportTest, DotBC) { csv.nextRow(); csv.appendValue(b); for (int k = 1; k <= kMaxSize; k *= kStepSize) { - auto run = [&](std::string_view backend, std::string_view pattern, - const DebugOptions& options) -> std::string_view { + auto run = [&](absl::string_view backend, absl::string_view pattern, + const DebugOptions& options) -> absl::string_view { auto test_name = absl::StrReplaceAll(TestName(), {{"/", "_"}}); auto module_name = absl::StrCat(test_name, "_", backend, "_", b, "_", k); @@ -1580,8 +1579,8 @@ TEST_P(AlgorithmsSupportTest, DotNC) { csv.nextRow(); csv.appendValue(m); for (int n = 1; n <= kMaxSize; n *= kStepSize) { - auto run = [&](std::string backend, std::string_view pattern, - const DebugOptions& options) -> std::string_view { + auto run = [&](std::string backend, absl::string_view pattern, + const DebugOptions& options) -> absl::string_view { auto test_name = absl::StrReplaceAll(TestName(), {{"/", "_"}}); auto module_name = absl::StrCat(test_name, "_", backend, "_", m, "_", kMaxK, "_", n, "_", algorithm_); @@ -1609,7 +1608,8 @@ TEST_P(AlgorithmsSupportTest, DotNC) { TEST_P(AlgorithmsSupportTest, IsDotAlgorithmSupportedByTriton) { // TODO: Weekly-sync 24-12-10 - GTEST_SKIP() << "TODO: Weekly-sync 24-12-10: Skip IsDotAlgorithmSupportedByTriton ."; + GTEST_SKIP() + << "TODO: Weekly-sync 24-12-10: Skip IsDotAlgorithmSupportedByTriton ."; // Here we test which dot algorithm is supported by triton. // In case of a change you need to update the expected results. @@ -1628,7 +1628,7 @@ TEST_P(AlgorithmsSupportTest, IsDotAlgorithmSupportedByTriton) { auto m = 128; auto n = 128; auto k = 128; - auto run = [&](std::string backend, std::string_view pattern, + auto run = [&](std::string backend, absl::string_view pattern, const DebugOptions& options) -> absl::StatusOr { auto test_name = absl::StrReplaceAll(TestName(), {{"/", "_"}}); auto module_name = absl::StrCat(test_name, "_", backend, "_", m, "_", kMaxK, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc index c3be827bf59cfc..7f3b990219c231 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -43,6 +43,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h" #include "xla/primitive_util.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/stream_executor/device_description.h" @@ -54,7 +55,6 @@ namespace xla::gpu::triton { using ::llvm::SmallVector; using ::mlir::ArrayRef; -using ::mlir::ImplicitLocOpBuilder; using ::mlir::ShapedType; using ::mlir::Type; using ::mlir::Value; @@ -65,13 +65,9 @@ namespace mh = ::mlir::mhlo; namespace mm = ::mlir::math; namespace mt = ::mlir::triton; -ScalarOrTensor::ScalarOrTensor(mlir::Value value) { - if (auto tt = mlir::dyn_cast(value.getType())) { - CHECK_GT(tt.getRank(), 0); - value_ = TensorValue{value}; - } else { - value_ = ScalarValue{value}; - } +ScalarOrTensor::ScalarOrTensor(mlir::Value value) : value_(value) { + CHECK(IsScalar() || UnwrapTensor().getType().getRank() > 0) + << "0D tensors are not supported by Triton"; } SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { @@ -83,7 +79,7 @@ SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { return result; } -absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { +absl::StatusOr TritonType(EmitterLocOpBuilder& b, PrimitiveType t) { switch (t) { case F64: return b.getF64Type(); @@ -114,7 +110,7 @@ absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { } } -Type StorageType(mlir::OpBuilder b, Type t) { +Type StorageType(EmitterLocOpBuilder& b, Type t) { if (t.isInteger(1)) { return b.getI8Type(); } @@ -126,7 +122,7 @@ bool IsFp8Type(Type t) { t.isFloat8E4M3FNUZ() || t.isFloat8E4M3B11FNUZ(); } -Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { +Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) { Type src_ty = value.getType(); Type src_element_ty = src_ty; Type fp32_ty = b.getF32Type(); @@ -243,7 +239,7 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { << llvm_ir::DumpToString(dst_element_ty); } -Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { +Value Subtract(EmitterLocOpBuilder& b, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return b.create(values[0], values[1]); } else { @@ -251,7 +247,7 @@ Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { } } -Value Compare(ImplicitLocOpBuilder& b, ValueRange values, +Value Compare(EmitterLocOpBuilder& b, ValueRange values, mh::ComparisonDirection direction) { const Type type = mlir::getElementTypeOrSelf(values[0]); if (mlir::isa(type)) { @@ -268,7 +264,7 @@ Value Compare(ImplicitLocOpBuilder& b, ValueRange values, values[0], values[1]); } -Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, +Value Maximum(EmitterLocOpBuilder& b, const se::DeviceDescription& device_info, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return b.create(values); @@ -289,7 +285,7 @@ Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, values[0], values[1]); } -Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, +Value Minimum(EmitterLocOpBuilder& b, const se::DeviceDescription& device_info, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return b.create(values); @@ -311,10 +307,10 @@ Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, values[0], values[1]); } -ScalarOrTensor Splat(ImplicitLocOpBuilder& b, ScalarOrTensor value, +ScalarOrTensor Splat(EmitterLocOpBuilder& b, ScalarOrTensor value, ArrayRef shape) { CHECK(!shape.empty()); - auto type = mlir::RankedTensorType::get(shape, value.Type()); + auto type = mlir::RankedTensorType::get(shape, value.getType()); return ScalarOrTensor(b.create(type, value.UnwrapUnsafe())); } @@ -330,7 +326,7 @@ bool IsSupportedElementwiseLibdeviceFunction(const HloInstruction& hlo) { } absl::StatusOr EmitElementwiseLibdeviceFunction( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloInstruction& hlo, ValueRange inputs) { auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); @@ -370,7 +366,7 @@ absl::StatusOr EmitElementwiseLibdeviceFunction( return res; } -absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, +absl::StatusOr EmitElementwise(EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloInstruction& hlo, @@ -457,7 +453,7 @@ absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, } } -absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, +absl::StatusOr EmitConstant(EmitterLocOpBuilder& b, const HloInstruction& constant) { TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type())); llvm::SmallVector shape{constant.shape().dimensions().begin(), diff --git a/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h index 17a1015ddfeaf8..7e20b6b3f6157f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h @@ -27,7 +27,6 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" @@ -36,6 +35,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -48,6 +48,8 @@ namespace xla::gpu::triton { // non-0D tensor. An attempt to use this class with 0D tensors will CHECK-fail // because 0D tensors are not supported by Triton. class ScalarOrTensor { + using TensorValue = mlir::TypedValue; + public: ScalarOrTensor() = default; @@ -55,17 +57,17 @@ class ScalarOrTensor { // value is a 0D tensor, because Triton does not support 0D tensors. explicit ScalarOrTensor(mlir::Value value); - bool IsScalar() const { return std::holds_alternative(value_); } - bool IsTensor() const { return std::holds_alternative(value_); } + bool IsScalar() const { return !IsTensor(); } + bool IsTensor() const { return mlir::isa(value_); } - mlir::Value UnwrapScalar() { + mlir::Value UnwrapScalar() const { CHECK(IsScalar()); - return std::get(value_).scalar_value; + return value_; } - mlir::Value UnwrapTensor() { + TensorValue UnwrapTensor() const { CHECK(IsTensor()); - return std::get(value_).tensor_value; + return mlir::cast(value_); } // Returns the underlying value regardless of whether it is a scalar or a @@ -73,25 +75,12 @@ class ScalarOrTensor { // both needs to use an `mlir::Value` and functions identically for scalars // and tensors. In other cases, prefer to use the `UnwrapScalar` or // `UnwrapTensor` methods. - mlir::Value UnwrapUnsafe() { - if (auto* scalar = std::get_if(&value_)) { - return scalar->scalar_value; - } - return std::get(value_).tensor_value; - } + mlir::Value UnwrapUnsafe() const { return value_; } - mlir::Type Type() { return UnwrapUnsafe().getType(); } + mlir::Type getType() const { return value_.getType(); } private: - struct ScalarValue { - mlir::Value scalar_value; - }; - - struct TensorValue { - mlir::Value tensor_value; - }; - - std::variant value_; + mlir::Value value_; }; // Triton requires that all block dimensions are a power of 2. @@ -101,9 +90,9 @@ llvm::SmallVector GetPaddedTileSizes( llvm::ArrayRef tile_sizes); // XLA -> Triton type conversions. -absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t); +absl::StatusOr TritonType(EmitterLocOpBuilder& b, PrimitiveType t); -mlir::Type StorageType(mlir::OpBuilder b, mlir::Type t); +mlir::Type StorageType(EmitterLocOpBuilder& b, mlir::Type t); // Get the value of the scalar constant's literal in a C++ type. template @@ -117,8 +106,7 @@ T ScalarConstantValue(const HloInstruction& instr, PrimitiveType dst_type) { // Create a scalar constant. template -ScalarOrTensor CreateConst(mlir::ImplicitLocOpBuilder b, mlir::Type type, - T value) { +ScalarOrTensor CreateConst(EmitterLocOpBuilder& b, mlir::Type type, T value) { if (mlir::isa(type)) { auto result = b.create(b.getIntegerAttr(type, value)); @@ -134,8 +122,8 @@ ScalarOrTensor CreateConst(mlir::ImplicitLocOpBuilder b, mlir::Type type, // Create a tensor constant. template -ScalarOrTensor CreateConst(mlir::ImplicitLocOpBuilder& b, mlir::Type type, - T value, llvm::ArrayRef shape) { +ScalarOrTensor CreateConst(EmitterLocOpBuilder& b, mlir::Type type, T value, + llvm::ArrayRef shape) { if (shape.empty()) { return CreateConst(b, type, value); } @@ -159,8 +147,7 @@ ScalarOrTensor CreateConst(mlir::ImplicitLocOpBuilder& b, mlir::Type type, // Create a constant of the same shape as `like` but with a new type and value. template -mlir::Value ConstLike(mlir::ImplicitLocOpBuilder& b, mlir::Value like, - T new_value) { +mlir::Value ConstLike(EmitterLocOpBuilder& b, mlir::Value like, T new_value) { if (auto src_shaped_ty = mlir::dyn_cast(like.getType())) { mlir::Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, new_value, src_shaped_ty.getShape()) @@ -169,25 +156,25 @@ mlir::Value ConstLike(mlir::ImplicitLocOpBuilder& b, mlir::Value like, return CreateConst(b, like.getType(), new_value).UnwrapUnsafe(); } -inline mlir::Value ZerosLike(mlir::ImplicitLocOpBuilder& b, mlir::Value x) { +inline mlir::Value ZerosLike(EmitterLocOpBuilder& b, mlir::Value x) { return ConstLike(b, x, 0); } -inline mlir::Value OnesLike(mlir::ImplicitLocOpBuilder& b, mlir::Value x) { +inline mlir::Value OnesLike(EmitterLocOpBuilder& b, mlir::Value x) { return ConstLike(b, x, 1); } bool IsFp8Type(mlir::Type t); -ScalarOrTensor Splat(mlir::ImplicitLocOpBuilder& b, ScalarOrTensor value, +ScalarOrTensor Splat(EmitterLocOpBuilder& b, ScalarOrTensor value, llvm::ArrayRef shape); // Triton type conversions. -mlir::Value Cast(mlir::ImplicitLocOpBuilder& b, mlir::Value value, +mlir::Value Cast(EmitterLocOpBuilder& b, mlir::Value value, mlir::Type dst_element_ty); // Emits a scalar constant. -absl::StatusOr EmitConstant(mlir::ImplicitLocOpBuilder& b, +absl::StatusOr EmitConstant(EmitterLocOpBuilder& b, const HloInstruction& constant); bool IsSupportedElementwiseLibdeviceFunction(const HloInstruction& hlo); @@ -195,12 +182,12 @@ bool IsSupportedElementwiseLibdeviceFunction(const HloInstruction& hlo); // Should only be called if IsSupportedElementwiseLibdeviceFunction() returns // true for `hlo`, otherwise an error is returned. absl::StatusOr EmitElementwiseLibdeviceFunction( - mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloInstruction& hlo, mlir::ValueRange inputs); absl::StatusOr EmitElementwise( - mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloInstruction& hlo, mlir::ValueRange inputs); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/tests/BUILD new file mode 100644 index 00000000000000..b000766eaf4df6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/tests/BUILD @@ -0,0 +1,26 @@ +load("//xla:lit.bzl", "lit_test_suite") # @unused + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +# copybara:uncomment_begin(triton-opt tool doesn't build in OSS) +# lit_test_suite( +# name = "mlir_lit_tests", +# srcs = glob(["*.mlir"]), +# cfg = "//xla:lit.cfg.py", +# tools = [ +# "@llvm-project//llvm:FileCheck", +# "//xla/service/gpu/tests:xla-opt", +# ], +# ) +# copybara:uncomment_end diff --git a/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_1d.mlir b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_1d.mlir new file mode 100644 index 00000000000000..a7c1096dbd555f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_1d.mlir @@ -0,0 +1,41 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s + +module { + tt.func @major_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0> : tensor<64x64xi8> + + %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array, packed_dim = 1 } : > +// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c64_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + + %1 = tt.advance %0, [%c64_i32, %c0_i32] : > +// CHECK-NEXT: %1 = tt.advance %0, [%c64_i32, %c0_i32] : > + + %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { +// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { + + %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> +// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + + %5 = tt.advance %arg3, [%c0_i32, %c64_i32] : > +// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c32_i32] : > + + %6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> +// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<64x32xi8> +// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<64x32xi8> +// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<64x32xi8> +// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<64x32xi8> -> tensor<64x32x2xi8> +// CHECK-NEXT: %10 = tt.reshape %9 : tensor<64x32x2xi8> -> tensor<64x64xi8> + + scf.yield %5, %6 : !tt.ptr>, tensor<64x64xi8> +// CHECK-NEXT: scf.yield %5, %10 : !tt.ptr>, tensor<64x64xi8> + } + %3 = tt.make_tensor_ptr %arg1, [%c1_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + tt.store %3, %2#1 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_2d.mlir b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_2d.mlir new file mode 100644 index 00000000000000..00b268e056b867 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_major_2d.mlir @@ -0,0 +1,44 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s + +module { + tt.func @major_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c16_i64 = arith.constant 16 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0> : tensor<64x64xi8> + + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c128_i64], [%c1_i64, %c16_i64], [%c0_i32, %c0_i32] {order = array, packed_dim = 1 } : > +// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c8_i64, %c128_i64], [%c1_i64, %c8_i64], [%c0_i32, %c0_i32] {order = array} : > + + %1 = tt.advance %0, [%c64_i32, %c0_i32] : > +// CHECK-NEXT: %1 = tt.advance %0, [%c32_i32, %c0_i32] : > + + %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { +// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { + + %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> +// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + + %5 = tt.advance %arg3, [%c0_i32, %c64_i32] : > +// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c64_i32] : > + + %6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> +// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<32x64xi8> +// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<32x64xi8> +// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<32x64xi8> +// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<32x64xi8> -> tensor<32x64x2xi8> +// CHECK-NEXT: %10 = tt.trans %9 {order = array} : tensor<32x64x2xi8> -> tensor<32x2x64xi8> +// CHECK-NEXT: %11 = tt.reshape %10 : tensor<32x2x64xi8> -> tensor<64x64xi8> + + scf.yield %5, %6 : !tt.ptr>, tensor<64x64xi8> +// CHECK-NEXT: scf.yield %5, %11 : !tt.ptr>, tensor<64x64xi8> + } + %3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + tt.store %3, %2#1 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + diff --git a/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_1d.mlir b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_1d.mlir new file mode 100644 index 00000000000000..06f3957e9505bc --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_1d.mlir @@ -0,0 +1,44 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s + +module { + tt.func @minor_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c128_i64 = arith.constant 128 : i64 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0> : tensor<64x64xi8> + + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array, packed_dim = 0 } : > +// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + + %1 = tt.advance %0, [%c0_i32, %c64_i32] : > +// CHECK-NEXT: %1 = tt.advance %0, [%c0_i32, %c64_i32] : > + + %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { +// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { + + %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> +// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + + %5 = tt.advance %arg3, [%c64_i32, %c0_i32] : > +// CHECK-NEXT: %5 = tt.advance %arg3, [%c32_i32, %c0_i32] : > + + %6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> +// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<32x64xi8> +// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<32x64xi8> +// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<32x64xi8> +// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<32x64xi8> -> tensor<32x64x2xi8> +// CHECK-NEXT: %10 = tt.trans %9 {order = array} : tensor<32x64x2xi8> -> tensor<32x2x64xi8> +// CHECK-NEXT: %11 = tt.reshape %10 : tensor<32x2x64xi8> -> tensor<64x64xi8> + + scf.yield %5, %6 : !tt.ptr>, tensor<64x64xi8> +// CHECK-NEXT: scf.yield %5, %11 : !tt.ptr>, tensor<64x64xi8> + } + %3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + tt.store %3, %2#1 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + + diff --git a/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_2d.mlir b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_2d.mlir new file mode 100644 index 00000000000000..462f0317767dac --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/tests/int4_packed_dim_minor_2d.mlir @@ -0,0 +1,43 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite --canonicalize -- %s | FileCheck --dump-input=never %s + +module { + tt.func @minor_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c128_i64 = arith.constant 128 : i64 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0> : tensor<64x64xi8> + + %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array, packed_dim = 0 } : > +// CHECK: %0 = tt.make_tensor_ptr %arg0, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + + %1 = tt.advance %0, [%c64_i32, %c0_i32] : > +// CHECK-NEXT: %1 = tt.advance %0, [%c64_i32, %c0_i32] : > + + %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { +// CHECK-NEXT: %2:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg3 = %1, %arg4 = %cst_0) -> (!tt.ptr>, tensor<64x64xi8>) : i32 { + + %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> +// CHECK-NEXT: %4 = tt.load %arg3 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + + %5 = tt.advance %arg3, [%c0_i32, %c64_i32] : > +// CHECK-NEXT: %5 = tt.advance %arg3, [%c0_i32, %c32_i32] : > + + %6 = arith.extsi %4 : tensor<64x64xi4> to tensor<64x64xi8> +// CHECK-NEXT: %6 = arith.shli %4, %cst : tensor<64x32xi8> +// CHECK-NEXT: %7 = arith.shrsi %6, %cst : tensor<64x32xi8> +// CHECK-NEXT: %8 = arith.shrsi %4, %cst : tensor<64x32xi8> +// CHECK-NEXT: %9 = tt.join %8, %7 : tensor<64x32xi8> -> tensor<64x32x2xi8> +// CHECK-NEXT: %10 = tt.reshape %9 : tensor<64x32x2xi8> -> tensor<64x64xi8> + + scf.yield %5, %6 : !tt.ptr>, tensor<64x64xi8> +// CHECK-NEXT: scf.yield %5, %10 : !tt.ptr>, tensor<64x64xi8> + } + %3 = tt.make_tensor_ptr %arg1, [%c128_i64, %c1_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + tt.store %3, %2#1 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + + diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index ace9a68b3e354d..4af1005413a172 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -61,7 +61,6 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -81,6 +80,9 @@ limitations under the License. #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" #include "xla/autotuning.pb.h" +#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/emitters/elemental_hlo_to_mlir.h" #include "xla/codegen/ir/xla_ops.h" #include "xla/hlo/analysis/indexing_analysis.h" #include "xla/hlo/analysis/indexing_map.h" @@ -95,9 +97,8 @@ limitations under the License. #include "xla/permutation_util.h" #include "xla/service/dump.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" +#include "xla/service/gpu/fusions/triton/compilation_pipeline.h" #include "xla/service/gpu/fusions/triton/emitter_helpers.h" #include "xla/service/gpu/fusions/triton/passes.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" @@ -137,7 +138,6 @@ namespace ttir = ::mlir::triton; using ::llvm::SmallVector; using ::mlir::ArrayRef; -using ::mlir::ImplicitLocOpBuilder; using ::mlir::ShapedType; using ::mlir::Type; using ::mlir::Value; @@ -156,29 +156,29 @@ namespace { using TensorValue = mlir::TypedValue; -ScalarOrTensor Broadcast(ImplicitLocOpBuilder& b, TensorValue value, +ScalarOrTensor Broadcast(EmitterLocOpBuilder& b, TensorValue value, ArrayRef shape) { return ScalarOrTensor( b.create(value.getType().clone(shape), value)); } -ScalarOrTensor Range(ImplicitLocOpBuilder& b, int32_t limit) { +ScalarOrTensor Range(EmitterLocOpBuilder& b, int32_t limit) { auto type = mlir::RankedTensorType::get(limit, b.getI32Type()); return ScalarOrTensor(b.create(type, 0, limit)); } -Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) { +Value AddPtr(EmitterLocOpBuilder& b, Value ptr, Value offset) { return b.create(ptr.getType(), ptr, offset); } -ScalarOrTensor EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, +ScalarOrTensor EmitParameterLoad(EmitterLocOpBuilder& b, Value pointer, ArrayRef boundary_checks) { if (auto make_tensor_ptr = pointer.getDefiningOp()) { if (make_tensor_ptr.getOffsets().empty()) { return ScalarOrTensor(b.create(make_tensor_ptr.getBase(), ttir::CacheModifier::NONE, ttir::EvictionPolicy::NORMAL, - /*isVolatile=*/false)); + /*isVolatile*/ false)); } } @@ -191,24 +191,24 @@ ScalarOrTensor EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, return ScalarOrTensor(b.create( pointer, boundary_checks, padding, ttir::CacheModifier::NONE, ttir::EvictionPolicy::NORMAL, - /*isVolatile=*/false)); + /*isVolatile*/ false)); } // Non-tensor pointer. return ScalarOrTensor(b.create( pointer, ttir::CacheModifier::NONE, ttir::EvictionPolicy::NORMAL, - /*isVolatile=*/false)); + /*isVolatile*/ false)); } absl::StatusOr EmitScope( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const TritonFusionAnalysis* analysis, absl::Span instructions, absl::flat_hash_map& values); absl::StatusOr EmitReduce( - ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_hlo_reduce, + EmitterLocOpBuilder& b, const TiledHloInstruction& tiled_hlo_reduce, absl::flat_hash_map& values, absl::string_view libdevice_path, const se::DeviceDescription& device_info) { @@ -218,7 +218,7 @@ absl::StatusOr EmitReduce( *::xla::Cast(tiled_hlo_reduce.hlo()); ScalarOrTensor input = values[tiled_hlo_reduce.operand(0)]; llvm::ArrayRef input_shape = - mlir::cast(input.Type()).getShape(); + mlir::cast(input.getType()).getShape(); absl::Span source_tensor_shape = hlo_reduce.operand(0)->shape().dimensions(); @@ -242,9 +242,9 @@ absl::StatusOr EmitReduce( // result are equal. for (int i = 0; i < input_shape.size() - 1; i++) { if (i < reduction_dimension) { - range = b.create(range, /*axis=*/0); + range = b.create(range, /*axis*/ 0); } else { - range = b.create(range, /*axis=*/i + 1); + range = b.create(range, /*axis*/ i + 1); } } Value mask = Broadcast(b, mlir::cast(range), input_shape) @@ -262,7 +262,7 @@ absl::StatusOr EmitReduce( } else { for (int i = 0; i < input_shape.size(); i++) { neutral = ScalarOrTensor( - b.create(neutral.UnwrapUnsafe(), /*axis=*/0)); + b.create(neutral.UnwrapUnsafe(), /*axis*/ 0)); } neutral = Broadcast(b, mlir::cast(neutral.UnwrapUnsafe()), input_shape); @@ -319,7 +319,7 @@ absl::StatusOr EmitReduce( // // TODO(b/331413981): get rid of this special handling once this is solved. absl::StatusOr EmitNestedFusion( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction& fusion_instruction, absl::flat_hash_map& values) { @@ -350,7 +350,7 @@ absl::StatusOr EmitNestedFusion( } ScalarOrTensor EmitTiledBroadcast( - ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast, + EmitterLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast, absl::flat_hash_map& values) { const llvm::SmallVector& input_tile_shape = tiled_broadcast.operand(0)->tile_sizes(); @@ -407,7 +407,7 @@ ScalarOrTensor EmitTiledBroadcast( } absl::StatusOr EmitTiledIota( - ImplicitLocOpBuilder& b, ValueRange tile_multi_index, + EmitterLocOpBuilder& b, ValueRange tile_multi_index, const TiledHloInstruction& tiled_iota) { const HloIotaInstruction* hlo_iota = ::xla::Cast(tiled_iota.hlo()); @@ -422,9 +422,9 @@ absl::StatusOr EmitTiledIota( tiled_iota.tile_offsets_indexing()); auto iota_dim_offset = b.create( - b.getI32Type(), mlir_converter::ApplyIndexing( - tile_offsets_indexing, /*dims=*/tile_multi_index, - /*symbols=*/{}, b)[iota_dim]); + b.getI32Type(), + emitters::ApplyIndexing(tile_offsets_indexing, /*dims=*/tile_multi_index, + /*symbols=*/{}, b)[iota_dim]); // First, stride as needed between the iota components. Value range = b.create( @@ -450,9 +450,9 @@ absl::StatusOr EmitTiledIota( // produce the whole iota tile. for (int i = 0; i < padded_tile_sizes.size() - 1; i++) { if (i < iota_dim) { - range = b.create(range, /*axis=*/0); + range = b.create(range, /*axis*/ 0); } else { - range = b.create(range, /*axis=*/i + 1); + range = b.create(range, /*axis*/ i + 1); } } @@ -460,7 +460,7 @@ absl::StatusOr EmitTiledIota( } // Reshapes a non-0D tensor of shape [1, 1, 1, ...] to a scalar. -ScalarOrTensor ReshapeTensorToScalar(ImplicitLocOpBuilder& b, Value input) { +ScalarOrTensor ReshapeTensorToScalar(EmitterLocOpBuilder& b, Value input) { auto element_type = mlir::cast(input.getType()).getElementType(); // First, reshape to a 1D tensor if not already the case. This is needed @@ -469,12 +469,12 @@ ScalarOrTensor ReshapeTensorToScalar(ImplicitLocOpBuilder& b, Value input) { if (mlir::cast(input.getType()).getRank() > 1) { Type output_tensor_type = mlir::RankedTensorType::get({1}, element_type); single_dim_tensor = b.create(output_tensor_type, input, - /*allow_reorder=*/true); + /*allow_reorder*/ true); } // Second, reduce to a scalar. ttir::ReduceOp reduction = - b.create(single_dim_tensor, /*axis=*/0); + b.create(single_dim_tensor, /*axis*/ 0); mlir::Location loc = b.getLoc(); mlir::Block* reducer = b.createBlock( @@ -495,7 +495,7 @@ ScalarOrTensor ReshapeTensorToScalar(ImplicitLocOpBuilder& b, Value input) { return ScalarOrTensor(reduction.getResult().front()); } -absl::StatusOr EmitTiledReshape(ImplicitLocOpBuilder& b, +absl::StatusOr EmitTiledReshape(EmitterLocOpBuilder& b, ArrayRef tile_sizes, ScalarOrTensor input) { SmallVector padded_tile_sizes = GetPaddedTileSizes(tile_sizes); @@ -511,7 +511,7 @@ absl::StatusOr EmitTiledReshape(ImplicitLocOpBuilder& b, // At this point we know that the input is a non-0D tensor. - auto input_shaped_type = mlir::cast(input.Type()); + auto input_shaped_type = mlir::cast(input.getType()); // Handle the case of reshaping [1,1,1...] to a scalar. if (tile_sizes.empty()) { @@ -531,7 +531,7 @@ absl::StatusOr EmitTiledReshape(ImplicitLocOpBuilder& b, return ScalarOrTensor(reshape.getResult()); } -Value EmitTiledTranspose(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, +Value EmitTiledTranspose(EmitterLocOpBuilder& b, ArrayRef tile_sizes, SmallVector dimensions, Value input) { SmallVector padded_tile_sizes = GetPaddedTileSizes(tile_sizes); @@ -546,7 +546,7 @@ Value EmitTiledTranspose(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, } absl::StatusOr EmitTiledBitcast( - ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_bitcast, + EmitterLocOpBuilder& b, const TiledHloInstruction& tiled_bitcast, Value input) { // Any Bitcast is decomposable to a transpose+reshape+transpose. auto trt = ShapeUtil::DecomposeBitcastToTrt( @@ -601,7 +601,7 @@ absl::StatusOr EmitTiledBitcast( } absl::StatusOr EmitTiledHloInstruction( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, const TiledHloInstruction& tiled_hlo, mlir::triton::FuncOp fn, ValueRange tile_multi_index, @@ -621,7 +621,7 @@ absl::StatusOr EmitTiledHloInstruction( // as i8. It's important to type checking that we perform a conversion after // loading if the type of the loaded parameter does not match what is // expected. - Type loaded_element_type = getElementTypeOrSelf(parameter.Type()); + Type loaded_element_type = getElementTypeOrSelf(parameter.getType()); TF_ASSIGN_OR_RETURN(Type expected_element_type, TritonType(b, hlo->shape().element_type())); @@ -705,7 +705,7 @@ absl::StatusOr EmitTiledHloInstruction( // Emit sequence of instructions using compatible tiling ordered producers // before consumers. absl::StatusOr EmitTiledComputation( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, const TiledHloComputation& tiled_computation, mlir::triton::FuncOp fn, @@ -728,7 +728,7 @@ absl::StatusOr EmitTiledComputation( // Emit sequence of instructions using compatible tiling ordered producers // before consumers. absl::StatusOr EmitScope( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const TritonFusionAnalysis* analysis, absl::Span instructions, @@ -791,7 +791,7 @@ absl::StatusOr EmitScope( // Computes the base pointer offset for the given tile multi-index and hlo shape // taking into account the physical layout of the hlo buffer. absl::StatusOr ComputeBasePtrOffset( - ImplicitLocOpBuilder b, ValueRange tile_multi_index, + EmitterLocOpBuilder& b, ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo) { const Shape& shape = tiled_hlo.hlo()->shape(); Shape linear_shape = ShapeUtil::MakeShape(shape.element_type(), @@ -809,9 +809,9 @@ absl::StatusOr ComputeBasePtrOffset( compose_indexing_maps.Simplify(); return b.create( - b.getI64Type(), mlir_converter::ApplyIndexing(compose_indexing_maps, - /*dims=*/tile_multi_index, - /*symbols=*/{}, b)[0]); + b.getI64Type(), emitters::ApplyIndexing(compose_indexing_maps, + /*dims=*/tile_multi_index, + /*symbols=*/{}, b)[0]); } } // namespace @@ -819,7 +819,7 @@ absl::StatusOr ComputeBasePtrOffset( namespace ir_emitter_triton_internal { SmallVector ComputeDelinearizedTileIndex( - ImplicitLocOpBuilder& b, + EmitterLocOpBuilder& b, absl::Span num_output_tiles_per_dim) { Value pid = b.create( b.getIndexType(), b.create(ttir::ProgramIDDim::X)); @@ -835,13 +835,13 @@ SmallVector ComputeDelinearizedTileIndex( /*dim_upper_bounds=*/{Product(num_output_tiles_per_dim)}, /*symbol_upper_bounds=*/{}); - return mlir_converter::ApplyIndexing(program_id_to_root_tile_offset, - /*dims=*/pid, - /*symbols=*/{}, b); + return emitters::ApplyIndexing(program_id_to_root_tile_offset, + /*dims=*/pid, + /*symbols=*/{}, b); } absl::StatusOr CreateMakeTensorPtrOp( - ImplicitLocOpBuilder& b, ValueRange tile_multi_index, + EmitterLocOpBuilder& b, ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, Value parent_base_ptr) { const llvm::SmallVector& tile_strides = tiled_hlo.tile_strides(); const Shape& shape = tiled_hlo.hlo()->shape(); @@ -863,9 +863,9 @@ absl::StatusOr CreateMakeTensorPtrOp( TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, tiled_hlo.tile_offsets_indexing()); auto tile_offsets_as_indices = - mlir_converter::ApplyIndexing(tile_offsets_indexing, - /*dims=*/tile_multi_index, - /*symbols=*/{}, b); + emitters::ApplyIndexing(tile_offsets_indexing, + /*dims=*/tile_multi_index, + /*symbols=*/{}, b); // Triton requires that all block dimensions are a power of 2. SmallVector padded_tile_sizes = @@ -917,12 +917,12 @@ absl::StatusOr CreateMakeTensorPtrOp( return MakeTensorPtrOpAndBoundaryChecks{ b.create( - /*base=*/tile_ptr, - /*shape=*/residual_shape, - /*strides=*/strides, - /*offsets=*/offsets, - /*tensorShape=*/llvm::to_vector_of(padded_tile_sizes), - /*order=*/order), + /*base*/ tile_ptr, + /*shape*/ residual_shape, + /*strides*/ strides, + /*offsets*/ offsets, + /*tensorShape*/ llvm::to_vector_of(padded_tile_sizes), + /*order*/ order), boundary_checks}; } @@ -951,7 +951,11 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, std::get(symbolic_tile_analysis_or); const HloInstruction* root = computation->root_instruction(); auto loc = mlir::NameLoc::get(builder.getStringAttr(root->name())); - ImplicitLocOpBuilder b(loc, builder); + EmitterLocOpBuilder b(loc, builder, + root->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_annotate_with_emitter_loc()); TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, symbolic_tile_analysis.ComputeTiledHloInstructions( @@ -972,7 +976,7 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, // as i8. It's important to type checking that we perform a conversion before // storing if the type of the result does not match the type of the output // pointer. - Type result_element_type = getElementTypeOrSelf(result.Type()); + Type result_element_type = getElementTypeOrSelf(result.getType()); Type result_storage_type = StorageType(b, result_element_type); if (result_element_type != result_storage_type) { @@ -1040,7 +1044,18 @@ absl::StatusOr> TranslateLLVMToLLVMIR( return llvmModule; } -absl::Status CreateInternalError(std::string_view message, +std::string DumpTritonIR(mlir::ModuleOp triton_module, bool dump_annotations) { + std::string triton_ir; + llvm::raw_string_ostream os(triton_ir); + triton_module.print(os, mlir::OpPrintingFlags().enableDebugInfo( + dump_annotations, dump_annotations)); + if (dump_annotations) { + return EmitterLocOpBuilder::FormatTritonIrWithAnnotations(triton_ir); + } + return triton_ir; +} + +absl::Status CreateInternalError(absl::string_view message, const HloFusionInstruction* fusion, mlir::ModuleOp triton_module) { std::string err; @@ -1060,17 +1075,21 @@ absl::StatusOr> CreateTritonModule( const BlockLevelParameters& block_level_parameters, mlir::MLIRContext& mlir_context) { LoadMlirDialectsForTriton(mlir_context); + const auto debug_options = fusion->GetModule()->config().debug_options(); const HloComputation* hlo_computation = fusion->fused_instructions_computation(); - mlir::OpBuilder b(&mlir_context); - auto loc = mlir::NameLoc::get(b.getStringAttr(hlo_computation->name())); + auto loc = mlir::NameLoc::get( + mlir::StringAttr::get(&mlir_context, hlo_computation->name())); + EmitterLocOpBuilder b( + loc, &mlir_context, + debug_options.xla_gpu_unsupported_annotate_with_emitter_loc()); + mlir::OwningOpRef triton_module = llvm_ir::CreateMlirModuleOp(loc); b.setInsertionPointToEnd(triton_module->getBody()); - const auto debug_options = fusion->GetModule()->config().debug_options(); // Build Triton kernel. SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { @@ -1079,7 +1098,11 @@ absl::StatusOr> CreateTritonModule( if (type == U16) { ir_type = b.getI16Type(); } else if (type == S4) { - ir_type = b.getI8Type(); + if (debug_options.xla_gpu_experimental_enable_triton_i4_rewrites()) { + ir_type = b.getI4Type(); + } else { + ir_type = b.getI8Type(); + } } else { TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type)); } @@ -1095,10 +1118,11 @@ absl::StatusOr> CreateTritonModule( } auto fn = b.create( - loc, fn_name, b.getFunctionType(fn_arg_types, std::nullopt)); + fn_name, b.getFunctionType(fn_arg_types, std::nullopt)); for (int i = 0; i < fn.getNumArguments(); ++i) { fn.setArgAttr(i, "tt.divisibility", b.getIntegerAttr(b.getI32Type(), 16)); } + fn.addEntryBlock(); b.setInsertionPointToStart(&fn.front()); @@ -1119,19 +1143,16 @@ absl::StatusOr> CreateTritonModule( return Internal("Unsupported fusion kind: %s", fusion_kind); } - b.create(loc); - - auto dump_triton_ir = [&]() { - std::string triton_ir; - llvm::raw_string_ostream os(triton_ir); - triton_module->print(os, - mlir::OpPrintingFlags().enableDebugInfo(true, true)); - return triton_ir; - }; + b.create(); if (DumpingEnabledForHloModule(*hlo_computation->parent())) { - DumpToFileInDirOrStdout(*hlo_computation->parent(), "triton_ir", - "before_validation.ttir", dump_triton_ir()); + DumpToFileInDirOrStdout( + *hlo_computation->parent(), "triton_ir", "before_validation.ttir", + DumpTritonIR(triton_module.get(), + fusion->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_annotate_with_emitter_loc())); } if (mlir::failed(mlir::verify(*triton_module))) { @@ -1147,12 +1168,21 @@ absl::StatusOr> CreateTritonModule( "Failed to create Triton module for fusion:", fusion, *triton_module); } - VLOG(6) << dump_triton_ir(); + VLOG(6) << DumpTritonIR(triton_module.get(), + fusion->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_annotate_with_emitter_loc()); // TODO(loislo): Remove this dump once we have the Triton IR dump in // CompileTritonToLLVM after the Triton optimization passes. if (DumpingEnabledForHloModule(*hlo_computation->parent())) { - DumpToFileInDirOrStdout(*hlo_computation->parent(), "triton_ir", "ttir", - dump_triton_ir()); + DumpToFileInDirOrStdout( + *hlo_computation->parent(), "triton_ir", "ttir", + DumpTritonIR(triton_module.get(), + fusion->GetModule() + ->config() + .debug_options() + .xla_gpu_unsupported_annotate_with_emitter_loc())); } return std::move(triton_module); @@ -1186,7 +1216,8 @@ absl::StatusOr TritonWrapper( const HloModule* hlo_module = fusion->GetModule(); return CompileTritonToLLVM(hlo_module->config(), hlo_module->name(), device_info, block_level_parameters, - triton_module.get(), llvm_module, mlir_context); + triton_module.get(), llvm_module, mlir_context, + /*is_xla_fusion=*/true); } absl::StatusOr CompileTritonToLLVM( @@ -1194,8 +1225,10 @@ absl::StatusOr CompileTritonToLLVM( const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mlir::ModuleOp triton_module, llvm::Module* llvm_module, - mlir::MLIRContext& mlir_context, bool emit_kernel) { + mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) { const auto& cc = device_info.gpu_compute_capability(); + const std::string arch_name = + std::visit([](auto& cc) { return cc.ToString(); }, cc); if (std::holds_alternative(cc)) { auto ccCuda = std::get(cc); if (!ccCuda.IsAtLeastAmpere()) { @@ -1255,7 +1288,10 @@ absl::StatusOr CompileTritonToLLVM( pm.addPass(CreateSimplifyAffinePass()); mlir::triton::nvidia_gpu::ClusterInfo cluster_info; - if (!CreateTritonPipeline(pm, cc, block_level_parameters, cluster_info) + if (!CreateTritonPipeline(&pm, arch_name, block_level_parameters.num_warps, + block_level_parameters.num_ctas, + block_level_parameters.num_stages, cluster_info, + is_xla_fusion) .ok()) { return Internal("Failed to create Triton pipeline."); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index 4a7db42acaf53d..0181bff7ffca62 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -27,7 +27,6 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Module.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Value.h" @@ -35,6 +34,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/hlo_module_config.h" @@ -87,21 +87,8 @@ absl::StatusOr CompileTritonToLLVM( const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mlir::ModuleOp triton_module, llvm::Module* llvm_module, - mlir::MLIRContext& mlir_context, bool emit_kernel = true); - -// Create Triton pipeline. -// -// `out_cluster_info` must be kept alive at least until pm.run() is called. -// It should be read after that. We have to pass the cluster dims to -// LaunchDimensions. Triton currently uses this as an out-parameter to return -// the cluster dims determined based on `config.num_ctas` and a heuristic. There -// are some signs that show that this was intended to be used as an in-out -// parameter which would give a hint to Triton which cluster dims we prefer to -// use, but that's not the case currently. -absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, - const BlockLevelParameters& block_level_parameters, - ::mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info); + mlir::MLIRContext& mlir_context, bool is_xla_fusion, + bool emit_kernel = true); std::string GetLibdevicePath(const HloModuleConfig& hlo_config, const se::DeviceDescription& device_info); @@ -111,8 +98,7 @@ namespace ir_emitter_triton_internal { // Computes the transformation from a 1-d program_id to a tile multi-index. llvm::SmallVector ComputeDelinearizedTileIndex( - mlir::ImplicitLocOpBuilder& b, - absl::Span num_output_tiles_per_dim); + EmitterLocOpBuilder& b, absl::Span num_output_tiles_per_dim); // Used for creating Triton Load and Store ops. struct MakeTensorPtrOpAndBoundaryChecks { @@ -124,10 +110,17 @@ struct MakeTensorPtrOpAndBoundaryChecks { }; absl::StatusOr CreateMakeTensorPtrOp( - mlir::ImplicitLocOpBuilder& b, mlir::ValueRange tile_multi_index, + EmitterLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value parent_base_ptr); } // namespace ir_emitter_triton_internal +// Dumps the Triton IR to a string. +// +// If `dump_annotations` is true, then the function also dumps the loc +// attributes of the instructions. Otherwise, it dumps the IR without +// annotations. +std::string DumpTritonIR(mlir::ModuleOp triton_module, bool dump_annotations); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index c1deef22e788b2..7e3be2bc6863c3 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -46,13 +46,13 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/xla.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/path.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace gpu { @@ -95,8 +95,9 @@ class TritonGemmTest : public TritonTest { public: DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); - // Do not fall back to cuBLAS, we are testing Triton. + // Do not fall back to cuBLAS and disable cuDNN; we are testing Triton. debug_options.set_xla_gpu_cublas_fallback(false); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0); // Do not autotune split-k by default, since this prevents deterministically // matching the optimized HLO. debug_options.set_xla_gpu_enable_split_k_autotuning(false); @@ -130,69 +131,6 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; -TEST_F(TritonGemmTest, NonstandardLayoutInt4) { - constexpr std::string_view kHloText = R"( - HloModule NonstandardLayoutInt4 - - ENTRY main { - p0 = s4[64,128]{0,1} parameter(0) - p1 = bf16[256,64]{1,0} parameter(1) - ROOT %dot = bf16[128,256]{1,0} dot(s4[64,128]{0,1} p0, bf16[256,64]{1,0} p1), - lhs_contracting_dims={0}, - rhs_contracting_dims={1} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); - EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( - CHECK: %[[param_0:.*]] = s4[64,128]{0,1:E(4)} parameter(0) - CHECK: %[[bitcast:.*]] = s4[128,64]{1,0:E(4)} bitcast(s4[64,128]{0,1:E(4)} %[[param_0]]) - CHECK: %[[convert:.*]] = bf16[128,64]{1,0} convert(s4[128,64]{1,0:E(4)} %[[bitcast]]) - CHECK: %[[param_1:.*]] = bf16[256,64]{1,0} parameter(1) - CHECK: ROOT %dot.1 = bf16[128,256]{1,0} dot(bf16[128,64]{1,0} %[[convert]], bf16[256,64]{1,0} %[[param_1]]), lhs_contracting_dims={1}, rhs_contracting_dims={1} - )")); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, NonstandardLayoutInt4WithManyNonContractingDims) { - // We cannot do triton_gemm and we use cuBLAS instead. - constexpr std::string_view kHloText = R"( - HloModule t - - ENTRY main { - p0 = s4[128,64,192]{1,0,2} parameter(0) - p1 = bf16[256,64]{1,0} parameter(1) - ROOT %dot = bf16[128,192,256]{2,1,0} dot(p0, p1), - lhs_contracting_dims={1}, - rhs_contracting_dims={1} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); - EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(CHECK: "__cublas$gemm")")); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, - NonstandardLayoutInt4WithManyNonContractingDimsReversedLayout) { - // We cannot do triton_gemm and we use cuBLAS instead. - constexpr std::string_view kHloText = R"( - HloModule t - - ENTRY main { - p0 = s4[128,64,192]{0,1,2} parameter(0) - p1 = bf16[256,64]{1,0} parameter(1) - ROOT %dot = bf16[128,192,256]{2,1,0} dot(p0, p1), - lhs_contracting_dims={1}, - rhs_contracting_dims={1} - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); - EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(CHECK: "__cublas$gemm")")); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - TEST_F(TritonGemmTest, FP8DotSmallTileDoesNotCrash) { GTEST_SKIP() << "TODO(b/337839570): Re-enable once the bug is fixed. " "Currently the test is not representative of the issue. " @@ -202,7 +140,7 @@ TEST_F(TritonGemmTest, FP8DotSmallTileDoesNotCrash) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_dot { @@ -224,349 +162,8 @@ ENTRY e { EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false)); } -TEST_F(TritonGemmTest, Int4NegatePlusConvertHLO) { - constexpr std::string_view kHloText = R"( - HloModule t - - ENTRY main { - lhs = s4[16,32,64]{2,1,0} parameter(0) - lhs_negated = s4[16,32,64]{2,1,0} negate(lhs) - lhs_converted = bf16[16,32,64]{2,1,0} convert(lhs_negated) - rhs = bf16[16,64,16]{2,1,0} parameter(1) - ROOT dot = bf16[16,32,16]{2,1,0} dot(lhs_converted, rhs), - lhs_contracting_dims={2}, - rhs_contracting_dims={1}, - lhs_batch_dims={0}, - rhs_batch_dims={0} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { - constexpr std::string_view kHloText = R"( - HloModule t - - ENTRY main { - lhs = s4[32,64,16]{2,1,0} parameter(0) - lhs_converted = bf16[32,64,16]{2,1,0} convert(lhs) - rhs = bf16[16,64,16]{2,1,0} parameter(1) - ROOT dot = bf16[16,32,16]{2,1,0} dot(lhs_converted, rhs), - lhs_contracting_dims={1}, - rhs_contracting_dims={1}, - lhs_batch_dims={2}, - rhs_batch_dims={0} - } - )"; - - const std::string pattern = - R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")"; - TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); - TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); - EXPECT_TRUE(ok); -} - -TEST_F(TritonGemmTest, LHSInt4WithMinorDimEqualTo1) { - // We prove that triton can handle int4 dot with non contracting dim size - // equal to 1. - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = s4[16,32,1]{2,1,0} parameter(0) - lhs_converted = bf16[16,32,1]{2,1,0} convert(lhs) - rhs = bf16[16,64,32]{2,1,0} parameter(1) - ROOT dot = bf16[16,1,64]{2,1,0} dot(lhs_converted, rhs), - lhs_contracting_dims={1}, - rhs_contracting_dims={2}, - lhs_batch_dims={0}, - rhs_batch_dims={0} - } - - ENTRY main { - lhs = s4[16,32,1]{2,1,0} parameter(0) - rhs = bf16[16,64,32]{2,1,0} parameter(1) - ROOT dot = bf16[16,1,64]{2,1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, RHSInt4WithMinorDimEqualTo1) { - // We prove that triton can handle int4 dot with non contracting dim size - // equal to 1. - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = bf16[16,32,64]{2,1,0} parameter(0) - rhs = s4[16,32,1]{2,1,0} parameter(1) - rhs_converted = bf16[16,32,1]{2,1,0} convert(rhs) - ROOT dot = bf16[16,64,1]{2,1,0} dot(lhs, rhs_converted), - lhs_contracting_dims={1}, - rhs_contracting_dims={1}, - lhs_batch_dims={0}, - rhs_batch_dims={0} - } - - ENTRY main { - lhs = bf16[16,32,64]{2,1,0} parameter(0) - rhs = s4[16,32,1]{2,1,0} parameter(1) - ROOT dot = bf16[16,64,1]{2,1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { - // We prove that triton can handle int4 dot with non minor - // lhs_contracting_dim. - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = s4[1024,8]{1,0} parameter(0) - lhs_converted = bf16[1024,8]{1,0} convert(lhs) - rhs = bf16[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), - lhs_contracting_dims={0}, - rhs_contracting_dims={0} - } - - ENTRY main { - lhs = s4[1024,8]{1,0} parameter(0) - rhs = bf16[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { - // We prove that triton can handle int4 dot with non minor - // lhs_contracting_dim. - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = s4[16,1024,8]{2,1,0} parameter(0) - lhs_converted = bf16[16,1024,8]{2,1,0} convert(lhs) - rhs = bf16[16,1024,4]{2,1,0} parameter(1) - ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), - lhs_batch_dims={0}, - lhs_contracting_dims={1}, - rhs_batch_dims={0}, - rhs_contracting_dims={1} - } - - ENTRY main { - lhs = s4[16,1024,8]{2,1,0} parameter(0) - rhs = bf16[16,1024,4]{2,1,0} parameter(1) - ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); -} - -TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { - // We prove that triton can handle int4 dot with minor lhs_contracting_dim. - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = s4[8,1024]{1,0} parameter(0) - lhs_converted = bf16[8,1024]{1,0} convert(lhs) - rhs = bf16[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - lhs = s4[8,1024]{1,0} parameter(0) - rhs = bf16[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = s4[8,1024]{1,0} parameter(0) - lhs_converted = bf16[8,1024]{1,0} convert(lhs) - lhs_negated = bf16[8,1024]{1,0} negate(lhs_converted) - rhs = bf16[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4]{1,0} dot(lhs_negated, rhs), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - lhs = s4[8,1024]{1,0} parameter(0) - rhs = bf16[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { - // We prove that triton can handle int4 dot with minor lhs_contracting_dim. - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = s4[16,8,1024]{2,1,0} parameter(0) - lhs_converted = bf16[16,8,1024]{2,1,0} convert(lhs) - rhs = bf16[16,1024,4]{2,1,0} parameter(1) - ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), - lhs_batch_dims={0}, - lhs_contracting_dims={2}, - rhs_batch_dims={0}, - rhs_contracting_dims={1} - } - - ENTRY main { - lhs = s4[16,8,1024]{2,1,0} parameter(0) - rhs = bf16[16,1024,4]{2,1,0} parameter(1) - ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = bf16[8,1024]{1,0} parameter(0) - rhs = s4[1024,4]{1,0} parameter(1) - rhs_converted = bf16[1024,4]{1,0} convert(rhs) - ROOT dot = bf16[8,4] dot(lhs, rhs_converted), - lhs_contracting_dims={1}, - rhs_contracting_dims={0} - } - - ENTRY main { - lhs = bf16[8,1024]{1,0} parameter(0) - rhs = s4[1024,4]{1,0} parameter(1) - ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = bf16[8,1024]{1,0} parameter(0) - rhs = s4[4,1024]{1,0} parameter(1) - rhs_converted = bf16[4,1024]{1,0} convert(rhs) - ROOT dot = bf16[8,4] dot(lhs, rhs_converted), - lhs_contracting_dims={1}, - rhs_contracting_dims={1} - } - - ENTRY main { - lhs = bf16[8,1024]{1,0} parameter(0) - rhs = s4[4,1024]{1,0} parameter(1) - ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = bf16[16,8,1024]{2,1,0} parameter(0) - rhs = s4[16,1024,4]{2,1,0} parameter(1) - rhs_converted = bf16[16,1024,4]{2,1,0} convert(rhs) - ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), - lhs_batch_dims={0}, - lhs_contracting_dims={2}, - rhs_batch_dims={0}, - rhs_contracting_dims={1} - } - - ENTRY main { - lhs = bf16[16,8,1024]{2,1,0} parameter(0) - rhs = s4[16,1024,4]{2,1,0} parameter(1) - ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - -TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { - constexpr std::string_view kHloText = R"( - HloModule t - - triton_computation { - lhs = bf16[16,8,1024]{2,1,0} parameter(0) - rhs = s4[16,4,1024]{2,1,0} parameter(1) - rhs_converted = bf16[16,4,1024]{2,1,0} convert(rhs) - ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), - lhs_batch_dims={0}, - lhs_contracting_dims={2}, - rhs_batch_dims={0}, - rhs_contracting_dims={2} - } - - ENTRY main { - lhs = bf16[16,8,1024]{2,1,0} parameter(0) - rhs = s4[16,4,1024]{2,1,0} parameter(1) - ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, - calls=triton_computation, - backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses( - kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); -} - TEST_F(TritonTest, TestGemm) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -658,7 +255,7 @@ CHECK: } } TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true triton_dot { @@ -748,7 +345,7 @@ CHECK: } } TEST_F(TritonTest, PredParametersAreTruncatedToI1) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_gemm_computation { @@ -789,7 +386,7 @@ CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> } TEST_F(TritonTest, CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm { @@ -832,7 +429,7 @@ CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] TEST_F(TritonTest, CodegenDynamicSliceWithCorrectOffsets) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_gemm { @@ -882,7 +479,7 @@ CHECK-DAG: tt.make_tensor_ptr %[[DYNAMIC_SLICE_INPUT]], [%[[C2_i64]], %[[ROW_L } TEST_F(TritonTest, SparseDot) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -913,7 +510,7 @@ CHECK: triton_xla.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : } TEST_F(TritonTest, SparseDotWithMasking) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -950,7 +547,7 @@ CHECK: triton_xla.sparse_dot %[[LHS_MASKED]], %[[RHS_MASKED]], %{{[^:]+}}, %[[ME } TEST_F(TritonTest, SparseDotBroadcastMetadata) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t triton_dot { @@ -987,7 +584,7 @@ CHECK: triton_xla.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : } TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_gemm_r { parameter_0 = s8[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) @@ -1017,7 +614,7 @@ CHECK-NOT: mma } TEST_F(TritonGemmTest, DebugOptionsArePropagated) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) p1 = s8[30,30] parameter(1) @@ -1069,7 +666,7 @@ ENTRY main { } TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_gemm_r { parameter_0 = f16[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) @@ -1101,7 +698,7 @@ TEST_F(TritonGemmTest, FailIfTooMuchShmem) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule module, is_scheduled=true triton_gemm_dot { @@ -1177,7 +774,7 @@ TEST_F(TritonGemmTestWithSplitK, // The condition mentioned in the test name is fulfilled by // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for // Ampere at the time of the addition of this test case. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule extracted ENTRY e { @@ -1331,7 +928,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SplitAndTransposeLhsExecutesCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -1361,7 +958,7 @@ TEST_F(TritonGemmTest, NondefaultOperandLayoutIsSupported) { #ifndef NDEBUG GTEST_SKIP() << "This test times out when -UNDEBUG is set."; #endif - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY r { p1 = f16[9,140,128]{2,1,0} parameter(1) cp = f16[9,140,128]{2,0,1} copy(p1) @@ -1534,7 +1131,7 @@ ENTRY e { } TEST_F(TritonGemmTest, MultipleBatchRequireSeparateTranspose) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -1557,7 +1154,7 @@ ENTRY e { } TEST_F(TritonGemmTest, CanCodegenNonBatchedDotWithConcatenationCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { parameter_0 = f32[3,10]{1,0} parameter(0) parameter_1 = f32[10,128]{1,0} parameter(1) @@ -1581,7 +1178,7 @@ ENTRY e { } TEST_F(TritonGemmTest, CanCodegenBatchedDotWithConcatenationCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { parameter_0 = f32[2,3,10]{2,1,0} parameter(0) parameter_1 = f32[2,10,128]{2,1,0} parameter(1) @@ -1626,7 +1223,7 @@ ENTRY e { } TEST_F(TritonTest, FloatToSignedIntConversion) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -1687,7 +1284,7 @@ ENTRY e { // This tests the complexity heuristics in TritonWrapper. TEST_F(TritonGemmTest, FailForTooComplexTiling) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule module, is_scheduled=true triton_gemm_dot { @@ -1974,7 +1571,7 @@ TEST_F(TritonGemmTest, DynamicSliceIsSupportedInLhsEndToEnd) { // is not strictly needed, because we also support clamping the indices. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2005,7 +1602,7 @@ ENTRY e { TEST_F(TritonGemmTest, DynamicSliceIsSupportedInRhs) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_gemm { @@ -2038,7 +1635,7 @@ ENTRY e { } TEST_F(TritonGemmTest, MultiplePathsToSameOperandWorks) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p0 = bf16[8192,512]{1,0} parameter(0) p1 = bf16[512,512]{1,0} parameter(1) @@ -2121,7 +1718,7 @@ TEST_F(TritonGemmTest, DynamicSliceOfMajormostContractingDimIsSupported) { // dimension is contracted. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_gemm { @@ -2158,7 +1755,7 @@ TEST_F(TritonGemmTest, DynamicSliceOfMajormostBatchDimIsSupported) { // dimension is a batch. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_gemm { @@ -2197,7 +1794,7 @@ TEST_F(TritonGemmTest, DynamicSliceSingleDimensionIntoReshapeIsSupported) { // layer weights and extracting them with dynamic slice. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_gemm { @@ -2264,7 +1861,7 @@ ENTRY e { } TEST_F(TritonGemmTest, BroadcastOfScalarWorksCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( fusion { p0 = f16[2,18] parameter(0) p1 = f16[256,2] parameter(1) @@ -2285,7 +1882,7 @@ ENTRY e { TF_ASSERT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText, "fusion", R"( CHECK: tt.dot - CHECK: arith.mulf %{{.*}}, %{{.*}} : tensor + CHECK: arith.mulf %{{.*}}, %{{.*}} : tensor CHECK: tt.broadcast %{{.*}} : tensor<1x1xf16> -> tensor<32x32xf16> CHECK: arith.mulf %{{.*}}, %{{.*}} : tensor<32x32xf16> )")); @@ -2334,7 +1931,7 @@ class TritonGemmLevel2TestAny : public TritonGemmLevel2Test { }; TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2360,7 +1957,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2391,7 +1988,7 @@ ENTRY e { TEST_F(TritonGemmLevel2Test, ParametersWithDifferentLayoutsAreSupportedInOneScope) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = s8[5,3] parameter(0) p0c = f16[5,3] convert(p0) @@ -2414,7 +2011,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2439,7 +2036,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = s8[7,3] parameter(0) c0 = f32[7,3] convert(p0) @@ -2470,7 +2067,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfScalarParameterIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[64,256] parameter(0) p0c = f32[64,256] convert(p0) @@ -2491,7 +2088,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfScalarConstantIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2517,7 +2114,7 @@ TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { c = s32[] constant(1) bc1 = s32[21]{0} broadcast(c), dimensions={} @@ -2541,7 +2138,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfVectorConstantIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2565,7 +2162,7 @@ TEST_F(TritonGemmLevel2Test, AlwaysFuseScalarConstantAtBroadcastInput) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = bf16[2,3,3]{2,1,0} parameter(0) p1 = bf16[3,2,3]{2,1,0} parameter(1) @@ -2592,7 +2189,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfVectorParameterIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_dot { p0 = f16[75] parameter(0) bc0 = f16[75,67] broadcast(p0), dimensions={0} @@ -2621,7 +2218,7 @@ TEST_F(TritonGemmLevel2Test, FuseConcatenation) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( e { p0 = s8[153,1536] parameter(0) p1 = s8[153,128] parameter(1) @@ -2647,7 +2244,7 @@ e { } TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheLeft) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2670,7 +2267,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheRight) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2693,7 +2290,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheLeft) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2716,7 +2313,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheRight) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2739,7 +2336,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumReturnsLHS) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2764,7 +2361,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumReturnsRHS) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2789,7 +2386,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumReturnsLHS) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2814,7 +2411,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumReturnsRHS) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t ENTRY e { @@ -2839,7 +2436,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SineOutputIsNotFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -2862,7 +2459,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SliceInputIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[97,121] parameter(0) s0 = f16[7,101] slice(p0), slice={[3:10], [10:111]} @@ -2883,7 +2480,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SliceInputWithReshapeIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f32[363,1536] parameter(0) p1 = f32[4,1536,611] parameter(1) @@ -2905,7 +2502,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, NestedSlicingWorks) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p1 = f32[6,24] parameter(1) slice1 = f32[5,20] slice(p1), slice={[1:6], [3:23]} @@ -2927,14 +2524,14 @@ ENTRY e { } TEST_F(TritonGemmTest, SlicedBatchDimensionIsSupported) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[3,3,256] parameter(0) s0 = f16[3,3,128] slice(p0), slice={[0:3], [0:3], [123:251]} r0 = f16[3,3,128] reshape(s0) p1 = f16[3,3,256] parameter(1) - s1 = f16[3,3,128] slice(p1), slice={[0:3], [0:3], [30:158]} - r1 = f16[3,3,128] reshape(s1) + svar1 = f16[3,3,128] slice(p1), slice={[0:3], [0:3], [30:158]} + r1 = f16[3,3,128] reshape(svar1) ROOT d = f16[128,3,3]{2,1,0} dot(r0, r1), lhs_batch_dims={2}, lhs_contracting_dims={1}, rhs_batch_dims={2}, rhs_contracting_dims={1} @@ -2952,7 +2549,7 @@ ENTRY e { TEST_F(TritonGemmTestWithSplitK, SplitKDoesNotBreakSlicedFragmentedContractingDimension) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[16,8,128]{2,1,0} parameter(0) s0 = f16[16,4,128]{2,1,0} slice(p0), @@ -2976,7 +2573,7 @@ ENTRY e { } TEST_F(TritonGemmTestWithSplitK, SplitKWithTrivialDimension) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY entry_computation { p0 = f16[1001,1]{1,0} parameter(0) convert = f32[1001,1]{1,0} convert(p0) @@ -2989,7 +2586,7 @@ ENTRY entry_computation { } TEST_F(TritonGemmLevel2Test, NarrowingConvertOutputIsFused) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -3015,7 +2612,7 @@ TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -3047,7 +2644,7 @@ TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -3083,7 +2680,7 @@ TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -3116,7 +2713,7 @@ TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { GTEST_SKIP() << "Skipped until corresponding issue on ROCm is fixed."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0t = (s8[5,18,20,150]) parameter(0) p0 = s8[5,18,20,150] get-tuple-element(p0t), index=0 @@ -3141,7 +2738,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SupportPredParametersUsedInExpressions) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p = pred[2,2]{1,0} parameter(0) a = f32[2,2]{1,0} parameter(1) @@ -4508,7 +4105,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -4533,7 +4130,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -4558,7 +4155,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -4584,7 +4181,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -4783,7 +4380,7 @@ TEST_F(TritonGemmTest, TestNoAutotuner) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "Autotuner is always in pipeline on Cuda."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) p1 = s8[30,30] parameter(1) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 5d6dc13a380ace..c9ca9b577bd25c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -62,7 +62,7 @@ class TritonEmitterTest : public GpuCodegenTest { }; TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -90,7 +90,7 @@ CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 1 : i32}> } TEST_F(TritonEmitterTest, ReductionOnMajormostAxisIsEmittedCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -118,7 +118,7 @@ CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 0 : i32}> } TEST_F(TritonEmitterTest, ReductionOnIntermediateAxisIsEmittedCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -148,7 +148,7 @@ CHECK: "tt.reduce"(%[[SELECT:.*]]) <{axis = 2 : i32}> } TEST_F(TritonEmitterTest, TestReductionWithTileSizeLargerThanSourceTensor) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -189,7 +189,7 @@ CHECK: }) // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithSoftMaxSingleParameter) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t add { Arg_0 = f32[] parameter(0) @@ -250,7 +250,7 @@ CHECK: } // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleParameters) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t add { @@ -313,7 +313,7 @@ CHECK-DAG: tt.store {{.*}} : !tt.ptr> } TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleTiledDimensions) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t max { @@ -394,7 +394,7 @@ CHECK-NEXT: tt.store {{.*}} : !tt.ptr> TEST_F( TritonEmitterTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongReductionDimProducesAccurateResults) { // NOLINT(whitespace/line_length) - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule h1 max_computation { @@ -432,7 +432,7 @@ TEST_F(TritonEmitterTest, NestedReducerFusionGetsCodegenedCorrectly) { GTEST_SKIP() << "BF16 not supported."; } - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule softmax fused_convert { @@ -471,7 +471,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongBatchDimProducesAccurateResults) { // NOLINT(whitespace/line_length) - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule h1 max_computation { @@ -504,7 +504,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalSplatDiamondScalarParameterProducesAccurateResults) { // NOLINT(whitespace/line_length) - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule h1 max_computation { @@ -559,7 +559,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalBroadcastOf1DParameterAlongNonReductionDimensionsProducesAccurateResults) { // NOLINT(whitespace/line_length) - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule h1 max_computation { @@ -593,7 +593,7 @@ ENTRY main { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, EmitterFailsIfComputeCapabilityIsBelowAmpere) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p0 = f32[10,10] parameter(0) p1 = f32[10,10] parameter(1) @@ -693,7 +693,7 @@ ENTRY entry_computation { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should b // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterReductionFusion) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t add { Arg_0 = f32[] parameter(0) @@ -735,7 +735,7 @@ CHECK: tt.store {{.*}} : !tt.ptr> TEST_F(TritonEmitterTest, TestGenericEmitterWithReductonAndMultidimensionalTile) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule t max { Arg_0 = f32[] parameter(0) @@ -763,7 +763,7 @@ ENTRY main { } TEST_F(TritonEmitterTest, TestSoftMaxWithTileElementsNotAllContiguous) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m region { @@ -792,7 +792,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileThatNeedsMasking) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m fused_computation { @@ -811,7 +811,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileElementsNotAllContiguous) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m fused_computation { @@ -830,7 +830,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileElementsNotAllContiguousUnaligned) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m fused_computation { @@ -853,7 +853,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, ReshapeIntoBroadcastIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { param_0 = f32[128,256]{1,0} parameter(0) reshape = f32[64,2,256]{2,1,0} reshape(param_0) @@ -879,7 +879,7 @@ CHECK: tt.reshape } TEST_F(TritonEmitterTest, BitcastIntoBroadcastIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { param_0 = f32[128,256]{1,0} parameter(0) bitcast = f32[64,2,256]{2,1,0} bitcast(param_0) @@ -905,7 +905,7 @@ CHECK: tt.reshape } TEST_F(TritonEmitterTest, BitcastNormalizedLayoutsIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p = s8[5,42] parameter(0) ROOT bitcast = s8[5,6,7] bitcast(p) @@ -933,7 +933,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastNonNormalizedInputLayoutIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,6,7] bitcast(p) @@ -961,7 +961,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastNonNormalizedOutputLayoutIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p = s8[5,42] parameter(0) ROOT bitcast = s8[5,6,7]{1,2,0} bitcast(p) @@ -990,7 +990,7 @@ CHECK: tt.store TEST_F(TritonEmitterTest, BitcastNonNormalizedInputOutputLayoutIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,6,7]{1,2,0} bitcast(p) @@ -1018,7 +1018,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastTransposeOnlyIsLoweredCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,42] bitcast(p) @@ -1047,7 +1047,7 @@ CHECK: tt.store // TODO(b/353484968): move this test to a deviceless file. TEST_F(TritonEmitterTest, GenericEmitterLowersBroadcastFrom0dOperandCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { param_0 = f32[] parameter(0) ROOT broadcast = f32[127,125]{1,0} broadcast(param_0), dimensions={} @@ -1071,7 +1071,7 @@ CHECK: tt.splat {{.*}} f32 -> tensor<8x4xf32> TEST_F(TritonEmitterTest, PredOutputIsStoredCorrectly) { // The 'pred' element type in XLA is unpacked and uses i8 for storage. This // is the only sub-byte type to have this behavior. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_computation { @@ -1104,7 +1104,7 @@ CHECK: tt.store {{.*}} %[[CASTED_OUT]] TEST_F(TritonEmitterTest, PredInputIsLoadedCorrectly) { // The 'pred' element type in XLA is unpacked and uses i8 for storage. This // is the only sub-byte type to have this behavior. - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_computation { @@ -1140,7 +1140,7 @@ CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1> } TEST_F(TritonEmitterTest, Transpose3D) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_computation { @@ -1170,7 +1170,7 @@ CHECK: tt.trans %[[TILE]] {order = array} : tensor<8x4x1xf32> // TODO(b/353484968): Delete this test once we have constraints to only // propagate tile sizes that are a power of 2. TEST_F(TritonEmitterTest, Transpose3D_TileFullDimThatIsNotPowerOf2) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m triton_computation { @@ -1192,7 +1192,7 @@ ENTRY main { } TEST_F(TritonEmitterTest, StridedIota4DIsCodegeneratedCorrectly) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( triton_computation { iota = f32[3,4,1000,5] iota(), iota_dimension=2 ROOT slice = f32[3,4,182,5] slice(iota), slice={[0:3], [0:4], [91:1000:5], [0:5]} diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_deviceless_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_deviceless_test.cc new file mode 100644 index 00000000000000..cf08812145db5a --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_deviceless_test.cc @@ -0,0 +1,125 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +#if defined(PLATFORM_GOOGLE) +#else + +#endif +namespace xla::gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +class AnnotationsTest : public GpuCodegenTest { + public: + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_unsupported_annotate_with_emitter_loc(true); + return debug_options; + } +}; + +TEST_F(AnnotationsTest, Annotations) { + static constexpr absl::string_view kHloText = R"( + HloModule Annotations + + triton_dot { + p0 = f32[8,8] parameter(0) + p1 = f32[8,8] parameter(1) + ROOT dot = f32[8,8] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 + } + + ENTRY e { + p0 = f32[8,8]{1, 0} parameter(0) + p1 = f32[8,8]{1, 0} parameter(1) + ROOT _ = f32[8,8] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + { + "block_m":32, + "block_n":32, + "block_k":32, + "split_k":1, + "num_stages":1, + "num_warps":1, + "num_ctas":1 + } + } + } + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + auto* comp = module->GetComputationWithName("triton_dot"); + EXPECT_NE(comp, nullptr); + auto fusion_backend_config = comp->FusionInstruction() + ->backend_config() + ->fusion_backend_config(); + BlockLevelParameters block_level_parameters = + BlockLevelParameters::FromBlockLevelFusionConfig( + fusion_backend_config.block_level_fusion_config()); + + auto* fusion = Cast(comp->FusionInstruction()); + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN( + auto triton_module, + CreateTritonModule("triton_fn", fusion, + TestGpuDeviceInfo::RTXA6000DeviceInfo(), + block_level_parameters, context)); + + std::string annotated_ir = DumpTritonIR(triton_module.get(), true); + + if constexpr (EmitterLocOpBuilder::kSourceLocationSupported) { + EXPECT_THAT(RunFileCheck(annotated_ir, R"( + CHECK: [[SOMETHING:.*]] "triton_dot -> [[FILE_LINE:triton_fusion_emitter.*:.*]]" + )"), + IsOkAndHolds(true)); + } else { + EXPECT_THAT(RunFileCheck(annotated_ir, R"( + CHECK: [[SOMETHING:.*]] "triton_dot" + )"), + IsOkAndHolds(true)); + } +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_int4_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_int4_device_test.cc new file mode 100644 index 00000000000000..0be5ca8afd8570 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_int4_device_test.cc @@ -0,0 +1,727 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { +namespace { + +class TritonTest : public GpuCodegenTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Always rewrite Gemms with Triton regardless of size. + debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + return debug_options; + } + + stream_executor::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return device_desc().gpu_compute_capability(); + } + stream_executor::GpuComputeCapability CudaAmpereOrRocm() { + if (std::holds_alternative( + GpuComputeComp())) { + return stream_executor::GpuComputeCapability{ + device_desc().rocm_compute_capability()}; + } else { + return stream_executor::GpuComputeCapability{ + stream_executor::CudaComputeCapability{ + stream_executor::CudaComputeCapability::AMPERE, 0}}; + } + } + + protected: + const stream_executor::DeviceDescription& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } +}; + +// The test class for the Triton MLIR pass that converts MLIR code that works +// with the plain int4 tensors to the packed int4 tensors. The goal is to prove +// that the pass generates the correct MLIR and it produces the same +// results. Eventually the pass will be enabled by default and the support for +// the int4 tensors will be removed from the Legacy Triton emitter. +class PlainInt4ToPackedInt4RewritePassTest : public TritonTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_experimental_enable_triton_i4_rewrites(true); + return debug_options; + } +}; + +TEST_F(PlainInt4ToPackedInt4RewritePassTest, + DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip ivestigate int4 " + "issue with triton."; + constexpr absl::string_view kHloText = R"( + HloModule DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales + + DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales { + w = s4[32,64,128]{2,1,0} parameter(0) + w.i8 = s8[32,64,128]{2,1,0} convert(w) + w.f32 = f32[32,64,128]{2,1,0} convert(w.i8) + scales = f32[32,128]{1,0} parameter(1) + scales.broadcast = f32[32,64,128]{2,1,0} broadcast(scales), dimensions={0,2} + weights.scaled = f32[32,64,128]{2,1,0} multiply(w.f32, scales.broadcast) + activations = f32[32,64,256]{2,1,0} parameter(2) + ROOT dot = f32[32,128,256]{2,1,0} dot(weights.scaled, activations), + lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + w = s4[32,64,128]{2,1,0} parameter(0) + scales = f32[32,128]{1,0} parameter(1) + p2 = f32[32,64,256]{2,1,0} parameter(2) + ROOT dot = f32[32,128,256]{2,1,0} fusion(w, scales, p2), + kind=kCustom, + calls=DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm" + } + } + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5})); +} + +using ::testing::TestParamInfo; +using ::testing::WithParamInterface; + +struct I4TestParams { + static std::string ToString(const TestParamInfo& params) { + return params.param.name; + } + + std::string Format(absl::string_view format) const { + return absl::StrReplaceAll( + format, {{"${name}", name}, + {"${lhs}", lhs}, + {"${rhs}", rhs}, + {"${lhs_contracting_dim}", absl::StrCat(lhs_contracting_dim)}, + {"${rhs_contracting_dim}", absl::StrCat(rhs_contracting_dim)}, + {"${out}", out}}); + } + bool HasBatchDim() const { + return std::vector(absl::StrSplit(lhs, ',')).size() > 2; + } + + std::string name; // The name of the test. + std::string lhs; // The lhs shape like "128,16". + std::string rhs; // The rhs shape like "128,256". + int lhs_contracting_dim; // The contracting dimension of the lhs. + int rhs_contracting_dim; // The contracting dimension of the rhs. + std::string out; // The output shape like "16,256". +}; + +class ParametrizedPlainInt4ToPackedInt4RewritePassTest + : public PlainInt4ToPackedInt4RewritePassTest, + public WithParamInterface {}; + +TEST_P(ParametrizedPlainInt4ToPackedInt4RewritePassTest, Int4WeightsOnTheLhs) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip ivestigate int4 " + "issue with triton."; + if (GetParam().HasBatchDim()) { + GTEST_SKIP() << "2d test ignores batch dim case."; + } + constexpr absl::string_view kHloTextTemplate = R"( + HloModule lhs_${name} + + lhs_${name} { + w.s4 = s4[${lhs}]{1,0} parameter(0) + w.s8 = s8[${lhs}]{1,0} convert(w.s4) + w.f32 = f32[${lhs}]{1,0} convert(w.s8) + a = f32[${rhs}]{1,0} parameter(1) + ROOT lhs_${name} = f32[${out}]{1,0} dot(w.f32, a), + lhs_contracting_dims={${lhs_contracting_dim}}, + rhs_contracting_dims={${rhs_contracting_dim}} + } + + ENTRY main { + w = s4[${lhs}]{1,0} parameter(0) + a = f32[${rhs}]{1,0} parameter(1) + ROOT gemm_fusion_dot.2 = f32[${out}]{1,0} fusion(w, a), + kind=kCustom, + calls=lhs_${name}, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm" + } + } + } + )"; + std::string hlo_text = GetParam().Format(kHloTextTemplate); + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, + ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5})) + << "Failed for HLO: " << hlo_text; +} + +TEST_P(ParametrizedPlainInt4ToPackedInt4RewritePassTest, + Int4WeightsOnTheLhsWithBatchDim) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip ivestigate int4 " + "issue with triton."; + if (!GetParam().HasBatchDim()) { + GTEST_SKIP() << "3d test ignores 2d case."; + } + constexpr absl::string_view kHloTextTemplate = R"( + HloModule ${name} + + fusion { + w.s4 = s4[${lhs}]{2,1,0} parameter(0) + w.s8 = s8[${lhs}]{2,1,0} convert(w.s4) + w.f32 = f32[${lhs}]{2,1,0} convert(w.s8) + a = f32[${rhs}]{2,1,0} parameter(1) + ROOT dot.0 = f32[${out}]{2,1,0} dot(w.f32, a), + lhs_contracting_dims={${lhs_contracting_dim}}, + rhs_contracting_dims={${rhs_contracting_dim}}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + + ENTRY gemm_fusion_dot_computation { + w = s4[${lhs}]{2,1,0} parameter(0) + a = f32[${rhs}]{2,1,0} parameter(1) + ROOT gemm_fusion_dot.2 = f32[${out}]{2,1,0} fusion(w, a), + kind=kCustom, + calls=fusion, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm" + } + } + } + )"; + std::string hlo_text = GetParam().Format(kHloTextTemplate); + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, + ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})) + << "Failed for HLO: " << hlo_text; +} + +TEST_P(ParametrizedPlainInt4ToPackedInt4RewritePassTest, Int4WeightsOnTheRhs) { + GTEST_SKIP() + << "TODO: Weekly-sync 25-01-13: Skip ivestigate int4 issue with triton."; + if (GetParam().HasBatchDim()) { + GTEST_SKIP() << "2d test ignores batch dim case."; + } + + constexpr absl::string_view kHloTextTemplate = R"( + HloModule rhs_${name} + + rhs_${name} { + a = f32[${lhs}]{1,0} parameter(0) + w.s4 = s4[${rhs}]{1,0} parameter(1) + w.s8 = s8[${rhs}]{1,0} convert(w.s4) + w.f32 = f32[${rhs}]{1,0} convert(w.s8) + ROOT rhs_${name} = f32[${out}]{1,0} dot(a, w.f32), + lhs_contracting_dims={${lhs_contracting_dim}}, + rhs_contracting_dims={${rhs_contracting_dim}} + } + + ENTRY main { + a = f32[${lhs}]{1,0} parameter(0) + w = s4[${rhs}]{1,0} parameter(1) + ROOT rhs_${name} = f32[${out}]{1,0} fusion(a, w), + kind=kCustom, + calls=rhs_${name}, + backend_config={ + "fusion_backend_config":{ + "kind":"__triton_gemm" + } + } + } + )"; + std::string hlo_text = GetParam().Format(kHloTextTemplate); + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, + ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5})) + << "Failed for HLO: " << hlo_text; +} + +std::vector Int4TestCases() { + return { + {"int4_dot_128_16_x_128_256", "128,16", "128,256", 0, 0, "16,256"}, + {"int4_dot_128_16_x_256_128", "128,16", "256,128", 0, 1, "16,256"}, + {"int4_dot_16_128_x_256_128", "16,128", "256,128", 1, 1, "16,256"}, + {"int4_dot_16_128_x_128_256", "16,128", "128,256", 1, 0, "16,256"}, + {"int4_dot_1_128_x_256_128", "1,128", "256,128", 1, 1, "1,256"}, + {"int4_dot_128_1_x_256_128", "128,1", "256,128", 0, 1, "1,256"}, + {"int4_dot_16_128_x_128_1", "16,128", "128,1", 1, 0, "16,1"}, + {"int4_dot_16_128_x_1_128", "16,128", "1,128", 1, 1, "16,1"}, + + {"dot_8_128_16_x_8_128_256", "8,128,16", "8,128,256", 1, 1, "8,16,256"}, + {"dot_8_128_16_x_8_256_128", "8,128,16", "8,256,128", 1, 2, "8,16,256"}, + {"dot_8_16_128_x_8_256_128", "8,16,128", "8,256,128", 2, 2, "8,16,256"}, + {"dot_8_16_128_x_8_128_256", "8,16,128", "8,128,256", 2, 1, "8,16,256"}, + {"dot_8_1_128_x_8_256_128", "8,1,128", "8,256,128", 2, 2, "8,1,256"}, + {"dot_8_128_1_x_8_256_128", "8,128,1", "8,256,128", 1, 2, "8,1,256"}, + {"dot_8_16_128_x_8_128_1", "8,16,128", "8,128,1", 2, 1, "8,16,1"}, + {"dot_8_16_128_x_8_1_128", "8,16,128", "8,1,128", 2, 2, "8,16,1"}, + }; +} + +INSTANTIATE_TEST_SUITE_P(PlainInt4ToPackedInt4RewritePassTests, + ParametrizedPlainInt4ToPackedInt4RewritePassTest, + ::testing::ValuesIn(Int4TestCases()), + I4TestParams::ToString); + +TEST_F(TritonTest, NonstandardLayoutInt4) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip ivestigate int4 " + "issue with triton."; + constexpr absl::string_view kHloText = R"( + HloModule NonstandardLayout + + ENTRY main { + p0 = s4[64,128]{0,1} parameter(0) + p1 = bf16[256,64]{1,0} parameter(1) + ROOT %dot = bf16[128,256]{1,0} dot(s4[64,128]{0,1} p0, bf16[256,64]{1,0} p1), + lhs_contracting_dims={0}, + rhs_contracting_dims={1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + CHECK: %[[param_0:.*]] = s4[64,128]{0,1:E(4)} parameter(0) + CHECK: %[[bitcast:.*]] = s4[128,64]{1,0:E(4)} bitcast(s4[64,128]{0,1:E(4)} %[[param_0]]) + CHECK: %[[convert:.*]] = bf16[128,64]{1,0} convert(s4[128,64]{1,0:E(4)} %[[bitcast]]) + CHECK: %[[param_1:.*]] = bf16[256,64]{1,0} parameter(1) + CHECK: ROOT %dot.1 = bf16[128,256]{1,0} dot(bf16[128,64]{1,0} %[[convert]], bf16[256,64]{1,0} %[[param_1]]), lhs_contracting_dims={1}, rhs_contracting_dims={1} + )")); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, NonstandardLayoutWithManyNonContractingDims) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip ivestigate int4 " + "issue with triton."; + // We cannot do triton_gemm and we use cuBLAS instead. + constexpr absl::string_view kHloText = R"( + HloModule t + + ENTRY main { + p0 = s4[128,64,192]{1,0,2} parameter(0) + p1 = bf16[256,64]{1,0} parameter(1) + ROOT %dot = bf16[128,192,256]{2,1,0} dot(p0, p1), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(CHECK: "__cublas$gemm")")); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, NonstandardLayoutWithManyNonContractingDimsReversedLayout) { + GTEST_SKIP() << "TODO(rocm): Weekly-sync 25-01-13: Skip ivestigate int4 " + "issue with triton."; + // We cannot do triton_gemm and we use cuBLAS instead. + constexpr absl::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[128,64,192]{0,1,2} parameter(0) + rhs = bf16[256,64]{1,0} parameter(1) + ROOT %dot = bf16[128,192,256]{2,1,0} dot(lhs, rhs), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(CHECK: "__cublas$gemm")")); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, NegatePlusConvertHLO) { + constexpr absl::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[16,32,64]{2,1,0} parameter(0) + lhs_negated = s4[16,32,64]{2,1,0} negate(lhs) + lhs_converted = bf16[16,32,64]{2,1,0} convert(lhs_negated) + rhs = bf16[16,64,16]{2,1,0} parameter(1) + ROOT dot = bf16[16,32,16]{2,1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={2}, + rhs_contracting_dims={1}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, RejectTritonFusionForWithMinorBatchDim) { + constexpr absl::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[32,64,16]{2,1,0} parameter(0) + lhs_converted = bf16[32,64,16]{2,1,0} convert(lhs) + rhs = bf16[16,64,16]{2,1,0} parameter(1) + ROOT dot = bf16[16,32,16]{2,1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, + rhs_contracting_dims={1}, + lhs_batch_dims={2}, + rhs_batch_dims={0} + } + )"; + + const std::string pattern = + R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + +TEST_F(TritonTest, LHSWithMinorDimEqualTo1) { + // We prove that triton can handle int4 dot with non contracting dim size + // equal to 1. + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,1024,1]{2,1,0} parameter(0) + lhs_converted = bf16[16,1024,1]{2,1,0} convert(lhs) + rhs = bf16[16,64,1024]{2,1,0} parameter(1) + ROOT dot = bf16[16,1,64]{2,1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, + rhs_contracting_dims={2}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + + ENTRY main { + lhs = s4[16,1024,1]{2,1,0} parameter(0) + rhs = bf16[16,64,1024]{2,1,0} parameter(1) + ROOT dot = bf16[16,1,64]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, RHSWithMinorDimEqualTo1) { + // We prove that triton can handle int4 dot with non contracting dim size + // equal to 1. + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,1024,64]{2,1,0} parameter(0) + rhs = s4[16,1024,1]{2,1,0} parameter(1) + rhs_converted = bf16[16,1024,1]{2,1,0} convert(rhs) + ROOT dot = bf16[16,64,1]{2,1,0} dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={1}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + + ENTRY main { + lhs = bf16[16,1024,64]{2,1,0} parameter(0) + rhs = s4[16,1024,1]{2,1,0} parameter(1) + ROOT dot = bf16[16,64,1]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, LHSNonMinorContractingDim) { + // We prove that triton can handle int4 dot with non minor + // lhs_contracting_dim. + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[1024,8]{1,0} parameter(0) + lhs_converted = bf16[1024,8]{1,0} convert(lhs) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={0}, + rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[1024,8]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, LHSNonMinorContractingDimWithBatchDim0) { + // We prove that triton can handle int4 dot with non minor + // lhs_contracting_dim. + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,1024,8]{2,1,0} parameter(0) + lhs_converted = bf16[16,1024,8]{2,1,0} convert(lhs) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), + lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = s4[16,1024,8]{2,1,0} parameter(0) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonTest, LHSMinorContractingDim) { + // We prove that triton can handle int4 dot with minor lhs_contracting_dim. + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[8,1024]{1,0} parameter(0) + lhs_converted = bf16[8,1024]{1,0} convert(lhs) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, ConvertPlusNegate) { + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[8,1024]{1,0} parameter(0) + lhs_converted = bf16[8,1024]{1,0} convert(lhs) + lhs_negated = bf16[8,1024]{1,0} negate(lhs_converted) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_negated, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, LHSMinorContractingDimWithBatchDim0) { + // We prove that triton can handle int4 dot with minor lhs_contracting_dim. + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,8,1024]{2,1,0} parameter(0) + lhs_converted = bf16[16,8,1024]{2,1,0} convert(lhs) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = s4[16,8,1024]{2,1,0} parameter(0) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, RHSTestWithNotMinorContractingDim) { + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[1024,4]{1,0} parameter(1) + rhs_converted = bf16[1024,4]{1,0} convert(rhs) + ROOT dot = bf16[8,4] dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + + ENTRY main { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, RHSTestWithMinorContractingDim) { + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[4,1024]{1,0} parameter(1) + rhs_converted = bf16[4,1024]{1,0} convert(rhs) + ROOT dot = bf16[8,4] dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[4,1024]{1,0} parameter(1) + ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, RHSTestWithMinorContractingDimWithBatchDim) { + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,1024,4]{2,1,0} parameter(1) + rhs_converted = bf16[16,1024,4]{2,1,0} convert(rhs) + ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonTest, RHSTestWithNotMinorContractingDimWithBatchDim0) { + constexpr absl::string_view kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,4,1024]{2,1,0} parameter(1) + rhs_converted = bf16[16,4,1024]{2,1,0} convert(rhs) + ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={2} + } + + ENTRY main { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,4,1024]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc index 3ffa43bca72bc9..01fa9c22d45d0f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include "absl/log/check.h" @@ -88,7 +87,7 @@ ENTRY e { } TEST_F(TritonGemmTest, LargeNonContractingProductWorks) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -112,7 +111,7 @@ ENTRY e { } TEST_F(TritonGemmTest, LargeBatchWorks) { - constexpr std::string_view kHloText = R"( + constexpr absl::string_view kHloText = R"( HloModule m ENTRY e { @@ -134,17 +133,9 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -class TritonSoftmaxTest : public GpuCodegenTest { - public: - DebugOptions GetDebugOptionsForTest() const override { - DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); - debug_options - .set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(true); - return debug_options; - } -}; +using TritonNormalizationTest = GpuCodegenTest; -TEST_F(TritonSoftmaxTest, +TEST_F(TritonNormalizationTest, CanEmitDiamondWithInputNumberOfElementsLargerThanInt32Max) { const std::string hlo_text = R"( HloModule softmax diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc index 62fdc9ff20b222..b3e70adddcda83 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -43,9 +44,9 @@ limitations under the License. #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" @@ -66,6 +67,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/fusions/triton/emitter_helpers.h" #include "xla/service/gpu/fusions/triton/xla_triton_ops.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -98,7 +100,6 @@ namespace mh = ::mlir::mhlo; using ::llvm::SmallVector; using ::mlir::ArrayRef; -using ::mlir::ImplicitLocOpBuilder; using ::mlir::ShapedType; using ::mlir::Type; using ::mlir::Value; @@ -106,7 +107,14 @@ using ::mlir::ValueRange; namespace { -absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { +bool IsTritonInt4RewritesEnabled(const HloInstruction& hlo) { + return hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_triton_i4_rewrites(); +} + +absl::StatusOr TritonType(EmitterLocOpBuilder& b, PrimitiveType t) { switch (t) { case F64: return b.getF64Type(); @@ -129,7 +137,7 @@ absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { case S4: // The unpacking to i8 is supported by the emitter. // We pass the s4 tensor as i8 tensor with the minor dimension having 2x // less elements and unpack in the inner loop of the triton kernel. - return b.getI8Type(); + return b.getI4Type(); case F8E5M2: return b.getFloat8E5M2Type(); case F8E4M3FN: @@ -141,7 +149,7 @@ absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { } } -Type StorageType(mlir::OpBuilder b, Type t) { +Type StorageType(EmitterLocOpBuilder& b, Type t) { if (t.isInteger(1)) { return b.getI8Type(); } @@ -150,7 +158,7 @@ Type StorageType(mlir::OpBuilder b, Type t) { // Create a scalar constant. template -ma::ConstantOp CreateConst(ImplicitLocOpBuilder b, Type type, T value) { +ma::ConstantOp CreateConst(EmitterLocOpBuilder b, Type type, T value) { if (mlir::isa(type)) { return b.create(b.getIntegerAttr(type, value)); } @@ -163,13 +171,14 @@ ma::ConstantOp CreateConst(ImplicitLocOpBuilder b, Type type, T value) { // Create a tensor constant. template -ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value, +ma::ConstantOp CreateConst(EmitterLocOpBuilder b, Type type, T value, llvm::ArrayRef shape) { auto tensor_type = mlir::RankedTensorType::get(shape, type); if (auto int_type = mlir::dyn_cast(type)) { return b.create(mlir::DenseElementsAttr::get( - tensor_type, mlir::APInt(int_type.getIntOrFloatBitWidth(), value, - /*isSigned=*/std::is_signed_v))); + tensor_type, + mlir::APInt(int_type.getIntOrFloatBitWidth(), value, + /*isSigned=*/std::is_signed_v, /*implicitTrunc=*/true))); } if (auto float_type = mlir::dyn_cast(type)) { return b.create(mlir::DenseElementsAttr::get( @@ -178,7 +187,7 @@ ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value, LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); } -Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { +Value ZerosLike(EmitterLocOpBuilder b, Value x) { if (auto src_shaped_ty = mlir::dyn_cast(x.getType())) { Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, 0, src_shaped_ty.getShape()); @@ -186,7 +195,7 @@ Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { return CreateConst(b, x.getType(), 0); } -Value OnesLike(ImplicitLocOpBuilder& b, Value x) { +Value OnesLike(EmitterLocOpBuilder b, Value x) { if (auto src_shaped_ty = mlir::dyn_cast(x.getType())) { Type src_ty = src_shaped_ty.getElementType(); return CreateConst(b, src_ty, 1, src_shaped_ty.getShape()); @@ -199,7 +208,7 @@ bool IsFp8Type(Type t) { t.isFloat8E4M3FNUZ() || t.isFloat8E4M3B11FNUZ(); } -Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { +Value Cast(EmitterLocOpBuilder b, Value value, Type dst_element_ty) { Type src_ty = value.getType(); Type src_element_ty = src_ty; Type fp32_ty = b.getF32Type(); @@ -277,14 +286,14 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { // TODO(b/266862493): Support unsigned integer types. // The current logic handles signed integer types only. Additional handling // is needed for unsigned integer types. - auto cst_int = [&](int64_t x) { + auto cst_int = [&](EmitterLocOpBuilder b, int64_t x) { if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape()); } else { return CreateConst(b, dst_element_ty, x); } }; - auto cst_float = [&](int64_t x) { + auto cst_float = [&](EmitterLocOpBuilder b, int64_t x) { if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape()); } else { @@ -297,16 +306,16 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { // value <= static_cast(INT_MIN) ? INT_MIN : ... auto clamped = b.create( - b.create(ma::CmpFPredicate::OLE, value, cst_float(min)), - cst_int(min), fptosi); + b.create(ma::CmpFPredicate::OLE, value, cst_float(b, min)), + cst_int(b, min), fptosi); // value >= static_cast(INT_MAX) ? INT_MAX : ... clamped = b.create( - b.create(ma::CmpFPredicate::OGE, value, cst_float(max)), - cst_int(max), clamped); + b.create(ma::CmpFPredicate::OGE, value, cst_float(b, max)), + cst_int(b, max), clamped); // isnan(value) ? 0 : ... return b.create( - b.create(ma::CmpFPredicate::UNO, value, value), cst_int(0), - clamped); + b.create(ma::CmpFPredicate::UNO, value, value), + cst_int(b, 0), clamped); } LOG(FATAL) << "Type conversion not supported: " @@ -314,7 +323,7 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { << llvm_ir::DumpToString(dst_element_ty); } -Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { +Value Subtract(EmitterLocOpBuilder b, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return b.create(values[0], values[1]); } else { @@ -322,7 +331,7 @@ Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { } } -Value Compare(ImplicitLocOpBuilder& b, ValueRange values, +Value Compare(EmitterLocOpBuilder b, ValueRange values, mh::ComparisonDirection direction) { const Type type = mlir::getElementTypeOrSelf(values[0]); if (mlir::isa(type)) { @@ -339,7 +348,7 @@ Value Compare(ImplicitLocOpBuilder& b, ValueRange values, values[0], values[1]); } -Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, +Value Maximum(EmitterLocOpBuilder b, const se::DeviceDescription& device_info, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return b.create(values); @@ -360,7 +369,7 @@ Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, values[0], values[1]); } -Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, +Value Minimum(EmitterLocOpBuilder b, const se::DeviceDescription& device_info, ValueRange values) { if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { return b.create(values); @@ -382,12 +391,12 @@ Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, values[0], values[1]); } -Value Splat(ImplicitLocOpBuilder& b, Value value, ArrayRef shape) { +Value Splat(EmitterLocOpBuilder b, Value value, ArrayRef shape) { auto type = mlir::RankedTensorType::get(shape, value.getType()); return b.create(type, value); } -absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, +absl::StatusOr EmitElementwise(EmitterLocOpBuilder b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloInstruction& hlo, @@ -474,7 +483,7 @@ absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, } } -absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, +absl::StatusOr EmitConstant(EmitterLocOpBuilder b, const HloInstruction& constant) { CHECK_EQ(constant.opcode(), HloOpcode::kConstant); CHECK(ShapeUtil::IsEffectiveScalar(constant.shape())); @@ -496,7 +505,7 @@ absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, } // Emit sequence of operations for unpacking 2xi4 -> i8. -absl::StatusOr EmitUnpackInt4(ImplicitLocOpBuilder& b, +absl::StatusOr EmitUnpackInt4(EmitterLocOpBuilder& b, const HloInstruction* hlo, int64_t unpack_dim_idx, Value& value) { VLOG(6) << "EmitUnpackInt4: " << hlo->ToString(); @@ -522,21 +531,21 @@ absl::StatusOr EmitUnpackInt4(ImplicitLocOpBuilder& b, using TensorValue = mlir::TypedValue; -Value Broadcast(ImplicitLocOpBuilder& b, TensorValue value, +Value Broadcast(EmitterLocOpBuilder b, TensorValue value, ArrayRef shape) { return b.create(value.getType().clone(shape), value); } -Value Range(ImplicitLocOpBuilder& b, int32_t limit) { +Value Range(EmitterLocOpBuilder b, int32_t limit) { auto type = mlir::RankedTensorType::get(limit, b.getI32Type()); return b.create(type, 0, limit); } -Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) { +Value AddPtr(EmitterLocOpBuilder b, Value ptr, Value offset) { return b.create(ptr.getType(), ptr, offset); } -Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, +Value EmitParameterLoad(EmitterLocOpBuilder b, Value pointer, ArrayRef boundary_checks) { // 0-D MakeTensorPtrOp // @@ -606,7 +615,7 @@ struct Side { int64_t unpack_dim_idx = 0; }; -absl::StatusOr EmitBroadcast(ImplicitLocOpBuilder& b, +absl::StatusOr EmitBroadcast(EmitterLocOpBuilder b, const TritonFusionAnalysis* analysis, const Side& side, const HloInstruction& broadcast, @@ -653,7 +662,7 @@ absl::StatusOr EmitBroadcast(ImplicitLocOpBuilder& b, // Emit sequence of instructions using compatible tiling ordered producers // before consumers. absl::StatusOr EmitScope( - ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + EmitterLocOpBuilder b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const TritonFusionAnalysis* analysis, const Side& side, absl::Span instructions, @@ -662,9 +671,14 @@ absl::StatusOr EmitScope( Value result; if (hlo->opcode() == HloOpcode::kConvert && hlo->operand(0)->shape().element_type() == S4) { - TF_ASSIGN_OR_RETURN( - auto unpacked, - EmitUnpackInt4(b, hlo, side.unpack_dim_idx, values[hlo->operand(0)])); + Value unpacked; + if (IsTritonInt4RewritesEnabled(*hlo)) { + unpacked = Cast(b, values[hlo->operand(0)], b.getI8Type()); + } else { + TF_ASSIGN_OR_RETURN(unpacked, + EmitUnpackInt4(b, hlo, side.unpack_dim_idx, + values[hlo->operand(0)])); + } std::vector operands({unpacked}); TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, device_info, *hlo, operands)); @@ -770,6 +784,12 @@ struct MatMulDims { int64_t n; int64_t k; + std::string ToString() const { + return absl::StrCat("MxNxK: ", m, "x", n, "x", k, + " contracting: lhs=", lhs_contracting_dim_idx, + " rhs=", rhs_contracting_dim_idx); + } + private: MatMulDims() = default; }; @@ -953,7 +973,7 @@ absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, // } else { // return choices.back(); // } -absl::StatusOr EmitMultiSelect(ImplicitLocOpBuilder b, Value index, +absl::StatusOr EmitMultiSelect(EmitterLocOpBuilder& b, Value index, ValueRange limits, ValueRange choices) { TF_RET_CHECK(choices.size() - 1 == limits.size()); Value result = choices[0]; @@ -983,7 +1003,7 @@ class MatMulEmitterHelper { MatMulEmitterHelper(absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloDotInstruction* dot_instr, - ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims, + EmitterLocOpBuilder& b, Type index_ty, MatMulDims dims, const MatMulLaunchConfig& launch_config, const TritonFusionAnalysis& analysis) : b_(b), @@ -1374,7 +1394,8 @@ class MatMulEmitterHelper { if (dim_bound % (properties.block_size * properties.split_value) != 0) { boundary_checks.push_back(bounds.size() - 1); } - if (hlo->shape().element_type() == PrimitiveType::S4) { + if (hlo->shape().element_type() == PrimitiveType::S4 && + !IsTritonInt4RewritesEnabled(*hlo)) { // For s4 type we need to divide the minor dim bound by 2 because it // is the packing dimension. But if the minor dim has length == 1 then // the major dim stride is also 1 and it is the packing dimension. @@ -1428,7 +1449,8 @@ class MatMulEmitterHelper { b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), batch_stride); - if (hlo->shape().element_type() == PrimitiveType::S4) { + if (hlo->shape().element_type() == PrimitiveType::S4 && + !IsTritonInt4RewritesEnabled(*hlo)) { pid_offset_batch = b_.create(pid_offset_batch, Cst(2)); } base = AddPtr(b_, base, pid_offset_batch); @@ -1453,11 +1475,35 @@ class MatMulEmitterHelper { b_.create(base, bounds, strides, tensor_offsets, block_dims, dim_order) .getResult()); + if (hlo->shape().element_type() == PrimitiveType::S4 && + IsTritonInt4RewritesEnabled(*hlo)) { + tensor_ptr.getDefiningOp()->setAttr("packed_dim", GetPackedDimAttr(side)); + } tensor_ptr = b_.create(tensor_ptr.getType(), tensor_ptr, block_offsets); return tensor_ptr; } + // Naive implementation of the packed_dim attribute for the int4 tensors. + // It doesn't take into account different layout schemes. + mlir::IntegerAttr GetPackedDimAttr(const Side& side) const { + int packed_dim = 0; + if (side.scope == TritonFusionAnalysis::Scope::LHS) { + if (dims_.lhs_contracting_dim_idx > dims_.lhs_noncontracting_dim_idx) { + packed_dim = 0; + } else { + packed_dim = 1; + } + } else if (side.scope == TritonFusionAnalysis::Scope::RHS) { + if (dims_.rhs_contracting_dim_idx > dims_.rhs_noncontracting_dim_idx) { + packed_dim = 1; + } else { + packed_dim = 0; + } + } + return b_.getI32IntegerAttr(packed_dim); + } + private: // Extend int32 indexes to int64, if necessary. Value ConvertScalar(Value value) { @@ -1471,7 +1517,7 @@ class MatMulEmitterHelper { Value Cst32(int32_t v) { return CreateConst(b_, i32_ty_, v); } Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); } - ImplicitLocOpBuilder& b_; + EmitterLocOpBuilder& b_; absl::string_view libdevice_path_; const se::DeviceDescription& device_info_; const HloDotInstruction* dot_instr_; @@ -1531,7 +1577,7 @@ ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, // Truncates |input| of F32 type to the number representable in Bf16 toward // zero. // It is used for Emit6xBfloat16MatMul. -Value TruncateToBF16TowardsZero(ImplicitLocOpBuilder& b, Value input) { +Value TruncateToBF16TowardsZero(EmitterLocOpBuilder& b, Value input) { ShapedType input_type = mlir::dyn_cast(input.getType()); Type input_type_as_i32 = input_type.clone(b.getI32Type()); Value input_as_i32 = b.create(input_type_as_i32, input); @@ -1544,14 +1590,14 @@ Value TruncateToBF16TowardsZero(ImplicitLocOpBuilder& b, Value input) { // Finds the middle 8 bits of |input|'s mantissa. // It is used for Emit6xBfloat16MatMul. -Value SoftMiddleEight(ImplicitLocOpBuilder& b, Value input) { +Value SoftMiddleEight(EmitterLocOpBuilder& b, Value input) { Value high = TruncateToBF16TowardsZero(b, input); return b.create(input, high); } // Finds the low 8 bits of |input|'s mantissa. // It is used for Emit6xBfloat16MatMul. -Value SoftLowEight(ImplicitLocOpBuilder& b, Value input) { +Value SoftLowEight(EmitterLocOpBuilder& b, Value input) { // Find the middle bits of the middle bits, and these are the low eight // bits. return SoftMiddleEight(b, SoftMiddleEight(b, input)); @@ -1559,13 +1605,13 @@ Value SoftLowEight(ImplicitLocOpBuilder& b, Value input) { // Rounds |input| to BF16 type. // It is used for Emit6xBfloat16MatMul. -Value RoundToBF16(ImplicitLocOpBuilder& b, Value input) { +Value RoundToBF16(EmitterLocOpBuilder& b, Value input) { return Cast(b, input, b.getBF16Type()); } // Checks |input| is finite f32 (not Nan and not infinite). // It is used for Emit6xBfloat16MatMul and Emit3xBfloat16MatMul. -Value CheckFiniteF32(ImplicitLocOpBuilder& b, Value input) { +Value CheckFiniteF32(EmitterLocOpBuilder& b, Value input) { Value positive_inf = CreateConst( b, b.getF32Type(), std::numeric_limits::infinity(), mlir::cast(input.getType()).getShape()); @@ -1575,7 +1621,7 @@ Value CheckFiniteF32(ImplicitLocOpBuilder& b, Value input) { // Leverages BF16 datatype for F32 matmul computation. It follows the guidance // from https://arxiv.org/pdf/1904.06376.pdf. -absl::StatusOr Emit6xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, +absl::StatusOr Emit6xBfloat16MatMul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc) { Type f32 = b.getF32Type(); TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); @@ -1623,7 +1669,7 @@ absl::StatusOr Emit6xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, // Compute F32 matmul with 3 BF16 dots. It is less accurate than // Emit6xBfloat16MatMul. -absl::StatusOr Emit3xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, +absl::StatusOr Emit3xBfloat16MatMul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc) { Type f32 = b.getF32Type(); TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); @@ -1690,7 +1736,7 @@ mt::InputPrecision InferDotPrecision(const HloDotInstruction* dot_instr) { } bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, - mlir::OpBuilder& builder, Value dot_input_lhs, + EmitterLocOpBuilder& b, Value dot_input_lhs, Value dot_input_rhs, const se::DeviceDescription& device_info) { const PrecisionConfig::Algorithm algorithm = @@ -1698,7 +1744,7 @@ bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, if (algorithm == PrecisionConfig::ALG_UNSET) { const HloModule* hlo_module = dot_instr->GetModule(); - Type f32 = builder.getF32Type(); + Type f32 = b.getF32Type(); return hlo_module->config() .debug_options() .xla_gpu_enable_bf16_6way_gemm() && @@ -1712,7 +1758,7 @@ bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, } bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr, - mlir::OpBuilder& builder, Value dot_input_lhs, + EmitterLocOpBuilder& b, Value dot_input_lhs, Value dot_input_rhs, const se::DeviceDescription& device_info) { const PrecisionConfig::Algorithm algorithm = @@ -1720,7 +1766,7 @@ bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr, if (algorithm == PrecisionConfig::ALG_UNSET) { const HloModule* hlo_module = dot_instr->GetModule(); - Type f32 = builder.getF32Type(); + Type f32 = b.getF32Type(); return hlo_module->config() .debug_options() .xla_gpu_enable_bf16_3way_gemm() && @@ -1772,7 +1818,7 @@ absl::Status CheckGemmTilingComplexityHeuristic( class Scopes { public: - Scopes(ImplicitLocOpBuilder& b, const HloInstruction* dot_instr, + Scopes(EmitterLocOpBuilder& b, const HloInstruction* dot_instr, const TritonFusionAnalysis& analysis, const MatMulDims& dims, const TritonGemmConfig& config, const MatMulLaunchConfig launch_config, bool is_sparse) @@ -1807,7 +1853,8 @@ class Scopes { int lhs_non_contracting_block_size = config.block_m; int lhs_contracting_block_size = config.block_k; int lhs_unpack_bound_idx = 0; - if (is_int4_param(analysis, TritonFusionAnalysis::Scope::LHS)) { + if (!IsTritonInt4RewritesEnabled(*dot_instr) && + is_int4_param(analysis, TritonFusionAnalysis::Scope::LHS)) { auto minor_dim = std::max(dims.lhs_contracting_dim_idx, dims.lhs_noncontracting_dim_idx); auto minor_bound = analysis @@ -1845,7 +1892,8 @@ class Scopes { int rhs_contracting_block_size = config.block_k; int rhs_non_contracting_block_size = config.block_n; int rhs_unpack_bound_idx = 0; - if (is_int4_param(analysis, TritonFusionAnalysis::Scope::RHS)) { + if (!IsTritonInt4RewritesEnabled(*dot_instr) && + is_int4_param(analysis, TritonFusionAnalysis::Scope::RHS)) { auto minor_dim = std::max(dims.rhs_contracting_dim_idx, dims.rhs_noncontracting_dim_idx); auto minor_bound = analysis @@ -1929,7 +1977,7 @@ class Scopes { enum MaskExpandDimension { kMajor = 0, kMinor = 1 }; -Value EmitMaskOnInput(ImplicitLocOpBuilder& b, +Value EmitMaskOnInput(EmitterLocOpBuilder& b, MaskExpandDimension expand_along_dimension, Value input, int dim_k_denom, Value k, int64_t dims_k, int64_t block_k, Value pid_k, int64_t other_dim_block_size) { @@ -1969,8 +2017,8 @@ Value EmitMaskOnInput(ImplicitLocOpBuilder& b, auto if_op = b.create( is_last_tile_cond, /*thenBranch=*/ - [&](mlir::OpBuilder& builder, mlir::Location loc) { - ImplicitLocOpBuilder b(loc, builder); + [&, &parent_builder = b](mlir::OpBuilder& builder, mlir::Location loc) { + EmitterLocOpBuilder b(loc, builder, parent_builder.annotate_loc()); // Make a range vector from 0 to block_k. auto range_from_0_to_k = Range(b, block_k_size); if (pid_k != nullptr) { @@ -2005,10 +2053,10 @@ Value EmitMaskOnInput(ImplicitLocOpBuilder& b, b.create(mlir::ValueRange(result)); }, /*elseBranch=*/ - [&](mlir::OpBuilder& builder, mlir::Location loc) { + [&, &parent_builder = b](mlir::OpBuilder& builder, mlir::Location loc) { // We don't need to mask anything but we need to expand the input. // Otherwise Triton complains. - ImplicitLocOpBuilder b(loc, builder); + EmitterLocOpBuilder b(loc, builder, parent_builder.annotate_loc()); b.create(mlir::ValueRange(expanded_input)); }); return if_op.getResult(0); @@ -2019,7 +2067,7 @@ Value EmitMaskOnInput(ImplicitLocOpBuilder& b, // Use tiling and execution parameters from 'config'. BlockLevelParameters are // ignored. // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -absl::Status EmitMatMul(mlir::OpBuilder builder, +absl::Status EmitMatMul(EmitterLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, @@ -2064,7 +2112,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX || ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX || ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k > INT_MAX; - Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32); + Type index_ty = b.getIntegerType(use_64bit_indexing ? 64 : 32); const HloInstruction* root = dot_instr->parent()->root_instruction(); TF_RET_CHECK(!root->shape().IsTuple()); @@ -2072,8 +2120,6 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // We'll be creating a lot of instructions from a single dot, use an // implicit loc builder so we don't have to pass around the location all the // time. - auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name())); - ImplicitLocOpBuilder b(loc, builder); TF_RETURN_IF_ERROR(ValidateMatMulConfig(config, *dot_instr)); const int split_k = config.split_k; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h index 540f511ec03061..e56eb7de099a9e 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h @@ -19,9 +19,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "mlir/IR/Builders.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_traversal.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -39,7 +39,7 @@ absl::StatusOr GetMatMulLaunchDimensions( // Use tiling and execution parameters from 'config'. BlockLevelParameters are // ignored. // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -absl::Status EmitMatMul(mlir::OpBuilder builder, +absl::Status EmitMatMul(EmitterLocOpBuilder& builder, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc index 82ad657d247083..9ce1839b23d6dc 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc @@ -16,7 +16,14 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/utils/hlo_traversal.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -28,7 +35,7 @@ absl::StatusOr GetMatMulLaunchDimensions( return absl::UnimplementedError("not supported for this build configuration"); } -absl::Status EmitMatMul(mlir::OpBuilder builder, +absl::Status EmitMatMul(EmitterLocOpBuilder& builder, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc index e570cb8a8bb7b3..5030e2268ea12a 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" @@ -44,6 +43,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_traversal.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" @@ -61,7 +61,6 @@ namespace xla::gpu::ir_emitter_triton_internal { namespace { using ::llvm::SmallVector; -using ::mlir::ImplicitLocOpBuilder; using ::mlir::MLIRContext; using ::mlir::OpBuilder; using ::mlir::Type; @@ -134,7 +133,7 @@ TritonMakeTensorPtrTest::CreateAndTileParameterHloInstruction( } mlir::triton::FuncOp CreateTritonFunction( - ImplicitLocOpBuilder& b, const std::vector shape_sizes) { + EmitterLocOpBuilder& b, const std::vector shape_sizes) { auto fn = b.create<::mlir::triton::FuncOp>( "func", b.getFunctionType({::mlir::triton::PointerType::get( @@ -166,7 +165,7 @@ TritonMakeTensorPtrTest::CreateTestTensorPtr( llvm_ir::CreateMlirModuleOp(loc); builder.setInsertionPointToEnd(triton_module->getBody()); - ImplicitLocOpBuilder b(loc, builder); + EmitterLocOpBuilder b(loc, builder); auto fn = CreateTritonFunction(b, parent_shape); SmallVector tile_multi_index = ComputeDelinearizedTileIndex( diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index 06e9369a3ee011..fb4ff691658ba0 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -889,13 +889,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(kSupportedDataTypes)), TwoPrimitiveTypesToString); -class TritonSoftmaxTest : public GpuCodegenTest, - public ::testing::WithParamInterface { +class TritonNormalizationTest + : public GpuCodegenTest, + public ::testing::WithParamInterface { public: DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); - debug_options - .set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(true); // TODO(b/38354253): Remove once HloTestBase does not remove constant // folding. debug_options.clear_xla_disable_hlo_passes(); @@ -903,7 +902,7 @@ class TritonSoftmaxTest : public GpuCodegenTest, } }; -TEST_P(TritonSoftmaxTest, CanFuseAndEmitExactSoftmax) { +TEST_P(TritonNormalizationTest, CanFuseAndEmitExactSoftmax) { PrimitiveType data_type = GetParam(); if (data_type == F16) { @@ -967,7 +966,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, CanFuseAndEmitFirstSoftmaxDiamond) { +TEST_P(TritonNormalizationTest, CanFuseAndEmitFirstSoftmaxDiamond) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( HloModule softmax @@ -1022,7 +1021,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, CanFuseAndEmitSoftmaxDiamondWithSmallRows) { +TEST_P(TritonNormalizationTest, CanFuseAndEmitSoftmaxDiamondWithSmallRows) { PrimitiveType data_type = GetParam(); constexpr absl::string_view kHloTextTemplate = R"( HloModule softmax @@ -1059,7 +1058,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0))); } -TEST_F(TritonSoftmaxTest, CanFuseAndEmitDiamondWithBF16Converts) { +TEST_F(TritonNormalizationTest, CanFuseAndEmitDiamondWithBF16Converts) { const std::string hlo_text = R"( HloModule softmax max_computation { @@ -1094,7 +1093,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, +TEST_P(TritonNormalizationTest, CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { PrimitiveType data_type = GetParam(); @@ -1148,7 +1147,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, +TEST_P(TritonNormalizationTest, CanFuseAndEmitSoftmaxWithIntermediateUnaryElementwise) { PrimitiveType data_type = GetParam(); @@ -1215,7 +1214,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { PrimitiveType data_type = GetParam(); @@ -1276,7 +1275,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, +TEST_P(TritonNormalizationTest, CanFuseAndEmitDiamondWithTrailingUnaryElementwiseAtTheRoot) { PrimitiveType data_type = GetParam(); @@ -1331,7 +1330,8 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { +TEST_P(TritonNormalizationTest, + CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1385,7 +1385,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, +TEST_P(TritonNormalizationTest, CanFuseAndEmitSoftmaxDiamondWithLastDimensionBitcastAfterReduce) { PrimitiveType data_type = GetParam(); @@ -1442,7 +1442,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, +TEST_P(TritonNormalizationTest, CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectly) { PrimitiveType data_type = GetParam(); @@ -1495,7 +1495,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -1551,7 +1551,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -1607,7 +1607,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -1674,7 +1674,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -1736,7 +1736,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitTwoBinaryElementwiseWhereBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -1799,7 +1799,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P(TritonSoftmaxTest, DiamondEmitterIsNumericallyStable) { +TEST_P(TritonNormalizationTest, DiamondEmitterIsNumericallyStable) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1833,7 +1833,7 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0))); } -TEST_P(TritonSoftmaxTest, CanFuseAndEmitRMSNormDiamond) { +TEST_P(TritonNormalizationTest, CanFuseAndEmitRMSNormDiamond) { PrimitiveType data_type = GetParam(); const std::string hlo_text_template = R"( @@ -1896,7 +1896,7 @@ ENTRY main.30 { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -1959,7 +1959,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -2022,7 +2022,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -2081,7 +2081,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -2139,7 +2139,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseProducerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -2198,7 +2198,7 @@ ENTRY main { } TEST_P( - TritonSoftmaxTest, + TritonNormalizationTest, CanFuseAndEmitBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducerIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); @@ -2242,10 +2242,10 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -INSTANTIATE_TEST_SUITE_P(TritonSoftmaxTestSuite, TritonSoftmaxTest, +INSTANTIATE_TEST_SUITE_P(TritonNormalizationTestSuite, TritonNormalizationTest, ::testing::Values(F32, F16, BF16)); -TEST_F(TritonSoftmaxTest, CanFuseAndEmitTritonSoftmaxWithTwoParameters) { +TEST_F(TritonNormalizationTest, CanFuseAndEmitTritonSoftmaxWithTwoParameters) { const std::string hlo_text = R"( HloModule layernorm @@ -2285,7 +2285,7 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_F(TritonSoftmaxTest, CanFuseAndEmitTritonSoftmaxWithNonBatchReduce) { +TEST_F(TritonNormalizationTest, CanFuseAndEmitTritonSoftmaxWithNonBatchReduce) { const std::string hlo_text = R"( HloModule layernorm diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc index 9a8f45539b1304..c3ef447d264561 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc @@ -24,7 +24,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Value.h" @@ -32,6 +31,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" @@ -74,14 +74,7 @@ absl::StatusOr CompileTritonToLLVM( const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mlir::ModuleOp triton_module, llvm::Module* llvm_module, - mlir::MLIRContext& mlir_context, bool emit_kernel) { - return absl::UnimplementedError("not supported for this build configuration"); -} - -absl::Status CreateTritonPipeline( - mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, - const BlockLevelParameters& block_level_parameters, - ::mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info) { + mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) { return absl::UnimplementedError("not supported for this build configuration"); } @@ -93,13 +86,13 @@ std::string GetLibdevicePath(const HloModuleConfig& hlo_config, namespace ir_emitter_triton_internal { llvm::SmallVector ComputeDelinearizedTileIndex( - mlir::ImplicitLocOpBuilder& b, + EmitterLocOpBuilder& b, absl::Span num_output_tiles_per_dim) { return {}; } absl::StatusOr CreateMakeTensorPtrOp( - mlir::ImplicitLocOpBuilder& b, mlir::ValueRange tile_multi_index, + EmitterLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value parent_base_ptr) { return absl::UnimplementedError("not supported for this build configuration"); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc index 8466ac7a70d52c..dbaecf015441dd 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include "mlir/IR/Builders.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -23,6 +21,8 @@ limitations under the License. #include "xla/hlo/utils/hlo_traversal.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/gpu/fusions/emitter_loc_op_builder.h" +#include "xla/service/gpu/fusions/triton/compilation_pipeline.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" @@ -44,16 +44,19 @@ TEST(TritonStub, CallStubApi) { LoadMlirDialectsForTriton(context); EXPECT_FALSE(TritonWrapper({}, nullptr, {}, {}, {}, nullptr, context).ok()); EXPECT_FALSE(CreateTritonModule({}, nullptr, {}, {}, context).ok()); - EXPECT_FALSE( - CompileTritonToLLVM({}, {}, {}, {}, {}, nullptr, context, {}).ok()); + EXPECT_FALSE(CompileTritonToLLVM({}, {}, {}, {}, {}, nullptr, context, + /*is_xla_fusion=*/true, {}) + .ok()); mlir::OpPassManager pm; ::mlir::triton::nvidia_gpu::ClusterInfo cluster_info; - EXPECT_FALSE(CreateTritonPipeline(pm, {}, {}, cluster_info).ok()); + EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info, + /*is_xla_fusion=*/true) + .ok()); EXPECT_EQ(GetLibdevicePath({}, {}), ""); - mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); + EmitterLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); EXPECT_TRUE( ir_emitter_triton_internal::ComputeDelinearizedTileIndex(builder, {}) @@ -74,7 +77,7 @@ TEST(TritonStub, CallLegacyMatMulApis) { EXPECT_FALSE(GetMatMulLaunchDimensions({}, *adaptor.get(), {}, {}).ok()); mlir::MLIRContext context; - mlir::OpBuilder builder(&context); + EmitterLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); EXPECT_FALSE(EmitMatMul(builder, {}, {}, nullptr, {}, {}).ok()); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc index 5d0c696ccc9807..0e5a2ffe7a6a60 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -556,18 +556,18 @@ TEST_P(ReduceTest, IsTritonSupportedReduction) { const std::string kHloTestTemplate = absl::Substitute(R"( add { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT add = $0[] add(Arg_0, Arg_1) + Arg_0 = $$0[] parameter(0) + Arg_1 = $$0[] parameter(1) + ROOT add = $$0[] add(Arg_0, Arg_1) } ENTRY triton_computation { - parameter_0 = $0[125,127] parameter(0) - constant_0 = $0[] constant($1) - ROOT reduce = $0[125] reduce(parameter_0, constant_0), + parameter_0 = $$0[125,127] parameter(0) + constant_0 = $$0[] constant($0) + ROOT reduce = $$0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -603,18 +603,18 @@ TEST_P( const std::string kHloTestTemplate = absl::Substitute(R"( add { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT add = $0[] add(Arg_0, Arg_1) + Arg_0 = $$0[] parameter(0) + Arg_1 = $$0[] parameter(1) + ROOT add = $$0[] add(Arg_0, Arg_1) } ENTRY triton_computation { - parameter_0 = $0[2,125,127] parameter(0) - constant_0 = $0[] constant($1) - ROOT reduce = $0[2] reduce(parameter_0, constant_0), + parameter_0 = $$0[2,125,127] parameter(0) + constant_0 = $$0[] constant($0) + ROOT reduce = $$0[2] reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -628,17 +628,17 @@ TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) { const std::string kHloTestTemplate = absl::Substitute(R"( add { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT add = $0[] add(Arg_0, Arg_1) + Arg_0 = $$0[] parameter(0) + Arg_1 = $$0[] parameter(1) + ROOT add = $$0[] add(Arg_0, Arg_1) } ENTRY triton_computation { - parameter_0 = $0[125,127] parameter(0) - constant_0 = $0[] constant($1) - ROOT reduce = $0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add + parameter_0 = $$0[125,127] parameter(0) + constant_0 = $$0[] constant($0) + ROOT reduce = $$0[127] reduce(parameter_0, constant_0), dimensions={0}, to_apply=add })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -653,24 +653,24 @@ TEST_P(ReduceTest, const std::string kHloTestTemplate = absl::Substitute(R"( add { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - Arg_2 = $0[] parameter(2) - Arg_3 = $0[] parameter(3) - add_0 = $0[] add(Arg_0, Arg_2) - add_1 = $0[] add(Arg_1, Arg_3) - ROOT pair = ($0[], $0[]) tuple(add_0, add_1) + Arg_0 = $$0[] parameter(0) + Arg_1 = $$0[] parameter(1) + Arg_2 = $$0[] parameter(2) + Arg_3 = $$0[] parameter(3) + add_0 = $$0[] add(Arg_0, Arg_2) + add_1 = $$0[] add(Arg_1, Arg_3) + ROOT pair = ($$0[], $$0[]) tuple(add_0, add_1) } ENTRY triton_computation { - parameter_0 = $0[125,127] parameter(0) - constant_0 = $0[] constant($1) - tuple = ($0[125], $0[125]) reduce( + parameter_0 = $$0[125,127] parameter(0) + constant_0 = $$0[] constant($0) + tuple = ($$0[125], $$0[125]) reduce( parameter_0, parameter_0, constant_0, constant_0), dimensions={1}, to_apply=add - ROOT reduce = $0[125] get-tuple-element(tuple), index=0 + ROOT reduce = $$0[125] get-tuple-element(tuple), index=0 })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -705,18 +705,18 @@ TEST_P(ReduceTest, UnsupportedReductionComputationFailsGracefullyWithTriton) { const std::string kHloTestTemplate = absl::Substitute(R"( custom_call { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT custom_call = $0[] custom-call(Arg_0, Arg_1), custom_call_target="foo" + Arg_0 = $$0[] parameter(0) + Arg_1 = $$0[] parameter(1) + ROOT custom_call = $$0[] custom-call(Arg_0, Arg_1), custom_call_target="foo" } ENTRY triton_computation { - parameter_0 = $0[125,127] parameter(0) - constant_0 = $0[] constant($1) - ROOT reduce = $0[125] reduce(parameter_0, constant_0), + parameter_0 = $$0[125,127] parameter(0) + constant_0 = $$0[] constant($0) + ROOT reduce = $$0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); @@ -744,18 +744,18 @@ TEST_P(ReductionComputationTest, DifferentBinaryOps) { const std::string kHloTestTemplate = absl::Substitute( R"( reduce_computation { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT output = $0[] $1(Arg_0, Arg_1) + Arg_0 = $$0[] parameter(0) + Arg_1 = $$0[] parameter(1) + ROOT output = $$0[] $0(Arg_0, Arg_1) } ENTRY triton_computation { - parameter_0 = $0[125,127] parameter(0) - constant_0 = $0[] constant($2) - ROOT reduce = $0[125] reduce(parameter_0, constant_0), + parameter_0 = $$0[125,127] parameter(0) + constant_0 = $$0[] constant($1) + ROOT reduce = $$0[125] reduce(parameter_0, constant_0), dimensions={1}, to_apply=reduce_computation })", - "$0", HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0"); + HloOpcodeString(opcode), dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( @@ -1119,9 +1119,9 @@ TEST_P(ConstantTest, ConstantEffectiveScalar) { const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { - ROOT const = $0[1,1] constant({{$1}}) + ROOT const = $$0[1,1] constant({{$0}}) })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, @@ -1137,9 +1137,9 @@ TEST_P(ConstantTest, Constant2D) { const std::string kHloTestTemplate = absl::Substitute(R"( ENTRY triton_computation { - ROOT const = $0[3,3] constant({{$1,$1,$1},{$1,$1,$1},{$1,$1,$1}}) + ROOT const = $$0[3,3] constant({{$0,$0,$0},{$0,$0,$0},{$0,$0,$0}}) })", - "$0", dtype_is_complex ? "(0, 0)" : "0"); + dtype_is_complex ? "(0, 0)" : "0"); TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( kHloTestTemplate, data_type, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc index 881fdbc89e634f..7b80af7ab5858e 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc @@ -39,6 +39,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" @@ -50,9 +52,7 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc new file mode 100644 index 00000000000000..dba98fd26f1b27 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc @@ -0,0 +1,509 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir::triton::xla { + +using ::xla::llvm_ir::DumpToString; + +namespace mt = ::mlir::triton; +namespace ma = ::mlir::arith; + +#define GEN_PASS_DEF_LOADINT4REWRITEPASS +#include "xla/service/gpu/fusions/triton/xla_triton_passes.h.inc" + +class I4ToI8Converter : public TypeConverter { + public: + Type convertIntegerType(IntegerType type) const { + VLOG(2) << "I4ToI8Converter: converting IntegerType for " + << DumpToString(type); + if (type.getWidth() == 4) { + auto new_type = IntegerType::get(type.getContext(), 8); + VLOG(2) << " -> I4ToI8Converter: IntegerType converted to " + << DumpToString(new_type); + return new_type; + } + return type; + } + + Type convertRankedTensorType(RankedTensorType type) const { + VLOG(2) << "I4ToI8Converter: RankedTensorType for " << DumpToString(type); + if (!type.getElementType().isInteger(4)) return type; + + auto shape = type.getShape(); + if (shape[0] == ShapedType::kDynamic) + return type; // Only handle static shapes for simplicity + + std::vector new_shape = shape; + new_shape[new_shape.size() - packed_dim_idx_ - 1] /= 2; + + auto new_type = RankedTensorType::get( + new_shape, IntegerType::get(type.getContext(), 8)); + VLOG(2) << " -> I4ToI8Converter: RankedTensorType converted to " + << DumpToString(new_type); + return new_type; + } + + PointerType convertPointerType(PointerType ptr_type) const { + VLOG(2) << "I4ToI8Converter: converting PointerType for " + << DumpToString(ptr_type); + auto pointee_type = ptr_type.getPointeeType(); + auto new_pointee_type = convertType(pointee_type); + auto new_ptr_type = + PointerType::get(new_pointee_type, ptr_type.getAddressSpace()); + VLOG(2) << " -> I4ToI8Converter: converted PointerType to " + << DumpToString(new_ptr_type); + return new_ptr_type; + } + + Type convertFunctionType(FunctionType func_type) const { + VLOG(2) << "I4ToI8Converter: converting FunctionType " + << DumpToString(func_type); + + SmallVector inputs; + if (failed(convertTypes(func_type.getInputs(), inputs))) return func_type; + + SmallVector results; + if (failed(convertTypes(func_type.getResults(), results))) return func_type; + + auto new_func_type = + FunctionType::get(func_type.getContext(), inputs, results); + VLOG(2) << " -> I4ToI8Converter: converted FunctionType to " + << DumpToString(new_func_type); + return new_func_type; + } + + explicit I4ToI8Converter(int packed_dim_idx) + : packed_dim_idx_(packed_dim_idx) { + // Passthrough for other types. + addConversion([](Type type) { + VLOG(2) << "I4ToI8Converter: passthrough for " << DumpToString(type); + return type; + }); + + // Convert i4 to i8 + addConversion( + [this](IntegerType type) { return this->convertIntegerType(type); }); + + // Convert tensor to tensor + addConversion([this](RankedTensorType type) { + return this->convertRankedTensorType(type); + }); + + // Convert !tt.ptr> to !tt.ptr> + addConversion( + [this](PointerType type) { return this->convertPointerType(type); }); + + // Convert function type to function type + addConversion( + [this](FunctionType type) { return this->convertFunctionType(type); }); + } + int packed_dim_idx() const { return packed_dim_idx_; } + + private: + int packed_dim_idx_; +}; + +// Divides a value by an integer constant. +Value div(ConversionPatternRewriter &r, Value value, int64_t constant) { + auto const_attr = r.getIntegerAttr(value.getType(), constant); + auto const_op = r.template create(value.getLoc(), const_attr); + return r.template create(value.getLoc(), value, const_op); +} + +// Divides a value by an integer constant. +Value ceilDiv(ConversionPatternRewriter &r, Value value, int64_t constant) { + auto const_attr = r.getIntegerAttr(value.getType(), constant); + auto const_op = r.template create(value.getLoc(), const_attr); + return r.template create(value.getLoc(), value, const_op); +} + +// Returns the integer value of a constant op. +// Returns std::nullopt if the value is not a constant op or the constant op +// does not have an integer value. +std::optional GetConstValue(Value value) { + if (auto const_op = value.getDefiningOp()) { + if (auto attr = dyn_cast(const_op.getValue())) { + return attr.getInt(); + } + } + return std::nullopt; +} + +class MakeTensorPtrOpConversionPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + MakeTensorPtrOpConversionPattern(const I4ToI8Converter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context), + converter_(converter) {} + + LogicalResult matchAndRewrite( + MakeTensorPtrOp op, + OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + // Convert the tensor type using the TypeConverter + auto new_type = getTypeConverter()->convertType(op.getType()); + if (op.getType() == new_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + SmallVector shape = adaptor.getShape(); + int affected_dim_idx = shape.size() - 1 - converter_.packed_dim_idx(); + // The shape of the i8 tensor is half of the i4 tensor but at least 1. + shape[affected_dim_idx] = ceilDiv(r, shape[affected_dim_idx], 2); + + // The stride of the i8 tensor is half of the i4 tensor but at least 1. + SmallVector new_strides = adaptor.getStrides(); + for (int i = 0; i < new_strides.size(); ++i) { + new_strides[i] = ceilDiv(r, new_strides[i], 2); + } + + r.replaceOpWithNewOp( + op, new_type, adaptor.getBase(), shape, new_strides, + adaptor.getOffsets(), adaptor.getOrderAttr()); + + return success(); + } + + private: + const I4ToI8Converter &converter_; +}; + +class AddPtrOpConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddPtrOp op, OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + // Convert the tensor type using the TypeConverter + auto new_type = getTypeConverter()->convertType(op.getType()); + if (op.getType() == new_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + // The increment for the next stripe of tiles along K dimension should be + // twice smaller. + auto ptr = adaptor.getOperands()[0]; + auto offset = adaptor.getOperands()[1]; + auto new_offset = div(r, offset, 2); + + r.replaceOpWithNewOp(op, new_type, ptr, new_offset); + + return success(); + } +}; + +class AdvanceOpConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + AdvanceOpConversionPattern(const I4ToI8Converter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context), + converter_(converter) {} + LogicalResult matchAndRewrite( + AdvanceOp op, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + VLOG(2) << "AvanceOpConversionPattern: matching\n" + << DumpToString(static_cast(op.getOperation())); + // Convert the tensor type using the TypeConverter + auto new_type = converter_.convertType(op.getType()); + if (op.getType() == new_type) { + VLOG(2) << "AdvanceOpConversionPattern: no conversion needed for " + << DumpToString(op.getType()); + return r.notifyMatchFailure(op, "no conversion needed"); + } + SmallVector offsets = adaptor.getOffsets(); + int affected_dim_idx = offsets.size() - 1 - converter_.packed_dim_idx(); + offsets[affected_dim_idx] = div(r, offsets[affected_dim_idx], 2); + auto new_op = r.replaceOpWithNewOp(op, new_type, + adaptor.getPtr(), offsets); + VLOG(2) << "AdvanceOpConversionPattern: replaced " + << DumpToString(op.getOperation()) << " with " + << DumpToString(static_cast(new_op)); + return success(); + } + + private: + const I4ToI8Converter &converter_; +}; + +// The generic converter for the ops that requires only type conversion. +template +class OpTypeConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + OpTypeConversionPattern(const I4ToI8Converter &converter, + MLIRContext *context) + : OpConversionPattern(converter, context), + converter_(converter) {} + LogicalResult matchAndRewrite( + OpType op, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + VLOG(2) << "OpTypeConversionPattern: matching\n" + << DumpToString(static_cast(op.getOperation())); + // Convert the tensor type using the TypeConverter + auto new_type = converter_.convertType(op.getType()); + if (op.getType() == new_type) { + VLOG(2) << "OpTypeConversionPattern: no conversion needed for " + << DumpToString(op.getType()); + return r.notifyMatchFailure(op, "no conversion needed"); + } + + r.replaceOpWithNewOp(op, new_type, adaptor.getOperands(), + op->getAttrs()); + return success(); + } + + private: + const I4ToI8Converter &converter_; +}; + +// The pattern converts the ExtSIOp that converts i4 tensor to i8 tensor to an +// unpack sequence that uses ShLIOp, ShRSIOp, JoinOp, TransOp and ReshapeOp to +// do the same thing. +class ExtSIInt4ToInt8Pattern : public OpConversionPattern { + public: + ExtSIInt4ToInt8Pattern(const I4ToI8Converter &converter, MLIRContext *context) + : OpConversionPattern(converter, context), + converter_(converter) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ma::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + VLOG(2) << "ExtSIInt4ToInt8Pattern: matching\n" + << DumpToString(static_cast(op)); + auto input_type = cast(op.getIn().getType()); + auto packed_type = converter_.convertType(input_type); + if (input_type == packed_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + // Make a new i8 tensor with the shape that is half of the int4 tensor. + auto loc = op.getLoc(); + + Value shift4_const = + r.create(loc, r.getIntegerAttr(r.getI8Type(), 4)); + Value shift4 = r.create(loc, packed_type, shift4_const); + Value shifted_lo = + r.create(loc, packed_type, adaptor.getIn(), shift4); + Value lo = r.create(loc, packed_type, shifted_lo, shift4); + Value hi = r.create(loc, packed_type, adaptor.getIn(), shift4); + Value hi_lo = r.create(loc, hi, lo); + if (converter_.packed_dim_idx() != 0) { + auto trans_attr = r.getDenseI32ArrayAttr({0, 2, 1}); + hi_lo = r.create(loc, hi_lo, trans_attr); + } + auto unpacked_type = input_type.clone(r.getI8Type()); + r.replaceOpWithNewOp(op, unpacked_type, hi_lo, + /*allow_reorder=*/false); + return success(); + } + + private: + const I4ToI8Converter &converter_; +}; + +// Traverses the operands of the op passing though the forOp and returns the +// list of ops that belong to the same argument. +std::vector TraverseUpwards(Operation *op) { + std::vector result; + while (op != nullptr) { + VLOG(2) << "op: \n" << DumpToString(op); + result.push_back(op); + // Handle the argN of the forOp. + if (auto arg = dyn_cast(op->getOperand(0))) { + // Add the other users of the argN except the op itself. Usually the argN + // is the arg of a ForOp, op is the LoadOp and the other user is the + // AdvanceOp. + for (auto user : arg.getUsers()) { + if (user != op) { + result.push_back(user); + } + } + // Translate the argN of the forOp to the corresponding op that was passed + // as the init arg. + if (auto forOp = + dyn_cast(arg.getParentBlock()->getParentOp())) { + auto arg_number = arg.getArgNumber(); + op = forOp.getInitArgs()[arg_number - 1].getDefiningOp(); + continue; + } + } + + op = op->getOperand(0).getDefiningOp(); + } + return result; +} + +// Finds all the ExtSIOp that require the type conversion. +std::vector FindInt4ExtSIOp(const ModuleOp &module) { + // It does not matter which packed dimension idx we use here, because use the + // converter to detect that the conversion is needed. + I4ToI8Converter converter(/*packed_dim_idx=*/0); + std::vector result; + module->walk([&](Operation *op) { + if (auto extSI = dyn_cast(op)) { + VLOG(2) << "found ExtSI: " << DumpToString(op); + auto input_type = extSI.getIn().getType(); + if (input_type != converter.convertType(input_type)) { + result.push_back(op); + } + } + return WalkResult::advance(); + }); + return result; +} + +// When both strides are 1 then the tensor is actually a vector. +bool IsSingleDimTensor(MakeTensorPtrOp &op) { + auto strides = op.getStrides(); + if (strides.size() != 2) return false; + + auto major_stride = GetConstValue(strides[0]); + bool is_major_stride_1 = major_stride.has_value() && *major_stride == 1; + auto minor_stride = GetConstValue(strides[1]); + bool is_minor_stride_2 = minor_stride.has_value() && *minor_stride == 1; + return is_major_stride_1 && is_minor_stride_2; +} + +// Checks which dimension is packed. We use packed_dim attribute to determine +// which dimension is packed. The tensor (Nx1) which is packed along the minor +// dimension, but every byte has two i4 elements belonging to different rows, so +// the tensor is packed along the major dimension and vice versa. In these +// cases we replace the Major dimension with the Minor dimension and vice versa. +int GetPackedDimIdx(MLIRContext *ctx, const std::vector &ops) { + for (auto *op : ops) { + if (!isa(op)) continue; + + auto make_tensor_ptr = dyn_cast(op); + int packed_dim = 0; + auto attr_dict = make_tensor_ptr->getAttrDictionary(); + if (attr_dict.contains("packed_dim")) { + auto packed_dim_attr = attr_dict.get(StringRef("packed_dim")); + auto packed_dim_int_attr = dyn_cast(packed_dim_attr); + VLOG(2) << "packed_dim: " << packed_dim_int_attr.getInt(); + packed_dim = packed_dim_int_attr.getInt(); + } + + if (IsSingleDimTensor(make_tensor_ptr)) { + return packed_dim == 0 ? 1 : 0; + } + + return packed_dim; + } + return 0; // Default to minor dimension. +} + +struct PlainInt4ToPackedInt4RewritePass + : public impl::LoadInt4RewritePassBase { + // The pass converts the types like tensor to tensor in the + // Triton dialect and replaces the ExtSIOp with the unpack sequence that + // accepts twice smaller i8 tensor and converts it to the twice bigger i8 + // tensor where every i4 element uses i8 space. At the end the module accepts + // the tt.ptr to the packed i4 tensor, and unpacks it to the i8 tensor for + // further processing. It gets the packed dimension from the MakeTensorPtrOp + // attribute. + void runOnOperation() override { + auto *ctx = &getContext(); + auto module = getOperation(); + + auto ext_ops = FindInt4ExtSIOp(module); + int packed_dim_idx = 0; + // TODO(b/383255324): Support the case when both sides of the dot are packed + // differently. + for (auto *op : ext_ops) { + VLOG(2) << "ext_op: " << DumpToString(op); + auto ops = TraverseUpwards(op); + packed_dim_idx = GetPackedDimIdx(ctx, ops); + } + + ConversionTarget target(*ctx); + I4ToI8Converter converter(packed_dim_idx); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + if (auto func_op = dyn_cast(op)) { + VLOG(2) << "check funcOp: " << DumpToString(func_op); + if (func_op.getFunctionType() != + converter.convertType(func_op.getFunctionType())) { + VLOG(2) << "funcOp not legal: " << DumpToString(func_op); + return false; + } + } + bool is_legal = converter.isLegal(op); + VLOG(2) << "is_legal: " << is_legal << " for " << DumpToString(op); + return is_legal; + }); + RewritePatternSet patterns(ctx); + scf::populateSCFStructuralTypeConversions(converter, patterns); + patterns.add(converter, ctx); + patterns.add>(converter, ctx); + patterns.add(converter, ctx); + patterns.add(converter, ctx); + patterns.add(converter, ctx); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + VLOG(2) << "failed to apply partial conversion"; + signalPassFailure(); + } + } +}; + +// The pass converts the types like tensor to tensor in the +// Triton dialect and replaces the ExtSIOp with the unpack sequence that accepts +// twice smaller i8 tensor and convert it to the twice bigger i8 tensor where +// every i4 element uses i8 space. At the end the module accepts the tt.ptr +// to the packed i4 tensor, and unpacks it to the i8 tensor for the further +// processing. It expects that the i4 tensor is packed along the major +// dimension. +std::unique_ptr CreateInt4ToPackedInt4RewritePass() { + return std::make_unique(); +} + +} // namespace mlir::triton::xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.h b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.h index 10f5e684cb5516..67034fe1df1897 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.h @@ -36,6 +36,7 @@ std::unique_ptr CreateSparseLocalLoadToLLVMPass(); std::unique_ptr CreateSparseDotOpToLLVMPass(); std::unique_ptr CreateSparseWGMMAOpToLLVMPass(); std::unique_ptr CreatePreventMmaV3LoopUnrollingPass(); +std::unique_ptr CreateInt4ToPackedInt4RewritePass(); // Returns true if the `op` contains an operation in it's regions that satisfies // the `fn`. diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.td b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.td index 49e003e392ed15..21db540475b390 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.td +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_passes.td @@ -95,4 +95,15 @@ def PreventMmaV3LoopUnrollingPass let constructor = "CreatePreventMmaV3LoopUnrollingPass()"; } +def LoadInt4RewritePass + : Pass<"int4-to-packed-int4-rewrite", "mlir::ModuleOp"> { + let summary = "Converts ops with int4 tensors to the ops with int4 packed to int8 tensors."; + let description = [{ + This pass replaces the int4 tensors with the int4 packed to int8 tensor of + the twice smaller size. It also replaces the plain ExtSIOp upcast to the + int8 tensor with the unpack sequence. + }]; + let constructor = "CreateInt4ToPackedInt4RewritePass()"; +} + #endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_PASSES_TD_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc index d4c84259f2dbd9..08d4bc8894a2ef 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc @@ -360,7 +360,7 @@ struct SparseBlockedToMMAPass auto pattern = std::make_unique(context, compute_capability); RewritePatternSet patterns(context, std::move(pattern)); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { return signalPassFailure(); } } @@ -975,8 +975,7 @@ struct SparseWGMMAOpToLLVMPass MLIRContext *context = &getContext(); auto pattern = std::make_unique(context); RewritePatternSet patterns(context, std::move(pattern)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc index 945f63a1f87c0d..76efde170bca39 100644 --- a/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_aot_compilation_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/literal_util.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/fusions/triton/triton_support.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc index d789b652df6d4a..43a99ea4fe612b 100644 --- a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc +++ b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc @@ -25,14 +25,11 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" namespace xla::gpu { - -using MemoryAwareScheduler = std::function( - const HloModule*, int64_t, int64_t*)>; - namespace { int64_t GetDefaultValue(HloOpcode opcode) { @@ -52,13 +49,13 @@ int64_t GetDefaultValue(HloOpcode opcode) { int64_t ComputeSuggestedCombinerThreshold( const HloModule& module, const se::DeviceDescription& device_info, - MemoryAwareScheduler scheduler, HloOpcode collective_opcode, - int64_t pointer_size) { + HloOpcode collective_opcode, int64_t pointer_size) { int64_t base_limit = module.config().device_memory_size() != 0 ? module.config().device_memory_size() : device_info.device_memory_size(); int64_t peak_memory_bytes = -1; - auto mem_schedule = scheduler(&module, pointer_size, &peak_memory_bytes); + auto mem_schedule = ScheduleGpuModuleWithMemoryScheduler( + &module, pointer_size, &peak_memory_bytes); if (!mem_schedule.ok() || peak_memory_bytes == -1) { VLOG(1) << "Cannot schedule module: " << mem_schedule.status().message(); diff --git a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h index 38a7890decb59b..d78abf552eeb33 100644 --- a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h +++ b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h @@ -17,10 +17,8 @@ limitations under the License. #define XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_ #include -#include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -36,9 +34,6 @@ namespace xla::gpu { // `collective_opcode`. int64_t ComputeSuggestedCombinerThreshold( const HloModule& module, const se::DeviceDescription& device_info, - std::function(const HloModule*, int64_t, - int64_t*)> - scheduler, HloOpcode collective_opcode, int64_t pointer_size); // Adds information that `instr` has been pipelined to the diff --git a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc index f0b213f343e587..9d7a9596641618 100644 --- a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc @@ -19,27 +19,20 @@ limitations under the License. #include #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_pipeliner.h" -#include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla::gpu { namespace { @@ -65,8 +58,7 @@ TEST_F(CollectiveCombinerUtilsTest, device_info.set_device_memory_size(20000); int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( - *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size); + *module, device_info, HloOpcode::kAllReduce, pointer_size); // device size = 20000 bytes // slop factor = 0.95 @@ -96,8 +88,7 @@ TEST_F(CollectiveCombinerUtilsTest, stream_executor::DeviceDescription device_info; int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( - *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size); + *module, device_info, HloOpcode::kAllReduce, pointer_size); // device size = 20000 bytes // slop factor = 0.95 @@ -106,45 +97,6 @@ TEST_F(CollectiveCombinerUtilsTest, EXPECT_EQ(suggested_threshold, 6712); } -TEST_F( - CollectiveCombinerUtilsTest, - ComputeSuggestedCombinerThresholdReturnsDefaultValueUponSchedulingFailure) { - absl::string_view kHloText = R"( - HloModule m - - ENTRY ar { - p0 = f32[32,32] parameter(0) - p1 = f32[32,32] parameter(1) - - ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), - custom_call_target="__cublas$gemm" - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - int pointer_size = 4; - stream_executor::DeviceDescription device_info; - device_info.set_device_memory_size(20000); - - auto sched_fun = [](const HloModule* m, int64_t p_sz, - int64_t* p) -> absl::StatusOr { - return absl::UnimplementedError("Fail."); - }; - - int64_t suggested_threshold_all_reduce = ComputeSuggestedCombinerThreshold( - *module, device_info, sched_fun, HloOpcode::kAllReduce, pointer_size); - int64_t suggested_threshold_all_gather = ComputeSuggestedCombinerThreshold( - *module, device_info, sched_fun, HloOpcode::kAllGather, pointer_size); - int64_t suggested_threshold_reduce_scatter = - ComputeSuggestedCombinerThreshold(*module, device_info, sched_fun, - HloOpcode::kReduceScatter, - pointer_size); - - EXPECT_EQ(suggested_threshold_all_reduce, kDefaultAllReduceCombineThreshold); - EXPECT_EQ(suggested_threshold_all_gather, kDefaultAllGatherCombineThreshold); - EXPECT_EQ(suggested_threshold_reduce_scatter, - kDefaultReduceScatterCombineThreshold); -} - TEST_F(CollectiveCombinerUtilsTest, AppendPipelinedInstructionAppendsPipelinedInstructionInfoForward) { // This is just a canonical IR which makes it easy to pipeline a collective diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index aa38ae597b9286..bb720566717bfc 100755 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -36,6 +35,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "llvm/ADT/DenseMap.h" @@ -174,6 +174,7 @@ limitations under the License. #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h" #include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h" #include "xla/service/gpu/reduce_scatter_combiner.h" #include "xla/service/gpu/reduction_utils.h" @@ -264,7 +265,6 @@ limitations under the License. #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/blocking_counter.h" #include "tsl/platform/casts.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" @@ -340,7 +340,7 @@ class GpuThunkAotCompilationResult : public AotCompilationResult { static absl::StatusOr> FromModule(const HloModule* hlo_module, const BufferAssignment* buffer_assignment, - std::string_view asm_text, absl::Span binary, + absl::string_view asm_text, absl::Span binary, const BinaryMap& dnn_compiled_graphs) { CompilationResultProto proto; *proto.mutable_hlo_module_with_config() = hlo_module->ToProtoWithConfig(); @@ -1044,19 +1044,6 @@ absl::Status RunFusionPasses(HloModule* hlo_module, .Run(hlo_module) .status()); - if (hlo_module->config().debug_options().xla_gpu_collect_cost_model_stats()) { - GpuHloCostAnalysis::Options cost_analysis_options{ - shape_size_fn, - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; - - HloPassPipeline post_fusion_analysis("post_fusion_analysis"); - post_fusion_analysis.AddPass( - gpu_device_info, cost_analysis_options); - TF_RETURN_IF_ERROR(post_fusion_analysis.Run(hlo_module).status()); - } - TF_RETURN_IF_ERROR( HorizontalFusionPipeline(gpu_device_info).Run(hlo_module).status()); @@ -1157,12 +1144,13 @@ absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { // that actually need to run asynchronously with a GPU specific backend // config. AsyncCollectiveCreator::CollectiveCreatorConfig config; + config.convert_all_gather = HloPredicateTrue; config.convert_all_reduce = HloPredicateTrue; + config.convert_all_to_all = HloPredicateTrue; config.convert_collective_broadcast = HloPredicateTrue; config.convert_collective_permute = HloPredicateTrue; - config.convert_all_gather = HloPredicateTrue; + config.convert_ragged_all_to_all = HloPredicateTrue; config.convert_reduce_scatter = HloPredicateTrue; - config.convert_all_to_all = HloPredicateTrue; pipeline.AddPass(std::move(config)); absl::flat_hash_set disabled_async_ops; @@ -1190,6 +1178,8 @@ absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { return !disabled_async_ops.contains(DebugOptions::REDUCESCATTER); case HloOpcode::kAllToAll: return !disabled_async_ops.contains(DebugOptions::ALLTOALL); + case HloOpcode::kRaggedAllToAll: + return !disabled_async_ops.contains(DebugOptions::RAGGEDALLTOALL); default: return false; } @@ -1592,14 +1582,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // in the softmax codegen pipeline. However we should run before // ReductionDimensionGrouper, as that makes matching the softmax pattern // harder. - if (debug_options - .xla_gpu_experimental_enable_triton_softmax_priority_fusion() && - ((cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || - rocm_cc != nullptr)) { - // Triton compilation needs normalized operations on bf16 (i.e. converted - // to f32). - add_float_normalization(pipeline); + if ((cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || + rocm_cc != nullptr) { pipeline.AddPass>(simplifier_options, gpu_version); pipeline.AddPass(/*is_layout_sensitive=*/true); @@ -1652,14 +1637,19 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); + // Recover host-offloader invariants (such as the single-use broadcast buffer + // initialization before loops) by re-running the offload legalizer. + pipeline.AddPass( + static_cast(stream_executor::MemoryType::kHost), + /* after_layout= */ true); + pipeline.AddPass(&NormalizeLayoutForGpuCustomCalls); // Layout normalization will create scatters that are not simplified and // also have unsorted update_window_dims. pipeline.AddPass(); - pipeline.AddPass( - static_cast(stream_executor::MemoryType::kHost)); + pipeline.AddPass(); TF_RETURN_IF_ERROR( AddConvAndGemmAutotuningPasses(&pipeline, gpu_version, options, @@ -1685,6 +1675,22 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); } + { + // Because of an issue with JAX remat and `SimplifyFPConversions` (see PR: + // https://github.com/jax-ml/jax/pull/22244), we can only eliminate the + // no-op reduce-precision operations after the last call to + // `SimplifyFPConversions`. We are creating a sub-pipeline here because that + // allows us to test this order in a unit test. + HloPassPipeline& remove_no_op_reduce_precision_pipeline = + pipeline.AddPass( + "remove-no-op-reduce-precision-algebraic-simplifier"); + AlgebraicSimplifierOptions simplifier_options_{simplifier_options}; + simplifier_options_.set_enable_remove_no_op_reduce_precision(true); + remove_no_op_reduce_precision_pipeline + .AddPass>(simplifier_options_, + gpu_version); + } + pipeline.AddPass(/*is_layout_sensitive=*/true); pipeline.AddPass( @@ -2145,7 +2151,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( }; std::vector compile_results(llvm_modules.size()); if (thread_pool.get() != nullptr) { - tsl::BlockingCounter counter(llvm_modules.size()); + absl::BlockingCounter counter(llvm_modules.size()); for (int i = 0; i < llvm_modules.size(); ++i) { thread_pool.get_mutable()->Schedule( [&compile_results, i, &llvm_modules, &counter, this, &module_config, @@ -2561,6 +2567,19 @@ absl::Status GpuCompiler::RunPreSchedulingPasses( const se::DeviceDescription& gpu_device_info) { HloPassPipeline pipeline("pre-scheduling-passes"); pipeline.AddPass(gpu_device_info); + if (module->config().debug_options().xla_gpu_collect_cost_model_stats()) { + GpuHloCostAnalysis::Options cost_analysis_options{ + ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}; + // Cost model analysis for compute. + pipeline.AddPass(gpu_device_info, + cost_analysis_options); + // Cost model analysis for collectives. + pipeline.AddPass(gpu_device_info, + ShapeSizeBytesFunction()); + } return pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc old mode 100755 new mode 100644 index ff1fcb6db14372..f3bae16777d10a --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" @@ -58,11 +59,9 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/collection_registry.h" @@ -611,7 +610,7 @@ TEST_F(GpuCompilerTestWithAutotuneDb, << "Autotuning results have only been generated for Hopper GPUs"; } const absl::string_view hlo_string = R"( -HloModule test +HloModule test ENTRY main { p0 = f8e4m3fn[12288,4096]{0,1} parameter(0) @@ -688,6 +687,7 @@ ENTRY main { DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_cublas_fallback(enable_blas_fallback); debug_options.set_xla_gpu_enable_triton_gemm(enable_triton); + debug_options.set_xla_gpu_cudnn_gemm_fusion_level(0); if (!enable_blas) { debug_options.add_xla_disable_hlo_passes("cublas-gemm-rewriter"); } @@ -1332,6 +1332,40 @@ class PassOrderTest : public GpuCompilerTest { CompileModule(config); } + // Fails if any of the passes matching `other_pass_regex` runs before + // the first occurrence of the pass matching `first_pass_regex`. + void VerifyPassRunsAtLeastOnceBefore(absl::string_view first_pass_regex, + absl::string_view other_pass_regex) { + if (!optimized_module_) { + CompileModule(GetModuleConfigForTest()); + } + int first_pass_first_run = std::numeric_limits::max(); + int other_pass_first_run = std::numeric_limits::max(); + int run_index = 0; + for (const HloPassMetadata& pass_metadata : + optimized_module_->metadata()->proto().pass_metadata()) { + if (RE2::FullMatch(pass_metadata.pass_name(), first_pass_regex)) { + VLOG(2) << "Pass " << pass_metadata.pass_name() + << " matches first_pass_regex." << std::endl; + first_pass_first_run = std::min(first_pass_first_run, run_index); + } + if (RE2::FullMatch(pass_metadata.pass_name(), other_pass_regex)) { + VLOG(2) << "Pass " << pass_metadata.pass_name() + << " matches other_pass_regex." << std::endl; + other_pass_first_run = std::min(other_pass_first_run, run_index); + } + ++run_index; + } + + EXPECT_NE(first_pass_first_run, std::numeric_limits::max()) + << "Did not run a pass matching " << first_pass_regex; + EXPECT_NE(other_pass_first_run, std::numeric_limits::max()) + << "Did not run a pass matching " << other_pass_regex; + EXPECT_LE(first_pass_first_run, other_pass_first_run) + << "A pass matching " << first_pass_regex + << " did not run before passes matching " << other_pass_regex; + } + // Fails if any of the passes with names matching the regular expression // `first_pass_regex` run after any of the passes matching `last_pass_regex` // or if none of the executed passes matches `first_pass_regex` or @@ -1412,8 +1446,24 @@ TEST_F(PassOrderTest, PassesAreRunInCorrectOrder) { /*last_pass_regex=*/"priority-fusion"); VerifyPassOrder(/*first_pass_regex=*/"layout-assignment", /*last_pass_regex=*/"layout_normalization"); - VerifyPassOrder(/*first_pass_regex=*/"host-offload-legalize", - /*last_pass_regex=*/"layout_normalization"); +} + +TEST_F(PassOrderTest, OffloadingPassesAreRunInCorrectOrder) { + // HostOffloadLegalize must run before LayoutNormalization to prevent + // the creation of invalid transpose/bitcast operations within + // host memory offloading segments. + VerifyPassRunsAtLeastOnceBefore(/*first_pass_regex=*/"host-offload-legalize", + /*other_pass_regex=*/"layout_normalization"); + + // CSE should not run between HostOffloadLegalize and HostOffloader + // because it could break the invariants established + // by the legalize pass, such as the buffer initialization broadcasts + // before loops having only a single use + // (see https://github.com/openxla/xla/issues/20373). + auto pass_range = + VerifyPassOrder(/*first_pass_regex=*/"host-offload-legalize", + /*last_pass_regex=*/"host-offloader"); + VerifyNotRunInBetween(pass_range, /*pass_regex=*/"cse"); } TEST_F(PassOrderTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) { @@ -1512,6 +1562,19 @@ TEST_F(PassOrderTest, GemmRewriterRunsAfterDotNormalizer) { VerifyNotRunInBetween(pass_range, /*pass_regex=*/"algsimp"); } +TEST_F(PassOrderTest, + ReducePrecisionIsRemovedAfterAllCallsToSimplifyFPConversions) { + // Because of an issue with JAX remat and `SimplifyFPConversions` (see PR: + // https://github.com/jax-ml/jax/pull/22244), we can only eliminate the + // no-op reduce-precision operations after the last call to + // `SimplifyFPConversions`. No-op reduce-precisions are removed within + // algebraic simplifier, if the option to remove them is set. In the compiler + // pipeline, this is done as a subpipeline, which should be after the last + // invocation of SimplifyFPConversions. + VerifyPassOrder("simplify-fp-conversions", + "remove-no-op-reduce-precision-algebraic-simplifier"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 56b03b13e188f9..f656e9691a0d3a 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" @@ -39,7 +40,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_clique.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/backends/gpu/collectives/gpu_clique_locking.h" +#include "xla/backends/gpu/collectives/gpu_cliques.h" #include "xla/core/collectives/rank_id.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -185,6 +186,17 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( namespace { +// A container for per-process persistent cliques. +struct PersistentCliquesMap { + absl::Mutex mutex; + AcquiredCliquesMap cliques_map ABSL_GUARDED_BY(mutex); +}; + +static PersistentCliquesMap& GetPersistentCliquesMap() { + static auto* persistent_cliques = new PersistentCliquesMap(); + return *persistent_cliques; +} + // Shared resources required for thunk initialization and execution. class ResourceRequests : public Thunk::ResourceRequests { public: @@ -220,7 +232,8 @@ class ResourceRequests : public Thunk::ResourceRequests { } absl::StatusOr AcquireCollectiveCliques( - const Thunk::CollectiveExecuteParams& params) { + const Thunk::CollectiveExecuteParams& params, + bool use_persistent_cliques) { if (cliques_.empty()) return Thunk::CollectiveCliques(); VLOG(2) << "Acquire " << cliques_.size() @@ -229,7 +242,8 @@ class ResourceRequests : public Thunk::ResourceRequests { << "; run_id=" << params.run_id.ToInt() << "; max number of channels for collectives " << params.collective_max_nchannels - << "; max number of channels for p2p " << params.p2p_max_nchannels; + << "; max number of channels for p2p " << params.p2p_max_nchannels + << "; use_persistent_cliques=" << use_persistent_cliques; std::vector ordered_cliques = GetOrderedCliqueRequests(); for (size_t i = 0; i < ordered_cliques.size(); ++i) { @@ -241,13 +255,16 @@ class ResourceRequests : public Thunk::ResourceRequests { } tsl::profiler::TraceMe trace([&] { - return tsl::profiler::TraceMeEncode("AcquireCollectiveCliques", - {{"num_cliques", cliques_.size()}}); + return tsl::profiler::TraceMeEncode( + "AcquireCollectiveCliques", + {{"num_cliques", cliques_.size()}, + {"use_persistent_cliques", use_persistent_cliques}}); }); auto start_micros = tsl::Env::Default()->NowMicros(); AcquiredCliquesMap cliques_map; + int32_t num_transient_cliques = 0; for (const CliqueRequest& r : ordered_cliques) { std::optional rank = r.key.rank(params.global_device_id); @@ -266,12 +283,43 @@ class ResourceRequests : public Thunk::ResourceRequests { int64_t max_channels = r.key.stream_kind() == AsyncStreamKind::kCollective ? params.collective_max_nchannels : params.p2p_max_nchannels; + + // Check if we have a persistent clique for this key. + if (use_persistent_cliques) { + auto& pc = GetPersistentCliquesMap(); + absl::MutexLock lock(&pc.mutex); + + if (auto it = pc.cliques_map.find(r.key); it != pc.cliques_map.end()) { + VLOG(2) << "Found persistent clique for key " << r.key.ToString(); + cliques_map[r.key] = it->second; + continue; + } + } + + // If we don't have a persistent clique we have to acquire a transient + // one. TF_ASSIGN_OR_RETURN( std::shared_ptr clique, AcquireGpuClique(params.collectives, params.executor, params.run_id, r.key, *clique_id_callback, *rank, r.num_local_participants, cliques_map, max_channels)); + ++num_transient_cliques; + + // Take a copy of the clique lock, so that we can reuse it. This is + // potentially unsafe in the case when we have multiple racing executions + // of XLA, as we might observe partial state and some of the replicas will + // use persistent clique, and others will try to acquire a new one. + // + // However given that persistent cliques is an unsafe escape hatch, any + // racing execution together with persistent cliques will lead to + // deadlocks anyway, so we don't bother to fix this. If anyone is doing + // it, it's 100% their fault and they will suffer. + if (use_persistent_cliques) { + auto& pc = GetPersistentCliquesMap(); + absl::MutexLock lock(&pc.mutex); + pc.cliques_map[r.key] = clique; + } cliques_map[r.key] = std::move(clique); } @@ -281,9 +329,11 @@ class ResourceRequests : public Thunk::ResourceRequests { << " collective cliques for global device id " << params.global_device_id.value() << " in " << (end_micros - start_micros) << " μs" - << "; run_id=" << params.run_id.ToInt(); + << "; run_id=" << params.run_id.ToInt() + << "; num_transient_cliques=" << num_transient_cliques; - return Thunk::CollectiveCliques(std::move(cliques_map)); + return Thunk::CollectiveCliques(std::move(cliques_map), + num_transient_cliques); } private: @@ -449,7 +499,11 @@ absl::Status ExecuteThunks( if (!mock_collectives) { TF_ASSIGN_OR_RETURN( collective_cliques, - resource_requests.AcquireCollectiveCliques(collective_params)); + resource_requests.AcquireCollectiveCliques( + collective_params, + debug_options + ? debug_options->xla_gpu_collectives_use_persistent_cliques() + : false)); } { // Initialize thunks using prepared resources before execution. @@ -470,9 +524,11 @@ absl::Status ExecuteThunks( } // Maybe join a round of rendezvous after thunk initialization. We do this - // only in presence of collective cliques which means that we have collective - // operations in the XLA operations that tend to cause deadlocks. - if (!collective_cliques.empty()) { + // only in presence of newly acquired collective cliques which means that we + // have collective operations and clique initialization is famous for + // introducing deadlocks if we try to execute it concurrently with other + // potentially memory-allocating operations. + if (collective_cliques.num_transient_cliques() > 0) { TF_RETURN_IF_ERROR( RendezvousAfterInitialization(run_options, debug_options)); } @@ -485,8 +541,18 @@ absl::Status ExecuteThunks( TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params)); - return MaybeSyncAndProfile(run_options, execution_timer.get(), - block_host_until_done ? main_stream : nullptr); + auto status = + MaybeSyncAndProfile(run_options, execution_timer.get(), + block_host_until_done ? main_stream : nullptr); + + Thunk::CleanupParams cleanup_params{ + executor, + &collective_params, + &collective_cliques, + }; + TF_RETURN_IF_ERROR(thunk_sequence.Cleanup(cleanup_params)); + + return status; } namespace { @@ -560,7 +626,7 @@ absl::Status RendezvousAfterInitialization( run_options->device_ordinal(), run_options->run_options().run_id().ToInt()); - RendezvousSingle( + Rendezvous( rendezvous_name, rendezvous_key, num_local_participants, absl::Seconds( debug_options diff --git a/third_party/xla/xla/service/gpu/gpu_float_support.cc b/third_party/xla/xla/service/gpu/gpu_float_support.cc index 38d64e54b56dc8..2f493c57e177ca 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support.cc @@ -23,6 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/float_support.h" #include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/stream_executor/device_description.h" @@ -50,6 +52,10 @@ bool GpuFloatSupport::SupportsMixedPrecisions(const HloInstruction& hlo) const { } bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { + if (IsCollective(&hlo) && + primitive_util::IsSubByteNonPredType(hlo.shape().element_type())) { + return false; + } switch (hlo.opcode()) { // Collective ops. case HloOpcode::kAllReduce: diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 8e7d524acd33d7..5a5cc36dce644c 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -46,11 +47,11 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/service/buffer_value.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/flag_utils.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" +#include "xla/service/gpu/model/sol_latency_estimator.h" #include "xla/service/gpu/transforms/pgle_accuracy_checker.h" #include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" @@ -262,21 +263,32 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit, config.schedule_send_recvs = true; config.memory_limit = memory_limit; config.parallel_collective_overlap_limit = collective_resource; + + CHECK(config.collective_broadcast_overlap_limit <= + config.parallel_collective_overlap_limit); + CHECK(config.all_to_all_overlap_limit <= + config.parallel_collective_overlap_limit); + CHECK(config.all_gather_overlap_limit <= + config.parallel_collective_overlap_limit); + CHECK(config.all_reduce_overlap_limit <= + config.parallel_collective_overlap_limit); + CHECK(config.reduce_scatter_overlap_limit <= + config.parallel_collective_overlap_limit); + return config; } tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( tensorflow::profiler::ProfiledInstructionsProto& profile, - const std::string& fingerprint) { + absl::string_view fingerprint) { tensorflow::profiler::ProfiledInstructionsProto result; bool merge_remat_clones = false; for (const auto& cost : profile.costs()) { - absl::string_view cost_name = cost.name(); std::string new_cost_name = cost.name(); absl::string_view cost_sep = "::"; - if (absl::StrContains(cost_name, cost_sep)) { - std::vector split_names = - absl::StrSplit(cost_name, cost_sep); + if (absl::StrContains(cost.name(), cost_sep)) { + std::vector split_names = + absl::StrSplit(cost.name(), cost_sep); if (split_names.size() != 2 || split_names[0] != fingerprint) { continue; } @@ -314,30 +326,33 @@ tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( return name; }; - // Map from stripped name -> pair - absl::flat_hash_map> costs; + struct Data { + double accumulated_cost = 0.0; + int64_t count = 0; + }; + absl::flat_hash_map costs; for (const auto& cost : result.costs()) { - std::pair& data = costs[strip_remat_suffix(cost.name())]; - data.first += cost.cost_us(); - data.second++; + Data& data = costs[strip_remat_suffix(cost.name())]; + data.accumulated_cost += cost.cost_us(); + data.count++; } tensorflow::profiler::ProfiledInstructionsProto merged_result; - for (const auto& cost : costs) { + for (const auto& [name, data] : costs) { auto* new_cost = merged_result.add_costs(); - double average = cost.second.first / cost.second.second; + double average = data.accumulated_cost / data.count; new_cost->set_cost_us(average); - new_cost->set_name(std::string(cost.first)); + new_cost->set_name(std::string(name)); } return merged_result; } std::optional ReadPGLEProfile( - const HloModule* module, const std::string& fingerprint) { + const HloModule& module, absl::string_view fingerprint) { tensorflow::profiler::ProfiledInstructionsProto profile; - absl::string_view fdo_profile = module->config().fdo_profile(); + absl::string_view fdo_profile = module.config().fdo_profile(); // First attempt to read the profile from `fdo_profile` in ModuleConfig if (!fdo_profile.empty()) { // Attempt to parse it as a binary proto. @@ -358,14 +373,14 @@ std::optional ReadPGLEProfile( } const std::string& pgle_profile_file_or_dir_path = - module->config() + module.config() .debug_options() .xla_gpu_pgle_profile_file_or_directory_path(); if (pgle_profile_file_or_dir_path.empty()) { return std::nullopt; } tsl::Env* env = tsl::Env::Default(); - auto read_text_or_binary_profile = [&profile, env, &fingerprint]( + auto read_text_or_binary_profile = [&profile, env, fingerprint]( const std::string& text_path, const std::string& binary_path) -> std::optional { @@ -398,7 +413,7 @@ std::optional ReadPGLEProfile( // specific module. if (env->IsDirectory(pgle_profile_file_or_dir_path).ok()) { std::string pgle_profile_path_prefix = - pgle_profile_file_or_dir_path + "/" + fingerprint; + absl::StrCat(pgle_profile_file_or_dir_path, "/", fingerprint); return read_text_or_binary_profile(pgle_profile_path_prefix + ".pbtxt", pgle_profile_path_prefix + ".pb"); } @@ -416,77 +431,47 @@ std::optional ReadPGLEProfile( pgle_profile_file_or_dir_path); } } -} // end namespace - -static int64_t GetSchedulerMemoryLimit( - const HloModule* module, const se::DeviceDescription& gpu_device_info, - int pointer_size); - -absl::StatusOr ScheduleGpuModule( - HloModule* module, int64_t pointer_size, - const se::DeviceDescription& gpu_device_info) { - tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); - int64_t memory_limit = - GetSchedulerMemoryLimit(module, gpu_device_info, pointer_size); - if (module->has_schedule()) { - return ScheduleMetadata{memory_limit}; - } - const DebugOptions& options = module->config().debug_options(); - if (options.xla_gpu_enable_pipelined_p2p()) { - HloPassPipeline prepare_pipeline("p2p-schedule-preparation"); - prepare_pipeline.AddPass(); - TF_RETURN_IF_ERROR(prepare_pipeline.Run(module).status()); +// Runs P2P schedule preparation prior any scheduling. +absl::Status RunP2PSchedulePreparation(HloModule* module) { + if (!module->config().debug_options().xla_gpu_enable_pipelined_p2p()) { + return absl::OkStatus(); } + HloPassPipeline prepare_pipeline("p2p-schedule-preparation"); + prepare_pipeline.AddPass(); + return prepare_pipeline.Run(module).status(); +} - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleGpuModuleWithMemoryScheduler(module, pointer_size)); - TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); - - // Tag the module with its 128 bit fingerprint. The fingerprint should include - // instruction name with ids. +// Adds fingerprint to the module before. +// +// Returns said fingerprint. +std::string TagWithFingerprint(HloModule* module) { std::string fingerprint = module->GetFingerprint128( HloPrintOptions::Canonical().set_print_backend_config(true)); FrontendAttributes attributes; (*attributes.mutable_map())[std::string(kFingerprintBeforeLHS)] = fingerprint; - module->add_frontend_attributes(attributes); + module->add_frontend_attributes(std::move(attributes)); VLOG(1) << "Fingerprint before LHS for module " << module->name() << "(" << module->unique_id() << ") = " << fingerprint; + return fingerprint; +} - const bool enable_latency_hiding_scheduler = - options.xla_gpu_enable_latency_hiding_scheduler() || - IsPassEnabledAtOptimizationEffort(*module); - - if (!enable_latency_hiding_scheduler) { - return ScheduleMetadata{memory_limit}; - } +// Returns latency estimator, key abstraction used by LHS which returns how much +// each instruction takes. If we return a PGO based estimator then we will +// additionally add fail-fast/warn checks to the pipeline which act in the +// absence of instruction in the profile. See `PGLEAccuracyChecker` for details. +std::unique_ptr GetLatencyEstimator( + const HloModule& module, int pointer_size, + const se::DeviceDescription& gpu_device_info, absl::string_view fingerprint, + const SchedulerConfig& config, HloPassPipeline& pipeline) { + const DebugOptions& options = module.config().debug_options(); - SchedulerConfig config = GetSchedulerConfig( - memory_limit, - module->config() - .debug_options() - .xla_gpu_experimental_parallel_collective_overlap_limit()); - CHECK((config.collective_broadcast_overlap_limit <= - config.parallel_collective_overlap_limit) && - (config.all_to_all_overlap_limit <= - config.parallel_collective_overlap_limit) && - (config.all_gather_overlap_limit <= - config.parallel_collective_overlap_limit) && - (config.all_reduce_overlap_limit <= - config.parallel_collective_overlap_limit) && - (config.reduce_scatter_overlap_limit <= - config.parallel_collective_overlap_limit)); auto gpu_latency_estimator = std::make_unique(pointer_size); - std::unique_ptr latency_estimator; std::optional profile = ReadPGLEProfile(module, fingerprint); - const bool enable_analytical_latency_estimator = - options.xla_gpu_enable_analytical_latency_estimator(); - HloPassPipeline pipeline("latency-hiding-scheduler"); if (profile.has_value()) { auto aggregator = std::make_unique(); auto pg_latency_estimator = std::make_unique( @@ -500,71 +485,73 @@ absl::StatusOr ScheduleGpuModule( DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { pipeline.AddPass(*pg_latency_estimator); } - latency_estimator = std::move(pg_latency_estimator); - } else if (enable_analytical_latency_estimator) { - latency_estimator = std::make_unique( + return pg_latency_estimator; + } + + if (options.xla_gpu_enable_analytical_latency_estimator()) { + LOG(INFO) << "Using analytical latency estimator"; + return std::make_unique( config, std::move(gpu_latency_estimator), gpu_device_info, [input_pointer_size = pointer_size](const Shape& shape) { return GetSizeOfShape(shape, input_pointer_size); }, - module->entry_computation()); - LOG(INFO) << "Using analytical latency estimator"; - } else { - latency_estimator = std::move(gpu_latency_estimator); + module.entry_computation()); } - auto async_tracker = [&]() -> std::unique_ptr { - return options.xla_gpu_lhs_enable_gpu_async_tracker() - ? std::make_unique(config) - : std::make_unique(config); - }(); + if (options.xla_gpu_enable_analytical_sol_latency_estimator()) { + LOG(INFO) << "Using Speed-of-Light (SoL) analytical latency estimator"; + return std::make_unique( + config, std::move(gpu_latency_estimator), gpu_device_info, + [input_pointer_size = pointer_size](const Shape& shape) { + return GetSizeOfShape(shape, input_pointer_size); + }, + module.entry_computation()); + } + return gpu_latency_estimator; +} + +// Adds necessary passes to perform latency hiding estimations for the +// `pipeline`. +absl::Status RunLatencyHidingSchedulerPasses( + HloModule* module, int pointer_size, absl::string_view fingerprint, + int64_t memory_limit, const se::DeviceDescription& gpu_device_info) { + SchedulerConfig config = GetSchedulerConfig( + memory_limit, + module->config() + .debug_options() + .xla_gpu_experimental_parallel_collective_overlap_limit()); auto shape_size_in_bytes = [pointer_size](const Shape& shape) { return GetSizeOfShape(shape, pointer_size); }; + + auto async_tracker = std::make_unique(config); + + HloPassPipeline pipeline("latency-hiding-scheduler"); + std::unique_ptr latency_estimator = GetLatencyEstimator( + *module, pointer_size, gpu_device_info, fingerprint, config, pipeline); + auto scheduler_core = std::make_unique( shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config, /*target_scheduling_rule=*/nullptr, /*early_target_scheduling_rule=*/nullptr, /*post_processing_fn=*/nullptr, /*scheduling_instruction_crosses_overlap_limit=*/ GpuScheduleCrossesOverlapLimit); - pipeline.AddPass(); + pipeline.AddPass( std::move(latency_estimator), std::move(async_tracker), std::move(scheduler_core), shape_size_in_bytes); + pipeline.AddPass(); + pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(module).status()); - - HloPassPipeline postprocessing_pipeline("schedule-postprocessing"); - postprocessing_pipeline.AddPass(); - TF_RETURN_IF_ERROR(postprocessing_pipeline.Run(module).status()); - - return ScheduleMetadata{memory_limit}; -} - -absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( - const HloModule* module, int64_t pointer_size, int64_t* peak_memory_bytes) { - return ScheduleModule( - module, - [pointer_size](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); - }, - ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, - PostProcessSchedule), - /*execution_threads=*/{}, /*peak_memory=*/peak_memory_bytes); -} - -HloInstructionSequence PostProcessSchedule( - const HloInstructionSequence& input) { - HloInstructionSequence result = PostprocessorToScheduleSyncCollectives(input); - return PostprocessorToScheduleAsEarlyOrLateAsPossible(result); + return pipeline.Run(module).status(); } // Compute the device memory limit to be used by passes like scheduler and // HLO rematerialization. -static int64_t GetSchedulerMemoryLimit( - const HloModule* module, const se::DeviceDescription& gpu_device_info, - int pointer_size) { +int64_t GetSchedulerMemoryLimit(const HloModule& module, + const se::DeviceDescription& gpu_device_info, + int pointer_size) { // There is a "base" value which is either specified in HloModuleConfig (this // value should take into account the fact that we need to leave some memory // free for allocations that happen outside of XLA's allocator) or @@ -574,14 +561,14 @@ static int64_t GetSchedulerMemoryLimit( // From that base value, subtract any input and output sizes (assuming they // are live throughout the execution) and then apply a slop factor. const int64_t base_limit = - module->config().device_memory_size() != 0 - ? module->config().device_memory_size() + module.config().device_memory_size() != 0 + ? module.config().device_memory_size() : gpu_device_info.device_memory_size() * 80 / 100; // Find the total size of inputs and outputs. int64_t total_io_size = 0; for (HloInstruction* param : - module->entry_computation()->parameter_instructions()) { + module.entry_computation()->parameter_instructions()) { ShapeUtil::ForEachSubshape( param->shape(), [&](const Shape& subshape, const ShapeIndex& /*index*/) { @@ -589,25 +576,86 @@ static int64_t GetSchedulerMemoryLimit( }); } ShapeUtil::ForEachSubshape( - module->result_shape(), + module.result_shape(), [&](const Shape& subshape, const ShapeIndex& /*index*/) { total_io_size += GetSizeOfShape(subshape, pointer_size); }); // If any inputs and outputs are aliased, do not double count them. - module->input_output_alias_config().ForEachAlias( + module.input_output_alias_config().ForEachAlias( [&](const ShapeIndex& output_index, const HloInputOutputAliasConfig::Alias&) { const Shape& subshape = - ShapeUtil::GetSubshape(module->result_shape(), output_index); + ShapeUtil::GetSubshape(module.result_shape(), output_index); total_io_size -= GetSizeOfShape(subshape, pointer_size); }); int64_t limit = (base_limit - total_io_size) * - module->config().debug_options().xla_gpu_memory_limit_slop_factor() / 100; + module.config().debug_options().xla_gpu_memory_limit_slop_factor() / 100; return limit; } +} // end namespace + +absl::StatusOr ScheduleGpuModule( + HloModule* module, int64_t pointer_size, + const se::DeviceDescription& gpu_device_info) { + tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); + + // Tag the module with its 128 bit fingerprint. The fingerprint should include + // instruction name with ids. + std::string fingerprint = TagWithFingerprint(module); + int64_t memory_limit = + GetSchedulerMemoryLimit(*module, gpu_device_info, pointer_size); + + // Module already has a schedule, do nothing. + if (module->has_schedule()) { + return ScheduleMetadata{memory_limit}; + } + + // Run the scheduler which minimizes peak memory usage. + // We need to run it anyway because LHS relies on it track buffers. See + // `xla::BufferInfoTracker::BufferInfoTracker()`. + TF_RETURN_IF_ERROR(RunP2PSchedulePreparation(module)); + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + ScheduleGpuModuleWithMemoryScheduler(module, pointer_size)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + + bool enable_latency_hiding_scheduler = + module->config() + .debug_options() + .xla_gpu_enable_latency_hiding_scheduler() || + IsPassEnabledAtOptimizationEffort(*module); + + // Run Latency Hiding Scheduler (LHS). It maximizes the compute-communication + // overlap, potentially at the cost of memory usage. + if (enable_latency_hiding_scheduler) { + TF_RETURN_IF_ERROR(RunLatencyHidingSchedulerPasses( + module, pointer_size, fingerprint, memory_limit, gpu_device_info)); + } + + return ScheduleMetadata{memory_limit}; +} + +absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( + const HloModule* module, int64_t pointer_size, int64_t* peak_memory_bytes) { + return ScheduleModule( + module, + [pointer_size](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); + }, + ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, + PostProcessSchedule), + /*execution_threads=*/{}, /*peak_memory=*/peak_memory_bytes); +} + +HloInstructionSequence PostProcessSchedule( + const HloInstructionSequence& input) { + HloInstructionSequence result = PostprocessorToScheduleSyncCollectives(input); + return PostprocessorToScheduleAsEarlyOrLateAsPossible(result); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index dc883c22ec1e8e..23def749504f0d 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -37,6 +36,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/backend.h" #include "xla/service/gpu/gpu_compiler.h" @@ -44,10 +45,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" @@ -78,7 +77,6 @@ class GpuHloScheduleTest : public HloTestBase { struct TestConfig { bool enable_latency_hiding_scheduler = false; - bool enable_gpu_async_tracker = false; bool enable_pipelined_p2p = false; std::string fdo_profile = ""; }; @@ -88,8 +86,6 @@ class GpuHloScheduleTest : public HloTestBase { DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_latency_hiding_scheduler( test_config.enable_latency_hiding_scheduler); - debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker( - test_config.enable_gpu_async_tracker); debug_options.set_xla_gpu_enable_pipelined_p2p( test_config.enable_pipelined_p2p); config.set_debug_options(debug_options); @@ -510,7 +506,6 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModel) { for (const SubTest& subtest : subtests) { TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.fdo_profile = subtest.profile; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -573,7 +568,6 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) { TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.fdo_profile = kProfile; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -634,7 +628,6 @@ TEST_F( TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.fdo_profile = kProfile; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -692,7 +685,6 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelWithRematData) { )pb"; TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.fdo_profile = ar_long_latency_proto_text; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -876,7 +868,6 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPairs2) { TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.enable_pipelined_p2p = true; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -973,7 +964,6 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvAllReduce) { TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.enable_pipelined_p2p = true; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -1095,7 +1085,6 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined1) { TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.enable_pipelined_p2p = true; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -1291,7 +1280,6 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPipelined2) { TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = true; test_config.enable_pipelined_p2p = true; TF_ASSERT_OK_AND_ASSIGN( auto module, @@ -1520,7 +1508,7 @@ TEST_P(GpuHloScheduleParameterizedTest, AsyncAllReduce) { EXPECT_TRUE(HasValidFingerprint(module.get())); } -TEST_P(GpuHloScheduleParameterizedTest, LHSResourceModel) { +TEST_F(GpuHloScheduleTest, LHSResourceModel) { const char* hlo_text = R"( HloModule AsyncModule apply_op { @@ -1559,19 +1547,13 @@ TEST_P(GpuHloScheduleParameterizedTest, LHSResourceModel) { ROOT t = (f32[32], f32[64], f32[32,32]) tuple(ar-done, %ag-done, add5) })"; - const bool enable_gpu_async_tracker = GetParam(); TestConfig test_config; test_config.enable_latency_hiding_scheduler = true; - test_config.enable_gpu_async_tracker = GetParam(); TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnVerifiedModule(hlo_text, GetModuleConfig(test_config))); SequentialHloOrdering order = BuildHloOrdering(module.get()); - // Count the number of collectives in flight. Without gpu async tracker, we - // will incorrectly have 2 in-flight (as base async tracker assumes each - // collective can be scheduled independently as they use different resource - // types), but with gpu async tracker we will have 1. uint32_t in_flight = 0; uint32_t max_in_flight = 0; for (const HloInstruction* inst : @@ -1584,8 +1566,7 @@ TEST_P(GpuHloScheduleParameterizedTest, LHSResourceModel) { } } - const uint32_t expected_max_in_flight = enable_gpu_async_tracker ? 1 : 2; - EXPECT_EQ(expected_max_in_flight, max_in_flight); + EXPECT_EQ(max_in_flight, 1); EXPECT_TRUE(HasValidFingerprint(module.get())); } @@ -1630,7 +1611,7 @@ TEST_F(GpuHloSchedulePostProcessTest, PostProcessAsyncCollectives) { module->schedule().sequence(module->entry_computation()); HloInstructionSequence result = PostProcessSchedule(input); - const std::vector expected_sequence = { + const std::vector expected_sequence = { "p0", "ar-start", // ar-start is async, should be scheduled as early as // possible. diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 5d0053752d19f4..2c50af565ef2b4 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -243,19 +243,58 @@ bool GpuAsyncTrackerBase::IsSupportedAsyncStart( return IsGpuAsyncStart(hlo); } +static bool IsPartiallyPipelinedSendRecvDone(const HloInstruction* instr) { + // Is send-done/recv-done but does not have send/recv operand. + return HloPredicateIsOp(instr) && + HloPredicateIsNotOp( + instr->operand(0)); +} + +static bool IsPartiallyPipelinedSendRecv(const HloInstruction* instr) { + // Is send/recv but does not feed into send-done/recv-done. + return HloPredicateIsOp(instr) && + instr->user_count() == 1 && + HloPredicateIsNotOp( + instr->users().front()); +} + void GpuAsyncTrackerBase::PostProcessScheduleGraph( HloScheduleGraph* schedule_graph, const LatencyEstimator* latency_estimator) const { - for (auto inst : schedule_graph->GetOriginalInstrList()) { + if (schedule_graph->GetOriginalInstrList().empty()) return; + auto debug_options = schedule_graph->GetOriginalInstrList() + .front() + ->GetModule() + ->config() + .debug_options(); + + for (const HloInstruction* inst : schedule_graph->GetOriginalInstrList()) { + // Schedule partially pipelined send/recv instructions late so that they can + // overlap with compute. Schedule send/recv late and, when unblocked, + // schedule send-done/recv-done early. + if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() && + IsPartiallyPipelinedSendRecv(inst)) { + HloGraphNode& node = schedule_graph->GetNode(inst); + node.SetForceDelay(true); + VLOG(5) << "Setting force delay for instruction: " << inst->ToString(); + } + if (debug_options.xla_gpu_enable_experimental_pipeline_parallelism_opt() && + IsPartiallyPipelinedSendRecvDone(inst)) { + HloGraphNode& node = schedule_graph->GetNode(inst); + node.SetForceEarly(true); + VLOG(5) << "Setting force early for instruction: " << inst->ToString(); + } + // Force pipelined Recv to be closed to Recvdone so that copies inserted // for RecvDone can be eliminated. - if (inst->opcode() == HloOpcode::kRecv) { - if (inst->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) { - HloGraphNode& node = schedule_graph->GetNode(inst); - node.SetForceEarly(true); - VLOG(5) << "Setting force early for instruction: " << inst->ToString(); - } + if (debug_options.xla_gpu_enable_pipelined_p2p() && + inst->opcode() == HloOpcode::kRecv && + inst->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) { + HloGraphNode& node = schedule_graph->GetNode(inst); + node.SetForceEarly(true); + VLOG(5) << "Setting force early for instruction: " << inst->ToString(); } + if (inst->has_backend_config()) { auto gpu_config = inst->backend_config(); if (gpu_config.ok()) { diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 7a1ddba502ab62..382e6e148e50e3 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_hlo_schedule.h" @@ -34,6 +35,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -76,11 +78,14 @@ class GpuLatencyHidingSchedulerBaseTest : public HloTestBase { return module; } - HloModuleConfig GetModuleConfig(absl::string_view fdo_profile) { + HloModuleConfig GetModuleConfig( + absl::string_view fdo_profile, + bool enable_experimental_pipeline_parallelism_opt = false) { HloModuleConfig config; DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true); - debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker(true); + debug_options.set_xla_gpu_enable_experimental_pipeline_parallelism_opt( + enable_experimental_pipeline_parallelism_opt); config.set_debug_options(debug_options); config.set_fdo_profile(fdo_profile); return config; @@ -444,5 +449,108 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest, GetIndexByName(instruction_sequence, "rs_1"))); } +TEST_F(GpuLatencyHidingSchedulerBaseTest, SchedulePipelinedSendRecvsLate) { + absl::string_view kHloModule = R"( + HloModule m + + while_condition { + tuple = ((f32[16,16], u32[], token[]), (f32[16,16], u32[], token[]), + f32[16,16], u32[]) parameter(0) + i = get-tuple-element(tuple), index=3 + n = u32[] constant(13) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = ((f32[16,16], u32[], token[]), (f32[16,16], u32[], token[]), + f32[16,16], u32[]) parameter(0) + send_ctx = get-tuple-element(tuple), index=0 + recv_ctx = get-tuple-element(tuple), index=1 + some_arg = get-tuple-element(tuple), index=2 + i = get-tuple-element(tuple), index=3 + some_res = f32[16,16] dot(some_arg, some_arg), lhs_contracting_dims={0}, + rhs_contracting_dims={1} + recv_done = (f32[16], token[]) recv-done(recv_ctx), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + send_done = token[] send-done(send_ctx), frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + after_all = token[] after-all() + send_ctx_ = (f32[16,16], u32[], token[]) send(some_arg, after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, + control-predecessors={send_done} + recv_ctx_ = (f32[16,16], u32[], token[]) recv(after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, + control-predecessors={recv_done} + c1 = u32[] constant(1) + i_ = add(i, c1) + ROOT tuple_ = ((f32[16,16], u32[], token[]), (f32[16,16], u32[], token[]), + f32[16,16], u32[]) tuple(send_ctx_, recv_ctx_, some_res, i_) + } + + + ENTRY main { + some_arg = f32[16,16] parameter(0) + after_all = token[] after-all() + send_ctx = (f32[16,16], u32[], token[]) send(some_arg, after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + recv_ctx = (f32[16,16], u32[], token[]) recv(after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + c0 = u32[] constant(0) + tuple = ((f32[16,16], u32[], token[]), (f32[16,16], u32[], token[]), + f32[16,16], u32[]) + tuple(send_ctx, recv_ctx, some_arg, c0) + tuple_ = ((f32[16,16], u32[], token[]), (f32[16,16], u32[], token[]), + f32[16,16], u32[]) + while(tuple), body=while_body, condition=while_condition + send_ctx_ = (f32[16,16], u32[], token[]) get-tuple-element(tuple_), index=0 + recv_ctx_ = (f32[16,16], u32[], token[]) get-tuple-element(tuple_), index=1 + recv_done = (f32[16], token[]) recv-done(recv_ctx_), frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + send_done = token[] send-done(send_ctx_), frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + } + )"; + + absl::string_view kFdoProfile = ""; + auto config = GetModuleConfig( + kFdoProfile, /*enable_experimental_pipeline_parallelism_opt=*/true); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + TF_EXPECT_OK( + ScheduleModule(module.get(), /*num_parallel_resources=*/2, + /*strictness=*/DebugOptions::PGLE_STRICTNESS_LEVEL_OFF)); + auto schedule = module->schedule(); + VLOG(3) << module->schedule().ToString(); + + // Expect send/recv and send/recv-done to be scheduled late so that they + // appear at the top of the while loop body. This is to ensure their execution + // overlaps with the present compute. + HloComputation* while_body = FindComputation(module.get(), "while_body"); + std::vector while_body_instrs = + schedule.sequence(while_body).instructions(); + + // Expect: `recv_ctx` -> `recv_done` -> `recv_ctx_` -> `some_res` + EXPECT_LT(GetIndexByName(while_body_instrs, "recv_ctx"), + GetIndexByName(while_body_instrs, "recv_done")); + EXPECT_LT(GetIndexByName(while_body_instrs, "recv_done"), + GetIndexByName(while_body_instrs, "recv_ctx_")); + EXPECT_LT(GetIndexByName(while_body_instrs, "recv_ctx_"), + GetIndexByName(while_body_instrs, "some_res")); + + // Expect: `send_ctx` -> `send_done` -> `send_ctx_` -> `some_res` + EXPECT_LT(GetIndexByName(while_body_instrs, "send_ctx"), + GetIndexByName(while_body_instrs, "send_done")); + EXPECT_LT(GetIndexByName(while_body_instrs, "send_done"), + GetIndexByName(while_body_instrs, "send_ctx_")); + EXPECT_LT(GetIndexByName(while_body_instrs, "send_ctx_"), + GetIndexByName(while_body_instrs, "some_res")); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc index 6520f8dbb0555c..551110d70dbc0d 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -31,9 +31,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc b/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc index 80f47db4994df5..66931a94c992ba 100644 --- a/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/third_party/xla/xla/service/gpu/hlo_to_ir_bindings.cc @@ -96,8 +96,8 @@ void HloToIrBindings::EmitBasePointersForHlos( << llvm_ir::ConstantHloToGlobalName(*non_io_hlo); BindHloToIrValue(*non_io_hlo, global_for_constant); } else { - llvm::Type* pointee_type = - llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); + llvm::Type* pointee_type = llvm_ir::ShapeToIrType( + non_io_hlo->shape(), module_->getContext()); BindHloToIrValue(*non_io_hlo, llvm_ir::EmitAllocaAtFunctionEntry( pointee_type, /*name=*/"", b_), @@ -128,7 +128,8 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, llvm::Value* base_ptr = GetBasePointer(hlo, shape_index); Shape new_shape = ShapeUtil::GetSubshape(hlo.shape(), shape_index); - llvm::Type* pointee_type = llvm_ir::ShapeToIrType(new_shape, module_); + llvm::Type* pointee_type = + llvm_ir::ShapeToIrType(new_shape, module_->getContext()); CHECK_NE(base_ptr, nullptr) << "Buffer not assigned for shape_index " << shape_index.ToString() << " of " << hlo.ToString(); diff --git a/third_party/xla/xla/service/gpu/ir_emitter.cc b/third_party/xla/xla/service/gpu/ir_emitter.cc index bcfac22d9c900d..f0587d8ec10110 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter.cc @@ -93,7 +93,8 @@ absl::Status IrEmitter::HandleGetTupleElement( // TODO(b/26344050): tighten the alignment here // based on the real element type. /*alignment=*/1, GetBasePointer(*operand), - llvm_ir::ShapeToIrType(operand->shape(), module_), &b_)); + llvm_ir::ShapeToIrType(operand->shape(), module_->getContext()), + &b_)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_nested.cc b/third_party/xla/xla/service/gpu/ir_emitter_nested.cc index 4d96a0cc14aed2..149bddb5bbe222 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_nested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_nested.cc @@ -203,12 +203,13 @@ absl::StatusOr IrEmitterNested::CodegenNestedComputation() { if (ShapeUtil::IsScalar(return_shape)) { llvm::Value* ret_value = - Load(llvm_ir::ShapeToIrType(return_shape, module_), root_value, - "load_ret_value"); + Load(llvm_ir::ShapeToIrType(return_shape, module_->getContext()), + root_value, "load_ret_value"); Store(ret_value, out_parameter); } else { CHECK(return_shape.IsTuple()); - llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_); + llvm::Type* tuple_type = + llvm_ir::ShapeToIrType(return_shape, module_->getContext()); for (int i = 0; i < return_shape.tuple_shapes_size(); i++) { const Shape& element_shape = return_shape.tuple_shapes(i); @@ -220,8 +221,11 @@ absl::StatusOr IrEmitterNested::CodegenNestedComputation() { element_shape, /*index=*/i, /*alignment=*/1, root_value, - llvm_ir::ShapeToIrType(root_instruction->shape(), module_), &b_); - Store(Load(llvm_ir::ShapeToIrType(element_shape, module_), source), + llvm_ir::ShapeToIrType(root_instruction->shape(), + module_->getContext()), + &b_); + Store(Load(llvm_ir::ShapeToIrType(element_shape, module_->getContext()), + source), destination); } } @@ -347,8 +351,8 @@ absl::StatusOr> CallNestedComputationWithScalarAddrs( const HloComputation& computation, absl::Span parameter_elements_addrs) { const Shape& return_shape = computation.root_instruction()->shape(); - llvm::Type* return_buffer_type = llvm_ir::ShapeToIrType( - return_shape, builder->GetInsertBlock()->getModule()); + llvm::Type* return_buffer_type = + llvm_ir::ShapeToIrType(return_shape, builder->getContext()); llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( return_buffer_type, "return_buffer", builder); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index c62df25477e365..7e95b23fe29847 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -142,6 +142,7 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/nccl_group_thunk.h" #include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" +#include "xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_recv_thunk.h" #include "xla/service/gpu/runtime/nccl_send_thunk.h" #include "xla/service/gpu/runtime/norm_thunk.h" @@ -231,7 +232,12 @@ absl::Status IrEmitterUnnested::EmitConditional(const HloInstruction* instr) { for (auto comp : instr->branch_computations()) { auto ir_emitter = IrEmitterUnnested::Create(ir_emitter_context_); TF_RETURN_IF_ERROR(ir_emitter->EmitHloComputation(comp)); - branch_thunks.push_back(ir_emitter->ConsumeThunkSequence()); + Thunk::ThunkInfo branch_thunk_info = + Thunk::ThunkInfo::WithProfileAnnotation(instr); + branch_thunk_info.profile_annotation += + absl::StrCat("_branch_", comp->name()); + branch_thunks.push_back( + ir_emitter->ConsumeThunkSequence(branch_thunk_info)); } ConditionalThunkConfig config = @@ -1420,7 +1426,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( ir_emitter_context_->gpu_device_info(), block_level_parameters, triton_module.get(), ir_emitter_context_->llvm_module(), mlir_context, - emit_kernels)); + /*is_xla_fusion=*/false, emit_kernels)); TF_ASSIGN_OR_RETURN( auto kernel_arguments, @@ -1871,7 +1877,15 @@ absl::Status IrEmitterUnnested::EmitNcclThunk( // A given collective op can be degenerate if across all groups formed // by it are singleton. In such a case, we don't need to do any communication // and we can just copy the input to the output. - bool is_degenerate = GetNcclCollectiveConfig(inst, use_global_device_ids) + // + // The only exception is RaggedAllToAll, which is not degenerate even if + // all groups are singleton. In a singleton group case, RaggedAllToAll becomes + // a generic equivalent of DynamicUpdateSlice, except update size is not + // statically known. This operation can not be expressed in term of standard + // HLO instructions, so the best solution we have is to use NCCL thunk even + // for degenerate cases. + bool is_degenerate = kind != Thunk::Kind::kNcclRaggedAllToAll && + GetNcclCollectiveConfig(inst, use_global_device_ids) .IsDegenerate(replica_count, partition_count); absl::Status implementable_status = NcclThunkType::CheckImplementable(inst, replica_count, partition_count); @@ -1911,7 +1925,34 @@ absl::Status IrEmitterUnnested::EmitNcclThunk( src_shape.layout().memory_space(), dst, dst_shape.layout().memory_space()); } - + } else if (kind == Thunk::Kind::kNcclRaggedAllToAll) { + // RaggedAllToAll operation has 6 operands: input, output, input_offset, + // send_size, output_offset, recv_size. + const Shape& input_shape = inst->operand(0)->shape(); + const Shape& result_shape = inst->shape(); + TF_ASSIGN_OR_RETURN(auto input_buffer, + GetAllocationSliceForHlo(inst->operand(0))); + TF_ASSIGN_OR_RETURN(auto result_buffer, GetAllocationSliceForHlo(inst)); + add_buffer(ShapeUtil::ElementsIn(input_shape), input_buffer, + input_shape.layout().memory_space(), result_buffer, + result_shape.layout().memory_space()); + + const Shape& output_shape = inst->operand(1)->shape(); + TF_ASSIGN_OR_RETURN(auto output_buffer, + GetAllocationSliceForHlo(inst->operand(1))); + + add_buffer(ShapeUtil::ElementsIn(result_shape), output_buffer, + output_shape.layout().memory_space(), output_buffer, + output_shape.layout().memory_space()); + + for (int64_t i = 2; i < operand_count; i++) { + const Shape& shape = inst->operand(i)->shape(); + TF_ASSIGN_OR_RETURN(auto slice, + GetAllocationSliceForHlo(inst->operand(i))); + add_buffer(ShapeUtil::ElementsIn(shape), slice, + shape.layout().memory_space(), slice, + shape.layout().memory_space()); + } } else { // For other operations simply zip operands with results. for (int64_t i = 0; i < operand_count; i++) { @@ -2070,34 +2111,28 @@ static const HloInstruction* FindCanonicalSendRecvStartOp( return canonical_start_op; } -absl::Status IrEmitterUnnested::EmitNcclGroupThunk(const HloInstruction* instr, - Thunk::Kind kind) { +absl::Status IrEmitterUnnested::EmitNcclGroupStartThunk( + const HloInstruction* instr) { emit_group_thunks_ = true; - for (const HloInstruction* instr : + std::optional stream_kind; + for (const HloInstruction* nested_instruction : instr->async_wrapped_computation()->instructions()) { - if (kind == Thunk::Kind::kNcclGroupStart) { - TF_RETURN_IF_ERROR(EmitHloInstruction(instr)); - } else { - // For kNcclGroupDone, we only need to emit the corresponding async done - // instructions. For now, only send/recv is supported. - switch (instr->opcode()) { - case HloOpcode::kSend: - TF_RETURN_IF_ERROR( - EmitNcclAsyncDone(Thunk::Kind::kNcclSendDone, instr)); - break; - case HloOpcode::kRecv: - TF_RETURN_IF_ERROR( - EmitNcclAsyncDone(Thunk::Kind::kNcclRecvDone, instr)); - break; - default: - break; - } + TF_RETURN_IF_ERROR(EmitHloInstruction(nested_instruction)); + if ((nested_instruction->opcode() == HloOpcode::kSend || + nested_instruction->opcode() == HloOpcode::kRecv) && + !stream_kind.has_value()) { + // We only need to modify the stream kind once, since all send/recv + // instructions in a group should have the same stream kind. + stream_kind = GetStreamKindForSendRecv( + Cast(nested_instruction)); } } auto thunk = std::make_unique( - instr, kind, std::move(scoped_thunk_sequence_)); - // TODO (rosiezou): use absl cleanup to automatically reset this boolean. + instr, Thunk::Kind::kNcclGroupStart, std::move(scoped_thunk_sequence_), + stream_kind.value_or(AsyncStreamKind::kCollective)); emit_group_thunks_ = false; + + GetCollectivesAsyncEvents().insert({instr, thunk->async_events()}); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -2233,9 +2268,17 @@ absl::StatusOr> IrEmitterUnnested::BuildWhileThunk( TF_ASSIGN_OR_RETURN( auto pred, GetAllocationSliceForHlo(condition->root_instruction(), {})); + Thunk::ThunkInfo cond_thunk_info = + Thunk::ThunkInfo::WithProfileAnnotation(instr); + cond_thunk_info.profile_annotation += "_condition"; + Thunk::ThunkInfo body_thunk_info = + Thunk::ThunkInfo::WithProfileAnnotation(instr); + body_thunk_info.profile_annotation += "_body"; + return std::unique_ptr(new WhileThunk( - thunk_info, pred, ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence(), trip_count)); + thunk_info, pred, + ir_emitter_condition->ConsumeThunkSequence(cond_thunk_info), + ir_emitter_body->ConsumeThunkSequence(body_thunk_info), trip_count)); } absl::Status IrEmitterUnnested::EmitTargetElementLoop( @@ -2375,8 +2418,6 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) { } else { collectives_async_events.try_emplace(instr, thunk->async_events()); } - } else { - collectives_async_events.try_emplace(instr, thunk->async_events()); } AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); @@ -2450,8 +2491,6 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) { } else { collectives_async_events.try_emplace(instr, thunk->async_events()); } - } else { - collectives_async_events.try_emplace(instr, thunk->async_events()); } AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); @@ -2511,7 +2550,7 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( case HloOpcode::kAsyncDone: { if (!instr->async_wrapped_computation() ->CanExpandIntoSingleInstruction()) { - return EmitNcclGroupThunk(instr, Thunk::kNcclGroupDone); + return EmitNcclAsyncDone(Thunk::kNcclGroupDone, instr); } const HloInstruction* wrapped = instr->async_wrapped_instruction(); switch (wrapped->opcode()) { @@ -2519,6 +2558,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( return EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, instr); case HloOpcode::kAllToAll: return EmitNcclAsyncDone(Thunk::kNcclAllToAllDone, instr); + case HloOpcode::kRaggedAllToAll: + return EmitNcclAsyncDone(Thunk::kNcclRaggedAllToAllDone, instr); case HloOpcode::kCollectiveBroadcast: return EmitNcclAsyncDone(Thunk::kNcclCollectiveBroadcastDone, instr); case HloOpcode::kFusion: @@ -2544,7 +2585,7 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( // Multi-op async start will emit a NCCL group thunk. if (!instr->async_wrapped_computation() ->CanExpandIntoSingleInstruction()) { - return EmitNcclGroupThunk(instr, Thunk::kNcclGroupStart); + return EmitNcclGroupStartThunk(instr); } const HloInstruction* wrapped = instr->async_wrapped_instruction(); switch (wrapped->opcode()) { @@ -2560,6 +2601,13 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( return EmitNcclThunk( Thunk::kNcclAllToAll, instr, all_to_all, std::nullopt); } + case HloOpcode::kRaggedAllToAll: { + auto* ragged_all_to_all = Cast(wrapped); + return EmitNcclThunk( + Thunk::kNcclRaggedAllToAll, instr, ragged_all_to_all, + std::nullopt); + } case HloOpcode::kCollectiveBroadcast: { auto* collective_broadcast = Cast(wrapped); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index 756166bfc7eed1..cc4e281a48de92 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -89,8 +89,9 @@ class IrEmitterUnnested : public IrEmitter { IrEmitterContext* ir_emitter_context); // Transfers the ownship of thunk_sequence_ out. - std::unique_ptr ConsumeThunkSequence() { - return std::make_unique(Thunk::ThunkInfo{}, + std::unique_ptr ConsumeThunkSequence( + Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo{}) { + return std::make_unique(thunk_info, std::move(thunk_sequence_)); } @@ -166,8 +167,7 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitHloInstruction(const HloInstruction* instr); - absl::Status EmitNcclGroupThunk(const HloInstruction* instr, - Thunk::Kind kind); + absl::Status EmitNcclGroupStartThunk(const HloInstruction* instr); absl::Status EmitTargetElementLoop( const HloInstruction& hlo, diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 1fa91f7b2084de..6b0ee17ab3213e 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -2,7 +2,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//xla:xla.bzl", "xla_cc_binary") load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") -load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") load("//xla/tests:build_defs.bzl", "DEFAULT_DISABLED_BACKENDS", "xla_test") load("//xla/tsl:tsl.bzl", "if_windows") load( @@ -32,7 +31,9 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", @@ -65,6 +66,7 @@ cc_library( "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", ], ) @@ -127,11 +129,11 @@ xla_test( "//xla:literal_util", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -151,18 +153,20 @@ cc_library( "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel", "//xla/stream_executor:launch_dim", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) @@ -203,6 +207,7 @@ xla_test( "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -214,12 +219,11 @@ cc_library( name = "topk_custom_kernel", srcs = ["topk_custom_kernel.cc"], hdrs = ["topk_custom_kernel.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + tags = ["gpu"], visibility = [":friends"], deps = [ ":custom_kernel", + ":topk_kernel_gpu", "//xla:types", "//xla:xla_data_proto_cc", "//xla/stream_executor:device_memory", @@ -231,9 +235,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", - ] + if_gpu_is_configured([ - ":topk_kernel_gpu", - ]), + ], ) xla_test( @@ -308,6 +310,7 @@ xla_test( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform", "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -523,6 +526,7 @@ cc_library( "//xla/stream_executor:launch_dim", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -540,6 +544,8 @@ xla_test( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc index 47cb849c611bcc..eca174f840cc5d 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include #include -#include #include #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -50,7 +50,7 @@ CustomKernel::CustomKernel(std::string name, cluster_dims_(cluster_dims), shared_memory_bytes_(shared_memory_bytes) {} -std::string_view CustomKernel::name() const { return name_; } +absl::string_view CustomKernel::name() const { return name_; } const se::MultiKernelLoaderSpec& CustomKernel::kernel_spec() const { return kernel_spec_; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h index 433f43f38ce49c..a6e6eb5b7353fc 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel.h @@ -19,8 +19,8 @@ limitations under the License. #include #include #include -#include +#include "absl/strings/string_view.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -53,7 +53,7 @@ class CustomKernel { se::BlockDim block_dims, se::ThreadDim thread_dims, se::ClusterDim cluster_dims, size_t shared_memory_bytes); - std::string_view name() const; + absl::string_view name() const; const se::MultiKernelLoaderSpec& kernel_spec() const; diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.cc b/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.cc index 3132ae44c709ba..88039dd467ae6b 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include -#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" namespace xla::gpu { @@ -46,7 +46,7 @@ absl::Status CustomKernelFusionRegistry::Register( } CustomKernelFusion* CustomKernelFusionRegistry::Lookup( - std::string_view name) const { + absl::string_view name) const { absl::MutexLock lock(&mutex_); if (auto it = registry_.find(name); it != registry_.end()) return it->second.get(); diff --git a/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.h b/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.h index ae5cb3e51dd947..2b12cc4d7557c0 100644 --- a/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.h +++ b/third_party/xla/xla/service/gpu/kernels/custom_kernel_fusion.h @@ -18,12 +18,14 @@ limitations under the License. #include #include -#include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/service/gpu/kernels/custom_kernel.h" @@ -126,7 +128,7 @@ class CustomKernelFusionRegistry { std::unique_ptr fusion); // Looks up custom kernel fusion by name. Return nullptr if it's not found. - CustomKernelFusion* Lookup(std::string_view name) const; + CustomKernelFusion* Lookup(absl::string_view name) const; private: mutable absl::Mutex mutex_; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc index 124569ea5461bc..d22175e66c239e 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -76,8 +76,8 @@ static void BM_RowMajorGemm(benchmark::State& state) { custom_kernel.shared_memory_bytes()); for (auto s : state) { - TF_CHECK_OK(stream->Launch(custom_kernel.thread_dims(), - custom_kernel.block_dims(), *gemm, args)); + TF_CHECK_OK(gemm->Launch(custom_kernel.thread_dims(), + custom_kernel.block_dims(), stream.get(), args)); TF_CHECK_OK(stream->BlockHostUntilDone()); } } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index 7cdc9507e3e7f0..bdf61784f937b9 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/platform.h" @@ -73,8 +74,8 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { se::KernelArgsDeviceMemoryArray arr( std::vector({a, b, c}), custom_kernel.shared_memory_bytes()); - TF_ASSERT_OK(stream->Launch(custom_kernel.thread_dims(), - custom_kernel.block_dims(), *gemm, arr)); + TF_ASSERT_OK(gemm->Launch(custom_kernel.thread_dims(), + custom_kernel.block_dims(), stream.get(), arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); @@ -122,8 +123,8 @@ TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { se::KernelArgsDeviceMemoryArray arr( std::vector({a, b, c}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), - custom_kernel->block_dims(), *gemm, arr)); + TF_ASSERT_OK(gemm->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), stream.get(), arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc index 228804d0d83b0f..21e6e56b7c7113 100644 --- a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include -#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -42,7 +42,7 @@ KernelArgsPacking(const se::Kernel &kernel, const se::KernelArgs &args) { // otherwise you will get a "CUDA_ERROR_NOT_FOUND: named symbol not found.". // E.g. `.visible .entry AddI32(...)` would have a kernel name of "AddI32". absl::StatusOr GetPtxCustomKernel(std::string kernel_name, - std::string_view ptx, + absl::string_view ptx, int num_args, se::BlockDim block_dim, se::ThreadDim thread_dim, diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h index 7ebe304df9c466..d39d6ca1baae02 100644 --- a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h @@ -18,16 +18,16 @@ limitations under the License. #include #include -#include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/stream_executor/launch_dim.h" namespace xla::gpu::kernel { absl::StatusOr GetPtxCustomKernel(std::string kernel_name, - std::string_view ptx, + absl::string_view ptx, int num_args, se::BlockDim block_dim, se::ThreadDim thread_dim, diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc index bf6f650876a6ea..a916d2b91f7ac4 100644 --- a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc @@ -17,9 +17,10 @@ limitations under the License. #include #include -#include #include +#include +#include "absl/strings/string_view.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/stream_executor/cuda/cuda_platform.h" #include "xla/stream_executor/device_memory.h" @@ -35,7 +36,7 @@ namespace xla::gpu::kernel { namespace se = ::stream_executor; -constexpr std::string_view kAddI32KernelPtx = R"( +constexpr absl::string_view kAddI32KernelPtx = R"( .version 4.0 .target sm_50 .address_size 64 @@ -101,8 +102,8 @@ TEST(PtxCustomKernelTest, GetPtxCustomKernel) { se::KernelArgsDeviceMemoryArray args( std::vector({a, b, c}), custom_kernel.shared_memory_bytes()); - TF_CHECK_OK(stream->Launch(custom_kernel.thread_dims(), - custom_kernel.block_dims(), *kernel, args)); + TF_CHECK_OK(kernel->Launch(custom_kernel.thread_dims(), + custom_kernel.block_dims(), stream.get(), args)); TF_CHECK_OK(stream->BlockHostUntilDone()); diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc index 2be74bf301ee1e..a2611258acc103 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/topk_kernel_common.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -35,14 +36,8 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" -#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) -#include "xla/service/gpu/kernels/topk_kernel_common.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - namespace xla::gpu::kernel::topk { -#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) - namespace { using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; @@ -135,16 +130,4 @@ absl::StatusOr GetTopKKernel(std::string name, } } -#else - -// Fallback implementation of creating a CustomKernel for TopK operation. -absl::StatusOr GetTopKKernel(std::string name, - PrimitiveType dtype, - size_t num_elements, size_t k, - size_t batch_size) { - return absl::InternalError("XLA compiled without CUDA support"); -} - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } // namespace xla::gpu::kernel::topk diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc index 4f6f62605996a6..0f8cd08cafdc8f 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include #include "absl/random/random.h" #include "absl/strings/ascii.h" @@ -118,8 +119,8 @@ TEST_P(TopKKernelTest, TopKFloat) { std::vector( {input_buffer, output_values, output_indices}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), - custom_kernel->block_dims(), *kernel, arr)); + TF_ASSERT_OK(kernel->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), stream.get(), arr)); std::vector got(k); ASSERT_TRUE(stream->BlockHostUntilDone().ok()); @@ -172,8 +173,8 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { std::vector( {input_buffer, output_values, output_indices}), custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), - custom_kernel->block_dims(), *kernel, arr)); + TF_ASSERT_OK(kernel->Launch(custom_kernel->thread_dims(), + custom_kernel->block_dims(), stream.get(), arr)); std::vector got(k); ASSERT_TRUE(stream->BlockHostUntilDone().ok()); diff --git a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cc b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cc index 1595d823b41fd8..41ffdbfba5aee6 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cc @@ -23,21 +23,23 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/primitive_util.h" #include "xla/service/gpu/kernels/topk_kernel_common.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/typed_kernel_factory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla::gpu { namespace { @@ -90,10 +92,10 @@ absl::Status TypedTopK(se::Stream* stream, se::DeviceMemoryBase data, size_t>::Create(executor, "topk", kernel_symbol))); - TF_RETURN_IF_ERROR(stream->ThenLaunch( - se::ThreadDim(num_threads, 1, 1), se::BlockDim(batch_size, 1, 1), - shmem_size, kernel, data_typed, num_elements, top_elements_typed, - top_indices_typed, k)); + TF_RETURN_IF_ERROR(kernel.Launch(se::ThreadDim(num_threads, 1, 1), + se::BlockDim(batch_size, 1, 1), shmem_size, + stream, data_typed, num_elements, + top_elements_typed, top_indices_typed, k)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc index 8ce9d80af12615..db457017091b5c 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel_test.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include +#include #include "absl/log/check.h" #include "absl/random/random.h" #include "absl/strings/substitute.h" diff --git a/third_party/xla/xla/service/gpu/launch_dimensions.cc b/third_party/xla/xla/service/gpu/launch_dimensions.cc index db060f1eb4b66e..401c9dd2070ad8 100644 --- a/third_party/xla/xla/service/gpu/launch_dimensions.cc +++ b/third_party/xla/xla/service/gpu/launch_dimensions.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -35,18 +36,38 @@ LaunchDimensions CalculateLaunchDimensions( return LaunchDimensions(); } num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor}); - const int kWarpSchedulers = 4; - int64_t threads_per_block = std::min( - gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); - int64_t num_blocks_total = CeilOfRatio(num_elements, threads_per_block); - int64_t num_blocks_y = CeilOfRatio( - num_blocks_total, gpu_device_info.block_dim_limit().x); - int64_t num_blocks_x = CeilOfRatio(num_blocks_total, num_blocks_y); + if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") { + int64_t threads_per_block_x = std::min( + gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); + + int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x); + CHECK(num_blocks < gpu_device_info.block_dim_limit().x); + + int threads_per_block_y = 1; + while ((num_blocks * threads_per_block_x) > + std::numeric_limits::max()) { + threads_per_block_x /= 2; + threads_per_block_y *= 2; + } + + return LaunchDimensions( + se::BlockDim(num_blocks, 1, 1), + se::ThreadDim(threads_per_block_x, threads_per_block_y, 1)); - return LaunchDimensions(se::BlockDim(num_blocks_x, num_blocks_y, 1), - se::ThreadDim(threads_per_block, 1, 1)); + } else { + int64_t threads_per_block = std::min( + gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); + + int64_t num_blocks_total = CeilOfRatio(num_elements, threads_per_block); + int64_t num_blocks_y = CeilOfRatio( + num_blocks_total, gpu_device_info.block_dim_limit().x); + int64_t num_blocks_x = CeilOfRatio(num_blocks_total, num_blocks_y); + + return LaunchDimensions(se::BlockDim(num_blocks_x, num_blocks_y, 1), + se::ThreadDim(threads_per_block, 1, 1)); + } } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index cbe802e772386a..3a1a2690e46db2 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -1,21 +1,9 @@ -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", -) -load( - "@local_config_sycl//sycl:build_defs.bzl", - "if_sycl_is_configured", -) load("//xla:xla.bzl", "xla_cc_test") load( "//xla/tsl:tsl.bzl", "if_google", "internal_visibility", ) -load( - "//xla/tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -38,12 +26,121 @@ cc_library( hdrs = [ "gpu_backend_lib.h", ], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ + ":load_ir_module", + ":utils", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_proto_cc", + "//xla/service/gpu:metrics", + "//xla/service/llvm_ir:llvm_command_line_options", + "//xla/service/llvm_ir:llvm_type_conversion_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:semantic_version", + "//xla/tsl/util:env_var", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitReader", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:ObjCARC", # buildcleaner: keep + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TargetParser", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:cuda_root_path", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:random", + "@local_tsl//tsl/platform:rocm_rocdl_path", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +cc_library( + name = "nvptx_backend", + srcs = [ + "nvptx_backend.cc", + ], + hdrs = [ + "nvptx_backend.h", + ], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":llvm_gpu_backend", ":load_ir_module", ":nvptx_libdevice_path", + "//xla:util", + "//xla:xla_proto_cc", + "//xla/service/gpu:metrics", + "//xla/service/llvm_ir:llvm_command_line_options", + "//xla/stream_executor:device_description", + "//xla/stream_executor:semantic_version", + "//xla/stream_executor/cuda:subprocess_compilation", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitReader", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:NVPTXCodeGen", # buildcleaner: keep + "@llvm-project//llvm:ObjCARC", # buildcleaner: keep + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + "@local_tsl//tsl/profiler/lib:traceme", + ], +) + +cc_library( + name = "amdgpu_backend", + srcs = [ + "amdgpu_backend.cc", + ], + hdrs = [ + "amdgpu_backend.h", + ], + deps = [ + ":llvm_gpu_backend", + ":load_ir_module", ":utils", "//xla:status_macros", "//xla:types", @@ -65,6 +162,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:AMDGPUAsmParser", + "@llvm-project//llvm:AMDGPUCodeGen", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:BitReader", "@llvm-project//llvm:BitWriter", @@ -74,7 +173,6 @@ cc_library( "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Linker", "@llvm-project//llvm:MC", - "@llvm-project//llvm:NVPTXCodeGen", # buildcleaner: keep "@llvm-project//llvm:ObjCARC", # buildcleaner: keep "@llvm-project//llvm:Passes", "@llvm-project//llvm:Scalar", @@ -82,6 +180,7 @@ cc_library( "@llvm-project//llvm:Target", "@llvm-project//mlir:NVVMDialect", "@local_config_cuda//cuda:cuda_headers", + "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:cuda_root_path", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -93,15 +192,7 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:scoped_annotation", "@local_tsl//tsl/profiler/lib:traceme", - ] + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_asm_compiler", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - "@llvm-project//llvm:AMDGPUCodeGen", - "@llvm-project//llvm:AMDGPUAsmParser", - ]) + if_sycl_is_configured([ - "@spirv_llvm_translator//:spirv_llvm_translator", - ]), + ], ) cc_library( @@ -148,11 +239,15 @@ cc_library( ) xla_cc_test( - name = "gpu_backend_lib_test", + name = "nvptx_backend_test", size = "small", - srcs = ["gpu_backend_lib_test.cc"], + srcs = ["nvptx_backend_test.cc"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ - ":llvm_gpu_backend", + ":nvptx_backend", "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc new file mode 100644 index 00000000000000..71e8990cbd3c52 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc @@ -0,0 +1,533 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h" + +#include +#include +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "absl/base/call_once.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/PassRegistry.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "llvm/Transforms/Scalar.h" +#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h" +#include "xla/service/llvm_ir/llvm_command_line_options.h" +#include "xla/service/llvm_ir/llvm_type_conversion_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/util/env_var.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" +#include "tsl/platform/random.h" +#include "tsl/platform/rocm_rocdl_path.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { +namespace gpu { +namespace { + +// Inline threshold value to use in LLVM AMDGPU backend. +const int kAMDGPUInlineThreshold = 0x100000; + +// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version. +std::vector GetROCDLPaths(std::string gcn_arch_name, + const std::string& rocdl_dir_path) { + // AMDGPU version-neutral bitcodes. + static std::vector* rocdl_filenames = + new std::vector( + {"opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc", + "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc", + "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc", + "oclc_abi_version_500.bc"}); + + // Construct full path to ROCDL bitcode libraries. + std::vector result; + result.reserve(rocdl_filenames->size() + 1); + for (auto& filename : *rocdl_filenames) { + result.push_back(tsl::io::JoinPath(rocdl_dir_path, filename)); + } + + // Add AMDGPU version-specific bitcodes. + std::vector tokens = absl::StrSplit(gcn_arch_name, ':'); + std::string amdgpu_version = gcn_arch_name; + if (!tokens.empty() && tokens[0].size() >= 3) { + amdgpu_version = tokens[0].substr(3); + } + result.push_back(tsl::io::JoinPath( + rocdl_dir_path, + absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc"))); + return result; +} + +struct HsacoCacheEntry { + uint64_t hash; + std::string ir; + std::string gfx; + std::vector hsaco; +}; + +struct HsacoCache { + protected: + std::vector cache; + std::mutex m_mutex; + int request_count = 0; + int hit_count = 0; + + public: + static bool Find(const std::string& ir, uint64_t& hash, + const std::string& gfx, std::vector& hsaco); + static void Add(const std::string& ir, uint64_t hash, const std::string& gfx, + const std::vector& hsaco); +}; + +static HsacoCache g_hsacoCache; // NOLINT: static/global vars forbidden + +bool HsacoCache::Find(const std::string& ir, uint64_t& hash, + const std::string& gfx, std::vector& hsaco) { + std::lock_guard lg(g_hsacoCache.m_mutex); + hash = std::hash{}(ir); + bool hit = false; + for (auto& x : g_hsacoCache.cache) { + if (x.hash != hash) continue; + if (x.gfx != gfx) continue; + if (x.ir != ir) continue; + hsaco = x.hsaco; + hit = true; + break; + } + g_hsacoCache.request_count++; + if (hit) g_hsacoCache.hit_count++; + if (!(g_hsacoCache.request_count % 50)) + VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, " + << g_hsacoCache.hit_count << " hits"; + return hit; +} + +void HsacoCache::Add(const std::string& ir, uint64_t hash, + const std::string& gfx, + const std::vector& hsaco) { + std::lock_guard lg(g_hsacoCache.m_mutex); + g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1); + g_hsacoCache.cache.back().ir = ir; + g_hsacoCache.cache.back().hash = hash; + g_hsacoCache.cache.back().gfx = gfx; + g_hsacoCache.cache.back().hsaco = hsaco; +} + +// Emits the given module to HSA Code Object. target_machine is an initialized +// TargetMachine for the AMDGPU target. +absl::StatusOr> EmitModuleToHsaco( + llvm::Module* module, llvm::TargetMachine* target_machine) { + auto* env = tsl::Env::Default(); + std::vector tempdir_vector; + env->GetLocalTempDirectories(&tempdir_vector); + if (tempdir_vector.empty()) { + return xla::Internal( + "Unable to locate a temporary directory for compile-time artifacts."); + } + std::string tempdir_name = tempdir_vector.front(); + VLOG(1) << "Compile-time artifacts located at: " << tempdir_name; + + bool keep_tempfiles = false; + TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES", + /*default_val=*/false, &keep_tempfiles)); + // Prepare filenames for all stages of compilation: + // IR, binary ISA, and HSACO. + std::string random_number = std::to_string(tsl::random::New64()); + std::string ir_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + ".ll"); + std::string ir_path = tsl::io::JoinPath(tempdir_name, ir_filename); + + std::string ir_opt_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + "_opt.ll"); + std::string ir_opt_path = tsl::io::JoinPath(tempdir_name, ir_opt_filename); + + std::string isabin_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + ".o"); + std::string isabin_path = tsl::io::JoinPath(tempdir_name, isabin_filename); + + std::string hsaco_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + ".hsaco"); + std::string hsaco_path = tsl::io::JoinPath(tempdir_name, hsaco_filename); + + std::error_code ec; + + // Dump LLVM IR. + std::unique_ptr ir_fs( + new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::OF_None)); + module->print(*ir_fs, nullptr); + ir_fs->flush(); + + // Emit GCN ISA binary. + llvm::legacy::PassManager pm; + pm.add(new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(module->getTargetTriple()))); + llvm::SmallVector stream; + llvm::raw_svector_ostream pstream(stream); + std::unique_ptr isabin_fs( + new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text)); + module->setDataLayout(target_machine->createDataLayout()); + target_machine->addPassesToEmitFile(pm, *isabin_fs, nullptr, + llvm::CodeGenFileType::ObjectFile); + pm.run(*module); + isabin_fs->flush(); + + if (keep_tempfiles) { + std::unique_ptr ir_fs( + new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::OF_None)); + module->print(*ir_fs, nullptr); + ir_fs->flush(); + } + // Locate lld. + std::string lld_path; + if (std::getenv("LLVM_PATH")) { + lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); + } else { + lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); + } + auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); + if (!lld_program) { + return xla::Internal("unable to find ld.lld in PATH: %s", + lld_program.getError().message()); + } + std::vector lld_args{ + llvm_ir::AsStringRef("ld.lld"), llvm_ir::AsStringRef("-flavor"), + llvm_ir::AsStringRef("gnu"), llvm_ir::AsStringRef("-shared"), + llvm_ir::AsStringRef(isabin_path), llvm_ir::AsStringRef("-o"), + llvm_ir::AsStringRef(hsaco_path), + }; + + std::string error_message; + int lld_result = + llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args), + std::nullopt, {}, 0, 0, &error_message); + if (lld_result) { + return xla::Internal("ld.lld execute fail: %s, error code %d", + error_message, lld_result); + } + + // Read HSACO. + std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate); + std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); + + std::vector hsaco(hsaco_file_size); + hsaco_file.seekg(0, std::ios::beg); + hsaco_file.read(reinterpret_cast(hsaco.data()), hsaco_file_size); + hsaco_file.close(); + if (!keep_tempfiles) { + remove(ir_path.c_str()); + remove(isabin_path.c_str()); + remove(hsaco_path.c_str()); + } + return hsaco; +} + +// Links ROCm-Device-Libs into the given module if the module needs it. +absl::Status LinkROCDLIfNecessary(llvm::Module* module, + std::string gcn_arch_name, + const std::string& rocdl_dir_path) { + if (!CouldNeedDeviceBitcode(*module)) { + return absl::OkStatus(); + } + + return LinkWithBitcodeVector(module, + GetROCDLPaths(gcn_arch_name, rocdl_dir_path)); +} + +absl::Status AMDGPUTargetModuleLinker( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& device_bitcode_dir_path) { + // Link the input module with ROCDL. + + auto compute_capability = + std::get_if(&gpu_version); + if (!compute_capability) { + return xla::Internal("Incompatible compute capability was specified."); + } + + std::string gcn_arch_name = compute_capability->gcn_arch_name(); + TF_RETURN_IF_ERROR( + LinkROCDLIfNecessary(module, gcn_arch_name, device_bitcode_dir_path)); + + // If ftz is enabled, set it as an attribute on every function in the module. + if (debug_options.xla_gpu_ftz()) { + for (llvm::Function& fn : *module) { + fn.addFnAttr("denormal-fp-math-f32", "preserve-sign"); + } + } + + return absl::OkStatus(); +} + +// The following routine maps a feature token extracted from the +// hipDeviceProp_t::gcnArchName string, and maps it to a valid feature_str +// to be used for creating the AMDGPUTarget. +// This mapping is currently in a state of flux because TF XLA uses its +// own copy of LLVM, which is different from the LLVM version used by +// hipcc/runtime in the ROCm install. Ordinarily this is not a problem, +// but right now, the LLVM version used by hipcc/runtime has "targetID" +// related changes which have not yet been upstreamed (to the LLVM repo) +// When that upstreaming happens (and TF LLVM pointer moves past the +// upstream commit), the following mapping will need to change +std::string MapGCNArchNameTokenToFeatureStr(const std::string& token, + const std::string& gfx) { + if (token == "sramecc+") { + return "+sramecc"; + } else if (token == "sramecc-") { + if (gfx == "gfx90a" || gfx == "gfx940" || gfx == "gfx941" || + gfx == "gfx942") + return ""; + return "-sramecc"; + } else if (token == "xnack+") { + return "+xnack"; + } else if (token == "xnack-") { + return "-xnack"; + } + return ""; +} + +std::pair GetFeatureStrFromGCNArchName( + const std::string& gcn_arch_name) { + std::string feature_str; + + std::string gfx = gcn_arch_name; + // For ROCm versions 4.0 and greater, we need to specify the correct + // feature str, based on the underlying GPU HW to get max performance. + std::vector tokens = absl::StrSplit(gcn_arch_name, ':'); + std::vector mapped_tokens; + if (!tokens.empty()) gfx = tokens[0]; + for (auto it = tokens.begin(); it != tokens.end(); it++) { + // Skip the first token, that is the gfxNNN str + // The rest of the tokens are the feature/targetid strings + if (it != tokens.begin()) { + std::string token(*it); + std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token, gfx); + mapped_tokens.push_back(mapped_token); + } + } + feature_str = absl::StrJoin(mapped_tokens, ","); + + return std::make_pair(gfx, feature_str); +} + +std::unique_ptr AMDGPUGetTargetMachine( + llvm::Triple target_triple, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { + auto compute_capability = + std::get_if(&gpu_version); + + std::string gcn_arch_name = compute_capability->gcn_arch_name(); + auto arch = GetFeatureStrFromGCNArchName(gcn_arch_name); + return GetTargetMachine(std::move(target_triple), arch.first, debug_options, + arch.second); +} + +// Returns the directory containing ROCm-Device-Libs files. +std::string GetROCDLDir(const DebugOptions& debug_options) { + std::vector potential_rocdl_dirs; + const std::string& datadir = debug_options.xla_gpu_cuda_data_dir(); + if (!datadir.empty()) { + potential_rocdl_dirs.push_back(datadir); + } + potential_rocdl_dirs.push_back(tsl::RocdlRoot()); + + // Tries all potential ROCDL directories in the order they are inserted. + // Returns the first directory that exists in the file system. + for (const std::string& potential_rocdl_dir : potential_rocdl_dirs) { + if (tsl::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) { + VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir; + return potential_rocdl_dir; + } + VLOG(2) << "Unable to find potential ROCm-Device-Libs dir " + << potential_rocdl_dir; + } + + // Last resort: maybe in the current folder. + return "."; +} + +void AMDGPUBackendInit(const DebugOptions& debug_options, + std::string& rocdl_dir_path) { + // Initialize the AMDGPU target; it's the only target we link with, so call + // its specific initialization functions instead of the catch-all + // InitializeAll*. + LLVMInitializeAMDGPUTarget(); + LLVMInitializeAMDGPUTargetInfo(); + LLVMInitializeAMDGPUTargetMC(); + LLVMInitializeAMDGPUAsmParser(); + LLVMInitializeAMDGPUAsmPrinter(); + + rocdl_dir_path = GetROCDLDir(debug_options); + llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); + gpu::InitializePasses(registry); +} + +std::vector GetAMDGPUBackendOptions( + const DebugOptions& debug_options) { + std::vector backend_llvm_opts; + + // Extra backend options must go after regular backend options in order to be + // able for the later to override the former. + auto backend_extra_llvm_opts = llvm_ir::ExtractXlaBackendExtraOptions( + debug_options.xla_backend_extra_options()); + backend_llvm_opts.insert(backend_llvm_opts.end(), + backend_extra_llvm_opts.cbegin(), + backend_extra_llvm_opts.cend()); + + return backend_llvm_opts; +} + +} // namespace + +namespace amdgpu { + +std::string LibDevicePath(std::string gcn_arch_name, + const std::string& rocdl_dir_path) { + auto libdevice_dir_paths = GetROCDLPaths(gcn_arch_name, rocdl_dir_path); + for (auto libdevice_dir_path : libdevice_dir_paths) { + if (libdevice_dir_path.find("ocml.bc")) { + return libdevice_dir_path; + } + } + return ""; +} + +absl::StatusOr> CompileToHsaco( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& module_config_cache_key) { + static absl::once_flag backend_init_flag; + // TODO(rocm) Ideally this would be refreshed if xla_gpu_cuda_data_dir + // changes. + static std::string rocdl_dir_path; // NOLINT: static/global vars forbidden + absl::call_once(backend_init_flag, AMDGPUBackendInit, debug_options, + rocdl_dir_path); + auto llvm_opts = GetAMDGPUBackendOptions(debug_options); + llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_opts); + + std::vector hsaco; + std::unique_ptr target_machine; + std::string str; + llvm::raw_string_ostream stream(str); + stream << *module; + // Delete the first two lines, since they usually vary even when the rest of + // the code is the same (but verify that they are what we expect). + if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") { + auto pos = str.find('\n'); + if (pos != std::string::npos) str = str.substr(pos + 1); + } + if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") { + auto pos = str.find('\n'); + if (pos != std::string::npos) str = str.substr(pos + 1); + } + str += module_config_cache_key; + { + tsl::profiler::TraceMe activity( + [&] { return absl::StrCat("Compiling IR", module->getName().str()); }, + tsl::profiler::TraceMeLevel::kInfo); + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); + + auto compute_capability = + std::get_if(&gpu_version); + if (!compute_capability) { + return xla::Internal("Incompatible compute capability was specified."); + } + + std::string gcn_arch_name = compute_capability->gcn_arch_name(); + + uint64_t hash; + if (HsacoCache::Find(str, hash, gcn_arch_name, hsaco)) { + VLOG(1) << "HSACO cache hit"; + return hsaco; + } + VLOG(1) << "HSACO cache miss"; + bool dump_lls = false; + if (dump_lls) { + static int hsaco_count = 0; + std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll"; + hsaco_count++; + std::ofstream ofs(name); + ofs << str; + ofs.close(); + } + + llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz"); + // Construct LLVM TargetMachine for AMDGPU. + std::unique_ptr target_machine = + AMDGPUGetTargetMachine(default_target_triple, gpu_version, + debug_options); + + // Link with ROCm-Device-Libs, and optimize the LLVM module. + TF_RETURN_IF_ERROR(gpu::LinkAndOptimizeModule( + module, gpu_version, debug_options, rocdl_dir_path, + AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(), + kAMDGPUInlineThreshold)); + + // Lower optimized LLVM module to HSA code object. + TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get())); + HsacoCache::Add(str, hash, gcn_arch_name, hsaco); + } + return hsaco; +} + +} // namespace amdgpu +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h new file mode 100644 index 00000000000000..f44218c1677f7b --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// LLVM-based compiler backend. +#ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_AMDGPU_BACKEND_H_ +#define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_AMDGPU_BACKEND_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/IR/Module.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" + +namespace xla::gpu::amdgpu { +// Get path to libdevice file. +std::string LibDevicePath(std::string gcn_arch_name, + const std::string& rocdl_dir_path); +// Compiles the argument module and returns it with LLVM AMDGPU backend. +// rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries. +// The contents of the module may be changed. +absl::StatusOr> CompileToHsaco( + llvm::Module* module, stream_executor::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& module_config_cache_key); +} // namespace xla::gpu::amdgpu + +#endif // XLA_SERVICE_GPU_LLVM_GPU_BACKEND_AMDGPU_BACKEND_H_ diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 07847f5590ee67..0fb6db0211b7af 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -15,37 +15,24 @@ limitations under the License. #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include -#include -#include -#include #include -#include #include -#include // NOLINT #include #include -#include #include // NOLINT #include #include #include -#include "absl/base/call_once.h" #include "absl/memory/memory.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "llvm/ADT/Any.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/LoopAnalysisManager.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" @@ -64,71 +51,30 @@ limitations under the License. #include "llvm/Passes/PassBuilder.h" #include "llvm/Passes/StandardInstrumentations.h" #include "llvm/Support/CodeGen.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/Program.h" -#include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/Internalize.h" #include "llvm/Transforms/Scalar.h" #include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h" -#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" #include "xla/service/gpu/llvm_gpu_backend/utils.h" -#include "xla/service/gpu/metrics.h" -#include "xla/service/llvm_ir/llvm_command_line_options.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/semantic_version.h" -#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" -#include "tsl/platform/random.h" -#include "tsl/platform/rocm_rocdl_path.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/scoped_annotation.h" -#include "tsl/profiler/lib/traceme.h" - -#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#include "xla/stream_executor/cuda/subprocess_compilation.h" -#endif - -#if TENSORFLOW_USE_SYCL -#include "LLVMSPIRVLib.h" -#include "LLVMSPIRVOpts.h" -#endif // TENSORFLOW_USE_SYCL namespace xla { namespace gpu { -namespace { +namespace { static llvm::codegen::RegisterCodeGenFlags CGF; - -// Inline threshold value to use in LLVM AMDGPU backend. -const int kAMDGPUInlineThreshold = 0x100000; - -// Default inline threshold value to use in llvm. -const int kDefaultInlineThreshold = 1100; - -// NOLINTBEGIN: clang-diagnostic-unused-function -// Convenience function for producing a name of a temporary compilation product -// from the input filename. -std::string MakeNameForTempProduct(absl::string_view input_filename, - absl::string_view extension) { - return ReplaceFilenameExtension(tsl::io::Basename(input_filename), extension); } -// NOLINTEND: clang-diagnostic-unused-function // Initializes LLVM passes. Uses the PassRegistry mechanism. void InitializePasses(llvm::PassRegistry* pass_registry) { @@ -186,26 +132,6 @@ std::unique_ptr GetTargetMachine( llvm::codegen::getExplicitCodeModel(), codegen_opt_level)); } -// Emits the given module to PTX. target_machine is an initialized TargetMachine -// for the NVPTX target. -std::string EmitModuleToPTX(llvm::Module* module, - llvm::TargetMachine* target_machine) { - tsl::profiler::ScopedAnnotation annotation([&] { - return absl::StrFormat("XlaEmitGpuAsm:#module=%s#", - module->getName().str()); - }); - std::string ptx; - llvm::raw_string_ostream stream(ptx); - llvm::buffer_ostream pstream(stream); - llvm::legacy::PassManager pm; - pm.add(new llvm::TargetLibraryInfoWrapperPass( - llvm::Triple(module->getTargetTriple()))); - target_machine->addPassesToEmitFile(pm, pstream, nullptr, - llvm::CodeGenFileType::AssemblyFile); - pm.run(*module); - return ptx; -} - // Returns whether the module could use any device bitcode library functions. bool CouldNeedDeviceBitcode(const llvm::Module& module) { for (const llvm::Function& function : module.functions()) { @@ -254,85 +180,16 @@ absl::Status LinkWithBitcodeVector( return absl::OkStatus(); } -// Links libdevice into the given module if the module needs libdevice. -absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, - const std::string& libdevice_path) { - if (!CouldNeedDeviceBitcode(*module)) { - return absl::OkStatus(); - } - - if (!tsl::Env::Default()->FileExists(libdevice_path).ok()) { - LOG(WARNING) - << "libdevice is required by this HLO module but was not found at " - << libdevice_path; - return xla::Internal("libdevice not found at %s", libdevice_path); - } - - VLOG(1) << "Linking with libdevice from: " << libdevice_path; - return LinkWithBitcodeVector(module, {libdevice_path}); -} - -absl::Status NVPTXTargetModuleLinker(llvm::Module* module, - se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& device_bitcode_path) { - // Link the input module with libdevice, to pull in implementations of some - // builtins. - TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_path)); - - // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass - // can access it. - module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", - debug_options.xla_gpu_ftz()); - - // If ftz is enabled, set it as an attribute on every function in the module. - if (debug_options.xla_gpu_ftz()) { - for (llvm::Function& fn : *module) { - fn.addFnAttr("denormal-fp-math-f32", "preserve-sign"); - } - } - - return absl::OkStatus(); -} - -std::unique_ptr NVPTXGetTargetMachine( - llvm::Triple target_triple, se::CudaComputeCapability compute_capability, - const DebugOptions& debug_options) { -#ifdef GOOGLE_CUDA - absl::StatusOr runtime_cuda_version = - stream_executor::GetAsmCompilerVersion( - debug_options.xla_gpu_cuda_data_dir()); - - constexpr stream_executor::SemanticVersion kCompileTimeCudaVersion{ - CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100, CUDA_VERSION % 10}; - - auto highest_supported_cuda_version = [&] { - if (runtime_cuda_version.ok()) { - return std::min(runtime_cuda_version.value(), kCompileTimeCudaVersion); - } - - return kCompileTimeCudaVersion; - }(); - - auto ptx_version = nvptx::DetermineHighestSupportedPtxVersionFromCudaVersion( - highest_supported_cuda_version); - int highest_supported_ptx_version = - ptx_version.major() * 10 + ptx_version.minor(); - - VLOG(1) << "Targeting PTX version: " << highest_supported_ptx_version; - std::string feature_str = - absl::StrFormat("+ptx%d", highest_supported_ptx_version); +namespace { -#else - std::string feature_str; -#endif // GOOGLE_CUDA - return GetTargetMachine(target_triple, nvptx::GetSmName(compute_capability), - debug_options, feature_str); +// NOLINTBEGIN: clang-diagnostic-unused-function +// Convenience function for producing a name of a temporary compilation product +// from the input filename. +std::string MakeNameForTempProduct(absl::string_view input_filename, + absl::string_view extension) { + return ReplaceFilenameExtension(tsl::io::Basename(input_filename), extension); } - -using TargetModuleLinker = - std::function; +// NOLINTEND: clang-diagnostic-unused-function void DumpModule(const std::string output_filename, const llvm::Module* module) { std::error_code ec; @@ -383,6 +240,8 @@ auto DumpCallbackForModule(std::string module_identifier, }; } +} // namespace + absl::Status LinkAndOptimizeModule( llvm::Module* module, se::GpuComputeCapability gpu_version, const DebugOptions& debug_options, const std::string& device_bitcode_path, @@ -465,756 +324,5 @@ absl::Status LinkAndOptimizeModule( return absl::OkStatus(); } -// One-time module initializer. -// Must be called only once -- DO NOT CALL DIRECTLY. -void NVPTXBackendInit() { - // Initialize the NVPTX target; it's the only target we link with, so call its - // specific initialization functions instead of the catch-all InitializeAll*. - LLVMInitializeNVPTXTarget(); - LLVMInitializeNVPTXTargetInfo(); - LLVMInitializeNVPTXTargetMC(); - LLVMInitializeNVPTXAsmPrinter(); - - // Initialize the LLVM optimization passes. - llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); - InitializePasses(registry); -} - -std::vector GetNVPTXBackendOptions( - const DebugOptions& debug_options) { - // Feed all customized flags here, so we can override them with llvm_cl_opts - // without redeploy the compiler for development purpose. - std::vector backend_llvm_opts; - - // This flag tunes a threshold in branch folding. The default threshold, which - // is one, is not suitable for CUDA programs where branches are more expensive - // than for CPU programs. Setting the threshold to 2 improves the latency of - // TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not affect the - // latency of other benchmarks so far. - // - // I also tried setting this threshold to other values: - // * 3-6 gives similar results as 2; - // * >6 start hurting the performance of at least dot product kernels. - // - // TODO(jingyue): The current threshold only considers the number of IR - // instructions which do not accurately reflect the true cost. We need a - // better cost model. - backend_llvm_opts.emplace_back("-bonus-inst-threshold=2"); - - // Use div.full -- it matters for some float-division heavy benchmarks. - // Using div.approx produces incorrect result for float32(max)/float32(max). - backend_llvm_opts.emplace_back("-nvptx-prec-divf32=1"); - - // SLPVectorizer is useful (vectorizes f16x2 ops) but slow. Most of the - // slowness appears to be in trying to form horizontal reductions, which don't - // exist in PTX *anyway*. Disable these. While we're here, tweak - // SLPVectorizer so it doesn't try to create large vectors -- f16x2 are the - // only vectors supported in PTX. - backend_llvm_opts.emplace_back("-slp-vectorize-hor=false"); - backend_llvm_opts.emplace_back("-slp-max-reg-size=32"); - - // Extra backend options must go after regular backend options in order to be - // able for the later to override the former. - auto backend_extra_llvm_opts = llvm_ir::ExtractXlaBackendExtraOptions( - debug_options.xla_backend_extra_options()); - backend_llvm_opts.insert(backend_llvm_opts.end(), - backend_extra_llvm_opts.cbegin(), - backend_extra_llvm_opts.cend()); - - return backend_llvm_opts; -} - -} // namespace - -namespace nvptx { - -std::string GetSmName(se::CudaComputeCapability compute_capability) { - int compute_capability_version = - compute_capability.major * 10 + compute_capability.minor; - int sm_version = 30; - // If the current compute capability isn't known, fallback to the - // most recent version before it. - int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62, - 61, 60, 53, 52, 50, 37, 35, 32, 30}; - for (int v : supported_versions) { - if (v <= compute_capability_version) { - sm_version = v; - break; - } - } - - // If the current CC isn't supported by LLVM and it is newer then - // the max supported LLVM version, do not warn about it. The end - // user can't do anything about this. E.g., PTX compiled for SM75 will - // run on SM80 too. - if (sm_version != compute_capability_version && - compute_capability_version < supported_versions[0]) { - LOG(WARNING) << "Unknown compute capability " - << compute_capability.ToString() - << ". Defaulting to telling LLVM that we're compiling for sm_" - << sm_version; - } - // On Hopper, default to sm_90a so that all instructions can be used. But - // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: - // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility - std::string_view extension = - (compute_capability.major == 9 && sm_version == 90) ? "a" : ""; - return absl::StrCat("sm_", sm_version, extension); -} - -absl::StatusOr CompileToPtx( - llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - std::function configure_target) { - static absl::once_flag backend_init_flag; - absl::call_once(backend_init_flag, NVPTXBackendInit); - auto llvm_opts = GetNVPTXBackendOptions(debug_options); - llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_opts); - - std::string ptx; - std::unique_ptr target_machine; - { - tsl::profiler::TraceMe activity( - [&] { return absl::StrCat("Compiling IR:", module->getName().str()); }, - tsl::profiler::TraceMeLevel::kInfo); - XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); - - // If the module has no functions or globals, there's nothing to compile. - // Just return an empty string. - if (module->empty() && module->global_empty()) { - VLOG(2) << "Module '" << module->getName().str() - << "' is empty. Skipping compilation."; - return std::string(); - } - - auto compute_capability = - std::get_if(&gpu_version); - if (!compute_capability) { - return xla::Internal("Incompatible compute capability was specified."); - } - - llvm::Triple default_target_triple("nvptx64-unknown-unknown"); - // Construct LLVM TargetMachine for NVPTX. - std::unique_ptr target_machine = NVPTXGetTargetMachine( - default_target_triple, *compute_capability, debug_options); - - // Apply target machine configuration from call-back if available. - if (configure_target) { - configure_target(target_machine.get()); - } - - uint64_t start_usecs = tsl::Env::Default()->NowMicros(); - - // Link with libdevice, and optimize the LLVM module. - TF_RETURN_IF_ERROR(LinkAndOptimizeModule( - module, gpu_version, debug_options, - LibDevicePath(debug_options.xla_gpu_cuda_data_dir()), - NVPTXTargetModuleLinker, default_target_triple, target_machine.get(), - kDefaultInlineThreshold)); - - uint64_t end_usecs = tsl::Env::Default()->NowMicros(); - RecordLlvmPassesDuration(end_usecs - start_usecs); - - start_usecs = tsl::Env::Default()->NowMicros(); - - // Lower optimized LLVM module to PTX. - ptx = EmitModuleToPTX(module, target_machine.get()); - - end_usecs = tsl::Env::Default()->NowMicros(); - RecordLlvmToPtxDuration(end_usecs - start_usecs); - } - return ptx; -} - -namespace { -constexpr stream_executor::SemanticVersion kFallbackPtxVersion{6, 5, 0}; -constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 5, 0}; -} // namespace - -stream_executor::SemanticVersion -DetermineHighestSupportedPtxVersionFromCudaVersion( - stream_executor::SemanticVersion cuda_version) { - if (cuda_version < stream_executor::SemanticVersion{11, 0, 0}) { - // For everything below CUDA 11 we just fall back to PTX 6.5. - // We don't support CUDA below 11 anymore. - return kFallbackPtxVersion; - } - - // Mapping determined from - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#release-notes - // Examples: - // CUDA 11.0 -> PTX 7.0 - // CUDA 11.1 -> PTX 7.1 - // CUDA 12.0 -> PTX 8.0 - // CUDA 12.4 -> PTX 8.4 - // This versioning scheme is valid until CUDA 12.6 - if (cuda_version < stream_executor::SemanticVersion{12, 6, 0}) { - return {cuda_version.major() - 4, cuda_version.minor(), 0}; - } - - // Return maximum known PTX version. - return kMaxPtxVersion; -} -} // namespace nvptx - -namespace { - -// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version. -std::vector GetROCDLPaths(std::string gcn_arch_name, - const std::string& rocdl_dir_path) { - // AMDGPU version-neutral bitcodes. - static std::vector* rocdl_filenames = - new std::vector( - {"opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc", - "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc", - "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc", - "oclc_abi_version_500.bc"}); - - // Construct full path to ROCDL bitcode libraries. - std::vector result; - result.reserve(rocdl_filenames->size() + 1); - for (auto& filename : *rocdl_filenames) { - result.push_back(tsl::io::JoinPath(rocdl_dir_path, filename)); - } - - // Add AMDGPU version-specific bitcodes. - std::vector tokens = absl::StrSplit(gcn_arch_name, ':'); - std::string amdgpu_version = gcn_arch_name; - if (!tokens.empty() && tokens[0].size() >= 3) { - amdgpu_version = tokens[0].substr(3); - } - result.push_back(tsl::io::JoinPath( - rocdl_dir_path, - absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc"))); - return result; -} - -struct HsacoCacheEntry { - uint64_t hash; - std::string ir; - std::string gfx; - std::vector hsaco; -}; - -struct HsacoCache { - protected: - std::vector cache; - std::mutex m_mutex; - int request_count = 0; - int hit_count = 0; - - public: - static bool Find(const std::string& ir, uint64_t& hash, - const std::string& gfx, std::vector& hsaco); - static void Add(const std::string& ir, uint64_t hash, const std::string& gfx, - const std::vector& hsaco); -}; - -static HsacoCache g_hsacoCache; // NOLINT: static/global vars forbidden - -bool HsacoCache::Find(const std::string& ir, uint64_t& hash, - const std::string& gfx, std::vector& hsaco) { - std::lock_guard lg(g_hsacoCache.m_mutex); - hash = std::hash{}(ir); - bool hit = false; - for (auto& x : g_hsacoCache.cache) { - if (x.hash != hash) continue; - if (x.gfx != gfx) continue; - if (x.ir != ir) continue; - hsaco = x.hsaco; - hit = true; - break; - } - g_hsacoCache.request_count++; - if (hit) g_hsacoCache.hit_count++; - if (!(g_hsacoCache.request_count % 50)) - VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, " - << g_hsacoCache.hit_count << " hits"; - return hit; -} - -void HsacoCache::Add(const std::string& ir, uint64_t hash, - const std::string& gfx, - const std::vector& hsaco) { - std::lock_guard lg(g_hsacoCache.m_mutex); - g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1); - g_hsacoCache.cache.back().ir = ir; - g_hsacoCache.cache.back().hash = hash; - g_hsacoCache.cache.back().gfx = gfx; - g_hsacoCache.cache.back().hsaco = hsaco; -} - -// Emits the given module to HSA Code Object. target_machine is an initialized -// TargetMachine for the AMDGPU target. -absl::StatusOr> EmitModuleToHsaco( - llvm::Module* module, llvm::TargetMachine* target_machine) { - auto* env = tsl::Env::Default(); - std::vector tempdir_vector; - env->GetLocalTempDirectories(&tempdir_vector); - if (tempdir_vector.empty()) { - return xla::Internal( - "Unable to locate a temporary directory for compile-time artifacts."); - } - std::string tempdir_name = tempdir_vector.front(); - VLOG(1) << "Compile-time artifacts located at: " << tempdir_name; - - bool keep_tempfiles = false; - TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES", - /*default_val=*/false, &keep_tempfiles)); - // Prepare filenames for all stages of compilation: - // IR, binary ISA, and HSACO. - std::string random_number = std::to_string(tsl::random::New64()); - std::string ir_filename = - absl::StrCat(module->getModuleIdentifier(), random_number + ".ll"); - std::string ir_path = tsl::io::JoinPath(tempdir_name, ir_filename); - - std::string ir_opt_filename = - absl::StrCat(module->getModuleIdentifier(), random_number + "_opt.ll"); - std::string ir_opt_path = tsl::io::JoinPath(tempdir_name, ir_opt_filename); - - std::string isabin_filename = - absl::StrCat(module->getModuleIdentifier(), random_number + ".o"); - std::string isabin_path = tsl::io::JoinPath(tempdir_name, isabin_filename); - - std::string hsaco_filename = - absl::StrCat(module->getModuleIdentifier(), random_number + ".hsaco"); - std::string hsaco_path = tsl::io::JoinPath(tempdir_name, hsaco_filename); - - std::error_code ec; - - // Dump LLVM IR. - std::unique_ptr ir_fs( - new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::OF_None)); - module->print(*ir_fs, nullptr); - ir_fs->flush(); - - // Emit GCN ISA binary. - llvm::legacy::PassManager pm; - pm.add(new llvm::TargetLibraryInfoWrapperPass( - llvm::Triple(module->getTargetTriple()))); - llvm::SmallVector stream; - llvm::raw_svector_ostream pstream(stream); - std::unique_ptr isabin_fs( - new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text)); - module->setDataLayout(target_machine->createDataLayout()); - target_machine->addPassesToEmitFile(pm, *isabin_fs, nullptr, - llvm::CodeGenFileType::ObjectFile); - pm.run(*module); - isabin_fs->flush(); - - if (keep_tempfiles) { - std::unique_ptr ir_fs( - new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::OF_None)); - module->print(*ir_fs, nullptr); - ir_fs->flush(); - } - // Locate lld. - std::string lld_path; - if (std::getenv("LLVM_PATH")) { - lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); - } else { - lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); - } - auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); - if (!lld_program) { - return xla::Internal("unable to find ld.lld in PATH: %s", - lld_program.getError().message()); - } - std::vector lld_args{ - llvm_ir::AsStringRef("ld.lld"), llvm_ir::AsStringRef("-flavor"), - llvm_ir::AsStringRef("gnu"), llvm_ir::AsStringRef("-shared"), - llvm_ir::AsStringRef(isabin_path), llvm_ir::AsStringRef("-o"), - llvm_ir::AsStringRef(hsaco_path), - }; - - std::string error_message; - int lld_result = - llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args), - std::nullopt, {}, 0, 0, &error_message); - if (lld_result) { - return xla::Internal("ld.lld execute fail: %s, error code %d", - error_message, lld_result); - } - - // Read HSACO. - std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate); - std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); - - std::vector hsaco(hsaco_file_size); - hsaco_file.seekg(0, std::ios::beg); - hsaco_file.read(reinterpret_cast(hsaco.data()), hsaco_file_size); - hsaco_file.close(); - if (!keep_tempfiles) { - remove(ir_path.c_str()); - remove(isabin_path.c_str()); - remove(hsaco_path.c_str()); - } - return hsaco; -} - -// Links ROCm-Device-Libs into the given module if the module needs it. -absl::Status LinkROCDLIfNecessary(llvm::Module* module, - std::string gcn_arch_name, - const std::string& rocdl_dir_path) { - if (!CouldNeedDeviceBitcode(*module)) { - return absl::OkStatus(); - } - - return LinkWithBitcodeVector(module, - GetROCDLPaths(gcn_arch_name, rocdl_dir_path)); -} - -absl::Status AMDGPUTargetModuleLinker( - llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& device_bitcode_dir_path) { - // Link the input module with ROCDL. - - auto compute_capability = - std::get_if(&gpu_version); - if (!compute_capability) { - return xla::Internal("Incompatible compute capability was specified."); - } - - std::string gcn_arch_name = compute_capability->gcn_arch_name(); - TF_RETURN_IF_ERROR( - LinkROCDLIfNecessary(module, gcn_arch_name, device_bitcode_dir_path)); - - // For rocm, we always enable flush to zero. (for cuda, this is determined - // via environemnt variables). This deceision was based on the observation - // Eugene had that the AMD GPU llvm backend has not picked up the atomic add - // instructions correctly without ftz enabled. We concluded that this should - // not has major impact as the hipcc path by default enables flush to zero for - // compilation. - // If ftz is enabled, set it as an attribute on every function in the module. - if (debug_options.xla_gpu_ftz()) { - for (llvm::Function& fn : *module) { - // may be necessary for the compiler to generate atomics (confirm!) - fn.addFnAttr("denormal-fp-math-f32", "preserve-sign"); - fn.addFnAttr("amdgpu-unsafe-fp-atomics", "true"); - } - } - - return absl::OkStatus(); -} - -// The following routine maps a feature token extracted from the -// hipDeviceProp_t::gcnArchName string, and maps it to a valid feature_str -// to be used for creating the AMDGPUTarget. -// This mapping is currently in a state of flux because TF XLA uses its -// own copy of LLVM, which is different from the LLVM version used by -// hipcc/runtime in the ROCm install. Ordinarily this is not a problem, -// but right now, the LLVM version used by hipcc/runtime has "targetID" -// related changes which have not yet been upstreamed (to the LLVM repo) -// When that upstreaming happens (and TF LLVM pointer moves past the -// upstream commit), the following mapping will need to change -std::string MapGCNArchNameTokenToFeatureStr(const std::string& token, - const std::string& gfx) { - if (token == "sramecc+") { - return "+sramecc"; - } else if (token == "sramecc-") { - if (gfx == "gfx90a" || gfx == "gfx940" || gfx == "gfx941" || - gfx == "gfx942") - return ""; - return "-sramecc"; - } else if (token == "xnack+") { - return "+xnack"; - } else if (token == "xnack-") { - return "-xnack"; - } - return ""; - -} - -std::pair GetFeatureStrFromGCNArchName( - const std::string& gcn_arch_name) { - std::string feature_str; - - std::string gfx = gcn_arch_name; - // For ROCm versions 4.0 and greater, we need to specify the correct - // feature str, based on the underlying GPU HW to get max performance. - std::vector tokens = absl::StrSplit(gcn_arch_name, ':'); - std::vector mapped_tokens; - if (!tokens.empty()) gfx = tokens[0]; - for (auto it = tokens.begin(); it != tokens.end(); it++) { - // Skip the first token, that is the gfxNNN str - // The rest of the tokens are the feature/targetid strings - if (it != tokens.begin()) { - std::string token(*it); - std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token, gfx); - mapped_tokens.push_back(mapped_token); - } - } - feature_str = absl::StrJoin(mapped_tokens, ","); - - return std::make_pair(gfx, feature_str); -} - -std::unique_ptr AMDGPUGetTargetMachine( - llvm::Triple target_triple, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options) { - auto compute_capability = - std::get_if(&gpu_version); - - std::string gcn_arch_name = compute_capability->gcn_arch_name(); - auto arch = GetFeatureStrFromGCNArchName(gcn_arch_name); - return GetTargetMachine(std::move(target_triple), arch.first, debug_options, - arch.second); -} - -// Returns the directory containing ROCm-Device-Libs files. -std::string GetROCDLDir(const DebugOptions& debug_options) { - std::vector potential_rocdl_dirs; - const std::string& datadir = debug_options.xla_gpu_cuda_data_dir(); - if (!datadir.empty()) { - potential_rocdl_dirs.push_back(datadir); - } - potential_rocdl_dirs.push_back(tsl::RocdlRoot()); - - // Tries all potential ROCDL directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const std::string& potential_rocdl_dir : potential_rocdl_dirs) { - if (tsl::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) { - VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir; - return potential_rocdl_dir; - } - VLOG(2) << "Unable to find potential ROCm-Device-Libs dir " - << potential_rocdl_dir; - } - - // Last resort: maybe in the current folder. - return "."; -} - -void AMDGPUBackendInit(const DebugOptions& debug_options, - std::string& rocdl_dir_path) { - // Initialize the AMDGPU target; it's the only target we link with, so call - // its specific initialization functions instead of the catch-all - // InitializeAll*. -#if TENSORFLOW_USE_ROCM - LLVMInitializeAMDGPUTarget(); - LLVMInitializeAMDGPUTargetInfo(); - LLVMInitializeAMDGPUTargetMC(); - LLVMInitializeAMDGPUAsmParser(); - LLVMInitializeAMDGPUAsmPrinter(); -#endif - - rocdl_dir_path = GetROCDLDir(debug_options); - llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); - InitializePasses(registry); -} - -std::vector GetAMDGPUBackendOptions( - const DebugOptions& debug_options) { - std::vector backend_llvm_opts; - - // Extra backend options must go after regular backend options in order to be - // able for the later to override the former. - auto backend_extra_llvm_opts = llvm_ir::ExtractXlaBackendExtraOptions( - debug_options.xla_backend_extra_options()); - backend_llvm_opts.insert(backend_llvm_opts.end(), - backend_extra_llvm_opts.cbegin(), - backend_extra_llvm_opts.cend()); - - return backend_llvm_opts; -} - -} // namespace - -namespace amdgpu { - -std::string LibDevicePath(std::string gcn_arch_name, - const std::string& rocdl_dir_path) { - auto libdevice_dir_paths = GetROCDLPaths(gcn_arch_name, rocdl_dir_path); - for (auto libdevice_dir_path : libdevice_dir_paths) { - if (libdevice_dir_path.find("ocml.bc")) { - return libdevice_dir_path; - } - } - return ""; -} - -absl::StatusOr> CompileToHsaco( - llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& module_config_cache_key) { - static absl::once_flag backend_init_flag; - // TODO(rocm) Ideally this would be refreshed if xla_gpu_cuda_data_dir - // changes. - static std::string rocdl_dir_path; // NOLINT: static/global vars forbidden - absl::call_once(backend_init_flag, AMDGPUBackendInit, debug_options, - rocdl_dir_path); - auto llvm_opts = GetAMDGPUBackendOptions(debug_options); - llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_opts); - - std::vector hsaco; - std::unique_ptr target_machine; - std::string str; - llvm::raw_string_ostream stream(str); - stream << *module; - // Delete the first two lines, since they usually vary even when the rest of - // the code is the same (but verify that they are what we expect). - if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") { - auto pos = str.find('\n'); - if (pos != std::string::npos) str = str.substr(pos + 1); - } - if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") { - auto pos = str.find('\n'); - if (pos != std::string::npos) str = str.substr(pos + 1); - } - str += module_config_cache_key; - { - tsl::profiler::TraceMe activity( - [&] { return absl::StrCat("Compiling IR", module->getName().str()); }, - tsl::profiler::TraceMeLevel::kInfo); - XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); - - auto compute_capability = - std::get_if(&gpu_version); - if (!compute_capability) { - return xla::Internal("Incompatible compute capability was specified."); - } - - std::string gcn_arch_name = compute_capability->gcn_arch_name(); - - uint64_t hash; - if (HsacoCache::Find(str, hash, gcn_arch_name, hsaco)) { - VLOG(1) << "HSACO cache hit"; - return hsaco; - } - VLOG(1) << "HSACO cache miss"; - bool dump_lls = false; - if (dump_lls) { - static int hsaco_count = 0; - std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll"; - hsaco_count++; - std::ofstream ofs(name); - ofs << str; - ofs.close(); - } - - llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz"); - // Construct LLVM TargetMachine for AMDGPU. - std::unique_ptr target_machine = - AMDGPUGetTargetMachine(default_target_triple, gpu_version, - debug_options); - - // Link with ROCm-Device-Libs, and optimize the LLVM module. - TF_RETURN_IF_ERROR(LinkAndOptimizeModule( - module, gpu_version, debug_options, rocdl_dir_path, - AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(), - kAMDGPUInlineThreshold)); - - // Lower optimized LLVM module to HSA code object. - TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get())); - HsacoCache::Add(str, hash, gcn_arch_name, hsaco); - } - return hsaco; -} - -} // namespace amdgpu - -namespace { - -std::unique_ptr SPIRGetTargetMachine( - llvm::Triple target_triple, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options) { - return nullptr; -} - -absl::Status SPIRTargetModuleLinker( - llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& device_bitcode_dir_path) { - return absl::OkStatus(); -} - -absl::StatusOr EmitModuleToSpir( - llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options) { -#if TENSORFLOW_USE_SYCL - SPIRV::TranslatorOpts::ExtensionsStatusMap ExtensionsStatus; - SPIRV::TranslatorOpts opts(SPIRV::VersionNumber::MaximumVersion, - ExtensionsStatus); - opts.enableAllExtensions(); // enable all SPIR-V extension first - - std::ostringstream oss; - std::string err; - bool success = llvm::writeSpirv(module, opts, oss, err); - if (!success) { - return xla::Internal("Fails to convert LLVM as SPIR-V: %s", err); - } - return oss.str(); -#else - return absl::UnimplementedError("Not implemented for SYCL"); -#endif -} - -void SPIRBackendInit() { - llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); - InitializePasses(registry); -} - -std::vector GetSPIRBackendOptions( - const DebugOptions& debug_options) { - std::vector backend_llvm_opts; - - backend_llvm_opts.emplace_back("-slp-vectorize-hor=false"); - backend_llvm_opts.emplace_back("-slp-min-reg-size=64"); - backend_llvm_opts.emplace_back("-slp-max-reg-size=64"); - - // Extra backend options must go after regular backend options in order to be - // able for the later to override the former. - auto backend_extra_llvm_opts = llvm_ir::ExtractXlaBackendExtraOptions( - debug_options.xla_backend_extra_options()); - backend_llvm_opts.insert(backend_llvm_opts.end(), - backend_extra_llvm_opts.cbegin(), - backend_extra_llvm_opts.cend()); - - return backend_llvm_opts; -} - -} // namespace - -namespace spir { - -absl::StatusOr> CompileToSpir( - llvm::Module* module, se::GpuComputeCapability gpu_version, - const DebugOptions& debug_options) { - std::string libdevice_dir_path; - static absl::once_flag backend_init_flag; - absl::call_once(backend_init_flag, SPIRBackendInit); - auto llvm_opts = GetSPIRBackendOptions(debug_options); - llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_opts); - - std::string spir; - { - XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); - - // If the module has no functions or globals, there's nothing to compile. - if (module->empty() && module->global_empty()) { - VLOG(2) << "Module '" << module->getName().str() - << "' is empty. Skipping compilation."; - return std::vector(); - } - - llvm::Triple default_target_triple("spir64-unknown-unknown"); - std::unique_ptr target_machine = - SPIRGetTargetMachine(default_target_triple, gpu_version, debug_options); - - TF_RETURN_IF_ERROR(LinkAndOptimizeModule( - module, gpu_version, debug_options, libdevice_dir_path, - SPIRTargetModuleLinker, default_target_triple, target_machine.get(), - kDefaultInlineThreshold)); - - // Lower optimized LLVM module to SPIR. - TF_ASSIGN_OR_RETURN(spir, - EmitModuleToSpir(module, gpu_version, debug_options)); - } - return std::vector(spir.begin(), spir.end()); -} - -} // namespace spir - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index a93a1d3e1590de..39fe8b4eb944e8 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -17,70 +17,49 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ #define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ -#include #include +#include #include -#include #include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" +#include "llvm/PassRegistry.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Triple.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/semantic_version.h" #include "xla/xla.pb.h" namespace xla { namespace gpu { -namespace nvptx { +// Initializes LLVM passes. Uses the PassRegistry mechanism. +void InitializePasses(llvm::PassRegistry* pass_registry); -// Gets the GPU name as it's known to LLVM for a given compute -// capability. If we see an unrecognized compute capability, we -// return the highest one that is known and below the selected device. -std::string GetSmName( - stream_executor::CudaComputeCapability compute_capability); +// Returns the TargetMachine, given a triple. +std::unique_ptr GetTargetMachine( + llvm::Triple triple, absl::string_view cpu_name, + const DebugOptions& debug_options, absl::string_view feature_str); -// Compiles the argument module and returns it. libdevice_dir_path is the parent -// directory of the libdevice bitcode libraries. The contents of the module may -// be changed. -// -// The Compile.* interfaces each create their own llvm::LLVMContext objects for -// thread safety, but note that LLVM's multithreaded support is very -// preliminary; multithreaded use is not recommended at this time. -absl::StatusOr CompileToPtx( - llvm::Module* module, stream_executor::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - std::function configure_target = nullptr); - -// Determine PTX version from CUDA version. -stream_executor::SemanticVersion -DetermineHighestSupportedPtxVersionFromCudaVersion( - stream_executor::SemanticVersion cuda_version); +// Returns whether the module could use any device bitcode library functions. +bool CouldNeedDeviceBitcode(const llvm::Module& module); -} // namespace nvptx +// Links the module with a vector of path to bitcode modules. +// The caller must guarantee that the paths exist. +absl::Status LinkWithBitcodeVector( + llvm::Module* module, const std::vector& bitcode_path_vector); -namespace amdgpu { -// Get path to libdevice file. -std::string LibDevicePath(std::string gcn_arch_name, - const std::string& rocdl_dir_path); -// Compiles the argument module and returns it with LLVM AMDGPU backend. -// rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries. -// The contents of the module may be changed. -absl::StatusOr> CompileToHsaco( - llvm::Module* module, stream_executor::GpuComputeCapability gpu_version, - const DebugOptions& debug_options, - const std::string& module_config_cache_key); -} // namespace amdgpu +using TargetModuleLinker = std::function; -namespace spir { -// Compiles the argument module and returns it. -absl::StatusOr> CompileToSpir( +// Links and optimizes the module. +absl::Status LinkAndOptimizeModule( llvm::Module* module, stream_executor::GpuComputeCapability gpu_version, - const DebugOptions& debug_options); -} // namespace spir + const DebugOptions& debug_options, const std::string& device_bitcode_path, + TargetModuleLinker module_linker, llvm::Triple default_target_triple, + llvm::TargetMachine* target_machine, int inline_threshold); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc new file mode 100644 index 00000000000000..9b0f94cc3e9f05 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc @@ -0,0 +1,361 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/llvm_gpu_backend/nvptx_backend.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/PassRegistry.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "llvm/Transforms/Scalar.h" +#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" +#include "xla/service/gpu/metrics.h" +#include "xla/service/llvm_ir/llvm_command_line_options.h" +#include "xla/stream_executor/cuda/subprocess_compilation.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/semantic_version.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/profiler/lib/scoped_annotation.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::gpu::nvptx { + +namespace { + +// Default inline threshold value to use in llvm. +const int kDefaultInlineThreshold = 1100; + +// Emits the given module to PTX. target_machine is an initialized TargetMachine +// for the NVPTX target. +std::string EmitModuleToPTX(llvm::Module* module, + llvm::TargetMachine* target_machine) { + tsl::profiler::ScopedAnnotation annotation([&] { + return absl::StrFormat("XlaEmitGpuAsm:#module=%s#", + module->getName().str()); + }); + std::string ptx; + llvm::raw_string_ostream stream(ptx); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager pm; + pm.add(new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(module->getTargetTriple()))); + target_machine->addPassesToEmitFile(pm, pstream, nullptr, + llvm::CodeGenFileType::AssemblyFile); + pm.run(*module); + return ptx; +} + +// Links libdevice into the given module if the module needs libdevice. +absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, + const std::string& libdevice_path) { + if (!CouldNeedDeviceBitcode(*module)) { + return absl::OkStatus(); + } + + if (!tsl::Env::Default()->FileExists(libdevice_path).ok()) { + LOG(WARNING) + << "libdevice is required by this HLO module but was not found at " + << libdevice_path; + return xla::Internal("libdevice not found at %s", libdevice_path); + } + + VLOG(1) << "Linking with libdevice from: " << libdevice_path; + return LinkWithBitcodeVector(module, {libdevice_path}); +} + +absl::Status NVPTXTargetModuleLinker(llvm::Module* module, + se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& device_bitcode_path) { + // Link the input module with libdevice, to pull in implementations of some + // builtins. + TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_path)); + + // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass + // can access it. + module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", + debug_options.xla_gpu_ftz()); + + // If ftz is enabled, set it as an attribute on every function in the module. + if (debug_options.xla_gpu_ftz()) { + for (llvm::Function& fn : *module) { + fn.addFnAttr("denormal-fp-math-f32", "preserve-sign"); + } + } + + return absl::OkStatus(); +} + +std::unique_ptr NVPTXGetTargetMachine( + llvm::Triple target_triple, se::CudaComputeCapability compute_capability, + const DebugOptions& debug_options) { + absl::StatusOr runtime_cuda_version = + stream_executor::GetAsmCompilerVersion( + debug_options.xla_gpu_cuda_data_dir()); + + constexpr stream_executor::SemanticVersion kCompileTimeCudaVersion{ + CUDA_VERSION / 1000, (CUDA_VERSION / 10) % 100, CUDA_VERSION % 10}; + + auto highest_supported_cuda_version = [&] { + if (runtime_cuda_version.ok()) { + return std::min(runtime_cuda_version.value(), kCompileTimeCudaVersion); + } + + return kCompileTimeCudaVersion; + }(); + + auto ptx_version = nvptx::DetermineHighestSupportedPtxVersionFromCudaVersion( + highest_supported_cuda_version); + int highest_supported_ptx_version = + ptx_version.major() * 10 + ptx_version.minor(); + + VLOG(1) << "Targeting PTX version: " << highest_supported_ptx_version; + std::string feature_str = + absl::StrFormat("+ptx%d", highest_supported_ptx_version); + + return GetTargetMachine(target_triple, nvptx::GetSmName(compute_capability), + debug_options, feature_str); +} + +// One-time module initializer. +// Must be called only once -- DO NOT CALL DIRECTLY. +void NVPTXBackendInit() { + // Initialize the NVPTX target; it's the only target we link with, so call its + // specific initialization functions instead of the catch-all InitializeAll*. + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + + // Initialize the LLVM optimization passes. + llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); + InitializePasses(registry); +} + +std::vector GetNVPTXBackendOptions( + const DebugOptions& debug_options) { + // Feed all customized flags here, so we can override them with llvm_cl_opts + // without redeploy the compiler for development purpose. + std::vector backend_llvm_opts; + + // This flag tunes a threshold in branch folding. The default threshold, which + // is one, is not suitable for CUDA programs where branches are more expensive + // than for CPU programs. Setting the threshold to 2 improves the latency of + // TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not affect the + // latency of other benchmarks so far. + // + // I also tried setting this threshold to other values: + // * 3-6 gives similar results as 2; + // * >6 start hurting the performance of at least dot product kernels. + // + // TODO(jingyue): The current threshold only considers the number of IR + // instructions which do not accurately reflect the true cost. We need a + // better cost model. + backend_llvm_opts.emplace_back("-bonus-inst-threshold=2"); + + // Use div.full -- it matters for some float-division heavy benchmarks. + // Using div.approx produces incorrect result for float32(max)/float32(max). + backend_llvm_opts.emplace_back("-nvptx-prec-divf32=1"); + + // SLPVectorizer is useful (vectorizes f16x2 ops) but slow. Most of the + // slowness appears to be in trying to form horizontal reductions, which don't + // exist in PTX *anyway*. Disable these. While we're here, tweak + // SLPVectorizer so it doesn't try to create large vectors -- f16x2 are the + // only vectors supported in PTX. + backend_llvm_opts.emplace_back("-slp-vectorize-hor=false"); + backend_llvm_opts.emplace_back("-slp-max-reg-size=32"); + + // Extra backend options must go after regular backend options in order to be + // able for the later to override the former. + auto backend_extra_llvm_opts = llvm_ir::ExtractXlaBackendExtraOptions( + debug_options.xla_backend_extra_options()); + backend_llvm_opts.insert(backend_llvm_opts.end(), + backend_extra_llvm_opts.cbegin(), + backend_extra_llvm_opts.cend()); + + return backend_llvm_opts; +} + +} // namespace + +std::string GetSmName(se::CudaComputeCapability compute_capability) { + int compute_capability_version = + compute_capability.major * 10 + compute_capability.minor; + int sm_version = 30; + // If the current compute capability isn't known, fallback to the + // most recent version before it. + int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62, + 61, 60, 53, 52, 50, 37, 35, 32, 30}; + for (int v : supported_versions) { + if (v <= compute_capability_version) { + sm_version = v; + break; + } + } + + // If the current CC isn't supported by LLVM and it is newer then + // the max supported LLVM version, do not warn about it. The end + // user can't do anything about this. E.g., PTX compiled for SM75 will + // run on SM80 too. + if (sm_version != compute_capability_version && + compute_capability_version < supported_versions[0]) { + LOG(WARNING) << "Unknown compute capability " + << compute_capability.ToString() + << ". Defaulting to telling LLVM that we're compiling for sm_" + << sm_version; + } + // On Hopper, default to sm_90a so that all instructions can be used. But + // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility + absl::string_view extension = + (compute_capability.major == 9 && sm_version == 90) ? "a" : ""; + return absl::StrCat("sm_", sm_version, extension); +} + +absl::StatusOr CompileToPtx( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + std::function configure_target) { + static absl::once_flag backend_init_flag; + absl::call_once(backend_init_flag, NVPTXBackendInit); + auto llvm_opts = GetNVPTXBackendOptions(debug_options); + llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_opts); + + std::string ptx; + std::unique_ptr target_machine; + { + tsl::profiler::TraceMe activity( + [&] { return absl::StrCat("Compiling IR:", module->getName().str()); }, + tsl::profiler::TraceMeLevel::kInfo); + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); + + // If the module has no functions or globals, there's nothing to compile. + // Just return an empty string. + if (module->empty() && module->global_empty()) { + VLOG(2) << "Module '" << module->getName().str() + << "' is empty. Skipping compilation."; + return std::string(); + } + + auto compute_capability = + std::get_if(&gpu_version); + if (!compute_capability) { + return xla::Internal("Incompatible compute capability was specified."); + } + + llvm::Triple default_target_triple("nvptx64-unknown-unknown"); + // Construct LLVM TargetMachine for NVPTX. + std::unique_ptr target_machine = NVPTXGetTargetMachine( + default_target_triple, *compute_capability, debug_options); + + // Apply target machine configuration from call-back if available. + if (configure_target) { + configure_target(target_machine.get()); + } + + uint64_t start_usecs = tsl::Env::Default()->NowMicros(); + + // Link with libdevice, and optimize the LLVM module. + TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + module, gpu_version, debug_options, + LibDevicePath(debug_options.xla_gpu_cuda_data_dir()), + NVPTXTargetModuleLinker, default_target_triple, target_machine.get(), + kDefaultInlineThreshold)); + + uint64_t end_usecs = tsl::Env::Default()->NowMicros(); + RecordLlvmPassesDuration(end_usecs - start_usecs); + + start_usecs = tsl::Env::Default()->NowMicros(); + + // Lower optimized LLVM module to PTX. + ptx = EmitModuleToPTX(module, target_machine.get()); + + end_usecs = tsl::Env::Default()->NowMicros(); + RecordLlvmToPtxDuration(end_usecs - start_usecs); + } + return ptx; +} + +namespace { +constexpr stream_executor::SemanticVersion kFallbackPtxVersion{6, 5, 0}; +constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 5, 0}; +} // namespace + +stream_executor::SemanticVersion +DetermineHighestSupportedPtxVersionFromCudaVersion( + stream_executor::SemanticVersion cuda_version) { + if (cuda_version < stream_executor::SemanticVersion{11, 0, 0}) { + // For everything below CUDA 11 we just fall back to PTX 6.5. + // We don't support CUDA below 11 anymore. + return kFallbackPtxVersion; + } + + // Mapping determined from + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#release-notes + // Examples: + // CUDA 11.0 -> PTX 7.0 + // CUDA 11.1 -> PTX 7.1 + // CUDA 12.0 -> PTX 8.0 + // CUDA 12.4 -> PTX 8.4 + // This versioning scheme is valid until CUDA 12.6 + if (cuda_version < stream_executor::SemanticVersion{12, 6, 0}) { + return {cuda_version.major() - 4, cuda_version.minor(), 0}; + } + + // Return maximum known PTX version. + return kMaxPtxVersion; +} +} // namespace xla::gpu::nvptx diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h new file mode 100644 index 00000000000000..9d42dc44935b6e --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.h @@ -0,0 +1,57 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// LLVM-based compiler backend. +#ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_H_ +#define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/IR/Module.h" +#include "llvm/Target/TargetMachine.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/semantic_version.h" +#include "xla/xla.pb.h" + +namespace xla::gpu::nvptx { + +// Gets the GPU name as it's known to LLVM for a given compute +// capability. If we see an unrecognized compute capability, we +// return the highest one that is known and below the selected device. +std::string GetSmName( + stream_executor::CudaComputeCapability compute_capability); + +// Compiles the argument module and returns it. libdevice_dir_path is the +// parent directory of the libdevice bitcode libraries. The contents of the +// module may be changed. +// +// The Compile.* interfaces each create their own llvm::LLVMContext objects +// for thread safety, but note that LLVM's multithreaded support is very +// preliminary; multithreaded use is not recommended at this time. +absl::StatusOr CompileToPtx( + llvm::Module* module, stream_executor::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + std::function configure_target = nullptr); + +// Determine PTX version from CUDA version. +stream_executor::SemanticVersion +DetermineHighestSupportedPtxVersionFromCudaVersion( + stream_executor::SemanticVersion cuda_version); + +} // namespace xla::gpu::nvptx + +#endif // XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_H_ diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc rename to third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend_test.cc index 57d8aa96872bc9..bc3f4ac7e83871 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_backend.h" #include -#include "absl/strings/str_cat.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/make_batch_pointers.cc b/third_party/xla/xla/service/gpu/make_batch_pointers.cc index f2516742e1dedd..ad569593a84924 100644 --- a/third_party/xla/xla/service/gpu/make_batch_pointers.cc +++ b/third_party/xla/xla/service/gpu/make_batch_pointers.cc @@ -24,9 +24,9 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" #if TENSORFLOW_USE_ROCM #include "xla/stream_executor/gpu/gpu_stream.h" @@ -64,10 +64,10 @@ absl::Status MakeBatchPointers(se::Stream* stream, se::DeviceMemoryBase>::Create(executor, "make_batch_pointers", make_batch_pointers::kernel()))); - TF_RETURN_IF_ERROR( - stream->ThenLaunch(se::ThreadDim(kThreads, 1, 1), - se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), kernel, - base_ptr, stride_bytes, n, ptrs_out)); + TF_RETURN_IF_ERROR(kernel.Launch(se::ThreadDim(kThreads, 1, 1), + se::BlockDim(CeilOfRatio(n, kThreads), 1, 1), + stream, base_ptr, stride_bytes, n, + ptrs_out)); #endif return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc b/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc index 099b64c0471e16..04aabe18e8c798 100644 --- a/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc +++ b/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/matmul_indexing_utils.h" +#include #include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/gpu/matmul_utils_test.cc b/third_party/xla/xla/service/gpu/matmul_utils_test.cc index d758d04169b7e0..77286130d12344 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils_test.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" diff --git a/third_party/xla/xla/service/gpu/metrics_test.cc b/third_party/xla/xla/service/gpu/metrics_test.cc index 836c32d0563cb5..a6a1346b563894 100644 --- a/third_party/xla/xla/service/gpu/metrics_test.cc +++ b/third_party/xla/xla/service/gpu/metrics_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 91f7de0dc7631d..18f7e035aa9d9b 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -43,6 +43,74 @@ cc_library( ], ) +cc_library( + name = "sol_latency_estimator", + srcs = ["sol_latency_estimator.cc"], + hdrs = ["sol_latency_estimator.h"], + deps = [ + ":gpu_hlo_cost_analysis", + ":gpu_performance_model", + ":gpu_performance_model_base", + ":sol_gpu_cost_model", + "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_cost_analysis", + "//xla/service:latency_hiding_scheduler", + "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "sol_latency_estimator_test", + srcs = ["sol_latency_estimator_test.cc"], + deps = [ + ":sol_gpu_cost_model", + ":sol_latency_estimator", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_cost_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "sol_gpu_cost_model", + srcs = ["sol_gpu_cost_model.cc"], + hdrs = ["sol_gpu_cost_model.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "sol_gpu_cost_model_test", + srcs = ["sol_gpu_cost_model_test.cc"], + deps = [ + ":sol_gpu_cost_model", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest", + ], +) + xla_test( name = "analytical_latency_estimator_test", srcs = ["analytical_latency_estimator_test.cc"], @@ -128,11 +196,11 @@ xla_cc_test( ":gpu_hlo_cost_analysis", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_cost_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", @@ -175,16 +243,14 @@ xla_cc_test( deps = [ ":gpu_hlo_cost_analysis", ":hlo_op_profiles", - "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", ], ) @@ -501,9 +567,9 @@ xla_cc_test( ":symbolic_tiled_hlo_instruction", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_traversal", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", @@ -617,7 +683,6 @@ xla_cc_test( "//xla/hlo/utils:hlo_traversal", "//xla/service:instruction_fusion", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -667,7 +732,6 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", @@ -686,6 +750,7 @@ cc_library( ":tiled_hlo_instruction_or_computation", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", @@ -711,6 +776,7 @@ xla_cc_test( ":symbolic_tile_analysis", ":tiled_hlo_instruction_or_computation", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service:hlo_module_config", @@ -877,3 +943,41 @@ xla_test( "@local_tsl//tsl/platform:test_main", ], ) + +cc_library( + name = "sol_gpu_cost_model_stats_collection", + srcs = ["sol_gpu_cost_model_stats_collection.cc"], + hdrs = ["sol_gpu_cost_model_stats_collection.h"], + deps = [ + ":sol_gpu_cost_model", + ":sol_latency_estimator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_verifier", + "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_description", + "//xla/tsl/platform:status", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "sol_gpu_cost_model_stats_collection_test", + srcs = ["sol_gpu_cost_model_stats_collection_test.cc"], + deps = [ + ":sol_gpu_cost_model_stats_collection", + "//xla:shape_util", + "//xla/hlo/testlib:filecheck", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/model/affine_map_evaluator.cc b/third_party/xla/xla/service/gpu/model/affine_map_evaluator.cc index b85703bde0b566..b4a58be494eb15 100644 --- a/third_party/xla/xla/service/gpu/model/affine_map_evaluator.cc +++ b/third_party/xla/xla/service/gpu/model/affine_map_evaluator.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/model/affine_map_evaluator.h" #include -#include #include "absl/types/span.h" #include "llvm/Support/MathExtras.h" diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 77a68f67f4e997..a583c692c2d8b5 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -547,7 +548,7 @@ std::vector FindContiguousIntervals( } // Case 2: f(thread_x) != thread_x * multiplier. auto intervals = FindIntervals(partitioned_expr.func_of_d0, - {indexing_map.GetDimVars(0).bounds}); + {indexing_map.GetDimVar(0).bounds}); // Case 2.1: g(s) != s. if (partitioned_expr.func_of_s0 != range) { return intervals; diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 928914835a1f5b..201f061e66c111 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc index 820d9925ab3193..c5711bf8c80bec 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" -#include #include #include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc index a874395061777f..0461814d6d7c6b 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc @@ -454,6 +454,90 @@ absl::Status GpuHloCostAnalysis::HandleReduce(const HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status GpuHloCostAnalysis::HandleAllReduceStart( + const HloInstruction* hlo) { + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachLeafShape( + hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + return absl::OkStatus(); +} + +absl::Status GpuHloCostAnalysis::HandleAllGather(const HloInstruction* hlo) { + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachLeafShape( + hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + return absl::OkStatus(); +} + +absl::Status GpuHloCostAnalysis::HandleAllGatherStart( + const HloInstruction* hlo) { + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachLeafShape( + hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + // Skip first element of a tuple as it expresses the input of the + // collective operation. + if (index.empty() || index.front() == 0) { + return; + } + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + return absl::OkStatus(); +} + +absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) { + auto* async_start = DynCast(hlo); + if (async_start->async_wrapped_opcode() != HloOpcode::kReduceScatter) { + VLOG(2) << "Only Reduce Scatter is supported."; + return absl::OkStatus(); + } + int index_to_skip = 1; + int64_t output_bytes_accessed = 0; + ShapeUtil::ForEachLeafShape( + hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + // Skip second element of a tuple as it is an output but it is not + // actual bytes transferred. + if (index.empty() || index.front() == index_to_skip) { + return; + } + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + return absl::OkStatus(); +} + +absl::Status GpuHloCostAnalysis::HandleReduceScatter( + const HloInstruction* hlo) { + int64_t output_bytes_accessed = 0; + + for (auto* operand : hlo->operands()) { + ShapeUtil::ForEachLeafShape( + operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + } + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + + return absl::OkStatus(); +} + absl::Status GpuHloCostAnalysis::HandleElementwiseOp( const HloInstruction* hlo) { current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(hlo); diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h index 81fcd09eaeae16..5561a321b318ed 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.h @@ -72,6 +72,11 @@ class GpuHloCostAnalysis : public HloCostAnalysis { absl::Status HandleConcatenate(const HloInstruction* hlo) override; absl::Status HandleAllReduce(const HloInstruction* allreduce) override; absl::Status HandleReduce(const HloInstruction* hlo) override; + absl::Status HandleAllReduceStart(const HloInstruction* hlo) override; + absl::Status HandleAllGather(const HloInstruction* hlo) override; + absl::Status HandleAllGatherStart(const HloInstruction* hlo) override; + absl::Status HandleAsyncStart(const HloInstruction* hlo) override; + absl::Status HandleReduceScatter(const HloInstruction* hlo) override; // Estimate the total size of IR accounting for both duplication // of producer code by consumer and the total number of basic blocks. diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc index 9f591ac8c25e6a..71b7da2332e30d 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc @@ -24,10 +24,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -308,8 +307,8 @@ f { m0 = s8[10] multiply(n0, n0) a0 = s8[10] add(n0, n0) s0 = s8[5] slice(a0), slice={[0:5]} - s1 = s8[2] slice(n0), slice={[4:6]} - n1 = s8[2] negate(s1) + svar1 = s8[2] slice(n0), slice={[4:6]} + n1 = s8[2] negate(svar1) ROOT c0 = s8[17] concatenate(s0, m0, n1), dimensions={0} } @@ -642,6 +641,136 @@ ENTRY entry_computation { EXPECT_EQ(analysis_.flop_count(*reduce), 32 * 39 * 6); } +TEST_F(GpuHloCostAnalysisTest, AsyncAllReduce) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT t = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + p = f32[4096] parameter(0) + ar-start = f32[4096] all-reduce-start(p), to_apply=add + ROOT _ = f32[4096] all-reduce-done(ar-start) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + const HloInstruction* all_reduce = + module->entry_computation()->root_instruction()->operand(0); + EXPECT_EQ(analysis_.output_bytes_accessed(*all_reduce), 4096 * 4); +} + +TEST_F(GpuHloCostAnalysisTest, AllGather) { + absl::string_view hlo_string = R"( +HloModule m + +ENTRY entry_computation { + p = f32[1024] parameter(0) + ROOT _ = f32[4096] all-gather(p), dimensions={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + const HloInstruction* all_gather = + module->entry_computation()->root_instruction(); + EXPECT_EQ(analysis_.output_bytes_accessed(*all_gather), 4096 * 4); +} + +TEST_F(GpuHloCostAnalysisTest, AsyncAllGather) { + absl::string_view hlo_string = R"( +HloModule m + +ENTRY entry_computation { + p.0 = f32[1024] parameter(0) + p.1 = f32[512] parameter(1) + ag-start = ((f32[1024],f32[512]), (f32[4096],f32[2048])) all-gather-start(p.0,p.1), + dimensions={0} + ROOT _ = (f32[4096],f32[2048]) all-gather-done(ag-start) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + const HloInstruction* all_gather = + module->entry_computation()->root_instruction()->operand(0); + // Output is (f32[4096], f32[2048]). + EXPECT_EQ(analysis_.output_bytes_accessed(*all_gather), 4096 * 4 + 2048 * 4); +} + +TEST_F(GpuHloCostAnalysisTest, ReduceScatter) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT t = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + p = f32[4096] parameter(0) + ROOT _ = f32[1024] reduce-scatter(p), dimensions={0}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + const HloInstruction* reduce_scatter = + module->entry_computation()->root_instruction(); + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce_scatter), 4096 * 4); +} + +TEST_F(GpuHloCostAnalysisTest, AsyncReduceScatter) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT t = f32[] add(param_0, param_1) +} + +async_computation { + param_3 = f32[4096] parameter(0) + param_4 = f32[2048] parameter(1) + ROOT r = (f32[1024],f32[512]) reduce-scatter(param_3,param_4), + dimensions={0}, + to_apply=add +} + +ENTRY entry_computation { + p.0 = f32[4096] parameter(0) + p.1 = f32[2048] parameter(1) + rs-start = ((f32[4096],f32[2048]),(f32[1024],f32[512])) async-start(p.0,p.1), calls=async_computation + ROOT _ = (f32[1024],f32[512]) async-done(rs-start) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + const HloInstruction* reduce_scatter = + module->entry_computation()->root_instruction()->operand(0); + // Output is (f32[1024],f32[512]). + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce_scatter), + 4096 * 4 + 2048 * 4); +} + TEST_F(GpuHloCostAnalysisTest, CustomOpProfileIsUsed) { absl::string_view hlo_string = R"( HloModule m diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 2a087829276d36..ab09a82537e9b3 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/model/gpu_performance_model.h" #include -#include #include #include diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc index 8cbf262c6f8b38..9769c8f6d0ed23 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h index 1655003def697f..0ac09b5dcf2bd7 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc index 77c357d3cbdc69..656037c48fc0d2 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/gpu/model/gpu_performance_model_base.h" -#include - #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc index 86ada554f5184c..dbab70f96ae12f 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profiler.h" +#include #include #include "xla/hlo/ir/hlo_opcode.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc index 94d1f6a9800784..db39b830f2eb37 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -51,8 +50,8 @@ namespace gpu { } /*static*/ std::unique_ptr HloOpProfiles::Load( - std::string_view profiles_text_proto, - std::string_view default_profile_name) { + absl::string_view profiles_text_proto, + absl::string_view default_profile_name) { ProfilesNestedMap profiles_map; DeviceHloInstructionProfiles all_device_profiles; CHECK(tsl::protobuf::TextFormat::ParseFromString( diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h index 28845d4ab4eea8..109f6b590435f2 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiles.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -51,8 +50,8 @@ class HloOpProfiles { // Loads profiles from the given text proto data. static std::unique_ptr Load( - std::string_view profiles_text_proto, - std::string_view default_profile_name); + absl::string_view profiles_text_proto, + absl::string_view default_profile_name); const HloOpProfile& GetProfile( const se::DeviceDescription& device_info) const; @@ -61,7 +60,7 @@ class HloOpProfiles { private: HloOpProfiles(ProfilesNestedMap profiles, - std::string_view default_profile_name) + absl::string_view default_profile_name) : profiles_(std::move(profiles)), default_profile_(profiles_.at(default_profile_name)) {} diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiles_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiles_test.cc index 2ca9ec201ec965..1e566ac1b8ae14 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiles_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiles_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profiles.h" +#include + #include #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.cc b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.cc new file mode 100644 index 00000000000000..1334f6c4185cd6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_gpu_cost_model.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/bits.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { +namespace gpu { +namespace { +// Constants for NCCL SoL model +constexpr double kHeaderOverhead = 0.025; +constexpr absl::string_view kNcclOpLaunchUs = "nccl_op_launch_us"; +constexpr absl::string_view kNicSpeedGbps = "nic_speed_gbps"; +constexpr absl::string_view kChunkPrepUs = "chunk_prep_us"; +constexpr absl::string_view kRttUs = "rtt_us"; +constexpr absl::string_view kGpusPerNode = "gpus_per_node"; +constexpr absl::string_view kChunkSizeBytes = "chunk_size_bytes"; + +// Returns the number of communicators in the mask. +// For example, if the mask is 0x0, this function returns 1. If the mask is 0x7, +// this function returns 8. +int NumCommunicators(const absl::string_view mask) { + // Assuming the mask is a hexadecimal number + uint64_t mask_value = std::stoul(std::string(mask), nullptr, 16); + int bit_count = absl::popcount(mask_value); // Count set bits + return static_cast(std::pow(2, bit_count)); +} + +// Returns the number of rounds for the given collective type. +int NumRounds(const SolGPUCostModel::CollectiveType& coll_type) { + // AllReduce requires ReduceScatter and AllGather, so it has 2 rounds. + return coll_type == SolGPUCostModel::CollectiveType::kAllReduce ? 2 : 1; +} + +} // namespace + +/*static*/ SolGPUCostModel::Config SolGPUCostModel::GetConfig( + const HloModule* module) { + SolGPUCostModel::Config config; + const auto& extra_options = + module->config() + .debug_options() + .xla_gpu_analytical_latency_estimator_options(); + for (const auto& [option_name, option_value] : extra_options) { + int64_t value; + double value_d; + VLOG(2) << "[SoL] option: " << option_name << " is " << option_value; + if (option_name == kNcclOpLaunchUs && + absl::SimpleAtoi(option_value, &value)) { + config.nccl_op_launch_time = absl::Microseconds(value); + } else if (option_name == kNicSpeedGbps && + absl::SimpleAtod(option_value, &value_d)) { + config.nic_speed_gbps = value_d; + } else if (option_name == kChunkPrepUs && + absl::SimpleAtoi(option_value, &value)) { + config.chunk_prep_time = absl::Microseconds(value); + } else if (option_name == kRttUs && + absl::SimpleAtoi(option_value, &value)) { + config.rtt = absl::Microseconds(value); + } else if (option_name == kGpusPerNode && + absl::SimpleAtoi(option_value, &value)) { + config.gpus_per_node = value; + } else if (option_name == kChunkSizeBytes && + absl::SimpleAtoi(option_value, &value)) { + config.chunk_size_bytes = value; + } + } + return config; +} + +SolGPUCostModel::SolGPUCostModel(const Config& sys_config) + : xla_flag_config_(sys_config) { + VLOG(2) << "[SoL] NIC speed: " << xla_flag_config_.nic_speed_gbps; + VLOG(2) << "[SoL] RTT: " << xla_flag_config_.rtt; + VLOG(2) << "[SoL] Chunk preparation time: " + << xla_flag_config_.chunk_prep_time; + VLOG(2) << "[SoL] NCCL op launch time: " + << xla_flag_config_.nccl_op_launch_time; + VLOG(2) << "[SoL] GPUs per node: " << xla_flag_config_.gpus_per_node; +} + +// This is a insignificant term, and we are making it consistent +// with the existing formula. +absl::Duration SolGPUCostModel::ChunkPrepLatency( + const int64_t per_gpu_msg_size_bytes) const { + return std::ceil(static_cast(per_gpu_msg_size_bytes) / + xla_flag_config_.chunk_size_bytes) * + xla_flag_config_.chunk_prep_time; +} + +absl::Duration SolGPUCostModel::TransferDuration( + const int64_t per_gpu_msg_size_bytes) const { + // x1e6 to comvert secs to microseconds; + // x1024*1024 *1024 to convert Gbytes/sec to bytes/sec + const long double ret = + (1e6 * static_cast(per_gpu_msg_size_bytes)) / + (std::pow(1024.0, 3) * xla_flag_config_.nic_speed_gbps); + return absl::Microseconds(ret * (1 + kHeaderOverhead)); +} + +absl::Duration SolGPUCostModel::RingLatency( + const int64_t buff_size_bytes, const int num_nodes, + const CollectiveType& coll_type, const absl::string_view mask) const { + const int num_gpus = NumGpusPerComm(num_nodes, coll_type, mask); + + int64_t per_gpu_msg_size_bytes; + if (coll_type == CollectiveType::kSendRecv) { + per_gpu_msg_size_bytes = buff_size_bytes; + } else { + per_gpu_msg_size_bytes = buff_size_bytes / num_gpus; + } + + // This is the number of GPUs per communicator per node. We assume that each + // GPU has a NIC, and this is also the number of NICs per communicator per + // node. + // Note that this happens to be correct value (i.e. 1) for SendRecv. + int num_gpus_per_node = num_gpus / num_nodes; + + // In each channel, consider one GPU next to the Ethernet link. Below is the + // sum of 3 time costs for each piece of data of size + // `per_gpu_msg_size_bytes` + // + // 1. transfer duration defined by the NIC bandwidth, + // 2. chunk preparation latency, and + // 3. RTT + // + // then followed by two factors: + // + // 1. Multiply by `num_gpus - 1`, as `num_gpus - 1` pieces of data will be + // sent over the link in AllGather. + // 2. Divide by `num_gpus_per_node` as there are `num_gpus_per_node` NICs + // and + // GPUs in each node for parallelism. + // + // Better estimates of terms like this will come in future versions + // of the SoL model. + absl::Duration ret = TransferDuration(per_gpu_msg_size_bytes) + + ChunkPrepLatency(per_gpu_msg_size_bytes) + + xla_flag_config_.rtt; + ret *= (num_gpus - 1.0) / static_cast(num_gpus_per_node); + // Multiply by the number of rounds, which is different for AllReduce. + ret = ret * NumRounds(coll_type); + + // Time to initiate the collective. + return ret + xla_flag_config_.nccl_op_launch_time; +} + +// Helper functions +int SolGPUCostModel::NumGpusPerComm(int num_nodes, + const CollectiveType& coll_type, + const absl::string_view mask) const { + if (coll_type == CollectiveType::kSendRecv) { + return 2; + } + int num_comms = NumCommunicators(mask); + CHECK_EQ(xla_flag_config_.gpus_per_node % num_comms, 0) + << "GPU_PER_NODE must be divisible by the number of communicators. " + "GPU_PER_NODE: " + << xla_flag_config_.gpus_per_node + << " Number of communicators: " << num_comms + << ". Adjust the number of GPUs per node with the flag " + "gpus_per_node in xla_gpu_analytical_latency_estimator_options."; + return num_nodes * xla_flag_config_.gpus_per_node / num_comms; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h new file mode 100644 index 00000000000000..b359f196382dbe --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model.h @@ -0,0 +1,84 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_H_ +#define XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { +namespace gpu { +inline constexpr absl::string_view kSplitMaskWorldLevel = "0x0"; + +class SolGPUCostModel { + // Speed-of-Light (SoL) analytical cost model for NCCL collectives. + public: + // Tunable system configuration, see + // xla_gpu_analytical_latency_estimator_options + struct Config { + absl::Duration nccl_op_launch_time; + double nic_speed_gbps; // it's GBytes/s, not Gbit/s (ex: 40Gb/s = 5GB/s) + absl::Duration chunk_prep_time; + absl::Duration rtt; + int64_t gpus_per_node; + int64_t chunk_size_bytes; + }; + enum CollectiveAlgorithmType { + RING = 0, + TREE, + }; + enum class CollectiveType { + kAllReduce, + kAllGather, + kReduceScatter, + kSendRecv, + }; + explicit SolGPUCostModel(const Config& sys_config); + + // Extract the SoL-related configuration from XLA flags. + static SolGPUCostModel::Config GetConfig(const HloModule* module); + + // Returns the latency of a NCCL ring collective. + // + // `buff_size_bytes`: the size of the message to be transferred. + // `num_nodes`: the number of nodes participating in the ring. + // `coll_type`: the type of the collective (eg AllGather). + // `mask`: the mask of the collective (AllWorld 0x0 vs RailAligned 0x7). + absl::Duration RingLatency( + int64_t buff_size_bytes, int num_nodes, const CollectiveType& coll_type, + absl::string_view mask = kSplitMaskWorldLevel) const; + + private: + // Helper functions to estimate the latency subcomponents + absl::Duration ChunkPrepLatency(int64_t per_gpu_msg_size_bytes) const; + + absl::Duration TransferDuration(int64_t per_gpu_msg_size_bytes) const; + // NumGpusPerComm returns GPUs number participating in a given NCCL + // collective operation. + int NumGpusPerComm(int num_nodes, const CollectiveType& coll_type, + absl::string_view mask) const; + + // SoL-related configuration for NCCL cost modelling passed by user as flags. + Config xla_flag_config_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_H_ diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection.cc b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection.cc new file mode 100644 index 00000000000000..766123c2c49297 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/model/sol_gpu_cost_model.h" +#include "xla/service/gpu/model/sol_latency_estimator.h" +#include "xla/tsl/platform/status.h" + +namespace xla::gpu { + +absl::StatusOr SolGpuCostModelStatsCollection::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + SolGPUCostModel::Config config = SolGPUCostModel::GetConfig(module); + + hlo_query::ForEachInstructionWithPred( + *module, + [](const HloInstruction* instr) { + return hlo_query::IsCollectiveCommunicationOp(instr->opcode()); + }, + [&](HloInstruction* instr) { + // Generate exec time for a collective. + absl::Duration exec_time = SolLatencyEstimator::ComputeCollectiveTime( + *instr, device_info_, shape_size_in_bytes_fn_, config); + + // Set it in the `CollectiveBackendConfig`. + auto gpu_config = instr->backend_config(); + TF_CHECK_OK(gpu_config.status()) << instr->ToString(); + auto reification_cost = gpu_config->mutable_collective_backend_config() + ->mutable_reification_cost(); + reification_cost->set_exec_time_us( + absl::ToDoubleMicroseconds(exec_time)); + TF_CHECK_OK(instr->set_backend_config(*gpu_config)); + }); + + return false; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h new file mode 100644 index 00000000000000..67fe7963fbe689 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_STATS_COLLECTION_H_ +#define XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_STATS_COLLECTION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/hlo_verifier.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +class SolGpuCostModelStatsCollection : public HloModulePass { + public: + explicit SolGpuCostModelStatsCollection( + const se::DeviceDescription& device_description, + ShapeSizeFn shape_size_in_bytes_fn) + : device_info_(device_description), + shape_size_in_bytes_fn_(shape_size_in_bytes_fn) {} + + absl::string_view name() const override { + return "sol-gpu-cost-model-stats-collection"; + } + + using HloPassInterface::Run; + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + se::DeviceDescription device_info_; + ShapeSizeFn shape_size_in_bytes_fn_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_MODEL_SOL_GPU_COST_MODEL_STATS_COLLECTION_H_ diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc new file mode 100644 index 00000000000000..35419533431963 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_gpu_cost_model_stats_collection.h" + +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using ShapeSizeFn = std::function; + +class SolGpuCostModelStatsCollectionTest : public HloTestBase { + public: + explicit SolGpuCostModelStatsCollectionTest() : HloTestBase() { + ShapeSizeFn shape_size_bytes = + [&shape_size_bytes](const Shape& shape) -> int64_t { + int64_t shape_size = 0; + if (shape.IsTuple()) { + for (auto& sub_shape : shape.tuple_shapes()) { + shape_size += shape_size_bytes(sub_shape); + } + return shape_size; + } + return ShapeUtil::ByteSizeOfElements(shape); + }; + shape_size_fn_ = shape_size_bytes; + } + + protected: + se::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + ShapeSizeFn shape_size_fn_; +}; + +TEST_F(SolGpuCostModelStatsCollectionTest, + RecordsRuntimeInformationForCollectives) { + constexpr absl::string_view kHloText = R"( + HloModule m + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[8192,4096] parameter(0) + + ar-start = f32[8192,4096] all-reduce-start(p0), to_apply=add, + replica_groups={{0,1,2,3,4,5,6,7}, {8,9,10,11,12,13,14,15}} + ROOT ar-done = f32[8192,4096] all-reduce-done(ar-start) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + TF_ASSERT_OK_AND_ASSIGN( + bool changed, SolGpuCostModelStatsCollection(device_info_, shape_size_fn_) + .Run(module.get())); + + VLOG(1) << module->ToString(); + + EXPECT_FALSE(changed); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( +// CHECK: ar-start +// CHECK-SAME: collective_backend_config +// CHECK-SAME: "exec_time_us":1407 +)")); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc new file mode 100644 index 00000000000000..d7892a13fe713a --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_gpu_cost_model.h" + +#include + +#include +#include "absl/time/time.h" +namespace xla { +namespace gpu { +namespace { +constexpr int64_t kTenMB = 10 * 1024 * 1024; // 10MB + +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +struct RingLatencyTestCase { + SolGPUCostModel::CollectiveType collective_type; + absl::Duration expected_latency; +}; + +class SolGPUCostModelTest : public TestWithParam { + protected: + SolGPUCostModelTest() + : model_({ + /*nccl_op_launch_time=*/absl::Microseconds(100), + /*nic_speed_gbps=*/100, + /*chunk_prep_time=*/absl::Microseconds(100), + /*rtt=*/absl::Microseconds(100), + /*gpus_per_node=*/100, + /*chunk_size_bytes=*/4 * 1024 * 1024, + }) {} + SolGPUCostModel model_; +}; + +TEST_P(SolGPUCostModelTest, TestRingLatency) { + const RingLatencyTestCase& test_case = GetParam(); + absl::Duration actual_latency = + absl::Trunc(model_.RingLatency(kTenMB, 1, test_case.collective_type), + absl::Microseconds(1)); + EXPECT_EQ(actual_latency, test_case.expected_latency); +} + +INSTANTIATE_TEST_SUITE_P( + SolGPUCostModelTests, SolGPUCostModelTest, + ValuesIn({ + {SolGPUCostModel::CollectiveType::kAllGather, absl::Microseconds(298)}, + {SolGPUCostModel::CollectiveType::kAllReduce, absl::Microseconds(497)}, + {SolGPUCostModel::CollectiveType::kReduceScatter, + absl::Microseconds(298)}, + {SolGPUCostModel::CollectiveType::kSendRecv, absl::Microseconds(350)}, + })); +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc b/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc new file mode 100644 index 00000000000000..0e2f67a9327110 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_latency_estimator.cc @@ -0,0 +1,198 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_latency_estimator.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/time/time.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_performance_model.h" +#include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/gpu/model/sol_gpu_cost_model.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" + +namespace xla { +namespace gpu { + +namespace { + +int GetNumGpus(const HloInstruction& instr) { + const HloInstruction* i = &instr; + if (instr.opcode() == HloOpcode::kAsyncStart) { + i = instr.async_wrapped_instruction(); + } + int size = 0; + for (auto& rg : i->replica_groups()) { + size += rg.replica_ids_size(); + } + return size; +} + +} // namespace + +/*static*/ absl::Duration SolLatencyEstimator::ComputeCollectiveTime( + const HloInstruction& instr, const se::DeviceDescription& gpu_device_info, + HloCostAnalysis::ShapeSizeFunction shape_size_fn, + const SolGPUCostModel::Config& sol_flags) { + GpuHloCostAnalysis analysis( + GpuHloCostAnalysis::Options{shape_size_fn, + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}, + gpu_device_info); + + CHECK_OK(instr.parent()->Accept(&analysis)); + + return SolLatencyEstimator::ComputeCollectiveTime( + instr, gpu_device_info, shape_size_fn, sol_flags, analysis); +} + +/*static*/ absl::Duration SolLatencyEstimator::ComputeCollectiveTime( + const HloInstruction& instr, const se::DeviceDescription& gpu_device_info, + HloCostAnalysis::ShapeSizeFunction shape_size_fn, + const SolGPUCostModel::Config& sol_flags, + const GpuHloCostAnalysis& analysis) { + const int num_nodes = GetNumGpus(instr) / sol_flags.gpus_per_node; + if (num_nodes == 1) { + VLOG(8) << "Returning only kernel launch overhead for a single node."; + return GpuPerformanceModelBase::kNcclKernelLaunchOverhead; + } + + if (HloDataflowAnalysis::IsAsynchronousOperationDone(instr.opcode())) { + VLOG(8) << "Returning 0 cost for async done op " << instr.name(); + return absl::ZeroDuration(); + } + SolGPUCostModel sol_model(sol_flags); + const int64_t msg_size = analysis.output_bytes_accessed(instr); + + switch (instr.opcode()) { + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: { + return sol_model.RingLatency(msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kAllGather); + } + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: { + return sol_model.RingLatency(msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kAllReduce); + } + case HloOpcode::kReduceScatter: { + return sol_model.RingLatency( + msg_size, num_nodes, SolGPUCostModel::CollectiveType::kReduceScatter); + } + case HloOpcode::kAsyncStart: { + if (instr.async_wrapped_opcode() == HloOpcode::kReduceScatter) { + return sol_model.RingLatency( + msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kReduceScatter); + } + break; + } + case HloOpcode::kRecv: + case HloOpcode::kSend: { + return sol_model.RingLatency(msg_size, num_nodes, + SolGPUCostModel::CollectiveType::kSendRecv); + } + // note: AllToAll is not yet supported in XLA + default: { + LOG(WARNING) + << "[SoL] Runtime estimate for " << instr.name() + << " not implemented. Returning only the kernel launch time."; + return GpuPerformanceModelBase::kNcclKernelLaunchOverhead; + } + } + return GpuPerformanceModelBase::kNcclKernelLaunchOverhead; +} + +LatencyEstimator::TimeCost SolLatencyEstimator::GetLatencyBetween( + const HloGraphNode& from, const HloGraphNode& target) const { + const HloOpcode from_op = from.GetInstr().opcode(); + if (!config_.schedule_send_recvs && + (from_op == HloOpcode::kSend || from_op == HloOpcode::kRecv)) { + return kLowLatency; + } + + if (IsAsyncPair(from, target)) { + double coll_time = absl::ToDoubleMicroseconds( + ComputeCollectiveTime(from.GetInstr(), gpu_info_, shape_size_function_, + sol_flags_, *cost_analysis_)); + VLOG(10) << "[SoL] Analytical estimator calculated latency between " + << from.GetInstr().name() << " and " << target.GetInstr().name() + << " to be: " << coll_time << " us."; + return coll_time; + } + return latency_estimator_->GetLatencyBetween(from, target); +} + +LatencyEstimator::TimeCost SolLatencyEstimator::NodeCost( + const HloInstruction* instr) const { + if (hlo_query::IsAsyncCollectiveStartOp(instr, /*include_send_recv=*/true) || + hlo_query::IsAsyncCollectiveDoneOp(instr, /*include_send_recv=*/true)) { + return kLowCost; + } + + absl::Duration total_estimated_time = + GpuPerformanceModel::EstimateRunTimeForInstruction( + instr, gpu_info_, &*cost_analysis_, + GpuPerformanceModelOptions::Default()) + .exec_time; + LatencyEstimator::TimeCost cost_in_us = + absl::ToDoubleMicroseconds(total_estimated_time); + VLOG(10) << "Analytical estimator calculated cost for: " << instr->name() + << ". Cost: " << cost_in_us; + return cost_in_us; +} + +SolLatencyEstimator::SolLatencyEstimator( + const SchedulerConfig& config, + std::unique_ptr latency_estimator, + const se::DeviceDescription& gpu_info, + HloCostAnalysis::ShapeSizeFunction shape_size_function, + HloComputation* computation) + : config_(config), + gpu_info_(gpu_info), + latency_estimator_(std::move(latency_estimator)), + shape_size_function_(shape_size_function), + sol_flags_(SolGPUCostModel::GetConfig(computation->parent())) { + cost_analysis_.emplace( + GpuHloCostAnalysis::Options{shape_size_function_, + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}, + gpu_info_); + TF_CHECK_OK(computation->Accept(&cost_analysis_.value())); + if (sol_flags_.nccl_op_launch_time == absl::ZeroDuration() || + sol_flags_.nic_speed_gbps == 0 || + sol_flags_.chunk_prep_time == absl::ZeroDuration() || + sol_flags_.rtt == absl::ZeroDuration() || sol_flags_.gpus_per_node == 0) { + LOG(WARNING) << "[SoL] Failed to parse SoL system config options."; + } +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/sol_latency_estimator.h b/third_party/xla/xla/service/gpu/model/sol_latency_estimator.h new file mode 100644 index 00000000000000..0c9da3d0abcce0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_latency_estimator.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_SOL_LATENCY_ESTIMATOR_H_ +#define XLA_SERVICE_GPU_MODEL_SOL_LATENCY_ESTIMATOR_H_ + +#include +#include + +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/sol_gpu_cost_model.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +class SolLatencyEstimator : public LatencyEstimator { + public: + // Implementation of SolLatencyEstimator using HloAnalysis and + // GPUPerformanceModel to estimate latencies for instructions. + SolLatencyEstimator(const SchedulerConfig& config, + std::unique_ptr latency_estimator, + const se::DeviceDescription& gpu_info, + HloCostAnalysis::ShapeSizeFunction shape_size_function, + HloComputation* computation); + + TimeCost GetLatencyBetween(const HloGraphNode& from, + const HloGraphNode& target) const override; + TimeCost NodeCost(const HloInstruction* instr) const override; + int CyclesPerMicrosecond() const override { + return latency_estimator_->CyclesPerMicrosecond(); + } + + static absl::Duration ComputeCollectiveTime( + const HloInstruction& instr, const se::DeviceDescription& gpu_device_info, + HloCostAnalysis::ShapeSizeFunction shape_size_fn, + const SolGPUCostModel::Config& sol_flags); + + static absl::Duration ComputeCollectiveTime( + const HloInstruction& instr, const se::DeviceDescription& gpu_device_info, + HloCostAnalysis::ShapeSizeFunction shape_size_fn, + const SolGPUCostModel::Config& sol_flags, + const GpuHloCostAnalysis& cost_analysis); + + static constexpr TimeCost kLowCost = 1.0; + static constexpr TimeCost kLowLatency = 1.0; + + private: + const SchedulerConfig config_; + const se::DeviceDescription& gpu_info_; + std::optional cost_analysis_; + std::unique_ptr latency_estimator_; + HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const SolGPUCostModel::Config sol_flags_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_SOL_LATENCY_ESTIMATOR_H_ diff --git a/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc b/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc new file mode 100644 index 00000000000000..de40364d29f887 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/sol_latency_estimator.h" + +#include +#include +#include + +#include +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/model/sol_gpu_cost_model.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using ::testing::TestParamInfo; +using ::testing::ValuesIn; +using ::testing::WithParamInterface; + +struct EstimatorTestCase { + std::string test_name; + std::string module_string; + HloOpcode opcode; + absl::Duration expected_latency; +}; + +class SolLatencyEstimatorTest : public HloTestBase, + public WithParamInterface { + protected: + SolLatencyEstimatorTest() + : shape_size_fn_(HloCostAnalysis::DefaultShapeSize), + gpu_device_info_( + backend().default_stream_executor()->GetDeviceDescription()), + sol_flags_({ + /*nccl_op_launch_time=*/absl::Microseconds(100), + /*nic_speed_gbps=*/100, + /*chunk_prep_time=*/absl::Microseconds(100), + /*rtt=*/absl::Microseconds(100), + /*gpus_per_node=*/8, + /*chunk_size_bytes=*/4 * 1024 * 1024, + }) {} + + absl::Duration ComputeCollectiveTime(const HloInstruction& instr) { + return SolLatencyEstimator::ComputeCollectiveTime( + instr, gpu_device_info_, shape_size_fn_, sol_flags_); + } + + HloCostAnalysis::ShapeSizeFunction shape_size_fn_; + const se::DeviceDescription& gpu_device_info_; + const SolGPUCostModel::Config sol_flags_; +}; + +TEST_P(SolLatencyEstimatorTest, TestLatencyEstimation) { + EstimatorTestCase test_case = GetParam(); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(test_case.module_string)); + HloInstruction* instr = + hlo_query::FindInstruction(module->entry_computation(), test_case.opcode); + absl::Duration actual_time_us = + absl::Trunc(ComputeCollectiveTime(*instr), absl::Microseconds(1)); + EXPECT_EQ(actual_time_us, test_case.expected_latency); +} + +std::vector GetSolLatencyEstimatorTestCases() { + EstimatorTestCase all_gather_intra_host = { + /*test_name=*/"all_gather_intra_host", + /*module_string=*/R"( +HloModule m + +ENTRY main { + p = bf16[16000,1000] parameter(0) + ag-start = (bf16[16000,1000], bf16[16000,8000]) all-gather-start(p), + replica_groups={{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}}, + channel_id=1, + use_global_device_ids=true, + dimensions={1} + ROOT ag-done = bf16[16000,8000] all-gather-done(ag-start) + +})", + /*opcode=*/HloOpcode::kAllGatherStart, + /*expected_latency=*/absl::Microseconds(1323), + }; + + EstimatorTestCase all_gather_inter_host_pairwise = { + /*test_name=*/"all_gather_intra_host_pairwise", + /*module_string=*/R"( +HloModule m + +ENTRY main { + p = bf16[16000,4000] parameter(0) + ag-start = (bf16[16000,4000], bf16[16000,8000]) all-gather-start(p), + replica_groups={{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}}, + channel_id=1, + use_global_device_ids=true, + dimensions={1} + ROOT ag-done = bf16[16000,8000] all-gather-done(ag-start) +})", + /*opcode=*/HloOpcode::kAllGatherStart, + /*expected_latency=*/absl::Microseconds(1323), + }; + + EstimatorTestCase all_gather_all_ranks = { + /*test_name=*/"all_gather_all_ranks", + /*module_string=*/R"( +HloModule m + +ENTRY main { + p = bf16[16000,500] parameter(0) + ag-start = (bf16[16000,500], bf16[16000,8000]) all-gather-start(p), + replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}, + channel_id=1, + use_global_device_ids=true, + dimensions={1} + ROOT ag-done = bf16[16000,8000] all-gather-done(ag-start) +})", + /*opcode=*/HloOpcode::kAllGatherStart, + /*expected_latency=*/absl::Microseconds(1323), + }; + + EstimatorTestCase reduce_scatter_all_ranks = { + /*test_name=*/"reduce_scatter_all_ranks", + /*module_string=*/R"( +HloModule m + +add { + param_0 = bf16[] parameter(0) + param_1 = bf16[] parameter(1) + ROOT t = bf16[] add(param_0, param_1) +} + +async_comp { + param_3 = bf16[8192,128256] parameter(0) + ROOT r = bf16[64,128256] reduce-scatter(param_3), + dimensions={0}, + to_apply=add, + replica_groups=[1,128]<=[128], + channel_id=1, + use_global_device_ids=true +} + +ENTRY main { + p = bf16[8192,128256] parameter(0) + rs-start = ((bf16[8192,128256]), bf16[64,128256]) async-start(p), calls=async_comp + ROOT rs-done = bf16[64,128256] async-done(rs-start) +})", + /*opcode=*/HloOpcode::kAsyncStart, + /*expected_latency=*/absl::Microseconds(10525), + }; + + return { + all_gather_intra_host, + all_gather_inter_host_pairwise, + all_gather_all_ranks, + reduce_scatter_all_ranks, + }; +} + +INSTANTIATE_TEST_SUITE_P(SolLatencyEstimatorTests, SolLatencyEstimatorTest, + ValuesIn(GetSolLatencyEstimatorTestCases()), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 77dd2a78b63460..1b2e3c5af075b3 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 51cc4930f16aa6..e35eda3d01657f 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -64,7 +64,7 @@ limitations under the License. #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_backend.h" #include "xla/service/gpu/llvm_gpu_backend/nvptx_utils.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/ptx_compile_options_from_debug_options.h" diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc index 177fd03f120dd0..a8a99a481adeb6 100644 --- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -57,7 +56,7 @@ limitations under the License. namespace xla::gpu { namespace { -constexpr std::string_view kSimpleHlo = R"( +constexpr absl::string_view kSimpleHlo = R"( HloModule simple ENTRY main { @@ -65,7 +64,7 @@ ENTRY main { ROOT neg = f32[10]{0} negate(p) } )"; -constexpr std::string_view kParallelCompilationHlo = R"( +constexpr absl::string_view kParallelCompilationHlo = R"( HloModule parallel_compilation ENTRY main { @@ -80,7 +79,7 @@ ENTRY main { } )"; -constexpr std::string_view kSM90AHlo = R"( +constexpr absl::string_view kSM90AHlo = R"( gemm_fusion_dot { %p0 = f16[64,1024]{1,0} parameter(0) %p1 = f16[1024,32,32]{2,1,0} parameter(1) @@ -102,16 +101,16 @@ ENTRY e { "num_ctas":1}}} })"; -constexpr std::string_view kResultsInNoPtxHlo = R"( +constexpr absl::string_view kResultsInNoPtxHlo = R"( ENTRY e { a = f32[5,5] parameter(0) ROOT _ = f32[5,5] custom-call(a, a), custom_call_target="__cublas$gemm", backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" })"; -std::string_view GetHlo(std::string_view name) { - static const absl::flat_hash_map* const - kHloMap = new absl::flat_hash_map( +absl::string_view GetHlo(absl::string_view name) { + static const absl::flat_hash_map* const + kHloMap = new absl::flat_hash_map( {{"simple", kSimpleHlo}, {"parallel_compilation", kParallelCompilationHlo}, {"requires_sm90a", kSM90AHlo}, @@ -119,14 +118,14 @@ std::string_view GetHlo(std::string_view name) { return kHloMap->at(name); } -void DumpArtifactIfEnabled(std::string_view name, +void DumpArtifactIfEnabled(absl::string_view name, absl::Span data) { if (std::string output_dir; tsl::io::GetTestUndeclaredOutputsDir(&output_dir)) { (void)tsl::WriteStringToFile( tsl::Env::Default(), tsl::io::JoinPath(output_dir, name), - std::string_view(reinterpret_cast(data.data()), - data.size())); + absl::string_view(reinterpret_cast(data.data()), + data.size())); } } @@ -134,7 +133,7 @@ using stream_executor::PtxCompilationMethod; using stream_executor::PtxLinkingMethod; std::string GenerateParametrizedTestname( - std::string_view name, PtxCompilationMethod compilation_method, + absl::string_view name, PtxCompilationMethod compilation_method, PtxLinkingMethod linking_method) { return absl::StrFormat("%v_CompilationMethod_%v_LinkingMethod_%v", name, compilation_method, linking_method); @@ -143,9 +142,9 @@ std::string GenerateParametrizedTestname( class NVPTXCompilationTests : public HloTestBase, public ::testing::WithParamInterface> { + absl::string_view, PtxCompilationMethod, PtxLinkingMethod>> { public: - void SkipTestIfUnsupported(std::string_view name, + void SkipTestIfUnsupported(absl::string_view name, PtxCompilationMethod compilation_method, PtxLinkingMethod linking_method) { using CudaComputeCapability = stream_executor::CudaComputeCapability; @@ -227,7 +226,7 @@ class NVPTXCompilationTests void SetUp() override { HloTestBase::SetUp(); - std::string_view name = std::get<0>(GetParam()); + absl::string_view name = std::get<0>(GetParam()); PtxCompilationMethod compilation_method = std::get<1>(GetParam()); PtxLinkingMethod linking_method = std::get<2>(GetParam()); SkipTestIfUnsupported(name, compilation_method, linking_method); @@ -247,8 +246,8 @@ class NVPTXCompilationTests }; TEST_P(NVPTXCompilationTests, CompileProgram) { - std::string_view name = std::get<0>(GetParam()); - std::string_view hlo_text = GetHlo(name); + absl::string_view name = std::get<0>(GetParam()); + absl::string_view hlo_text = GetHlo(name); auto module = ParseAndReturnVerifiedModule(hlo_text).value(); HloModuleConfig hlo_module_config = module->config(); @@ -270,8 +269,8 @@ MATCHER(MatchesSectionNameAndBinarySize, "") { } TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { - std::string_view name = std::get<0>(GetParam()); - std::string_view hlo_text = GetHlo(name); + absl::string_view name = std::get<0>(GetParam()); + absl::string_view hlo_text = GetHlo(name); auto compile = [&](PtxCompilationMethod compilation_method, PtxLinkingMethod linking_method) { auto module = ParseAndReturnVerifiedModule(hlo_text).value(); @@ -392,7 +391,7 @@ INSTANTIATE_TEST_SUITE_P( PtxLinkingMethod::kDriver, PtxLinkingMethod::kNvJitLink)), [](const ::testing::TestParamInfo>& info) { + absl::string_view, PtxCompilationMethod, PtxLinkingMethod>>& info) { return GenerateParametrizedTestname(std::get<0>(info.param), std::get<1>(info.param), std::get<2>(info.param)); diff --git a/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc b/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc index 6b07a79cd4ecd8..2d9813dda1e6a0 100644 --- a/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc +++ b/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc @@ -26,7 +26,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "xla/service/reduce_scatter_combiner.h" #include "tsl/platform/statusor.h" @@ -76,8 +75,7 @@ absl::StatusOr GpuReduceScatterCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kReduceScatter, pointer_size_); + *module, device_info_, HloOpcode::kReduceScatter, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 4161c53e282099..f517b4203bb4a3 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -51,9 +51,6 @@ cc_library( name = "command_buffer_cmd", srcs = ["command_buffer_cmd.cc"], hdrs = ["command_buffer_cmd.h"], - local_defines = if_cuda_is_configured([ - "GOOGLE_CUDA=1", - ]), deps = [ ":annotation", ":custom_call_thunk", @@ -64,7 +61,6 @@ cc_library( ":nccl_collective_broadcast_thunk", ":nccl_collective_thunk", ":thunk", - ":while_thunk", "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:shape_util", @@ -73,11 +69,11 @@ cc_library( "//xla:util", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", - "//xla/core/collectives:communicator", "//xla/ffi:call_frame", "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", + "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -90,6 +86,7 @@ cc_library( "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/kernels:custom_kernel", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:kernel", @@ -99,9 +96,6 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -147,7 +141,7 @@ cc_library( ":wait_for_streams_thunk", ":while_thunk", "//xla:util", - "//xla/service:buffer_assignment", + "//xla/runtime:buffer_use", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -165,6 +159,7 @@ xla_test( ":command_buffer_cmd", ":thunk", "//xla:types", + "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:platform_util", @@ -358,6 +353,7 @@ xla_test( "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:platform_util", @@ -391,6 +387,7 @@ xla_test( ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]), + data = ["//xla/stream_executor/gpu:gpu_test_kernels_fatbin"] ) cc_library( @@ -608,7 +605,6 @@ cc_library( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:matmul_utils", "//xla/service/gpu/autotuning:autotuner_util", - "//xla/stream_executor", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_blas_lt", @@ -743,24 +739,61 @@ cc_library( ":thunk", "//xla:shape_util", "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "nccl_ragged_all_to_all_thunk", + srcs = ["nccl_ragged_all_to_all_thunk.cc"], + hdrs = ["nccl_ragged_all_to_all_thunk.h"], + deps = [ + ":nccl_collective_thunk", + ":thunk", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/backends/gpu/collectives:gpu_clique_key", + "//xla/backends/gpu/collectives:gpu_collectives", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) @@ -774,6 +807,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", @@ -793,24 +827,32 @@ cc_library( ":nccl_collective_thunk", ":nccl_p2p_thunk_common", ":thunk", + "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", + "//xla/service:computation_placer", "//xla/service:global_device_id", "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", ], ) @@ -825,7 +867,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_clique_locking", + "//xla/backends/gpu/collectives:gpu_cliques", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/backends/gpu/collectives:gpu_collectives_plugin", "//xla/core/collectives:communicator", @@ -904,17 +946,19 @@ cc_library( "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) @@ -930,17 +974,19 @@ cc_library( "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) @@ -955,9 +1001,10 @@ cc_library( "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/hlo/ir:hlo", + "//xla/stream_executor:event", "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -1069,7 +1116,7 @@ cc_library( "//xla:executable_run_options", "//xla:util", "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_clique_locking", + "//xla/backends/gpu/collectives:gpu_cliques", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", @@ -1085,7 +1132,6 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/gtl:int_type", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.cc b/third_party/xla/xla/service/gpu/runtime/annotation.cc index f1473d476cf982..c367793c3ddbf4 100644 --- a/third_party/xla/xla/service/gpu/runtime/annotation.cc +++ b/third_party/xla/xla/service/gpu/runtime/annotation.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -33,9 +32,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/printer.h" #include "tsl/platform/errors.h" #include "tsl/profiler/lib/nvtx_utils.h" @@ -61,7 +60,7 @@ StringHandle RegisterString(const std::string& str) { // Nsight Systems supports some basic HTML markup in annotation strings. This // escaping stops things like from disappearing. -std::ostream& PrintEscaped(std::ostream& os, std::string_view str) { +std::ostream& PrintEscaped(std::ostream& os, absl::string_view str) { for (char c : str) { switch (c) { case '<': @@ -92,7 +91,7 @@ HloPrintOptions PrintOptions() { // Sortable struct representing a frame in the Python stacktrace attached to a // given instruction. struct StackFrame { - std::string_view file_name, function_name, op_name; + absl::string_view file_name, function_name, op_name; int line, column; private: @@ -126,7 +125,7 @@ struct StackFrame { class SourceLocationVisitor : public ConstDfsHloVisitorWithDefault { public: explicit SourceLocationVisitor( - std::string_view op_name_prefix_to_remove__ = {}) + absl::string_view op_name_prefix_to_remove__ = {}) : op_name_prefix_to_remove_{op_name_prefix_to_remove__} {} std::string AsString(int32_t common_prefix) const { @@ -161,7 +160,7 @@ class SourceLocationVisitor : public ConstDfsHloVisitorWithDefault { // sections of the name are common to all operations in the kernel, and the // individual call stack frames in the kernel-level annotation show the // final parts of the op_name that have not already been shown. - std::string_view op_name = meta.op_name(); + absl::string_view op_name = meta.op_name(); if (!op_name.empty()) { op_name = op_name.substr(op_name_prefix_to_remove_.size()); } @@ -234,7 +233,7 @@ class SourceLocationVisitor : public ConstDfsHloVisitorWithDefault { } oss << '\n'; } - std::string_view op_name_prefix_to_remove_{}; + absl::string_view op_name_prefix_to_remove_{}; std::set> location_set_{}; }; @@ -255,8 +254,8 @@ absl::Status VisitInstAndCalledButNotOperands(Visitor& visitor, // Split `a` and `b` by `delim` into two lists of possibly-empty tokens, then // rejoin the first N of those lists that match by `delim`. Note: it is // unspecified which argument the return value points into. -std::string_view LongestPrefix(std::string_view a, std::string_view b, - char delim = '/') { +absl::string_view LongestPrefix(absl::string_view a, absl::string_view b, + char delim = '/') { auto split_a = absl::StrSplit(a, delim); auto split_b = absl::StrSplit(b, delim); @@ -270,7 +269,7 @@ std::string_view LongestPrefix(std::string_view a, std::string_view b, common_prefix_len += a_it->size(); // length of a matching token } - return std::string_view(a.data(), common_prefix_len); + return absl::string_view(a.data(), common_prefix_len); } // Find the longest prefix among instructions' op_name metadata @@ -286,15 +285,15 @@ class OpNamePrefixVisitor : public ConstDfsHloVisitorWithDefault { return absl::OkStatus(); } - std::string_view longest_op_name_prefix() const { + absl::string_view longest_op_name_prefix() const { return prefix_.value_or(""); } private: - std::optional prefix_; + std::optional prefix_; }; -std::string_view GetLongestOpNamePrefix(const HloModule& mod) { +absl::string_view GetLongestOpNamePrefix(const HloModule& mod) { // In the presence of (at least) debug callbacks, calling Accept on the root // instruction of the module may not reach all instructions in the module. OpNamePrefixVisitor visitor{}; @@ -308,7 +307,7 @@ std::string_view GetLongestOpNamePrefix(const HloModule& mod) { return visitor.longest_op_name_prefix(); } -std::string_view GetLongestOpNamePrefix(const HloInstruction& inst) { +absl::string_view GetLongestOpNamePrefix(const HloInstruction& inst) { OpNamePrefixVisitor visitor{}; if (!VisitInstAndCalledButNotOperands(visitor, inst).ok()) { return {}; @@ -316,7 +315,7 @@ std::string_view GetLongestOpNamePrefix(const HloInstruction& inst) { return visitor.longest_op_name_prefix(); } -std::string MakeTitle(const HloModule& mod, std::string_view longest_prefix) { +std::string MakeTitle(const HloModule& mod, absl::string_view longest_prefix) { if (longest_prefix.empty()) { return absl::StrFormat("XlaModule:#hlo_module=%s,program_id=%d#", mod.name(), mod.unique_id()); @@ -379,7 +378,7 @@ std::pair GetLongestSourceLocationPrefix( } } // namespace -ModuleAnnotation::ModuleAnnotation(std::string_view module_name_) +ModuleAnnotation::ModuleAnnotation(absl::string_view module_name_) : title_str_(absl::StrFormat("XlaModule:#hlo_module=%s#", module_name_)), title_(RegisterString(title_str_)), module_name_(RegisterString(std::string{module_name_})) {} @@ -441,12 +440,12 @@ uint64_t ModuleAnnotation::NvtxSchemaId() { } namespace { -std::string MakeKernelName(std::string_view prefix, +std::string MakeKernelName(absl::string_view prefix, const HloInstruction& inst) { // Sometimes an instruction doesn't have metadata, but the computations that // it calls do have metadata. Consider all of those metadata op_name entries // and attach the longest prefix to this launch. - std::string_view op_name = GetLongestOpNamePrefix(inst); + absl::string_view op_name = GetLongestOpNamePrefix(inst); if (op_name.empty()) { return absl::StrFormat("Thunk:#hlo_op=%s#", inst.name()); } else if (op_name.substr(0, prefix.size()) != prefix) { @@ -477,7 +476,7 @@ KernelAnnotation::KernelAnnotation(const ModuleAnnotation& module_annotation, called_hlo_dump(RegisterString("\n" + CalledInstructionsAsString(inst))) { } -ModuleAnnotations::ModuleAnnotations(std::string_view module_name) +ModuleAnnotations::ModuleAnnotations(absl::string_view module_name) : top_level(module_name) {} uint64_t KernelAnnotation::NvtxSchemaId() { @@ -549,7 +548,7 @@ ScopedModuleAnnotations::~ScopedModuleAnnotations() { } std::optional GetKernelAnnotation( - std::string_view profile_annotation) { + absl::string_view profile_annotation) { if (profile_annotation.empty()) { return {}; } diff --git a/third_party/xla/xla/service/gpu/runtime/annotation.h b/third_party/xla/xla/service/gpu/runtime/annotation.h index e5e170891a31c9..13d34e35dbc1f6 100644 --- a/third_party/xla/xla/service/gpu/runtime/annotation.h +++ b/third_party/xla/xla/service/gpu/runtime/annotation.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -33,11 +32,11 @@ namespace xla::gpu { // HloModule class ModuleAnnotation { public: - explicit ModuleAnnotation(std::string_view module_name); + explicit ModuleAnnotation(absl::string_view module_name); explicit ModuleAnnotation(const HloModule& mod); - std::string_view longest_op_name_prefix() const { return longest_prefix_; } - explicit operator std::string_view() const { return title_str_; } + absl::string_view longest_op_name_prefix() const { return longest_prefix_; } + explicit operator absl::string_view() const { return title_str_; } tsl::profiler::StringHandle title() const { return title_; } static uint64_t NvtxSchemaId(); int32_t common_stack_frames() const { return common_stack_frames_; } @@ -62,7 +61,7 @@ struct KernelAnnotation { KernelAnnotation(const ModuleAnnotation& module_annotation, const HloInstruction& inst); - explicit operator std::string_view() const { return title_str; } + explicit operator absl::string_view() const { return title_str; } static uint64_t NvtxSchemaId(); private: @@ -81,11 +80,11 @@ struct KernelAnnotation { // Parsed/prepared information for an HloModule that gets propagated to NVTX // ranges/profilers/... at execution time. struct ModuleAnnotations { - explicit ModuleAnnotations(std::string_view module_name); + explicit ModuleAnnotations(absl::string_view module_name); explicit ModuleAnnotations(const HloModule&); ModuleAnnotation top_level; - absl::flat_hash_map kernels; + absl::flat_hash_map kernels; }; //===----------------------------------------------------------------------===// @@ -104,7 +103,7 @@ class ScopedModuleAnnotations { const ModuleAnnotations* GetCurrentModuleAnnotations(); std::optional GetKernelAnnotation( - std::string_view profile_annotation); + absl::string_view profile_annotation); } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index 1e7248b238d79d..af5b42f5066a3b 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -43,7 +42,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" -#include "xla/core/collectives/communicator.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/ffi/call_frame.h" @@ -70,6 +68,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -79,7 +78,6 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/trace_command_buffer_factory.h" -#include "xla/tsl/concurrency/ref_count.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/util.h" #include "tsl/platform/env.h" @@ -91,7 +89,7 @@ limitations under the License. namespace xla::gpu { using ExecutionScopeId = se::CommandBuffer::ExecutionScopeId; -using MemoryAccess = CommandBufferCmd::MemoryAccess; +using MemoryAccess = BufferUse::MemoryAccess; std::string CommandBufferCmdString(CommandBufferCmdType type) { switch (type) { @@ -105,7 +103,7 @@ std::string CommandBufferCmdString(CommandBufferCmdType type) { } } -static std::string_view ReductionKindString(ReductionKind kind) { +static absl::string_view ReductionKindString(ReductionKind kind) { switch (kind) { case ReductionKind::MAX: return "max"; @@ -197,13 +195,13 @@ CommandBufferCmdSequence::CommandBufferCmdSequence( : synchronization_mode_(synchronization_mode) {} void CommandBufferCmdSequence::Append(std::unique_ptr cmd) { - for (const CommandBufferCmd::BufferUsage& buffer : cmd->buffers()) { + for (const BufferUse& buffer : cmd->buffers()) { buffers_.insert(buffer); - allocs_indices_.insert(buffer.slice.index()); + allocs_indices_.insert(buffer.slice().index()); } ExecutionStreamId execution_stream_id = cmd->execution_stream_id(); - CommandBufferCmd::BufferUsageVector buffers = cmd->buffers(); + CommandBufferCmd::BufferUseVector buffers = cmd->buffers(); bool requires_barrier = HasConflicts(execution_stream_id, buffers); // Always add barriers between commands if we want to serialize execution. @@ -243,41 +241,39 @@ absl::Status CommandBufferCmdSequence::Initialize( return absl::OkStatus(); } +namespace { +// Returns true if slice overlaps with any of the slices in read set. +bool Overlaps(const BufferAllocation::Slice& slice, + const absl::flat_hash_set& slices) { + if (slices.contains(slice)) return true; + for (auto& read : slices) + if (read.OverlapsWith(slice)) return true; + return false; +} +} // namespace + bool CommandBufferCmdSequence::HasConflicts( ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers) { + const CommandBufferCmd::BufferUseVector& buffers) { auto& rwset = read_write_sets_[execution_stream_id]; - // Returns true if slice overlaps with any of the slices in read set. - auto read_overlap = [&](const BufferAllocation::Slice& slice) { - if (rwset.read.contains(slice)) return true; - for (auto& read : rwset.read) - if (read.OverlapsWith(slice)) return true; - return false; - }; - - // Returns true if slice overlaps with any of the slices in write set. - auto write_overlap = [&](const BufferAllocation::Slice& slice) { - if (rwset.write.contains(slice)) return true; - for (auto& write : rwset.write) - if (write.OverlapsWith(slice)) return true; - return false; - }; - return absl::c_any_of(buffers, [&](const auto& buffer) { - return buffer.access == MemoryAccess::kWrite - ? write_overlap(buffer.slice) || read_overlap(buffer.slice) - : write_overlap(buffer.slice); + return buffer.access() == MemoryAccess::kWrite + ? Overlaps(buffer.slice(), rwset.write) || + Overlaps(buffer.slice(), rwset.read) + : Overlaps(buffer.slice(), rwset.write); }); } void CommandBufferCmdSequence::TrackBuffers( ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers) { + const CommandBufferCmd::BufferUseVector& buffers) { auto& rwset = read_write_sets_[execution_stream_id]; - for (const CommandBufferCmd::BufferUsage& buffer : buffers) { - if (buffer.access == MemoryAccess::kWrite) rwset.write.insert(buffer.slice); - if (buffer.access == MemoryAccess::kRead) rwset.read.insert(buffer.slice); + for (const BufferUse& buffer : buffers) { + if (buffer.access() == MemoryAccess::kWrite) + rwset.write.insert(buffer.slice()); + if (buffer.access() == MemoryAccess::kRead) + rwset.read.insert(buffer.slice()); } } @@ -286,7 +282,7 @@ void CommandBufferCmdSequence::ClearTrackedBuffers( read_write_sets_[execution_stream_id] = ReadWriteSet(); } -static std::string_view RecordModeString( +static absl::string_view RecordModeString( CommandBufferCmdSequence::RecordMode mode) { switch (mode) { case CommandBufferCmdSequence::RecordMode::kExclusive: @@ -352,8 +348,8 @@ absl::Status CommandBufferCmdSequence::Record( return absl::OkStatus(); } -const absl::flat_hash_set& -CommandBufferCmdSequence::buffers() const { +const absl::flat_hash_set& CommandBufferCmdSequence::buffers() + const { return buffers_; } @@ -375,13 +371,13 @@ std::vector CommandBufferCmdSequence::barriers() const { TracedCommandBuffer::TracedCommandBuffer( const CommandBufferCmd* trace_cmd, - CommandBufferCmd::BufferUsageVector buffers, int64_t capacity) + CommandBufferCmd::BufferUseVector buffers, int64_t capacity) : trace_cmd_(trace_cmd), capacity_(capacity), entries_(capacity) { CHECK_GT(capacity, 0) << "capacity must be larger than 0"; // NOLINT // Collect unique buffer allocation indices in a set first and convert to // vector as flat hash set iteration has measurable overheads. absl::flat_hash_set allocs_indices; - for (auto& buffer : buffers) allocs_indices.insert(buffer.slice.index()); + for (auto& buffer : buffers) allocs_indices.insert(buffer.slice().index()); allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end()); } @@ -492,7 +488,7 @@ absl::Status TracedCommandBufferCmd::AddTracedCommandBuffer( // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kMemset32Kernel = R"( +inline constexpr absl::string_view kMemset32Kernel = R"( .version 4.0 .target sm_50 .address_size 64 @@ -541,26 +537,28 @@ ComputationIdCmd::ComputationIdCmd(ExecutionStreamId execution_stream_id, dest_(dest), kind_(kind) {} -CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() { +CommandBufferCmd::BufferUseVector ComputationIdCmd::buffers() { return {{dest_, MemoryAccess::kWrite}}; } absl::Status ComputationIdCmd::Initialize(const Thunk::InitializeParams& params, StateManager& state) { -#if defined(GOOGLE_CUDA) - { - absl::MutexLock lock(&mutex_); - if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); - } + auto cuda_cc = std::get_if( + ¶ms.executor->GetDeviceDescription().gpu_compute_capability()); + if (cuda_cc != nullptr) { + { + absl::MutexLock lock(&mutex_); + if (memset_kernels_.contains(params.executor)) return absl::OkStatus(); + } - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - CreateKernel("memset32", 3, kMemset32Kernel, - /*cubin_data=*/{}, params.executor, - /*shared_mem_bytes=*/0)); + TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, + CreateKernel("memset32", 3, kMemset32Kernel, + /*cubin_data=*/{}, params.executor, + /*shared_mem_bytes=*/0)); - absl::MutexLock lock(&mutex_); - memset_kernels_.emplace(params.executor, std::move(kernel)); -#endif // GOOGLE_CUDA + absl::MutexLock lock(&mutex_); + memset_kernels_.emplace(params.executor, std::move(kernel)); + } return absl::OkStatus(); } @@ -586,25 +584,29 @@ absl::Status ComputationIdCmd::Record( << "; value=" << value << "; execution_scope_id=" << execution_scope_id.value(); VLOG(5) << " Id: " << dest_ << " (" << dst.opaque() << ")"; + auto cuda_cc = std::get_if( + &execute_params.stream->parent() + ->GetDeviceDescription() + .gpu_compute_capability()); + + if (cuda_cc != nullptr) { + se::Kernel* memset_kernel = [&] { + absl::MutexLock lock(&mutex_); + return memset_kernels_[execute_params.stream->parent()].get(); + }(); + + if (memset_kernel == nullptr) { + return absl::InternalError( + "Memset kernel not loaded on a command buffer executor"); + } -#if defined(GOOGLE_CUDA) - se::Kernel* memset_kernel = [&] { - absl::MutexLock lock(&mutex_); - return memset_kernels_[execute_params.stream->parent()].get(); - }(); - - if (memset_kernel == nullptr) { - return absl::InternalError( - "Memset kernel not loaded on a command buffer executor"); + auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); + return command_buffer->Launch(execution_scope_id, se::ThreadDim(1), + se::BlockDim(1), *memset_kernel, *args); + } else { + return command_buffer->Memset(execution_scope_id, &dst, value, + /*num_elements=*/1); } - - auto args = se::PackKernelArgs(/*shmem_bytes=*/0, int64_t{1}, value, dst); - return command_buffer->Launch(execution_scope_id, se::ThreadDim(1), - se::BlockDim(1), *memset_kernel, *args); -#else - return command_buffer->Memset(execution_scope_id, &dst, value, - /*num_elements=*/1); -#endif // GOOGLE_CUDA } //===----------------------------------------------------------------------===// @@ -674,8 +676,8 @@ absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& execute_params, dims_.block_counts(), *kernel, *kernel_args); } -CommandBufferCmd::BufferUsageVector LaunchCmd::buffers() { - BufferUsageVector buffers; +CommandBufferCmd::BufferUseVector LaunchCmd::buffers() { + BufferUseVector buffers; for (int32_t i = 0; i < args_.size(); ++i) { buffers.emplace_back(args_[i], args_access_[i]); } @@ -746,8 +748,8 @@ absl::Status CustomKernelLaunchCmd::Record( custom_kernel_.block_dims(), *kernel, kernel_args); } -CommandBufferCmd::BufferUsageVector CustomKernelLaunchCmd::buffers() { - BufferUsageVector buffers; +CommandBufferCmd::BufferUseVector CustomKernelLaunchCmd::buffers() { + BufferUseVector buffers; for (int32_t i = 0; i < args_.size(); ++i) { buffers.emplace_back(args_[i], args_access_[i]); } @@ -790,7 +792,7 @@ absl::Status MemcpyDeviceToDeviceCmd::Record( num_bytes_); } -CommandBufferCmd::BufferUsageVector MemcpyDeviceToDeviceCmd::buffers() { +CommandBufferCmd::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() { return {{dst_, MemoryAccess::kWrite}, {src_, MemoryAccess::kRead}}; } @@ -822,7 +824,7 @@ absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& execute_params, /*num_elements=*/dst_.size()); } -CommandBufferCmd::BufferUsageVector MemzeroCmd::buffers() { +CommandBufferCmd::BufferUseVector MemzeroCmd::buffers() { return {{dst_, MemoryAccess::kWrite}}; } @@ -857,7 +859,7 @@ absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& execute_params, /*num_elements=*/dst_.size() / sizeof(uint32_t)); } -CommandBufferCmd::BufferUsageVector Memset32Cmd::buffers() { +CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() { return {{dst_, MemoryAccess::kWrite}}; } @@ -894,8 +896,8 @@ absl::Status IfCmd::Record(const Thunk::ExecuteParams& execute_params, bool IfCmd::force_update() { return then_commands_.force_update(); } -CommandBufferCmd::BufferUsageVector IfCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector IfCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(pred_, MemoryAccess::kRead); buffers.insert(then_commands_.buffers().begin(), then_commands_.buffers().end()); @@ -942,8 +944,8 @@ bool IfElseCmd::force_update() { return (then_commands_.force_update() || else_commands_.force_update()); } -CommandBufferCmd::BufferUsageVector IfElseCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector IfElseCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(pred_, MemoryAccess::kRead); buffers.insert(then_commands_.buffers().begin(), then_commands_.buffers().end()); @@ -992,8 +994,8 @@ bool CaseCmd::force_update() { [](const auto& seq) { return seq.force_update(); }); } -CommandBufferCmd::BufferUsageVector CaseCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector CaseCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(index_, MemoryAccess::kRead); for (auto& branch : branches_commands_) { buffers.insert(branch.buffers().begin(), branch.buffers().end()); @@ -1039,8 +1041,8 @@ absl::Status ForCmd::Record(const Thunk::ExecuteParams& execute_params, bool ForCmd::force_update() { return body_commands_.force_update(); } -CommandBufferCmd::BufferUsageVector ForCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector ForCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(loop_counter_, MemoryAccess::kWrite); buffers.insert(body_commands_.buffers().begin(), body_commands_.buffers().end()); @@ -1089,8 +1091,8 @@ bool WhileCmd::force_update() { return (cond_commands_.force_update() || body_commands_.force_update()); } -CommandBufferCmd::BufferUsageVector WhileCmd::buffers() { - absl::flat_hash_set buffers; +CommandBufferCmd::BufferUseVector WhileCmd::buffers() { + absl::flat_hash_set buffers; buffers.emplace(pred_, MemoryAccess::kWrite); buffers.insert(cond_commands_.buffers().begin(), cond_commands_.buffers().end()); @@ -1152,7 +1154,7 @@ absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { +CommandBufferCmd::BufferUseVector GemmCmd::buffers() { return {{lhs_buffer_, MemoryAccess::kRead}, {rhs_buffer_, MemoryAccess::kRead}, {output_buffer_, MemoryAccess::kWrite}, @@ -1292,8 +1294,8 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector CublasLtCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() { + BufferUseVector buffer_usage; buffer_usage.reserve(13); buffer_usage.push_back({a_buffer_, MemoryAccess::kRead}); buffer_usage.push_back({b_buffer_, MemoryAccess::kRead}); @@ -1366,8 +1368,8 @@ absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector CuDnnCmd::buffers() { - CommandBufferCmd::BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() { + CommandBufferCmd::BufferUseVector buffer_usage; buffer_usage.reserve(args_.size()); for (int i = 0; i < args_.size() - 1; ++i) { buffer_usage.push_back({args_[i], MemoryAccess::kRead}); @@ -1390,50 +1392,47 @@ absl::Status CustomCallCmd::Record(const Thunk::ExecuteParams& execute_params, return RecordXlaFfiCall(execute_params, record_params, command_buffer); } -absl::Status CustomCallCmd::RecordLegacyCustomCall( +namespace { +// Records each buffer associated with each slice into the provided vector. +// Returns an error if any of the slices is missing a buffer allocation. +absl::Status GetBuffers( const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { - std::vector buffers; - buffers.reserve(operands_.size() + results_.size()); - for (auto& slices : {operands_, results_}) { - for (const std::optional& slice : slices) { - if (!slice.has_value()) { - buffers.push_back(nullptr); - continue; - } - - if (!slice->slice.allocation()) { - return absl::InternalError( - "custom call input missing buffer allocation"); - } + absl::Span> slices, + std::vector& buffers, absl::string_view label) { + for (int i = 0; i < slices.size(); ++i) { + if (!slices[i].has_value()) { + buffers.push_back(nullptr); + VLOG(5) << label << i << ": null"; + continue; + } - buffers.push_back( - execute_params.buffer_allocations->GetDeviceAddress(slice->slice) - .opaque()); + if (!slices[i]->slice.allocation()) { + return absl::InternalError("custom call input missing buffer allocation"); } + + auto buffer = + execute_params.buffer_allocations->GetDeviceAddress(slices[i]->slice) + .opaque(); + VLOG(5) << label << i << ": " << slices[i]->slice << " (" << buffer << ")"; + buffers.push_back(buffer); } + return absl::OkStatus(); +} +} // namespace +absl::Status CustomCallCmd::RecordLegacyCustomCall( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + std::vector buffers; + buffers.reserve(operands_.size() + results_.size()); ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); VLOG(5) << "CustomCallCmd: target_name=" << target_name_ << ", execution_scope_id=" << execution_scope_id.value(); - for (int i = 0; i < operands_.size(); ++i) { - if (operands_[i].has_value()) { - VLOG(5) << " Operand " << i << ": " << operands_[i]->slice << " (" - << buffers[i] << ")"; - } else { - VLOG(5) << " Operand " << i << ": null"; - } - } - for (int i = 0; i < results_.size(); ++i) { - if (results_[i].has_value()) { - VLOG(5) << " Result " << i << ": " << results_[i]->slice << " (" - << buffers[operands_.size() + i] << ")"; - } else { - VLOG(5) << " Result " << i << ": null"; - } - } + TF_RETURN_IF_ERROR( + GetBuffers(execute_params, operands_, buffers, " Operand ")); + TF_RETURN_IF_ERROR( + GetBuffers(execute_params, results_, buffers, " Result ")); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN( auto nested_cmd, se::TraceCommandBufferFactory::Create( @@ -1452,11 +1451,6 @@ absl::Status CustomCallCmd::RecordLegacyCustomCall( return command_buffer->AddNestedCommandBuffer(execution_scope_id, *nested_cmd); -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return Unavailable( - "Custom calls on GPU are not supported in this configuration. Please " - "build with --config=cuda or --config=rocm"); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } absl::Status CustomCallCmd::RecordXlaFfiCall( @@ -1464,7 +1458,8 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( const RecordParams& record_params, se::CommandBuffer* command_buffer) { // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. We have to keep attributes - // separate from arguments, as they do not change after thunk is constructed. + // separate from arguments, as they do not change after thunk is + // constructed. ffi::CallFrameBuilder builder(operands_.size(), results_.size()); ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); @@ -1512,7 +1507,6 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( builder.AddAttributes(attrs.Build()); ffi::CallFrame call_frame = builder.Build(); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN( auto nested_cmd, se::TraceCommandBufferFactory::Create( @@ -1530,15 +1524,10 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( return command_buffer->AddNestedCommandBuffer(execution_scope_id, *nested_cmd); -#else // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - return Unavailable( - "Custom calls on GPU are not supported in this configuration. Please " - "build with --config=cuda or --config=rocm"); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } -CommandBufferCmd::BufferUsageVector CustomCallCmd::buffers() { - CommandBufferCmd::BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CustomCallCmd::buffers() { + CommandBufferCmd::BufferUseVector buffer_usage; for (auto& slices : {operands_, results_}) { for (const std::optional& slice : slices) { if (!slice.has_value()) continue; @@ -1571,7 +1560,7 @@ absl::Status BarrierCmd::Record(const Thunk::ExecuteParams& execute_params, return absl::OkStatus(); } -BarrierCmd::BufferUsageVector BarrierCmd::buffers() { return {}; } +BarrierCmd::BufferUseVector BarrierCmd::buffers() { return {}; } //===----------------------------------------------------------------------===// // CollectiveCmd @@ -1689,8 +1678,8 @@ absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector AllReduceCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1756,8 +1745,8 @@ absl::Status ReduceScatterCmd::Record( }); } -CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1820,8 +1809,8 @@ absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector AllToAllCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1883,8 +1872,8 @@ absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params, }); } -CommandBufferCmd::BufferUsageVector AllGatherCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -1948,8 +1937,8 @@ absl::Status CollectiveBroadcastCmd::Record( }); } -CommandBufferCmd::BufferUsageVector CollectiveBroadcastCmd::buffers() { - BufferUsageVector buffer_usage; +CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() { + BufferUseVector buffer_usage; for (auto& buffer : buffers_) { buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); @@ -2003,14 +1992,14 @@ DynamicSliceFusionCmd::DynamicSliceFusionCmd( // Force update the command when there is any non-constant value slice offset, // because the memory address might changed if the offset is loop -// iterator or operator outputs even if the parent command's memory pointers do -// not change. +// iterator or operator outputs even if the parent command's memory pointers +// do not change. bool DynamicSliceFusionCmd::force_update() { return !absl::c_all_of(slices_, [](const DynamicSliceThunk::SliceDef& slice) { if (!slice.offsets.has_value()) return true; return absl::c_all_of(slice.offsets.value(), [](DynamicSliceThunk::Offset offset) { - return std::holds_alternative(offset); + return std::holds_alternative(offset); }); }); } @@ -2106,7 +2095,7 @@ absl::Status DynamicSliceFusionCmd::Record( for (auto [offset_idx, values] : llvm::enumerate(llvm::zip( *slice.offsets, src_shape.dimensions(), dst_shape.dimensions()))) { auto [offset, src_dim, dst_dim] = values; - if (uint64_t* const_offset = std::get_if(&offset)) { + if (int64_t* const_offset = std::get_if(&offset)) { // Forward slice offsets that are known constant values VLOG(2) << " - arg " << argument_idx << "[" << offset_idx << "]: constant offset = " << *const_offset; @@ -2170,8 +2159,8 @@ absl::Status DynamicSliceFusionCmd::Record( argument_buffer.GetByteSlice(new_offset, new_size); } - // Safe to create a local BufferAllocations here since buffers are only slices - // of bigger ones allocated elsewhere. + // Safe to create a local BufferAllocations here since buffers are only + // slices of bigger ones allocated elsewhere. BufferAllocations slice_allocations(slice_buffers, orig_allocations.device_ordinal(), orig_allocations.memory_allocator()); @@ -2189,14 +2178,15 @@ absl::Status DynamicSliceFusionCmd::Record( *nested_command_buffer); } -CommandBufferCmd::BufferUsageVector DynamicSliceFusionCmd::buffers() { - CommandBufferCmd::BufferUsageVector buffers; +CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() { + CommandBufferCmd::BufferUseVector buffers; auto embed_buffers = embedded_commands_->buffers(); for (auto buffer_usage : embed_buffers) { - CHECK(embeded_to_origin_slice_map_[buffer_usage.slice.index()].has_value()); + CHECK( + embeded_to_origin_slice_map_[buffer_usage.slice().index()].has_value()); buffers.emplace_back( - embeded_to_origin_slice_map_[buffer_usage.slice.index()].value(), - buffer_usage.access); + embeded_to_origin_slice_map_[buffer_usage.slice().index()].value(), + buffer_usage.access()); } return buffers; } diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index 818fb45b247c40..820771e142c67a 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -40,6 +39,7 @@ limitations under the License. #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/ffi/api/c_api.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" @@ -119,28 +119,7 @@ class CommandBufferCmd { : cmd_type_(cmd_type), execution_stream_id_(execution_stream_id) {} virtual ~CommandBufferCmd() = default; - enum class MemoryAccess { kRead, kWrite }; - - // BufferUsage tracks memory access type for a buffer slice, so that we can - // correctly insert command buffer barriers to avoid read/write conflicts. - struct BufferUsage { - BufferUsage(BufferAllocation::Slice slice, MemoryAccess access) - : slice(slice), access(access) {} - - template - friend H AbslHashValue(H h, const BufferUsage& buffer) { - return H::combine(std::move(h), buffer.slice, buffer.access); - } - - bool operator==(const BufferUsage& other) const { - return slice == other.slice && access == other.access; - } - - BufferAllocation::Slice slice; - MemoryAccess access; - }; - - using BufferUsageVector = absl::InlinedVector; + using BufferUseVector = absl::InlinedVector; // A base class for externally managed command state. // @@ -211,7 +190,7 @@ class CommandBufferCmd { // This argument allows conditional commands to record a command sequence // into non-default execution scope. se::CommandBuffer::ExecutionScopeId execution_scope_id = - se::CommandBuffer::kDefaulExecutionScope; + se::CommandBuffer::kDefaultExecutionScope; }; // See Thunk documentation for XLA execution stages (prepare, initialize, @@ -245,7 +224,7 @@ class CommandBufferCmd { // Returns all buffers used by the cmd. These will be used to track cmd // updates, thus they need to be consistent across calls to the function. - virtual BufferUsageVector buffers() = 0; + virtual BufferUseVector buffers() = 0; // Returns true if command implemented as a nested command buffer. virtual bool IsNestedCommandBuffer() const { return false; } @@ -261,8 +240,8 @@ class CommandBufferCmd { virtual se::CommandBuffer::ExecutionScopeId GetExecutionScope( const CommandBufferCmd::RecordParams& record_params) const; - std::string_view profile_annotation() const { return profile_annotation_; } - void set_profile_annotation(std::string_view profile_annotation) { + absl::string_view profile_annotation() const { return profile_annotation_; } + void set_profile_annotation(absl::string_view profile_annotation) { profile_annotation_ = profile_annotation; } @@ -356,7 +335,7 @@ class CommandBufferCmdSequence { RecordMode mode = RecordMode::kExclusive); // Returns buffers referenced by commands in this sequence. - const absl::flat_hash_set& buffers() const; + const absl::flat_hash_set& buffers() const; // Returns buffer allocations indices referenced by commands in this sequence. const absl::flat_hash_set& allocs_indices() const; @@ -383,16 +362,16 @@ class CommandBufferCmdSequence { // Functions for tracking buffer usage of recorded commands and figuring out // when the next command requires a barrier for correctness. bool HasConflicts(ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers); + const CommandBufferCmd::BufferUseVector& buffers); void TrackBuffers(ExecutionStreamId execution_stream_id, - const CommandBufferCmd::BufferUsageVector& buffers); + const CommandBufferCmd::BufferUseVector& buffers); void ClearTrackedBuffers(ExecutionStreamId execution_stream_id); SynchronizationMode synchronization_mode_; std::vector commands_; // Buffers referenced by commands in this sequence. - absl::flat_hash_set buffers_; + absl::flat_hash_set buffers_; // Buffer allocations indices referenced by commands in this sequence. absl::flat_hash_set allocs_indices_; @@ -419,7 +398,7 @@ class CommandBufferCmdSequence { class TracedCommandBuffer : public CommandBufferCmd::State { public: explicit TracedCommandBuffer(const CommandBufferCmd* trace_cmd, - CommandBufferCmd::BufferUsageVector buffers, + CommandBufferCmd::BufferUseVector buffers, int64_t capacity = 16); // Returns cached command buffer traced using the same buffer addresses or @@ -477,7 +456,7 @@ class ComputationIdCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dest_; @@ -504,8 +483,8 @@ class LaunchCmd : public CommandBufferCmd { public: LaunchCmd(ExecutionStreamId execution_stream_id, std::string kernel_name, absl::Span args, - absl::Span args_access, LaunchDimensions dims, - int64_t shmem_bytes); + absl::Span args_access, + LaunchDimensions dims, int64_t shmem_bytes); absl::Status Initialize(const Thunk::InitializeParams& params, StateManager& state) override; @@ -514,12 +493,12 @@ class LaunchCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: std::string kernel_name_; std::vector args_; - std::vector args_access_; + std::vector args_access_; LaunchDimensions dims_; int64_t shmem_bytes_; @@ -538,7 +517,7 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { public: CustomKernelLaunchCmd(ExecutionStreamId execution_stream_id, absl::Span args, - absl::Span args_access, + absl::Span args_access, CustomKernel custom_kernel); absl::Status Initialize(const Thunk::InitializeParams& params, @@ -548,11 +527,11 @@ class CustomKernelLaunchCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: std::vector args_; - std::vector args_access_; + std::vector args_access_; CustomKernel custom_kernel_; // Command sequence can be recorded concurrently for multiple command buffers @@ -576,7 +555,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dst_; @@ -597,7 +576,7 @@ class MemzeroCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dst_; @@ -616,7 +595,7 @@ class Memset32Cmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice dst_; @@ -641,7 +620,7 @@ class IfCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice pred_; @@ -667,7 +646,7 @@ class IfElseCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice pred_; @@ -693,7 +672,7 @@ class CaseCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice index_; @@ -719,7 +698,7 @@ class ForCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: int32_t num_iterations_; @@ -746,7 +725,7 @@ class WhileCmd : public CommandBufferCmd { bool force_update() override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: BufferAllocation::Slice pred_; @@ -773,7 +752,7 @@ class GemmCmd : public TracedCommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } @@ -815,7 +794,7 @@ class CublasLtCmd : public TracedCommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } @@ -868,7 +847,7 @@ class CuDnnCmd : public TracedCommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } @@ -921,7 +900,7 @@ class CustomCallCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool IsNestedCommandBuffer() const final { return true; } private: @@ -970,7 +949,7 @@ class BarrierCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: const ExecutionStreamId from_stream_id_; @@ -1040,7 +1019,7 @@ class AllReduceCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1066,7 +1045,7 @@ class ReduceScatterCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1092,7 +1071,7 @@ class AllToAllCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1118,7 +1097,7 @@ class AllGatherCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; AsyncStreamKind GetAsyncStreamKind() override { return AsyncStreamKind::kCollective; @@ -1143,7 +1122,7 @@ class CollectiveBroadcastCmd : public CollectiveCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; private: std::vector buffers_; @@ -1167,7 +1146,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { std::vector> offset_byte_sizes); absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state); + StateManager& state) override; absl::Status Prepare(const Thunk::PrepareParams& params, Thunk::ResourceRequests& resource_requests) final; @@ -1176,7 +1155,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer) override; - BufferUsageVector buffers() override; + BufferUseVector buffers() override; bool force_update() override; diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index de9734682ff870..09e3d5f4cffdee 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/runtime/buffer_use.h" #include "xla/service/gpu/runtime/command_buffer_cmd.h" #include "xla/service/gpu/runtime/conditional_thunk.h" #include "xla/service/gpu/runtime/copy_thunk.h" @@ -62,13 +63,14 @@ static absl::Status AppendCommands( //===----------------------------------------------------------------------===// using Command = std::unique_ptr; +using xla::BufferUse; static auto ArgsAccess(const std::vector& written) { - absl::InlinedVector args_access; + absl::InlinedVector args_access; args_access.reserve(written.size()); for (bool w : written) { - args_access.push_back(w ? CommandBufferCmd::MemoryAccess::kWrite - : CommandBufferCmd::MemoryAccess::kRead); + args_access.push_back(w ? BufferUse::MemoryAccess::kWrite + : BufferUse::MemoryAccess::kRead); } return args_access; } diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 803dbc37e2bcbc..2ab4d61ff4ab6d 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -45,9 +45,9 @@ limitations under the License. namespace xla::gpu { -using BufferUsage = CommandBufferCmd::BufferUsage; -using BufferUsageVector = CommandBufferCmd::BufferUsageVector; -using MemoryAccess = CommandBufferCmd::MemoryAccess; +using xla::BufferUse; +using BufferUseVector = CommandBufferCmd::BufferUseVector; +using MemoryAccess = BufferUse::MemoryAccess; static se::StreamExecutor* GpuExecutor() { auto name = @@ -65,7 +65,7 @@ static constexpr auto s1 = ExecutionStreamId(1); // buffer usage vector to the command buffer cmd sequence. struct TestOnlyCommandBufferCmd : public CommandBufferCmd { TestOnlyCommandBufferCmd(ExecutionStreamId execution_stream_id, - BufferUsageVector buffer_usage) + BufferUseVector buffer_usage) : CommandBufferCmd(CommandBufferCmdType::kUnknownCmd, execution_stream_id), buffer_usage(buffer_usage) {} @@ -75,9 +75,9 @@ struct TestOnlyCommandBufferCmd : public CommandBufferCmd { return absl::OkStatus(); } - BufferUsageVector buffers() override { return buffer_usage; } + BufferUseVector buffers() override { return buffer_usage; } - BufferUsageVector buffer_usage; + BufferUseVector buffer_usage; }; class FakeCmd : public CommandBufferCmd { @@ -91,7 +91,7 @@ class FakeCmd : public CommandBufferCmd { se::CommandBuffer* command_buffer) override { return absl::OkStatus(); } - BufferUsageVector buffers() override { return BufferUsageVector{}; } + BufferUseVector buffers() override { return BufferUseVector{}; } }; TEST(CommandBufferCmdTest, SerializeExecution) { @@ -101,13 +101,13 @@ TEST(CommandBufferCmdTest, SerializeExecution) { auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); // Reads from overlapping slices do not require barriers by default. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice1, MemoryAccess::kRead); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice1, BufferUse::kRead); CommandBufferCmdSequence commands( CommandBufferCmdSequence::SynchronizationMode::kSerialize); - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -121,12 +121,12 @@ TEST(CommandBufferCmdTest, NoReadBarrier) { auto slice1 = BufferAllocation::Slice(&alloc0, 50, 100); // Reads from overlapping slices do not require barriers. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice1, MemoryAccess::kRead); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice1, BufferUse::kRead); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -140,12 +140,12 @@ TEST(CommandBufferCmdTest, NoWriteBarrier) { auto slice0 = BufferAllocation::Slice(&alloc0, 0, 100); auto slice1 = BufferAllocation::Slice(&alloc0, 200, 100); - auto use0 = BufferUsage(slice0, MemoryAccess::kWrite); - auto use1 = BufferUsage(slice1, MemoryAccess::kWrite); + auto use0 = BufferUse(slice0, BufferUse::kWrite); + auto use1 = BufferUse(slice1, BufferUse::kWrite); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -160,14 +160,14 @@ TEST(CommandBufferCmdTest, WriteConflictBarrier) { // Reads from overlapping slices can be done in parallel, and before a write // into overlapping slice we need to insert a barrier. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice0, MemoryAccess::kRead); - auto use2 = BufferUsage(slice1, MemoryAccess::kWrite); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice0, BufferUse::kRead); + auto use2 = BufferUse(slice1, BufferUse::kWrite); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s0, BufferUsageVector{use1}); - commands.Emplace(s0, BufferUsageVector{use2}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s0, BufferUseVector{use1}); + commands.Emplace(s0, BufferUseVector{use2}); ASSERT_EQ(commands.barriers().size(), 3); EXPECT_EQ(commands.barriers().at(0), false); @@ -183,12 +183,12 @@ TEST(CommandBufferCmdTest, NoWriteConflictsAcrossStreams) { // Read and write happens on different execution streams and we do not insert // any automatic barriers between streams. - auto use0 = BufferUsage(slice0, MemoryAccess::kRead); - auto use1 = BufferUsage(slice1, MemoryAccess::kWrite); + auto use0 = BufferUse(slice0, BufferUse::kRead); + auto use1 = BufferUse(slice1, BufferUse::kWrite); CommandBufferCmdSequence commands; - commands.Emplace(s0, BufferUsageVector{use0}); - commands.Emplace(s1, BufferUsageVector{use1}); + commands.Emplace(s0, BufferUseVector{use0}); + commands.Emplace(s1, BufferUseVector{use1}); ASSERT_EQ(commands.barriers().size(), 2); EXPECT_EQ(commands.barriers().at(0), false); @@ -350,8 +350,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); auto args = {slice_a, slice_a, slice_b}; // b = a + a - auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, - MemoryAccess::kWrite}; + auto args_access = {BufferUse::kRead, MemoryAccess::kRead, BufferUse::kWrite}; // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; @@ -422,9 +421,9 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); - CommandBufferCmd::BufferUsageVector buffers = { - {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, - {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + CommandBufferCmd::BufferUseVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), BufferUse::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), BufferUse::kWrite}}; TracedCommandBuffer traced_cmd_buffer(&traced_cmd, buffers, /*capacity=*/trace_cache_size); @@ -512,9 +511,9 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) { BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); - CommandBufferCmd::BufferUsageVector buffers = { - {BufferAllocation::Slice(&alloc0, 0, 1024), MemoryAccess::kRead}, - {BufferAllocation::Slice(&alloc1, 0, 1024), MemoryAccess::kWrite}}; + CommandBufferCmd::BufferUseVector buffers = { + {BufferAllocation::Slice(&alloc0, 0, 1024), BufferUse::kRead}, + {BufferAllocation::Slice(&alloc1, 0, 1024), BufferUse::kWrite}}; se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index ebc2be4506a49d..d60a9a02660ebf 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/types/span.h" +#include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" @@ -68,7 +69,7 @@ limitations under the License. namespace xla::gpu { -using MemoryAccess = CommandBufferCmd::MemoryAccess; +using MemoryAccess = BufferUse::MemoryAccess; using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; namespace { @@ -799,7 +800,7 @@ TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); std::vector lhs_offsets = { - DynamicSliceThunk::Offset(2UL), DynamicSliceThunk::Offset(0UL)}; + DynamicSliceThunk::Offset(2l), DynamicSliceThunk::Offset(0l)}; std::vector> arguments = { std::optional(slice_lhs), diff --git a/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc index 88c7273744cd17..f299a24717add1 100644 --- a/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/conditional_thunk.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -116,7 +115,7 @@ absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { [](bool* pred) { return *pred ? 0 : 1; }}, branch_index_or_pred); - std::string_view branch_kind = + absl::string_view branch_kind = std::visit(VariantVisitor{[](int32_t*) { return "index"; }, [](bool*) { return "pred"; }}, branch_index_or_pred); diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc index d0ec3a65283710..6c561cbfe340e0 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc @@ -176,7 +176,7 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) { *slice.offsets, src_shape.dimensions(), dst_shape.dimensions()))) { auto [offset, src_dim, dst_dim] = values; - if (uint64_t* const_offset = std::get_if(&offset)) { + if (int64_t* const_offset = std::get_if(&offset)) { // Forward slice offsets that are known constant values VLOG(2) << " - arg " << argument_idx << "[" << offset_idx << "]: constant offset = " << *const_offset; diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h index 6adc4a62f72d9d..29e17f1bc6aa51 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h @@ -48,7 +48,7 @@ class DynamicSliceThunk : public Thunk { // Dynamic slice offset can be either: (1) a statically known constant value // or (2) a truly dynamic offset that is computed on device and have to be // transferred to host. - using Offset = std::variant; + using Offset = std::variant; DynamicSliceThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc index a26de45ddaa853..54cb63e8fee4d9 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc @@ -119,7 +119,7 @@ absl::Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { se::StreamExecutor* executor = params.stream->parent(); LaunchDimensions launch_dimensions; std::optional cluster_dim; - const se::Kernel* kernel = nullptr; + se::Kernel* kernel = nullptr; TF_ASSIGN_OR_RETURN( se::Stream * stream, @@ -198,7 +198,7 @@ absl::Status CustomKernelThunk::Initialize(const InitializeParams& params) { absl::Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { se::StreamExecutor* executor = params.stream->parent(); - const se::Kernel* kernel = [&] { + se::Kernel* kernel = [&] { absl::MutexLock lock(&mutex_); return kernel_cache_[executor].get(); }(); @@ -222,12 +222,12 @@ absl::Status CustomKernelThunk::ExecuteOnStream(const ExecuteParams& params) { custom_kernel_.shared_memory_bytes()); if (auto cluster = custom_kernel_.cluster_dims(); cluster.has_value()) { - return params.stream->Launch(custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *cluster, *kernel, - args); + return kernel->Launch(custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), *cluster, params.stream, + args); } else { - return params.stream->Launch(custom_kernel_.thread_dims(), - custom_kernel_.block_dims(), *kernel, args); + return kernel->Launch(custom_kernel_.thread_dims(), + custom_kernel_.block_dims(), params.stream, args); } } diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h index d26e5cab3a182f..caab6242a764d8 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/base/thread_annotations.h" @@ -141,7 +140,7 @@ class CustomKernelThunk : public Thunk { return args_; } - std::string_view custom_kernel_name() const { return custom_kernel_.name(); } + absl::string_view custom_kernel_name() const { return custom_kernel_.name(); } const std::vector& written() const { return written_; } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc index 9741a5db1a31f7..b6395226e6e786 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -23,13 +23,16 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/container/node_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" @@ -40,9 +43,11 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -219,63 +224,87 @@ absl::Status RunAllToAll(GpuCollectives* collectives, bool has_split_dimension, std::vector& buffers, se::Stream& stream, Communicator* comm) { int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal; + VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal + << ", has_split_dimension: " << has_split_dimension; TF_RETURN_IF_ERROR( MaybeRegisterBuffers(collectives, stream.parent(), buffers, comm)); TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); - TF_RETURN_IF_ERROR(collectives->GroupStart()); + PrimitiveType element_type = buffers[0].element_type; + int32_t element_count = buffers[0].element_count; + + // All buffers must have the same element type and count. + bool all_buffers_match = absl::c_all_of(buffers, [&](const auto& buffer) { + return buffer.element_type == element_type && + buffer.element_count == element_count; + }); + + if (!all_buffers_match) { + return InvalidArgument( + "All buffers must have the same element type and count"); + } // AllToAll can operate in two modes. Either it specifies a split dimension, // in which case inputs are split and outputs concatenated in that dimension // (here, we only support dimension 0), or it takes a list of inputs // and produces a tuple of outputs. - if (has_split_dimension) { - for (DeviceBufferPair& buffer : buffers) { - TF_RET_CHECK(buffer.element_count % num_ranks == 0) - << "Buffer was not an exact multiple of the number of participants."; + absl::InlinedVector send_buffers; + absl::InlinedVector recv_buffers; - size_t chunk_elements = buffer.element_count / num_ranks; + if (has_split_dimension) { + TF_RET_CHECK(element_count % num_ranks == 0) + << "Buffer element count must be an exact multiple of the number of " + "participants"; + size_t chunk_element_count = element_count / num_ranks; + for (const DeviceBufferPair& buffer : buffers) { for (int peer = 0; peer < num_ranks; ++peer) { - se::DeviceMemoryBase send_slice = - collectives->Slice(buffer.source_buffer, buffer.element_type, - peer * chunk_elements, chunk_elements); - - se::DeviceMemoryBase recv_slice = - collectives->Slice(buffer.destination_buffer, buffer.element_type, - peer * chunk_elements, chunk_elements); - - TF_RETURN_IF_ERROR(comm->Send(send_slice, buffer.element_type, - chunk_elements, peer, - GpuCollectives::On(stream))); - - TF_RETURN_IF_ERROR(comm->Recv(recv_slice, buffer.element_type, - chunk_elements, peer, - GpuCollectives::On(stream))); + send_buffers.push_back(collectives->Slice( + buffer.source_buffer, element_type, peer * chunk_element_count, + chunk_element_count)); + recv_buffers.push_back(collectives->Slice( + buffer.destination_buffer, element_type, peer * chunk_element_count, + chunk_element_count)); } } - } else { - TF_RET_CHECK(buffers.size() == num_ranks) - << "Number of inputs didn't match the number of participants."; - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; + return comm->AllToAll(send_buffers, recv_buffers, element_type, + chunk_element_count, GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(comm->Send(buffer.source_buffer, buffer.element_type, - buffer.element_count, i, - GpuCollectives::On(stream))); - - TF_RETURN_IF_ERROR(comm->Recv(buffer.destination_buffer, - buffer.element_type, buffer.element_count, - i, GpuCollectives::On(stream))); + } else { + for (const DeviceBufferPair& buffer : buffers) { + send_buffers.push_back(buffer.source_buffer); + recv_buffers.push_back(buffer.destination_buffer); } + + return comm->AllToAll(send_buffers, recv_buffers, element_type, + element_count, GpuCollectives::On(stream)); } +} + +static absl::Status SendPtrToPeer(void* ptr, RankId peer, Communicator* comm, + se::Stream& stream) { + VLOG(3) << absl::StreamFormat( + "RecvPtrFromPeer on device #%d; peer=%d; comm=%p; stream=%p", + stream.parent()->device_ordinal(), peer.value(), comm, &stream); + + return comm->Send(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer, + GpuCollectives::On(stream)); +} + +static absl::Status RecvPtrFromPeer(void* ptr, RankId peer, Communicator* comm, + se::Stream& stream) { + VLOG(3) << absl::StreamFormat( + "RecvPtrFromPeer on device #%d; peer=%d; comm=%p; stream=%p", + stream.parent()->device_ordinal(), peer.value(), comm, &stream); - return collectives->GroupEnd(); + return comm->Recv(se::DeviceMemoryBase(ptr, sizeof(void*)), U64, 1, peer, + GpuCollectives::On(stream)); } +// TODO(b/380457503): Memcpy AllToAll implementation must be moved to +// NcclCommunicator implementation. absl::Status RunMemCpyAllToAll( GpuCollectives* collectives, bool has_split_dimension, std::vector& buffers, se::Stream& stream, @@ -299,19 +328,19 @@ absl::Status RunMemCpyAllToAll( TF_RET_CHECK(buffer.element_count % num_ranks == 0) << "Buffer was not an exact multiple of the number of participants."; - size_t chunk_elements = buffer.element_count / num_ranks; + size_t chunk_element_count = buffer.element_count / num_ranks; TF_RETURN_IF_ERROR(collectives->GroupStart()); for (int peer = 0; peer < num_ranks; ++peer) { se::DeviceMemoryBase recv_slice = collectives->Slice(buffer.destination_buffer, buffer.element_type, - peer * chunk_elements, chunk_elements); + peer * chunk_element_count, chunk_element_count); send_pointer_map[peer] = (uint64_t)recv_slice.opaque(); - TF_RETURN_IF_ERROR(comm->SendPtrToPeer(&send_pointer_map[peer], peer, - GpuCollectives::On(stream))); - TF_RETURN_IF_ERROR(comm->RecvPtrFromPeer( - &receive_pointer_map[peer], peer, GpuCollectives::On(stream))); + TF_RETURN_IF_ERROR( + SendPtrToPeer(&send_pointer_map[peer], RankId(peer), comm, stream)); + TF_RETURN_IF_ERROR(RecvPtrFromPeer(&receive_pointer_map[peer], + RankId(peer), comm, stream)); } TF_RETURN_IF_ERROR(collectives->GroupEnd()); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); @@ -319,7 +348,7 @@ absl::Status RunMemCpyAllToAll( for (int peer = 0; peer < num_ranks; ++peer) { se::DeviceMemoryBase send_slice = collectives->Slice(buffer.source_buffer, buffer.element_type, - peer * chunk_elements, chunk_elements); + peer * chunk_element_count, chunk_element_count); se::DeviceMemoryBase dst_addr = se::DeviceMemoryBase((void*)receive_pointer_map[peer]); TF_RETURN_IF_ERROR( @@ -335,10 +364,10 @@ absl::Status RunMemCpyAllToAll( send_pointer_map[peer] = (uint64_t)buffers[peer].destination_buffer.opaque(); - TF_RETURN_IF_ERROR(comm->SendPtrToPeer(&send_pointer_map[peer], peer, - GpuCollectives::On(stream))); - TF_RETURN_IF_ERROR(comm->RecvPtrFromPeer(&receive_pointer_map[peer], peer, - GpuCollectives::On(stream))); + TF_RETURN_IF_ERROR( + SendPtrToPeer(&send_pointer_map[peer], RankId(peer), comm, stream)); + TF_RETURN_IF_ERROR(RecvPtrFromPeer(&receive_pointer_map[peer], + RankId(peer), comm, stream)); } TF_RETURN_IF_ERROR(collectives->GroupEnd()); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc index 8b292e3617fa41..5ea0c6d7cca866 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" @@ -77,8 +78,8 @@ absl::Status RunCollectiveBroadcast(std::vector& buffers, TF_RETURN_IF_ERROR(comm->Broadcast( // Always use rank 0 since we always broadcast from the first id in // replica_groups - src_addr, dest_addr, buffer.element_type, buffer.element_count, 0, - GpuCollectives::On(stream))); + src_addr, dest_addr, buffer.element_type, buffer.element_count, + RankId(0), GpuCollectives::On(stream))); } return collectives->GroupEnd(); } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc index b2a046321efe33..fc5ce2264bd96e 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc @@ -22,25 +22,36 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" namespace xla { namespace gpu { namespace { + absl::StatusOr GetCurrentId( Thunk::CollectiveExecuteParams* collective_params, const NcclP2PConfig& config) { @@ -161,13 +172,49 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize( if (p2p_memcpy_enabled_) { TF_ASSIGN_OR_RETURN(const int64_t current_id, GetCurrentId(params.collective_params, config_)); + absl::MutexLock lock(&barrier_mutex_); + if (barrier_flags_.find(current_id) == barrier_flags_.end()) { + if (!params.stream->parent()->HostMemoryRegister( + &barrier_flags_[current_id], sizeof(uint8_t))) { + LOG(ERROR) << "Registering barrier flag failed."; + } + } + + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, {buffer_}, + config_.config.operand_element_type)); + TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; + DeviceBufferPair& buffer = device_buffers[0]; + const NcclP2PConfig::SourceTargetMapEntry source_target = + NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); + + const std::optional source_id = source_target.source; + se::DeviceMemoryBase dest_addr = buffer.destination_buffer; TF_RETURN_IF_ERROR(recv_ptr_map_.InitializeId(current_id)); + + if (source_id) { + TF_RETURN_IF_ERROR( + recv_ptr_map_.PutRecvPtr(current_id, dest_addr.opaque())); + } } return absl::OkStatus(); } +absl::Status NcclCollectivePermuteStartThunk::Cleanup( + const CleanupParams& params) { + TF_ASSIGN_OR_RETURN(const int64_t current_id, + GetCurrentId(params.collective_params, config_)); + + absl::MutexLock lock(&barrier_mutex_); + if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) { + LOG(ERROR) << "Unregistering barrier flag failed."; + } + return absl::OkStatus(); +} + absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective( const ExecuteParams& params, se::Stream& stream, CommunicatorHandle comm_handle) { @@ -190,6 +237,14 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective( p2p_memcpy_enabled_; TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params)); + if (use_memcpy) { + se::DeviceMemoryBase sync_var_address = + se::DeviceMemoryBase((void*)(&barrier_flags_[current_id])); + TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce( + sync_var_address, sync_var_address, PrimitiveType::U8, 1, + ReductionKind::MIN, GpuCollectives::On(stream))); + } + return ::xla::gpu::RunCollectivePermute( collectives, source_target, device_buffers[0], stream, comm_handle.comm, device_string, current_id, use_memcpy, recv_ptr_map_); @@ -231,8 +286,8 @@ absl::Status RunCollectivePermute( TF_RETURN_IF_ERROR( MaybeRegisterBuffers(collectives, stream.parent(), {buffer}, comm)); - const std::optional source_id = source_target.source; - const std::optional target_id = source_target.target; + std::optional source_id = source_target.source; + std::optional target_id = source_target.target; se::DeviceMemoryBase src_addr = buffer.source_buffer; se::DeviceMemoryBase dest_addr = buffer.destination_buffer; @@ -241,38 +296,15 @@ absl::Status RunCollectivePermute( device_string, current_id, source_id.value_or(-1), target_id.value_or(-1)); - // If all peers are local, only get/send device pointer values and invoke - // memcpy. - if (use_memcpy) { - // If sending to another peer, get the pointer value of the src addr. - // Only change the pointer value when it's different from stored one. - if (source_id) { - TF_RETURN_IF_ERROR( - recv_ptr_map.PutRecvPtr(current_id, dest_addr.opaque())); - } - } else { - // GroupStart/End API is needed only if we will issue both send & recv - // calls. - const bool is_nccl_group_needed = (target_id && source_id); - if (is_nccl_group_needed) { - TF_RETURN_IF_ERROR(collectives->GroupStart()); - } - // Send source buffer to target peer if needed. - if (target_id) { - TF_RETURN_IF_ERROR(comm->Send(src_addr, buffer.element_type, - buffer.element_count, *target_id, - GpuCollectives::On(stream))); - } + if (!use_memcpy) { + std::optional source_rank; + std::vector target_ranks; + if (source_id) source_rank = RankId(*source_id); + if (target_id) target_ranks.push_back(RankId(*target_id)); - // Receive data from the source peer to the destination buffer. - if (source_id) { - TF_RETURN_IF_ERROR(comm->Recv(dest_addr, buffer.element_type, - buffer.element_count, *source_id, - GpuCollectives::On(stream))); - } - if (is_nccl_group_needed) { - TF_RETURN_IF_ERROR(collectives->GroupEnd()); - } + TF_RETURN_IF_ERROR(comm->CollectivePermute( + src_addr, dest_addr, buffer.element_type, buffer.element_count, + source_rank, target_ranks, GpuCollectives::On(stream))); } if (!source_id) { @@ -282,12 +314,9 @@ absl::Status RunCollectivePermute( device_string); TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); } + if (use_memcpy && target_id) { TF_ASSIGN_OR_RETURN(auto recv_ptr, recv_ptr_map.GetRecvPtr(*target_id)); - if (recv_ptr.IsUnavailable()) { - // TODO make BlockUntilReady support AsyncValueRef directly. - BlockUntilReady(recv_ptr.GetAsyncValue()); - } VLOG(3) << "Using memcpy, received target pointer: " << recv_ptr.get() << " current_id " << current_id << " target_id: " << *target_id; diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.h index bcc124b3dafcd1..8753df53eb6562 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.h @@ -52,9 +52,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { absl::Status InitializeId(int64_t current_id) { absl::MutexLock lock(&mutex_); - if (recv_ptrs_.find(current_id) == recv_ptrs_.end()) { - recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef(); - } + recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef(); return absl::OkStatus(); } @@ -102,6 +100,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { int64_t partition_count, const Buffer& buffer, bool p2p_memcpy_enabled); absl::Status Initialize(const InitializeParams& params) override; + absl::Status Cleanup(const CleanupParams& params) override; static const char* GetHloOpName() { return "collective-permute-start"; } @@ -115,6 +114,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { const NcclP2PConfig config_; const Buffer buffer_; RecvPtrMap recv_ptr_map_; + absl::Mutex barrier_mutex_; + std::unordered_map barrier_flags_; bool p2p_memcpy_enabled_ = false; int64_t device_count_; }; diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc index 1c839dcb18c9bf..47211e6f437c91 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -479,10 +479,10 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { "first call to collective operation %d; run_id=%d", config().op_id, params.collective_params->run_id.ToInt()); - RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name, - rendezvous_key, num_local_participants, - /*warn_stuck_timeout=*/absl::Seconds(20), - /*terminate_timeout=*/absl::Seconds(40)); + Rendezvous(first_call_rendezvous_flag_, rendezvous_name, rendezvous_key, + num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h index 66c831779607e7..5b5ba1fcf26995 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -143,6 +143,7 @@ class NcclCollectiveThunk : public Thunk { private: friend class NcclCollectiveThunk; friend class NcclCollectiveDoneThunk; + friend class NcclGroupThunk; absl::Status Initialize(se::StreamExecutor* executor); absl::StatusOr GetEvent(se::StreamExecutor* executor); @@ -209,7 +210,7 @@ class NcclCollectiveThunk : public Thunk { // // TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make // sure that all NCCL resources are allocated just once. - RendezvousSingleFlag first_call_rendezvous_flag_; + RendezvousFlag first_call_rendezvous_flag_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.cc index 7e3cdca6120f86..a2bc58cb0dde0a 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.cc @@ -15,16 +15,20 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_group_thunk.h" +#include #include #include #include #include "absl/status/status.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/stream.h" #include "xla/util.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -32,8 +36,11 @@ namespace gpu { NcclGroupThunk::NcclGroupThunk(const HloInstruction* instruction, Thunk::Kind kind, - std::vector> thunks) - : Thunk(kind, ThunkInfo::WithProfileAnnotation(instruction)) { + std::vector> thunks, + AsyncStreamKind stream_kind) + : Thunk(kind, ThunkInfo::WithProfileAnnotation(instruction)), + stream_kind_(stream_kind), + async_events_(new NcclCollectiveThunk::AsyncEvents()) { for (auto& thunk : thunks) { thunks_.emplace_back(std::move(thunk)); } @@ -46,6 +53,9 @@ absl::Status NcclGroupThunk::Prepare(const PrepareParams& params, return absl::OkStatus(); } absl::Status NcclGroupThunk::Initialize(const InitializeParams& params) { + if (async_events_) { + TF_RETURN_IF_ERROR(async_events_->Initialize(params.executor)); + } for (const std::unique_ptr& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->Initialize(params)); } @@ -55,12 +65,22 @@ absl::Status NcclGroupThunk::Initialize(const InitializeParams& params) { absl::Status NcclGroupThunk::ExecuteOnStream( const Thunk::ExecuteParams& params) { TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params)); - + int64_t async_stream_idx = static_cast(stream_kind_); + // Async streams are already assigned in gpu_executable.cc::ExecuteThunks. + // async_streams is therefore guaranteed to be non-null and to have enough + // elements to index by the AsyncStreamKind enum. + se::Stream* async_stream = + params.collective_params->async_streams.at(async_stream_idx); + TF_RETURN_IF_ERROR(async_stream->WaitFor(params.stream)); TF_RETURN_IF_ERROR(collectives->GroupStart()); for (const std::unique_ptr& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(params)); } TF_RETURN_IF_ERROR(collectives->GroupEnd()); + TF_ASSIGN_OR_RETURN(se::Event * event, + async_events_->GetEvent(params.stream->parent())); + TF_RETURN_IF_ERROR(async_stream->RecordEvent(event)); + return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.h index d70a85b2c4cf67..9e40ad778f7ac3 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_group_thunk.h @@ -16,12 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_GROUP_THUNK_H_ #define XLA_SERVICE_GPU_RUNTIME_NCCL_GROUP_THUNK_H_ -#include #include -#include +#include #include "absl/status/status.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" namespace xla { @@ -34,14 +35,20 @@ namespace gpu { class NcclGroupThunk : public Thunk { public: NcclGroupThunk(const HloInstruction* instruction, Thunk::Kind kind, - std::vector> thunks); + std::vector> thunks, + AsyncStreamKind stream_kind); absl::Status Prepare(const PrepareParams& params, ResourceRequests& resource_requests) override; absl::Status ExecuteOnStream(const Thunk::ExecuteParams& params) override; absl::Status Initialize(const InitializeParams& params) override; + std::shared_ptr async_events() const { + return async_events_; + } private: ThunkSequence thunks_; + AsyncStreamKind stream_kind_; + std::shared_ptr async_events_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.cc new file mode 100644 index 00000000000000..848541c5d99f88 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.cc @@ -0,0 +1,324 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" +#include "xla/backends/gpu/collectives/gpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +// RaggedAllToAll has 4 operands with ragged tensor metadata: input_offsets, +// send_sizes, output_offsets, and recv_sizes. +constexpr int64_t kNumRaggedMetadataOperands = 4; + +NcclRaggedAllToAllConfig GetNcclRaggedAllToAllConfig( + const HloRaggedAllToAllInstruction* instr) { + NcclRaggedAllToAllConfig config; + config.config = GetNcclCollectiveConfig(instr, std::nullopt); + config.num_ragged_rows = instr->operand(2)->shape().dimensions(0); + config.ragged_row_element_size = + ShapeUtil::ElementsIn(instr->shape()) / instr->shape().dimensions(0); + return config; +} + +// A wrapper around an raw data buffer that indexes values based on the +// PrimitiveType that is stored in the buffer. +class IntegerOperandData { + public: + IntegerOperandData(PrimitiveType element_type, void* data) + : element_type_(element_type), data_(data) {} + + int64_t get(int i) const { + switch (element_type_) { + case PrimitiveType::S32: + case PrimitiveType::U32: + return reinterpret_cast(data_)[i]; + case PrimitiveType::S64: + case PrimitiveType::U64: + return reinterpret_cast(data_)[i]; + default: + LOG(FATAL) << "Unsupported element type: " << element_type_; + } + } + + int64_t operator[](int i) const { return get(i); } + + private: + PrimitiveType element_type_; + void* data_; +}; + +// Loads the offsets and sizes of the input and output ragged tensors from +// device memory. +// +// The parameter `ragged_metadata_allocs` is a vector of pointers to the buffers +// in the host memory allocated by StreamExecutor to copy data from the device +// memory. +absl::StatusOr> LoadRaggedTensorMetadata( + se::Stream& stream, const std::vector& buffers, + const std::vector& ragged_metadata_allocs) { + std::vector indices; + for (int i = 0; i < kNumRaggedMetadataOperands; ++i) { + TF_RETURN_IF_ERROR(stream.Memcpy(ragged_metadata_allocs[i], + buffers[i + 2].source_buffer, + buffers[i + 2].source_buffer.size())); + indices.push_back(IntegerOperandData(buffers[i + 2].element_type, + ragged_metadata_allocs[i])); + } + + // Wait for the copies to complete. + if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { + return absl::InternalError(absl::StrFormat( + "Failed to complete all kernels launched on stream %p: %s", &stream, + blocked.message())); + } + + return indices; +} + +} // namespace + +NcclRaggedAllToAllStartThunk::NcclRaggedAllToAllStartThunk( + ThunkInfo thunk_info, const HloRaggedAllToAllInstruction* instr, + std::vector buffers, bool p2p_memcpy_enabled) + : NcclCollectiveThunk(Thunk::kNcclAllToAllStart, thunk_info, + IsSyncCollective(instr)), + config_(GetNcclRaggedAllToAllConfig(instr)), + buffers_(std::move(buffers)) { + CHECK_EQ(config_.config.operand_count, buffers_.size()); +} + +/*static*/ absl::Status NcclRaggedAllToAllStartThunk::CheckImplementable( + const HloRaggedAllToAllInstruction* instr, int64_t replica_count, + int64_t partition_count) { + auto status = [&instr]() -> absl::Status { + for (HloInstruction* operand : instr->operands()) { + Shape shape = operand->shape(); + TF_RETURN_IF_ERROR(IsValidOperand(shape, Thunk::kNcclRaggedAllToAll)); + } + return absl::OkStatus(); + }; + return AddOpDescription( + status(), instr, replica_count, partition_count); +} + +/*static*/ CollectiveOpGroupMode NcclRaggedAllToAllStartThunk::GetGroupMode( + const HloRaggedAllToAllInstruction* instr) { + return GetNcclRaggedAllToAllConfig(instr).config.group_mode; +} + +absl::Status NcclRaggedAllToAllStartThunk::Initialize( + const InitializeParams& params) { + TF_RETURN_IF_ERROR(NcclCollectiveThunk::Initialize(params)); + + // Allocate temp buffers in the host memory to load the sizes and offsets of + // ragged tensors from device memory. + absl::MutexLock lock(&mutex_); + if (!host_buffer_allocs_.contains(params.executor)) { + std::vector> allocs; + for (int64_t i = 0; i < kNumRaggedMetadataOperands; ++i) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alloc, + params.executor->HostMemoryAllocate( + config_.num_ragged_rows * sizeof(int64_t))); + allocs.push_back(std::move(alloc)); + } + host_buffer_allocs_.emplace(params.executor, std::move(allocs)); + } + + if (!device_buffer_allocs_.contains(params.executor)) { + se::DeviceMemoryBase output_offsets_device_buffer = + params.executor->Allocate(config_.num_ragged_rows * sizeof(int64_t)); + + if (output_offsets_device_buffer.is_null()) { + return absl::InternalError("Failed to allocate output offsets buffer."); + } + + device_buffer_allocs_.emplace(params.executor, + output_offsets_device_buffer); + } + + return absl::OkStatus(); +} + +absl::Status NcclRaggedAllToAllStartThunk::Cleanup( + const CleanupParams& params) { + absl::MutexLock lock(&mutex_); + + if (device_buffer_allocs_.contains(params.executor)) { + se::DeviceMemoryBase alloc = + device_buffer_allocs_.extract(params.executor).mapped(); + params.executor->Deallocate(&alloc); + } + + return absl::OkStatus(); +} + +absl::Status NcclRaggedAllToAllStartThunk::RunNcclCollective( + const ExecuteParams& params, se::Stream& stream, + CommunicatorHandle comm_handle) { + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(params, buffers_, + config_.config.operand_element_type)); + + TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params)); + + // Get buffer allocs to load sizes and offsets of ragged tensors from device + // memory. + std::vector ragged_metadata_allocs(4); + se::DeviceMemoryBase output_offsets_device_buffer; + { + absl::MutexLock lock(&mutex_); + auto it = host_buffer_allocs_.find(stream.parent()); + CHECK(it != host_buffer_allocs_.end()); + + for (int64_t i = 0; i < kNumRaggedMetadataOperands; ++i) { + ragged_metadata_allocs[i] = + reinterpret_cast(it->second[i]->opaque()); + } + + auto jt = device_buffer_allocs_.find(stream.parent()); + CHECK(jt != device_buffer_allocs_.end()); + output_offsets_device_buffer = jt->second; + } + + return xla::gpu::RunRaggedAllToAll( + collectives, config_.ragged_row_element_size, device_buffers, stream, + comm_handle.comm, ragged_metadata_allocs, output_offsets_device_buffer); +} + +AsyncStreamKind NcclRaggedAllToAllStartThunk::GetAsyncStreamKind() const { + return AsyncStreamKind::kCollective; +} + +// Runs AllToAll on a buffer that contains ragged tensor metadata. +absl::Status RunAllToAllOnIndexBuffer( + GpuCollectives* collectives, const se::DeviceMemoryBase& source_buffer, + const se::DeviceMemoryBase& destination_buffer, PrimitiveType element_type, + se::Stream& stream, Communicator* comm) { + TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); + + TF_RETURN_IF_ERROR(collectives->GroupStart()); + for (int peer = 0; peer < num_ranks; ++peer) { + se::DeviceMemoryBase send_slice = collectives->Slice( + source_buffer, element_type, /*offset=*/peer, /*count=*/1); + se::DeviceMemoryBase recv_slice = collectives->Slice( + destination_buffer, element_type, /*offset=*/peer, /*count=*/1); + + TF_RETURN_IF_ERROR(comm->Send(send_slice, element_type, /*count=*/1, + RankId(peer), GpuCollectives::On(stream))); + + TF_RETURN_IF_ERROR(comm->Recv(recv_slice, element_type, /*count=*/1, + RankId(peer), GpuCollectives::On(stream))); + } + + TF_RETURN_IF_ERROR(collectives->GroupEnd()); + return stream.BlockHostUntilDone(); +} + +absl::Status RunRaggedAllToAll( + GpuCollectives* collectives, int64_t ragged_row_element_size, + const std::vector& original_buffers, se::Stream& stream, + Communicator* comm, const std::vector& ragged_metadata_allocs, + const se::DeviceMemoryBase& output_offsets_device_buffer) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing ragged-all-to-all from device ordinal: " + << device_ordinal; + TF_RETURN_IF_ERROR(MaybeRegisterBuffers(collectives, stream.parent(), + original_buffers, comm)); + + TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); + + std::vector buffers = original_buffers; + + // `output_offsets` of the RaggedAllToAll instruction are sharded in a way, + // that `output_offset[i]` is an offset in the i-th peer output buffer. To + // make it work for NCCL model with send/recv, we need to know offsets in the + // local output buffer. To get the correct offsets we perform an AllToAll on + // the output_offsets buffer. + DeviceBufferPair& output_offsets_buffer_pair = buffers[4]; + TF_RETURN_IF_ERROR(RunAllToAllOnIndexBuffer( + collectives, output_offsets_buffer_pair.source_buffer, + output_offsets_device_buffer, output_offsets_buffer_pair.element_type, + stream, comm)); + output_offsets_buffer_pair.source_buffer = output_offsets_device_buffer; + + TF_ASSIGN_OR_RETURN( + std::vector ragged_metadata, + LoadRaggedTensorMetadata(stream, buffers, ragged_metadata_allocs)); + + const IntegerOperandData& input_offsets = ragged_metadata[0]; + const IntegerOperandData& send_sizes = ragged_metadata[1]; + const IntegerOperandData& output_offsets = ragged_metadata[2]; + const IntegerOperandData& recv_sizes = ragged_metadata[3]; + + TF_RETURN_IF_ERROR(collectives->GroupStart()); + + const DeviceBufferPair& data_buffer = buffers[0]; + for (int peer = 0; peer < num_ranks; ++peer) { + se::DeviceMemoryBase send_slice = + collectives->Slice(data_buffer.source_buffer, data_buffer.element_type, + input_offsets[peer] * ragged_row_element_size, + send_sizes[peer] * ragged_row_element_size); + + se::DeviceMemoryBase recv_slice = collectives->Slice( + data_buffer.destination_buffer, data_buffer.element_type, + output_offsets[peer] * ragged_row_element_size, + recv_sizes[peer] * ragged_row_element_size); + + TF_RETURN_IF_ERROR(comm->Send(send_slice, data_buffer.element_type, + send_sizes[peer] * ragged_row_element_size, + RankId(peer), GpuCollectives::On(stream))); + + TF_RETURN_IF_ERROR(comm->Recv(recv_slice, data_buffer.element_type, + recv_sizes[peer] * ragged_row_element_size, + RankId(peer), GpuCollectives::On(stream))); + } + + return collectives->GroupEnd(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.h new file mode 100644 index 00000000000000..d085aab44d2945 --- /dev/null +++ b/third_party/xla/xla/service/gpu/runtime/nccl_ragged_all_to_all_thunk.h @@ -0,0 +1,102 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_RUNTIME_NCCL_RAGGED_ALL_TO_ALL_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME_NCCL_RAGGED_ALL_TO_ALL_THUNK_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/backends/gpu/collectives/gpu_clique_key.h" +#include "xla/backends/gpu/collectives/gpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream.h" + +namespace xla { +namespace gpu { + +struct NcclRaggedAllToAllConfig { + NcclCollectiveConfig config; + int64_t num_ragged_rows = 1; + int64_t ragged_row_element_size = 1; +}; + +// Thunk that performs a NCCL-based Ragged-All-to-All among CUDA GPU-based +// replicas. +class NcclRaggedAllToAllStartThunk : public NcclCollectiveThunk { + public: + NcclRaggedAllToAllStartThunk(ThunkInfo thunk_info, + const HloRaggedAllToAllInstruction* instr, + std::vector buffers, + bool p2p_memcpy_enabled); + + // Returns whether the given instruction can be lowered to a nccl + // ragged-all-to-all call. + static absl::Status CheckImplementable( + const HloRaggedAllToAllInstruction* instr, int64_t replica_count, + int64_t partition_count); + + absl::Status Initialize(const InitializeParams& params) override; + + absl::Status Cleanup(const CleanupParams& params) override; + + static const char* GetHloOpName() { return "ragged-all-to-all-start"; } + + static CollectiveOpGroupMode GetGroupMode( + const HloRaggedAllToAllInstruction* instr); + + const NcclCollectiveConfig& config() const override { return config_.config; } + absl::Span buffers() const { return buffers_; } + + protected: + absl::Status RunNcclCollective(const ExecuteParams& params, + se::Stream& stream, + CommunicatorHandle comm_handle) override; + + AsyncStreamKind GetAsyncStreamKind() const override; + + private: + const NcclRaggedAllToAllConfig config_; + const std::vector buffers_; + + absl::Mutex mutex_; + absl::flat_hash_map>> + host_buffer_allocs_ ABSL_GUARDED_BY(mutex_); + + absl::flat_hash_map + device_buffer_allocs_ ABSL_GUARDED_BY(mutex_); +}; + +absl::Status RunRaggedAllToAll( + GpuCollectives* collectives, int64_t ragged_row_element_size, + const std::vector& buffers, se::Stream& stream, + Communicator* comm, const std::vector& ragged_metadata_allocs, + const se::DeviceMemoryBase& output_offsets_device_buffer); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_RUNTIME_NCCL_RAGGED_ALL_TO_ALL_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc index 58286c02039ace..b5dfd81dd05bfc 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" +#include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" @@ -33,9 +34,10 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -131,8 +133,8 @@ absl::Status NcclRecvThunk::RunNcclCollective(const ExecuteParams& params, } if (should_run) { TF_RETURN_IF_ERROR(comm_handle.comm->Recv( - dest_addr, buffer.element_type, buffer.element_count, *source_id, - GpuCollectives::On(stream))); + dest_addr, buffer.element_type, buffer.element_count, + RankId(*source_id), GpuCollectives::On(stream))); } } else { diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc index 7a86bd2ce69fff..8692b9dd0cf712 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" +#include "xla/core/collectives/rank_id.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" @@ -34,9 +35,10 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_p2p_thunk_common.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -132,8 +134,8 @@ absl::Status NcclSendThunk::RunNcclCollective(const ExecuteParams& params, if (should_run) { TF_RETURN_IF_ERROR(comm_handle.comm->Send( - src_addr, buffer.element_type, buffer.element_count, *target_id, - GpuCollectives::On(stream))); + src_addr, buffer.element_type, buffer.element_count, + RankId(*target_id), GpuCollectives::On(stream))); } } diff --git a/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc b/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc index 46a7ebb3bf8fb8..b968e34e72d75f 100644 --- a/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/send_recv_thunk.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -50,7 +49,7 @@ using tsl::profiler::TraceMeEncode; // For sharded buffers we should execute Send/Recv operations only on devices // with maximal sharding, and do nothing on every other device. static absl::StatusOr ShouldSkip( - std::string_view operation, const Thunk::ExecuteParams& params, + absl::string_view operation, const Thunk::ExecuteParams& params, const std::optional& device_constraint) { if (!device_constraint.has_value()) return false; diff --git a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc index b7f051d1d119ed..c759339a430032 100644 --- a/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/sequential_thunk.cc @@ -75,6 +75,8 @@ absl::Status SequentialThunk::Initialize(const InitializeParams& params) { } absl::Status SequentialThunk::ExecuteOnStream(const ExecuteParams& params) { + std::optional seq_annotation = + GetKernelAnnotation(profile_annotation()); for (const std::unique_ptr& thunk : thunks_) { std::optional annotation = GetKernelAnnotation(thunk->profile_annotation()); diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index bb698234cf7c34..39d72a9d3443c8 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -31,7 +31,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/backends/gpu/collectives/gpu_clique_locking.h" +#include "xla/backends/gpu/collectives/gpu_cliques.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" @@ -53,8 +53,10 @@ namespace gpu { // Thunk::CollectiveCliques //===----------------------------------------------------------------------===// -Thunk::CollectiveCliques::CollectiveCliques(AcquiredCliquesMap cliques_map) - : cliques_map_(std::move(cliques_map)) {} +Thunk::CollectiveCliques::CollectiveCliques(AcquiredCliquesMap cliques_map, + int32_t num_transient_cliques) + : cliques_map_(std::move(cliques_map)), + num_transient_cliques_(num_transient_cliques) {} absl::StatusOr Thunk::CollectiveCliques::GetComm( const GpuCliqueKey& clique_key, RankId rank) const { @@ -281,6 +283,9 @@ Thunk::ExecuteParams::ExecuteParams( CASE(kNcclAllToAllDone); CASE(kNcclSend); CASE(kNcclSendDone); + CASE(kNcclRaggedAllToAll); + CASE(kNcclRaggedAllToAllStart); + CASE(kNcclRaggedAllToAllDone); CASE(kNcclRecv); CASE(kNcclRecvDone); CASE(kFft); diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 2ed926660d18d6..90aae04f7d1770 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -33,7 +32,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" -#include "xla/backends/gpu/collectives/gpu_clique_locking.h" +#include "xla/backends/gpu/collectives/gpu_cliques.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" @@ -155,6 +154,9 @@ class Thunk { kNcclAllToAll, kNcclAllToAllStart, kNcclAllToAllDone, + kNcclRaggedAllToAll, + kNcclRaggedAllToAllStart, + kNcclRaggedAllToAllDone, kNcclSend, kNcclSendDone, kNcclRecv, @@ -179,7 +181,7 @@ class Thunk { // clear what else should become a part of "executable source", we likely // need to keep some information about available symbols and signatures. struct ExecutableSource { - std::string_view text; // PTX for NVIDIA backend + absl::string_view text; // PTX for NVIDIA backend absl::Span binary; // CUBIN for NVIDIA backends BinaryMap dnn_compiled_graphs; }; @@ -216,7 +218,8 @@ class Thunk { class CollectiveCliques { public: CollectiveCliques() = default; - explicit CollectiveCliques(AcquiredCliquesMap cliques_map); + CollectiveCliques(AcquiredCliquesMap cliques_map, + int32_t num_transient_cliques); absl::StatusOr GetComm(const GpuCliqueKey& clique_key, RankId rank) const; @@ -231,8 +234,16 @@ class Thunk { bool empty() const { return cliques_map_.empty(); } + bool num_transient_cliques() const { return num_transient_cliques_; } + private: AcquiredCliquesMap cliques_map_; + + // The number of acquired non-persistent clique. We need to keep track of + // newly created communicators to insert rendezvous after first + // initialization, because otherwise we observe deadlocks with NCCL + // collectives backends. + int32_t num_transient_cliques_ = 0; }; //===--------------------------------------------------------------------===// @@ -441,7 +452,7 @@ class Thunk { virtual std::string ToString(int indent) const { return ""; } Kind kind() const { return kind_; } - std::string_view profile_annotation() const { return profile_annotation_; } + absl::string_view profile_annotation() const { return profile_annotation_; } // Prepares thunk for execution. // diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index f9ceeba4a6424b..0349d81ffe6eae 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/layout.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 961b7bcf6a81e6..72942a7b30344a 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" +#include #include #include #include @@ -23,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -386,7 +386,7 @@ absl::StatusOr> CreateKernel( return kernel; } -absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, +absl::Status ExecuteKernelOnStream(se::Kernel& kernel, absl::Span args, const LaunchDimensions& dims, se::Stream* stream) { @@ -394,11 +394,11 @@ absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, std::unique_ptr kernel_args, se::PackKernelArgs(args, kernel.metadata())); - return stream->Launch(dims.thread_counts_per_block(), dims.block_counts(), - kernel, *kernel_args); + return kernel.Launch(dims.thread_counts_per_block(), dims.block_counts(), + stream, *kernel_args); } -absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, +absl::Status ExecuteKernelOnStream(se::Kernel& kernel, absl::Span args, const LaunchDimensions& dims, const se::ClusterDim& cluster_dim, @@ -407,8 +407,8 @@ absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, std::unique_ptr kernel_args, se::PackKernelArgs(args, kernel.metadata())); - return stream->Launch(dims.thread_counts_per_block(), dims.block_counts(), - cluster_dim, kernel, *kernel_args); + return kernel.Launch(dims.thread_counts_per_block(), dims.block_counts(), + cluster_dim, stream, *kernel_args); } // Unimplemented for integers yet. @@ -509,10 +509,10 @@ static void InitializeTypedBuffer(se::Stream* stream, constexpr int threads_per_block = 256; constexpr int blocks_per_grid = (host_buffer_bytes + threads_per_block - 1) / threads_per_block; - TF_CHECK_OK(stream->ThenLaunch(se::ThreadDim(threads_per_block, 1, 1), - se::BlockDim(blocks_per_grid, 1, 1), *kernel, - buffer, host_buffer_bytes, - static_cast(buffer.size()))); + TF_CHECK_OK(kernel->Launch(se::ThreadDim(threads_per_block, 1, 1), + se::BlockDim(blocks_per_grid, 1, 1), stream, + buffer, host_buffer_bytes, + static_cast(buffer.size()))); } void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, @@ -641,7 +641,7 @@ std::vector KeepNonFailures( } absl::Status AllAlgorithmsFailedInternalError( - std::optional instr_str, + std::optional instr_str, absl::Span profile_results) { std::ostringstream msg; if (instr_str.has_value()) { @@ -659,7 +659,7 @@ absl::Status AllAlgorithmsFailedInternalError( } absl::Status NoAlgorithmSuppliedInternalError( - std::optional instr_str) { + std::optional instr_str) { std::ostringstream msg; if (instr_str.has_value()) { msg << "There are no algorithm candidates for computing: \n " @@ -703,7 +703,7 @@ absl::Span TopResultsWithinMeasurementError( absl::StatusOr PickBestResult( absl::Span profile_results, - std::optional instr_str, + std::optional instr_str, HloModuleConfig hlo_module_config) { if (profile_results.empty()) { return NoAlgorithmSuppliedInternalError(instr_str); diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index d0338595f9f17d..87a91c0bd10fbb 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/status.h" @@ -105,13 +104,13 @@ absl::StatusOr> CreateKernel( uint32_t shared_mem_bytes = 0); // Runs loaded kernel on the stream with the provided arguments. -absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, +absl::Status ExecuteKernelOnStream(se::Kernel& kernel, absl::Span args, const LaunchDimensions& dims, se::Stream* stream); // Runs loaded kernel on the stream with the provided arguments. -absl::Status ExecuteKernelOnStream(const se::Kernel& kernel, +absl::Status ExecuteKernelOnStream(se::Kernel& kernel, absl::Span args, const LaunchDimensions& dims, const se::ClusterDim& cluster_dim, @@ -142,7 +141,7 @@ absl::StatusOr GetDNNDataTypeFromPrimitiveType( // If deterministic output is requested, returns first (not failing) result. absl::StatusOr PickBestResult( absl::Span profile_results, - std::optional instr_str, + std::optional instr_str, HloModuleConfig hlo_module_config); // Returns whether determinism is required. diff --git a/third_party/xla/xla/service/gpu/target_util.cc b/third_party/xla/xla/service/gpu/target_util.cc index c86e9d01a0d938..82294c57304efc 100644 --- a/third_party/xla/xla/service/gpu/target_util.cc +++ b/third_party/xla/xla/service/gpu/target_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -393,11 +394,12 @@ llvm::CallInst* EmitDeviceFunctionCall( llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); for (PrimitiveType input_type : input_types) { ir_input_types.push_back( - llvm_ir::PrimitiveTypeToIrType(input_type, module)); + llvm_ir::PrimitiveTypeToIrType(input_type, b->getContext())); } llvm::FunctionType* callee_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(output_type, module), // Return type. - ir_input_types, // Parameter types. + llvm_ir::PrimitiveTypeToIrType(output_type, + b->getContext()), // Return type. + ir_input_types, // Parameter types. false); // No variadic arguments. // Declares the callee if it is not declared already. diff --git a/third_party/xla/xla/service/gpu/target_util_test.cc b/third_party/xla/xla/service/gpu/target_util_test.cc index a486c405612fa4..862f4f262defce 100644 --- a/third_party/xla/xla/service/gpu/target_util_test.cc +++ b/third_party/xla/xla/service/gpu/target_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/target_util.h" +#include #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -23,6 +24,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 163e476f246981..5b799e730bff23 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -54,14 +54,14 @@ cc_library( deps = [ "//xla:debug_options_flags", "//xla:shape_util", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:executable", "//xla/service:gpu_plugin", "//xla/service:hlo_module_config", "//xla/service/gpu:gpu_executable", "//xla/stream_executor:platform_manager", - "//xla/tests:filecheck", "//xla/tests:llvm_irgen_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], @@ -214,6 +214,9 @@ xla_test( deps = [ ":gpu_codegen_test", "//xla:error_spec", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -249,8 +252,8 @@ xla_test( "//xla:literal_util", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", @@ -266,7 +269,7 @@ xla_test( deps = [ ":gpu_codegen_test", "//xla:error_spec", - "//xla/tests:verified_hlo_module", + "//xla/hlo/testlib:verified_hlo_module", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], @@ -282,7 +285,7 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", - "//xla/tests:verified_hlo_module", + "//xla/hlo/testlib:verified_hlo_module", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -298,7 +301,7 @@ xla_test( deps = [ ":gpu_codegen_test", "//xla:error_spec", - "//xla/tests:verified_hlo_module", + "//xla/hlo/testlib:verified_hlo_module", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], @@ -351,7 +354,7 @@ xla_test( ":gpu_codegen_test", "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/tests:verified_hlo_module", + "//xla/hlo/testlib:verified_hlo_module", "@local_tsl//tsl/platform:test_main", ], ) @@ -406,7 +409,6 @@ xla_test( "//xla:error_spec", "//xla/service:hlo_module_config", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -557,6 +559,7 @@ lit_test_suite( "calling_convention.hlo", "dot_bf16.hlo", "kernel_reuse.hlo", + "offload_scan_output.hlo", "pad_to_static.hlo", "rng_get_and_update_state.hlo", "single_instruction.hlo", @@ -609,7 +612,7 @@ lit_test_suite( # name = "xla-opt", # srcs = ["xla-opt.cc"], # deps = [ -# "//xla/service/gpu/fusions/transforms:passes", +# "//xla/backends/gpu/codegen/transforms:passes", # "//xla/service/gpu/fusions/triton:xla_triton", # "//xla/service/gpu/fusions/triton:xla_triton_passes", # "@llvm-project//mlir:AllExtensions", @@ -883,7 +886,11 @@ xla_test( srcs = ["nop_custom_call_test.cc"], backends = ["gpu"], deps = [ + "//xla:literal", + "//xla:literal_util", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", + "//xla/tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/service/gpu/tests/fp8_to_llvm_hopper.mlir b/third_party/xla/xla/service/gpu/tests/fp8_to_llvm_hopper.mlir index b9228a4c56efb7..3ecb4e2bb1a1a1 100644 --- a/third_party/xla/xla/service/gpu/tests/fp8_to_llvm_hopper.mlir +++ b/third_party/xla/xla/service/gpu/tests/fp8_to_llvm_hopper.mlir @@ -5,9 +5,9 @@ // When this test fails, change the mapping in ir_emitter_triton.cc. // See b/345700241. #mma = #ttg.nvidia_mma<{ - versionMajor = 2, - versionMinor = 0, - warpsPerCTA = [1, 1], + versionMajor = 2, + versionMinor = 0, + warpsPerCTA = [1, 1], instrShape = [16, 8] }> diff --git a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc index 2f31cb10b58aa4..934a7a6bf6c883 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/executable.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/verified_hlo_module.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h index a6269783536b72..d77a4463055fa5 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h +++ b/third_party/xla/xla/service/gpu/tests/gpu_codegen_test.h @@ -20,9 +20,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tests/llvm_irgen_test_base.h" -#include "xla/tests/verified_hlo_module.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc index 716797f1ba36b4..61b1bbf59f696c 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "xla/error_spec.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_copy_alone_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_copy_alone_test.cc index 65e538a6a61d94..413411f27f4a68 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_copy_alone_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_copy_alone_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "xla/error_spec.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc index 8ef34f0ba63363..6a0149915efa5a 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_copy_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/tests/gpu_ftz_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_ftz_test.cc index f0338549b37d8b..b3ba6a26c8b021 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_ftz_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/verified_hlo_module.h" // Check that the ftz (flush denormals to zero) flag is reflected in PTX as // expected. diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index e8d8a04f1a93ec..abdb9f471d1ce1 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1263,6 +1263,136 @@ class FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM } }; +class FlashAttentionBMMScaleSegmentMaskSoftmaxBMM + : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_Sequence_Packing_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit_impl, entry_computation_layout={(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})->(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + ENTRY main.22 { + Arg_0.1 = bf16[2,512,2,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,512,2,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[2,512,2,64]{3,2,1,0} parameter(2) + constant.5 = s32[] constant(256) + broadcast.6 = s32[4]{0} broadcast(constant.5), dimensions={} + constant.7 = s32[5]{0} constant({0, 32768, 65536, 98304, 131072}) + custom-call.8 = (bf16[2,2,512,64]{3,1,2,0}, f32[4,2,512]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, broadcast.6, broadcast.6, /*index=5*/constant.7, constant.7), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}, s32[5]{0}, s32[5]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 2}} + get-tuple-element.11 = u8[0]{0} get-tuple-element(custom-call.8), index=2 + get-tuple-element.10 = f32[4,2,512]{2,1,0} get-tuple-element(custom-call.8), index=1 + Arg_3.4 = bf16[2,512,2,64]{3,2,1,0} parameter(3) + get-tuple-element.9 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.8), index=0 + transpose.12 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.9), dimensions={0,2,1,3} + custom-call.13 = (bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, get-tuple-element.10, Arg_3.4, /*index=5*/transpose.12, broadcast.6, broadcast.6, constant.7, constant.7), custom_call_target="__cudnn$fmhaSoftmaxBackward", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, f32[4,2,512]{2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}, s32[5]{0}, s32[5]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 2}} + get-tuple-element.17 = u8[0]{0} get-tuple-element(custom-call.13), index=3 + get-tuple-element.14 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=0 + transpose.18 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.14), dimensions={0,2,1,3} + get-tuple-element.15 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=1 + transpose.19 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.15), dimensions={0,2,1,3} + get-tuple-element.16 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=2 + transpose.20 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.16), dimensions={0,2,1,3} + ROOT tuple.21 = (bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}) tuple(transpose.12, transpose.18, transpose.19, transpose.20) + } // main.22 + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit_ref, entry_computation_layout={(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})->(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + _where.9 { + Arg_0.10 = pred[512]{0} parameter(0) + Arg_1.11 = s32[512]{0} parameter(1) + Arg_2.12 = s32[512]{0} parameter(2) + ROOT select.13 = s32[512]{0} select(Arg_0.10, Arg_1.11, Arg_2.12) + } + + floor_divide.14 { + Arg_0.15 = s32[512]{0} parameter(0) + sign.23 = s32[512]{0} sign(Arg_0.15) + Arg_1.16 = s32[] parameter(1) + sign.24 = s32[] sign(Arg_1.16) + broadcast.25 = s32[512]{0} broadcast(sign.24), dimensions={} + compare.26 = pred[512]{0} compare(sign.23, broadcast.25), direction=NE + broadcast.27 = s32[512]{0} broadcast(Arg_1.16), dimensions={} + remainder.28 = s32[512]{0} remainder(Arg_0.15, broadcast.27) + constant.19 = s32[] constant(0) + broadcast.20 = s32[512]{0} broadcast(constant.19), dimensions={} + compare.29 = pred[512]{0} compare(remainder.28, broadcast.20), direction=NE + and.30 = pred[512]{0} and(compare.26, compare.29) + broadcast.21 = s32[512]{0} broadcast(Arg_1.16), dimensions={} + divide.22 = s32[512]{0} divide(Arg_0.15, broadcast.21) + constant.17 = s32[] constant(1) + broadcast.18 = s32[512]{0} broadcast(constant.17), dimensions={} + subtract.31 = s32[512]{0} subtract(divide.22, broadcast.18) + ROOT call.32 = s32[512]{0} call(and.30, subtract.31, divide.22), to_apply=_where.9 + } // floor_divide.14 + + ENTRY main.61 { + Arg_0.1 = bf16[2,512,2,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,512,2,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[2,512,2,64]{3,2,1,0} parameter(2) + iota.8 = s32[512]{0} iota(), iota_dimension=0 + constant.7 = s32[] constant(256) + call.33 = s32[512]{0} call(iota.8, constant.7), to_apply=floor_divide.14 + broadcast.34 = s32[2,512]{1,0} broadcast(call.33), dimensions={1} + reshape.35 = s32[2,512,1]{2,1,0} reshape(broadcast.34) + broadcast.37 = s32[2,512,1]{2,1,0} broadcast(reshape.35), dimensions={0,1,2} + reshape.38 = s32[2,512]{1,0} reshape(broadcast.37) + broadcast.39 = s32[2,512,512]{2,1,0} broadcast(reshape.38), dimensions={0,1} + reshape.36 = s32[2,1,512]{2,1,0} reshape(broadcast.34) + broadcast.40 = s32[2,1,512]{2,1,0} broadcast(reshape.36), dimensions={0,1,2} + reshape.41 = s32[2,512]{1,0} reshape(broadcast.40) + broadcast.42 = s32[2,512,512]{2,1,0} broadcast(reshape.41), dimensions={0,2} + compare.43 = pred[2,512,512]{2,1,0} compare(broadcast.39, broadcast.42), direction=NE + convert.44 = bf16[2,512,512]{2,1,0} convert(compare.43) + reshape.45 = bf16[2,1,512,512]{3,2,1,0} reshape(convert.44) + constant.5 = bf16[] constant(-2.199e+12) + broadcast.6 = bf16[2,1,512,512]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.46 = bf16[2,1,512,512]{3,2,1,0} multiply(reshape.45, broadcast.6) + custom-call.47 = (bf16[2,2,512,64]{3,1,2,0}, f32[2,2,512]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, multiply.46), custom_call_target="__cudnn$fmhaScaleBiasSoftmax", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,1,512,512]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1}} + get-tuple-element.50 = u8[0]{0} get-tuple-element(custom-call.47), index=2 + get-tuple-element.49 = f32[2,2,512]{2,1,0} get-tuple-element(custom-call.47), index=1 + Arg_3.4 = bf16[2,512,2,64]{3,2,1,0} parameter(3) + get-tuple-element.48 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.47), index=0 + transpose.51 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.48), dimensions={0,2,1,3} + custom-call.52 = (bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, get-tuple-element.49, Arg_3.4, /*index=5*/multiply.46, transpose.51), custom_call_target="__cudnn$fmhaScaleBiasSoftmaxBackward", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, f32[2,2,512]{2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,1,512,512]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1}} + get-tuple-element.56 = u8[0]{0} get-tuple-element(custom-call.52), index=3 + get-tuple-element.53 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=0 + transpose.57 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.53), dimensions={0,2,1,3} + get-tuple-element.54 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=1 + transpose.58 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.54), dimensions={0,2,1,3} + get-tuple-element.55 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=2 + transpose.59 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.55), dimensions={0,2,1,3} + ROOT tuple.60 = (bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}) tuple(transpose.51, transpose.57, transpose.58, transpose.59) + } // main.61 + )"; + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 6, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.6.0."; + } + XlaBuilder builder(TestName()); + // Cudnn sequence packing packs multiple batches(segments) into one batch + // using offsets and seqlen tensors to indicate where each segment begins + std::string hlo_string = + GetModuleFlash_Attention_Training_Sequence_Packing_HloString_BF16(); // NOLINT + // Reference implementation is regular attention with segment mask + std::string hlo_string_ref = + GetModuleFlash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_HloString_BF16(); // NOLINT + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{1e-3, 1e-3})); + } +}; + class FlashAttentionBMMScaleSoftmaxBMMF8 : public MultiHeadedAttentionTest {}; class FlashAttentionBMMScaleSoftmaxDropoutBMM @@ -1378,6 +1508,13 @@ XLA_TEST_F(FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM, bfloat16>(); // NOLINT } +// BMM1 - Scale - SegmentMask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleSegmentMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2< + bfloat16>(); // NOLINT +} + absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef() { static constexpr absl::string_view hlo_text = R"( @@ -1471,8 +1608,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, XlaBuilder builder(TestName()); std::string ref_bnth = R"( custom-call.4.0 = ( - bf16[4,4,16,16]{3,1,2,0}, - u8[0]{0} + bf16[4,4,16,16]{3,1,2,0} ) custom-call( convert.19, convert.31, @@ -1546,8 +1682,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0}, - u8[16]{0} + f32[1,1,1,1]{3,2,1,0} ) custom-call( convert.18, convert.30, @@ -1652,8 +1787,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, std::string ref_btnh = R"( custom-call.4.0 = ( - bf16[4,16,4,16]{3,2,1,0}, - u8[0]{0} + bf16[4,16,4,16]{3,2,1,0} ) custom-call( convert.19, convert.31, @@ -1726,8 +1860,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0}, - u8[16]{0} + f32[1,1,1,1]{3,2,1,0} ) custom-call( convert.18, convert.30, diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 755c2b0374dfbe..31ad49c7984515 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/tests/int4_to_packed_int4.mlir b/third_party/xla/xla/service/gpu/tests/int4_to_packed_int4.mlir new file mode 100644 index 00000000000000..29cdd45524d57c --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/int4_to_packed_int4.mlir @@ -0,0 +1,110 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite %s --mlir-print-ir-after-all + +module { + tt.func @gemm_fusion_dot_2_impl(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + %0 = tt.get_program_id x : i32 + %c16_i32 = arith.constant 16 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = arith.muli %1, %c8_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %3 = arith.subi %c1_i32, %2 : i32 + %4 = arith.cmpi slt, %3, %c8_i32 : i32 + %5 = arith.select %4, %3, %c8_i32 : i32 + %6 = arith.remsi %0, %5 : i32 + %7 = arith.addi %2, %6 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %8 = arith.remsi %0, %c16_i32_0 : i32 + %9 = arith.divsi %8, %5 : i32 + %c128_i32 = arith.constant 128 : i32 + %10 = arith.muli %7, %c128_i32 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %11 = arith.addi %10, %c0_i32 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i32_1 = arith.constant 0 : i32 + %c128_i64_2 = arith.constant 128 : i64 + %c0_i32_3 = arith.constant 0 : i32 + %c128_i64_4 = arith.constant 128 : i64 + %c0_i32_5 = arith.constant 0 : i32 + %12 = arith.addi %c0_i32_3, %c0_i32_5 : i32 + %c64_i64 = arith.constant 64 : i64 + %c0_i32_6 = arith.constant 0 : i32 + %c64_i64_7 = arith.constant 64 : i64 + %c8192_i32 = arith.constant 8192 : i32 + %13 = tt.get_program_id y : i32 + %c0_i32_8 = arith.constant 0 : i32 + %14 = arith.addi %c0_i32_8, %13 : i32 + %15 = arith.muli %14, %c8192_i32 : i32 + %16 = tt.addptr %arg0, %15 : !tt.ptr, i32 + %17 = tt.make_tensor_ptr %16, [%c128_i64_2, %c64_i64_7], [%c1_i64, %c128_i64_4], [%c0_i32_1, %c0_i32_6] {order = array} : > + %18 = tt.advance %17, [%10, %c0_i32_3] : > + %c0_i32_9 = arith.constant 0 : i32 + %c256_i64 = arith.constant 256 : i64 + %c0_i32_10 = arith.constant 0 : i32 + %19 = arith.addi %c0_i32_9, %c0_i32_10 : i32 + %c64_i64_11 = arith.constant 64 : i64 + %c0_i32_12 = arith.constant 0 : i32 + %c64_i64_13 = arith.constant 64 : i64 + %c128_i32_14 = arith.constant 128 : i32 + %20 = arith.muli %9, %c128_i32_14 : i32 + %c1_i64_15 = arith.constant 1 : i64 + %c0_i32_16 = arith.constant 0 : i32 + %21 = arith.addi %20, %c0_i32_16 : i32 + %c256_i64_17 = arith.constant 256 : i64 + %c0_i32_18 = arith.constant 0 : i32 + %c256_i64_19 = arith.constant 256 : i64 + %c16384_i32 = arith.constant 16384 : i32 + %22 = tt.get_program_id y : i32 + %c0_i32_20 = arith.constant 0 : i32 + %23 = arith.addi %c0_i32_20, %22 : i32 + %24 = arith.muli %23, %c16384_i32 : i32 + %25 = tt.addptr %arg1, %24 : !tt.ptr, i32 + %26 = tt.make_tensor_ptr %25, [%c64_i64_13, %c256_i64_19], [%c256_i64, %c1_i64_15], [%c0_i32_12, %c0_i32_18] {order = array} : > + %27 = tt.advance %26, [%c0_i32_9, %20] : > + %c0_i32_21 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c32_i32 = arith.constant 32 : i32 + %28:3 = scf.for %arg3 = %c0_i32_21 to %c64_i32 step %c32_i32 iter_args(%arg4 = %18, %arg5 = %27, %arg6 = %cst) -> (!tt.ptr>, !tt.ptr>, tensor<128x128xf32>) : i32 { + %39 = tt.load %arg4 : !tt.ptr> + %c0_i32_35 = arith.constant 0 : i32 + %c32_i32_36 = arith.constant 32 : i32 + %40 = tt.advance %arg4, [%c0_i32_35, %c32_i32_36] : > + %41 = tt.load %arg5 : !tt.ptr> + %c32_i32_37 = arith.constant 32 : i32 + %c0_i32_38 = arith.constant 0 : i32 + %42 = tt.advance %arg5, [%c32_i32_37, %c0_i32_38] : > + %43 = arith.extsi %39 : tensor<128x32xi4> to tensor<128x32xi8> + %44 = arith.sitofp %43 : tensor<128x32xi8> to tensor<128x32xf32> + %45 = tt.dot %44, %41, %arg6 : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + scf.yield %40, %42, %45 : !tt.ptr>, !tt.ptr>, tensor<128x128xf32> + } + %c128_i32_22 = arith.constant 128 : i32 + %29 = arith.muli %7, %c128_i32_22 : i32 + %c256_i64_23 = arith.constant 256 : i64 + %c0_i32_24 = arith.constant 0 : i32 + %30 = arith.addi %29, %c0_i32_24 : i32 + %c128_i64_25 = arith.constant 128 : i64 + %c0_i32_26 = arith.constant 0 : i32 + %c128_i64_27 = arith.constant 128 : i64 + %c128_i32_28 = arith.constant 128 : i32 + %31 = arith.muli %9, %c128_i32_28 : i32 + %c1_i64_29 = arith.constant 1 : i64 + %c0_i32_30 = arith.constant 0 : i32 + %32 = arith.addi %31, %c0_i32_30 : i32 + %c256_i64_31 = arith.constant 256 : i64 + %c0_i32_32 = arith.constant 0 : i32 + %c256_i64_33 = arith.constant 256 : i64 + %c32768_i32 = arith.constant 32768 : i32 + %33 = tt.get_program_id y : i32 + %c0_i32_34 = arith.constant 0 : i32 + %34 = arith.addi %c0_i32_34, %33 : i32 + %35 = arith.muli %34, %c32768_i32 : i32 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i32 + %37 = tt.make_tensor_ptr %36, [%c128_i64_27, %c256_i64_33], [%c256_i64_23, %c1_i64_29], [%c0_i32_26, %c0_i32_32] {order = array} : > + %38 = tt.advance %37, [%29, %31] : > + tt.store %38, %28#2 : !tt.ptr> + tt.return + } +} diff --git a/third_party/xla/xla/service/gpu/tests/int4_to_packed_int4_small.mlir b/third_party/xla/xla/service/gpu/tests/int4_to_packed_int4_small.mlir new file mode 100644 index 00000000000000..a7323a4afaed8b --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/int4_to_packed_int4_small.mlir @@ -0,0 +1,12 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite %s + +module { + tt.func @dot_test(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<16x16xi8> { + %c0 = arith.constant 0 : i32 + %c16 = arith.constant 16: i64 + %0 = tt.make_tensor_ptr %arg0, [%c16, %c16], [%c16, %c16], [%c0, %c0] {order = array} : > + %1 = tt.load %0 : !tt.ptr> + %2 = arith.extsi %1 : tensor<16x16xi4> to tensor<16x16xi8> + tt.return %2 : tensor<16x16xi8> + } +} diff --git a/third_party/xla/xla/service/gpu/tests/mixed_precision_dot.mlir b/third_party/xla/xla/service/gpu/tests/mixed_precision_dot.mlir new file mode 100644 index 00000000000000..1116ae9391c35e --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/mixed_precision_dot.mlir @@ -0,0 +1,12 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) { + // CHECK-LABEL: @f16_to_f8_dot_operand + + %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand> + tt.return + } +} diff --git a/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc index d979d18aa8ac9d..06df6792eb3e9a 100644 --- a/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/tests/offload_scan_output.hlo b/third_party/xla/xla/service/gpu/tests/offload_scan_output.hlo new file mode 100644 index 00000000000000..ab954aa43ea91b --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/offload_scan_output.hlo @@ -0,0 +1,59 @@ +// RUN: hlo-opt %s --platform=gpu --stage=hlo --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb --split-input-file | FileCheck --check-prefixes=CHECK %s + +HloModule jit_f, entry_computation_layout={()->(f32[4]{0:S(5)}, f32[4]{0})}, allow_spmd_sharding_propagation_to_output={true,true} + +// # Simplified from the following Python script. +// +// import jax +// import jax.numpy as jnp +// +// p = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") +// +// @jax.jit +// def f(): +// def g(_1, _2): +// return None, (jax.device_put(jnp.array(1.0), p), jnp.array(2.0)) +// return jax.lax.scan(g, None, length = 4)[1] +// +// print(f()[0].sharding) # doesn't crash + +// Verify that the optimized code allocates one pinned-host buffer. +// CHECK: f32[4]{0:S(5)} custom-call(), custom_call_target="AllocateBuffer" +// CHECK-NOT: custom-call(), custom_call_target="AllocateBuffer" + +body { + body-arg.tuple = (s32[], f32[4]{0}, f32[4]{0}) parameter(0) + index.s32 = s32[] get-tuple-element(body-arg.tuple), index=0 + one.s32 = s32[] constant(1) + add.32 = s32[] add(index.s32, one.s32) + pinned-host-buffer = f32[4]{0} get-tuple-element(body-arg.tuple), index=1 + one.f32 = f32[] constant(1) + custom-call.9 = f32[] custom-call(one.f32), custom_call_target="annotate_device_placement", + custom_call_has_side_effect=true, + frontend_attributes={_xla_buffer_placement="pinned_host"} + reshape.22 = f32[1]{0} reshape(custom-call.9) + new-pinned-host-buffer = f32[4]{0} dynamic-update-slice(pinned-host-buffer, reshape.22, index.s32) + device-buffer = f32[4]{0} get-tuple-element(body-arg.tuple), index=2 + two.f32 = f32[] constant(2) + reshape.27 = f32[1]{0} reshape(two.f32) + new-device-buffer = f32[4]{0} dynamic-update-slice(device-buffer, reshape.27, index.s32) + ROOT new-body-arg.tuple = (s32[], f32[4]{0}, f32[4]{0}) tuple(add.32, new-pinned-host-buffer, new-device-buffer) +} // body + +cond { + cond-arg.tuple = (s32[], f32[4]{0}, f32[4]{0}) parameter(0) + cond-index.s32 = s32[] get-tuple-element(cond-arg.tuple), index=0 + four.s32 = s32[] constant(4) + ROOT cond-result = pred[] compare(cond-index.s32, four.s32), direction=LT +} // cond + +ENTRY main { + zero.s32 = s32[] constant(0) + zero.f32 = f32[] constant(0) + empty-buffer = f32[4]{0} broadcast(zero.f32), dimensions={} + while.tuple = (s32[], f32[4]{0}, f32[4]{0}) tuple(zero.s32, empty-buffer, empty-buffer) + while = (s32[], f32[4]{0}, f32[4]{0}) while(while.tuple), condition=cond, body=body + output-pinned-host-buffer = f32[4]{0} get-tuple-element(while), index=1 + output-device-buffer = f32[4]{0} get-tuple-element(while), index=2 + ROOT result.tuple = (f32[4]{0}, f32[4]{0}) tuple(output-pinned-host-buffer, output-device-buffer) +} // main diff --git a/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc b/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc index 74af9b4093e1d6..000467cad297a3 100644 --- a/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc +++ b/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir index 4d575a07687e6d..75608281b3fe27 100644 --- a/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir +++ b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir @@ -5,20 +5,20 @@ // Note: 'canonicalize' folds redundant (back-and-forth) convert_layout ops. -// CHECK-DAG: #[[BLOCKED4x4:.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-DAG: #[[BLOCKED4x4:.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> // CHECK-DAG: #[[BLOCKED1x1:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> module { // CHECK: @sparse_dot tt.func @sparse_dot() { // CHECK-NEXT: %[[A:.*]] = arith.constant dense<1.000000e+00> - // CHECK-SAME: : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>> + // CHECK-SAME: : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED4x4]]}>> // CHECK-NEXT: %[[B:.*]] = arith.constant dense<2.000000e+00> - // CHECK-SAME: : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>> + // CHECK-SAME: : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED4x4]]}>> // CHECK-NEXT: %[[C:.*]] = arith.constant dense<0.000000e+00> - // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]> + // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]> // CHECK-NEXT: %[[META:.*]] = arith.constant dense<13107> - // CHECK-SAME: : tensor<64x4xi16, #triton_xla.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>> + // CHECK-SAME: : tensor<64x4xi16, #triton_xla.sparse_dot_meta<{parent = #[[BLOCKED4x4]]}>> %a = arith.constant dense<1.00e+00> : tensor<64x32xf16> %b = arith.constant dense<2.00e+00> : tensor<64x64xf16> %c = arith.constant dense<0.00e+00> : tensor<64x64xf32> @@ -40,7 +40,7 @@ module { // A use with side effects so we don't DCE the whole function. tt.print "" { hex = false, isSigned = array} : %d : tensor<64x64xf32> - // CHECK-NEXT: tt.return + // CHECK-NEXT: tt.return tt.return } } diff --git a/third_party/xla/xla/service/gpu/tests/sparse_local_load_to_llvm.mlir b/third_party/xla/xla/service/gpu/tests/sparse_local_load_to_llvm.mlir index 47da295e9728dc..cdff37628da49e 100644 --- a/third_party/xla/xla/service/gpu/tests/sparse_local_load_to_llvm.mlir +++ b/third_party/xla/xla/service/gpu/tests/sparse_local_load_to_llvm.mlir @@ -8,8 +8,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-LABEL: sparse_local_load_ampere - tt.func @sparse_local_load_ampere(%A_alloc: !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory>, - %B_alloc: !ttg.memdesc<64x32xf16, #shared, #ttg.shared_memory>, + tt.func @sparse_local_load_ampere(%A_alloc: !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory>, + %B_alloc: !ttg.memdesc<64x32xf16, #shared, #ttg.shared_memory>, %meta_alloc: !ttg.memdesc<32x4xi16, #shared, #ttg.shared_memory>) { // A_dot and B_dot local loads shouldn not match with -sparse-local-load-to-llvm // CHECK-COUNT-2: ttg.local_load diff --git a/third_party/xla/xla/service/gpu/tests/swap_conv_operands_test.cc b/third_party/xla/xla/service/gpu/tests/swap_conv_operands_test.cc index 2885c8af11ff33..48dc1db38cd5dc 100644 --- a/third_party/xla/xla/service/gpu/tests/swap_conv_operands_test.cc +++ b/third_party/xla/xla/service/gpu/tests/swap_conv_operands_test.cc @@ -13,8 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + +#include "absl/status/statusor.h" #include "xla/error_spec.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" // TODO(b/210165681): The tests in this file are fragile to HLO op names. @@ -24,11 +30,20 @@ namespace gpu { namespace { -class SwapConvOperandsTest : public GpuCodegenTest {}; +class SwapConvOperandsTest : public GpuCodegenTest { + public: + absl::StatusOr GpuComputeCapability() { + TF_ASSIGN_OR_RETURN( + std::unique_ptr device_description, + GetTestPlatform()->DescriptionForDevice(0)); + + return device_description->gpu_compute_capability(); + } +}; // Here, we swap the operands of a convolution to avoid the performance penalty // associated with convolutions with large padding. This tests that the operands -// are swapped in this case, and that the emitted convolution is sucessfully +// are swapped in this case, and that the emitted convolution is successfully // lowered to a cuDNN custom-call. TEST_F(SwapConvOperandsTest, LargePadding) { const char* hlo_text = R"( @@ -42,10 +57,22 @@ ENTRY swap_conv { } )"; - MatchOptimizedHloWithShapes(hlo_text, - R"( + TF_ASSERT_OK_AND_ASSIGN(se::GpuComputeCapability gpu_compute_capability, + GpuComputeCapability()); + + if (std::get_if(&gpu_compute_capability) + ->IsAtLeastHopper()) { + MatchOptimizedHloWithShapes(hlo_text, + R"( +// CHECK: [[cudnn_conv_1_0:%[^ ]+]] = (f32[1,32,32,128]{3,2,1,0}, u8[{{.*}}]{0}) custom-call(f32[1,30,30,512]{3,2,1,0} {{[^ ]+}}, f32[128,3,3,512]{3,2,1,0} {{[^ ]+}}), window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward" + )"); + } else { + MatchOptimizedHloWithShapes(hlo_text, + R"( // CHECK: [[cudnn_conv_1_0:%[^ ]+]] = (f32[1,128,32,32]{3,2,1,0}, u8[{{.*}}]{0}) custom-call(f32[1,512,30,30]{3,2,1,0} [[fusion_1_1:%[^ ]+]], f32[128,512,3,3]{3,2,1,0} [[transpose_1_2:%[^ ]+]]), window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward" - )"); + )"); + } + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); } @@ -62,10 +89,21 @@ ENTRY swap_conv { } )"; - MatchOptimizedHloWithShapes(hlo_text, - R"( + TF_ASSERT_OK_AND_ASSIGN(se::GpuComputeCapability gpu_compute_capability, + GpuComputeCapability()); + + if (std::get_if(&gpu_compute_capability) + ->IsAtLeastHopper()) { + MatchOptimizedHloWithShapes(hlo_text, + R"( +// CHECK: [[cudnn_conv_1_0:%[^ ]+]] = (f32[1,32,32,128]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(f32[1,30,30,512]{3,2,1,0} {{[^ ]+}}, f32[128,3,3,512]{3,2,1,0} {{[^ ]+}}), window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForward" + )"); + } else { + MatchOptimizedHloWithShapes(hlo_text, + R"( // CHECK: [[cudnn_conv_1_0:%[^ ]+]] = (f32[1,128,32,32]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(f32[1,512,30,30]{3,2,1,0} [[fusion_1_1:%[^ ]+]], f32[128,512,3,3]{3,2,1,0} [[transpose_1_2:%[^ ]+]]), window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward" - )"); + )"); + } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); } diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc index ba6cede789f3bf..7bfda500c22806 100644 --- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc +++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc @@ -15,7 +15,7 @@ limitations under the License. #include "mlir/InitAllExtensions.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" #include "xla/service/gpu/fusions/triton/xla_triton_ops.h" #include "xla/service/gpu/fusions/triton/xla_triton_passes.h" #include "third_party/triton/bin/RegisterTritonDialects.h" diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 8eed41a577f3f3..4a94e9adaf9461 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -13,6 +13,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ + "//xla/hlo/tools/hlo_opt:__subpackages__", "//xla/service/gpu:__subpackages__", "//xla/tools/hlo_opt:__subpackages__", ], @@ -75,7 +76,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/service:hlo_creation_utils", "//xla/service:pattern_matcher", "//xla/service/gpu:matmul_utils", @@ -97,7 +98,7 @@ xla_cc_test( deps = [ ":algebraic_simplifier", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", @@ -354,10 +355,12 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -511,13 +514,13 @@ xla_cc_test( deps = [ ":collective_permute_cycle_decomposer", "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -527,6 +530,7 @@ cc_library( hdrs = ["collective_select_folder.h"], deps = [ "//xla:comparison_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:collective_ops_utils", @@ -629,12 +633,12 @@ xla_test( ":command_buffer_scheduling", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_executable", "//xla/stream_executor:device_description", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status", @@ -941,10 +945,12 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:convert_mover", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:convert_mover", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:reshape_mover", "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -955,9 +961,7 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/stream_executor:semantic_version", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", @@ -1026,9 +1030,9 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:hlo_dce", "//xla/hlo/transforms:reshape_decomposer", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:computation_layout", "//xla/service:hlo_cse", "//xla/service:hlo_module_config", @@ -1161,10 +1165,10 @@ xla_test( deps = [ ":cudnn_norm_rewriter", "//xla:error_spec", + "//xla/hlo/testlib:filecheck", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/tests:filecheck", "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", ] + if_cuda_is_configured([ @@ -1250,9 +1254,9 @@ xla_cc_test( "//xla:util", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:reshape_mover", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/service:call_inliner", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -1559,7 +1563,7 @@ cc_library( "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", "@com_google_absl//absl/algorithm:container", @@ -1587,7 +1591,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/testlib:filecheck", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1608,11 +1612,13 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/ffi:ffi_api", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_traversal", + "//xla/service:call_graph", "//xla/service:custom_call_target_registry", - "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:gpu_constants", @@ -2083,7 +2089,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:sub_byte_normalization", + "//xla/hlo/transforms/simplifiers:sub_byte_normalization", "//xla/service:hlo_creation_utils", "//xla/service/gpu:gpu_fusible", "//xla/stream_executor:device_description", @@ -2116,7 +2122,7 @@ xla_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:hlo_cost_analysis", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -2178,6 +2184,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:filecheck", "//xla/service:computation_layout", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -2286,7 +2293,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "//xla/service:call_graph", "//xla/service:instruction_fusion", @@ -2324,6 +2331,7 @@ xla_cc_test( ":nest_gemm_fusion", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/service:hlo_cost_analysis", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -2331,7 +2339,6 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:matmul_utils", - "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -2425,7 +2432,6 @@ cc_library( "@com_google_absl//absl/time", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -2855,6 +2861,7 @@ cc_library( "//xla/service/gpu/model:triton_emitter_constraints", "//xla/stream_executor:device_description", "//xla/tools:hlo_decomposer_lib", + "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3099,7 +3106,7 @@ xla_cc_test( deps = [ ":topk_splitter", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:pattern_matcher", "//xla/service:topk_rewriter", "//xla/tests:hlo_test_base", @@ -3334,8 +3341,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:hlo_constant_folding", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_constant_folding", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_creation_utils", "//xla/service:pattern_matcher", diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc index b2dfccaca8ed03..3f25b26af7adde 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc @@ -74,7 +74,7 @@ TEST_F(AllGatherDynamicSliceSimplifierTest, AllPartitions) { dimensions={0}, channel_id=1, use_global_device_ids=true %pid = u32[] partition-id() %pid_s32 = s32[] convert(%pid) - %slice_size = s32[] constant(32) + %slice_size = s32[] constant(32) %offset = s32[] multiply(%pid_s32, %slice_size) %zero = s32[] constant(0) ROOT %ds = f32[32,8,128]{2,1,0} dynamic-slice(%ag, %offset, %zero, %zero), @@ -94,7 +94,7 @@ TEST_F(AllGatherDynamicSliceSimplifierTest, AllPartitions) { TEST_F(AllGatherDynamicSliceSimplifierTest, AllReplicasWithReshape) { absl::string_view hlo_string = R"( HloModule AllGather - + ENTRY %AllGather { %param = f32[32,8,128]{2,1,0} parameter(0) %ag = f32[256,8,128]{2,1,0} all-gather(%param), replica_groups={{0,1,2,3,4,5,6,7}}, diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc index a941cb6681cedd..18acf38c4d1986 100644 --- a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc @@ -65,9 +65,10 @@ absl::StatusOr AsyncWrapper::Run( continue; } - // Otherwise, follow any `calls` to discover other instructions that can - // potentially be made async. - if (HloPredicateIsOp(instruction)) { + // Otherwise, follow anything other than `fusion`s to discover other + // instructions that can potentially be made async. + if (HloPredicateIsOp(instruction)) { std::copy(instruction->called_computations().begin(), instruction->called_computations().end(), std::back_inserter(computations)); diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc index 345fac37bd4707..183832154238e1 100644 --- a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc @@ -25,11 +25,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { namespace { @@ -81,5 +83,115 @@ TEST_F(AsyncWrapperTest, BasicFusion) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(AsyncWrapperTest, OpWithinWhileShouldWrapInAsync) { + const char* hlo = R"( + HloModule m + + body { + param = (f32[1], s32[]) parameter(0) + p0 = f32[1] get-tuple-element(param), index=0 + agg1 = f32[1] custom-call(p0), custom_call_target="foo" + agg2 = f32[1] custom-call(p0), custom_call_target="bar" + done = f32[1] add(agg1, agg2) + iter = s32[] get-tuple-element(param), index=1 + c1 = s32[] constant(1) + add = s32[] add(iter, c1) + ROOT tuple = (f32[1], s32[]) tuple(done, add) + } + + condition { + param.1 = (f32[1], s32[]) parameter(0) + iter.1 = s32[] get-tuple-element(param.1), index=1 + c4 = s32[] constant(4) + ROOT compare = pred[] compare(iter.1, c4), direction=LT + } + + ENTRY main { + c0 = s32[] constant(0) + p0.1 = f32[1] parameter(0) + agg3 = f32[1] custom-call(p0.1), custom_call_target="baz" + tuple = (f32[1], s32[]) tuple(agg3, c0) + while = (f32[1], s32[]) while(tuple), body=body, condition=condition + ROOT done.1 = f32[1] get-tuple-element(while), index=0 + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + AsyncWrapper wrapper(HloPredicateIsOp); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + wrapper.Run(module.get(), /*execution_threads=*/{})); + EXPECT_TRUE(changed); + EXPECT_EQ(CountAsyncInstructions(module->entry_computation()), 2); + HloInstruction* while_op = hlo_query::FindInstruction( + module->entry_computation(), HloOpcode::kWhile); + ASSERT_NE(while_op, nullptr); + EXPECT_EQ(CountAsyncInstructions(while_op->while_body()), 4); +} + +TEST_F(AsyncWrapperTest, OpWithinConditionalShouldWrapInAsync) { + const char* hlo = R"( + HloModule m + + true_computation { + p0.1 = f32[] parameter(0) + ROOT res.1 = f32[] custom-call(p0.1), custom_call_target="foo" + } + + false_computation { + p0.2 = f32[] parameter(0) + ROOT res.2 = f32[] custom-call(p0.2), custom_call_target="foo" + } + + ENTRY main { + p0 = f32[] parameter(0) + c0 = f32[] constant(0) + compare = pred[] compare(p0, c0), direction=GE + ROOT done = f32[] conditional(compare, p0, p0), true_computation=true_computation, false_computation=false_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + AsyncWrapper wrapper(HloPredicateIsOp); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + wrapper.Run(module.get(), /*execution_threads=*/{})); + EXPECT_TRUE(changed); + EXPECT_EQ(CountAsyncInstructions(module->entry_computation()), 0); + HloInstruction* conditional_op = hlo_query::FindInstruction( + module->entry_computation(), HloOpcode::kConditional); + ASSERT_NE(conditional_op, nullptr); + EXPECT_EQ(CountAsyncInstructions(conditional_op->true_computation()), 2); + EXPECT_EQ(CountAsyncInstructions(conditional_op->false_computation()), 2); +} + +TEST_F(AsyncWrapperTest, OpWithinFusionShouldNotWrapInAsync) { + const char* hlo = R"( + foo { + p0 = f32[1] parameter(0) + ROOT custom-call = f32[1] custom-call(p0), custom_call_target="bar" + } + ENTRY main { + c0 = s32[] constant(0) + p0.1 = f32[1] parameter(0) + agg.1 = f32[1] fusion(p0.1), kind=kLoop, calls=foo + agg.2 = f32[1] custom-call(agg.1), custom_call_target="bar" + ROOT done.1 = (f32[1], f32[1]) tuple(agg.1, agg.2) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + + AsyncWrapper wrapper(HloPredicateIsOp); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + wrapper.Run(module.get(), /*execution_threads=*/{})); + EXPECT_TRUE(changed); + EXPECT_EQ(CountAsyncInstructions(module->entry_computation()), 2); + + HloInstruction* fusion = hlo_query::FindInstruction( + module->entry_computation(), HloOpcode::kFusion); + EXPECT_EQ(CountAsyncInstructions(fusion->fused_instructions_computation()), + 0); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc index ab4b5466dda500..082e4dc3af1087 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc @@ -25,16 +25,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla { namespace { using ::testing::HasSubstr; -using CollectivePermuteCycleDecomposerTest = HloTestBase; +using CollectivePermuteCycleDecomposerTest = HloHardwareIndependentTestBase; using Decomposer = CollectivePermuteCycleDecomposer; HloPrintOptions PrintOptions() { diff --git a/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc b/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc index 1d850d4aa516a3..5b6c4c008ee894 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_ops_utils.h" +#include "xla/shape_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -51,12 +52,20 @@ struct FoldableSelect { HloInstruction* false_operand; }; +const HloInstruction* FindInnerScalarOp(const HloInstruction* inst) { + while (inst->opcode() == HloOpcode::kConvert || + inst->opcode() == HloOpcode::kBroadcast) { + inst = inst->operand(0); + } + return inst; +} + // Matches foldable select ops that we can analyse and returns handy references // to %constant, %true_operand, %false_operand of the op. Matches, e.g., // // ``` // select( -// broadcast(compare(partition-id(), constant)), +// broadcast(compare(convert(partition-id()), constant)), // true_operand, // false_operand) // ``` @@ -65,7 +74,7 @@ struct FoldableSelect { // // ``` // select( -// compare(partition-id(), constant), +// compare(replica-id(), constant), // true_operand, // false_operand) // ``` @@ -74,21 +83,22 @@ std::optional MatchFoldableSelect(HloInstruction* select) { return std::nullopt; } - // Match select predicate (may be broadcasted). - const HloInstruction* predicate_candidate = select->operand(0); - if (HloPredicateIsOp(predicate_candidate)) - predicate_candidate = predicate_candidate->operand(0); + // Match select predicate. + const HloInstruction* predicate_candidate = + FindInnerScalarOp(select->operand(0)); const HloCompareInstruction* compare = DynCast(predicate_candidate); - if (compare == nullptr) return std::nullopt; + if (compare == nullptr) { + return std::nullopt; + } if (compare->direction() != Comparison::Direction::kEq && compare->direction() != Comparison::Direction::kNe) { return std::nullopt; } // Find replica-id or partition-id op and constant op, swap if needed. - const HloInstruction* id_op = compare->operand(0); - const HloInstruction* constant_op = compare->operand(1); + const HloInstruction* id_op = FindInnerScalarOp(compare->operand(0)); + const HloInstruction* constant_op = FindInnerScalarOp(compare->operand(1)); if (HloPredicateIsNotOp(constant_op)) { std::swap(id_op, constant_op); } @@ -104,35 +114,41 @@ std::optional MatchFoldableSelect(HloInstruction* select) { } // Match constant. - if (HloPredicateIsNotOp(constant_op)) + if (HloPredicateIsNotOp(constant_op) || + !ShapeUtil::IsScalar(constant_op->shape())) { return std::nullopt; + } std::optional constant_id = constant_op->literal().GetFirstInteger(); - if (!constant_id.has_value()) return std::nullopt; + if (!constant_id.has_value()) { + return std::nullopt; + } return FoldableSelect{compare->direction(), *constant_id, collective_mode, select->mutable_operand(1), select->mutable_operand(2)}; } +bool SelectPredicateEval(const FoldableSelect& select_match, + const SourceTargetPair& pair) { + int64_t src_id = pair.first; + return select_match.cmp_direction == Comparison::Direction::kEq + ? src_id == select_match.constant_id + : src_id != select_match.constant_id; +}; + std::optional StaticallyEvaluatePredicateForAllSourceIDs( - FoldableSelect select_match, SourceTargetPairs pairs) { + const FoldableSelect& select_match, const SourceTargetPairs& pairs) { // If there are no pairs, the predicate is undefined. if (pairs.empty()) return std::nullopt; // Evaluate the select predicate for the first source target pair. CHECK(select_match.cmp_direction == Comparison::Direction::kEq || select_match.cmp_direction == Comparison::Direction::kNe); - auto select_predicate_eval = [&select_match](const SourceTargetPair& pair) { - int64_t src_id = pair.first; - return select_match.cmp_direction == Comparison::Direction::kEq - ? src_id == select_match.constant_id - : src_id != select_match.constant_id; - }; - bool result_candidate = select_predicate_eval(pairs.front()); + bool result_candidate = SelectPredicateEval(select_match, pairs.front()); // Check that the result is the same for all source target pairs. If not, // we have a contradiction and cannot statically evaluate the predicate. We // return std::nullopt in this case. if (!absl::c_all_of(pairs, [&](const SourceTargetPair& it) -> bool { - return result_candidate == select_predicate_eval(it); + return result_candidate == SelectPredicateEval(select_match, it); })) { return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc index 12faa97377d1bf..42ecc87717cffa 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc @@ -41,7 +41,7 @@ using ::testing::HasSubstr; class CollectiveSelectFolderTest : public HloTestBase { public: - absl::Status ExpectNoTranform(std::string_view hlo_template) { + absl::Status ExpectNoTranform(absl::string_view hlo_template) { return RunAndCheckHloRewrite(hlo_template, CollectiveSelectFolder(), /*expect_change=*/false) .status(); @@ -49,8 +49,8 @@ class CollectiveSelectFolderTest : public HloTestBase { }; void VerifyDirectDataFeedSPMD(HloModule* module, - std::string_view expected_fwd_operand, - std::string_view expected_bwd_operand) { + absl::string_view expected_fwd_operand, + absl::string_view expected_bwd_operand) { auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kCollectivePermute); @@ -423,7 +423,7 @@ TEST_F(CollectiveSelectFolderTest, } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, RunAndCheckHloRewrite(kHlo, CollectiveSelectFolder(), /*expect_change=*/true)); const absl::string_view kExpected = R"( @@ -449,5 +449,40 @@ TEST_F(CollectiveSelectFolderTest, EXPECT_TRUE(filecheck_result); } +TEST_F(CollectiveSelectFolderTest, DtypeConvertedPartitionId) { + const absl::string_view kHlo = R"( + HloModule test + + ENTRY computation { + param = (f32[1,1,28672,2048]{3,2,1,0}, f32[1,1,28672,2048]{3,2,1,0}) + parameter(0) + get-tuple-element-a = f32[1,1,28672,2048]{3,2,1,0} + get-tuple-element(param), index=0 + get-tuple-element-b = f32[1,1,28672,2048]{3,2,1,0} + get-tuple-element(param), index=1 + partition-id.1 = u32[] partition-id() + convert = s32[] convert(partition-id.1) + constant.148 = s32[] constant(3) + compare.83 = pred[] compare(convert, constant.148), direction=EQ + select.33 = f32[1,1,28672,2048]{3,2,1,0} select(compare.83, + get-tuple-element-a, get-tuple-element-b) + ROOT cp-a = f32[1,1,28672,2048]{3,2,1,0} collective-permute(select.33), + channel_id=1, source_target_pairs={{3,0}} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + RunAndCheckHloRewrite(kHlo, CollectiveSelectFolder(), + /*expect_change=*/true)); + const absl::string_view kExpected = R"( + // CHECK: %[[PARAM:.*]] = {{.*}} parameter(0) + // CHECK: %[[DATA_A:.*]] = {{.*}} get-tuple-element({{.*}} %[[PARAM]]), index=0 + // CHECK: ROOT %[[DATA_A_:.*]] = {{.*}} collective-permute({{.*}} %[[DATA_A]]) + )"; + TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result, + RunFileCheck(module->ToString(), kExpected)); + EXPECT_TRUE(filecheck_result); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h index de8a88517035d2..8b1de0602f7c2a 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h +++ b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h @@ -26,11 +26,23 @@ limitations under the License. namespace xla { // CollectiveSendRecvCombiner is a pass that scans for all send/recv pairs -// which are part of the same computation, and transforms them into wrapped -// single-op computations that are executed asynchronously. This pass also +// which are part of the same computation, and transforms them into a wrapped +// multi-op computation that can be executed asynchronously. This pass also // replaces the corresponding send-done and recv-done instructions with -// async-done functions. This pass is primarily used for pipelining send/recv -// and send-done/recv-done instructions across while loop iteration boundaries. +// async-done functions. This pass shouldn't be applied to send/recv +// instructions that are called in a while loop, since it will force all +// send/recv instructions in the same group to finish executing before +// computation can continue.Partial grouping of send/recv instructions in the +// same NCCL group will lead to deadlocks and is therefore discouraged. In +// practice this means that there exists at least one send or recv instruction +// in the same NCCL group that doesn't have a matching send/recv. An example of +// partial grouping with deadlock written in HLO pseudocode: +// wrapped_send_recv {send1, recv1, recv2} +// async_start = async_start(inputs), calls=wrapped_send_recv +// loop_input = gte(async_done(async_start)) +// while_loop_output = while(loop_input) +// send2_data = gte(while_loop_output) +// output_token = send2(send2_data) class CollectiveSendRecvCombiner : public HloModulePass { public: absl::string_view name() const override { diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 195e218b4c7d64..89744b057eb791 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -24,12 +24,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -1091,7 +1091,7 @@ TEST_F(CommandBufferSchedulingTest, AsyncFusion) { TEST_F(CommandBufferSchedulingTest, AsyncAlltoAll) { const char* hlo = R"( HloModule m, is_scheduled=true - + async_computation.1 { param.1 = f32[4,8,128]{2,1,0} parameter(0) ROOT all-to-all.1 = f32[4,8,128]{2,1,0} all-to-all(param.1), channel_id=1, dimensions={1} @@ -1099,7 +1099,7 @@ TEST_F(CommandBufferSchedulingTest, AsyncAlltoAll) { ENTRY main { param.0 = f32[4,8,128]{2,1,0} parameter(0) - all-to-all-start = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}) async-start(param.0), calls=async_computation.1 + all-to-all-start = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}) async-start(param.0), calls=async_computation.1 ROOT all-to-all-done = f32[4,8,128]{2,1,0} async-done(all-to-all-start) })"; diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc index d411fd064dd5e0..567d66ac7a0b0a 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc index d01ffd1829b7f8..8039ceca20c825 100644 --- a/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc @@ -759,7 +759,7 @@ TEST_F(ConvRewriterTest, TestInvalidTypes) { })"); // Test complex types - for (std::string_view type : {"c64", "c128"}) { + for (absl::string_view type : {"c64", "c128"}) { const std::string module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", type}}); TF_ASSERT_OK_AND_ASSIGN(auto m, diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc index d38ab70864ac4c..8fb271138f1dde 100644 --- a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" -#include - #include #include #include "absl/status/status.h" @@ -50,7 +48,7 @@ class GpuConvertAsyncCollectivesToSyncTest : public HloTestBase { } // Returns true if the instruction with the given name is synchronous. - bool IsSync(HloModule *module, std::string_view name) { + bool IsSync(HloModule *module, absl::string_view name) { const HloInstruction *inst = FindInstruction(module, name); if (inst == nullptr) { return false; diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc index 1b34fb13a72903..23706a4dbcf149 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc @@ -75,7 +75,7 @@ absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { } HloInstruction* root = fused_computation->root_instruction(); if (IsReductionFromOrToContiguousDimensions(*root, device_description_) || - root->opcode() == HloOpcode::kScatter || + HloPredicateIsOp(root) || (hlo->IsMultiOutputFusion() && absl::c_all_of(root->operands(), HloPredicateIsOp))) { diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc index fc4ccdd304cc5f..0c4910cc2e5318 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_gemm_rewriter_test.cc @@ -2166,12 +2166,11 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) { -#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60000 - auto rocm_switch = false; // GELU is only available from ROCM 6.0 -#else - auto rocm_switch = true; -#endif - if (!IsCuda() && rocm_switch) { + auto runtime_version = GetRuntimeVersion(); + bool rocm_gelu_available = + IsRocm() && + (runtime_version >= stream_executor::SemanticVersion(6, 0, 0)); + if (IsRocm() && !rocm_gelu_available) { GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; } const char* hlo_text = R"( @@ -2234,7 +2233,7 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) { - if (!IsCuda()) { + if (IsRocm()) { GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; } const char* hlo_text = R"( @@ -2294,7 +2293,7 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) { - if (!IsCuda()) { + if (IsRocm()) { GTEST_SKIP() << "TODO: Unsupported blas-lt epilogue on ROCM"; } const char* hlo_text = R"( @@ -2982,7 +2981,7 @@ ENTRY test { } TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) { - if (!IsCuda()) { + if (IsRocm()) { GTEST_SKIP() << "TODO: Unsupported blas-lt F64 datatype on ROCM"; } const char* hlo_text = R"( @@ -3170,7 +3169,7 @@ ENTRY main { // Test gemm matrix bias add fusion with mix type and out of place update(C != // D) TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlace) { - if (!IsCuda()) { + if (IsRocm()) { GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM"; } std::vector> @@ -3215,7 +3214,7 @@ ENTRY test { // Test batch gemm matrix bias add fusion with mix type and out of place // update(C != D) TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeOutOfPlaceBatched) { - if (!IsCuda()) { + if (IsRocm()) { GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM"; } std::vector> @@ -3259,7 +3258,7 @@ ENTRY test { // Test gemm matrix bias add fusion with mix type and in place update(C = D) TEST_F(CublasLtGemmRewriteTest, MatrixBiasMixTypeInPlace) { - if (!IsCuda()) { + if (IsRocm()) { GTEST_SKIP() << "TODO: Unsupported mixed datatypes on ROCM"; } std::vector> diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 0dc92c47d2cb55..67f33164fa2638 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -149,12 +149,14 @@ absl::StatusOr HloCustomCallToCuDnnGraph( GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); const int sliding_window_length = config.sliding_window_length(); + const int max_seg_per_batch = config.max_seg_per_batch(); TF_ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionOperationGraph( dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, activation, static_cast(config.fmha_scale()), dropout_rate > 0.0, - dropout_rate, dnn_mask_type, sliding_window_length)); + dropout_rate, dnn_mask_type, sliding_window_length, + max_seg_per_batch)); return graph; } else if (IsFwdCustomCallTofMHAF8(*custom_call)) { TF_ASSIGN_OR_RETURN( @@ -230,12 +232,19 @@ absl::StatusOr HloCustomCallToCuDnnGraph( // Unused fwd_output_shape ++input_index; + const int max_seg_per_batch = config.max_seg_per_batch(); if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL || + max_seg_per_batch > 1) { // skip q_seqlen and kv_seqlen input_index += 2; } + + if (max_seg_per_batch > 1) { + // skip q_offsets and kv_offsets + input_index += 2; + } TF_RET_CHECK(input_index == custom_call->operand_count()); int output_index = 0; @@ -312,7 +321,8 @@ absl::StatusOr HloCustomCallToCuDnnGraph( bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, d_bmm1_rhs, d_bmm2_rhs, bias, dropout_rate, config.seed(), config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, - dnn_mask_type, force_deterministic, sliding_window_length)); + dnn_mask_type, force_deterministic, sliding_window_length, + max_seg_per_batch)); return graph; } else { TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index c4ebb27d62ab71..9808f50a7f1d88 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -35,6 +35,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/transforms/simplifiers/convert_mover.h" #include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" @@ -50,9 +52,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/semantic_version.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -2197,7 +2197,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { MAYBE_SKIP_TEST("I8"); - const std::string_view module_str = R"( + const absl::string_view module_str = R"( HloModule Test ENTRY test_entry { @@ -2224,7 +2224,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { MAYBE_SKIP_TEST("I8"); - const std::string_view module_str = R"( + const absl::string_view module_str = R"( HloModule Test ENTRY test_entry { @@ -2251,7 +2251,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { MAYBE_SKIP_TEST("I8"); - const std::string_view module_str = R"( + const absl::string_view module_str = R"( HloModule Test ENTRY test_entry { @@ -2277,7 +2277,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { } TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) { - const std::string_view module_str = R"( + const absl::string_view module_str = R"( HloModule Test ENTRY test_entry { diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc index ce90055036f413..649058153d6cf8 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc @@ -176,13 +176,6 @@ inline std::optional GetComputeDataType( return compute_dtype; } -int FusionLevel(const HloInstruction& hlo) { - return hlo.GetModule() - ->config() - .debug_options() - .xla_gpu_cudnn_gemm_fusion_level(); -}; - // Extracts dimensions and strides from HLO tensors in the format expected by // cuDNN. class GemmDimensionAdapter { @@ -277,9 +270,6 @@ class GemmDimensionAdapter { if (spec->size() == 1) { // The dimension is not split, nothing to do. } else if (spec->size() == 2) { - if (FusionLevel(hlo) < 3) { - return std::nullopt; - } if (!dims.lhs_batch_dimensions().empty()) { VLOG(8) << "Noncontracting dimension split is not compatible with " "batch dimensions."; @@ -498,8 +488,7 @@ absl::StatusOr> HloFusionToCuDnnGraph( return std::nullopt; } continue; - } else if (FusionLevel(fusion) >= 2 && - HloPredicateIsOp(hlo)) { + } else if (HloPredicateIsOp(hlo)) { if (const auto const_tensor = HandleConstantHloToCudnnGraph(*hlo, graph); const_tensor.has_value()) { hlo_to_cudnn[hlo] = const_tensor.value(); @@ -508,9 +497,8 @@ absl::StatusOr> HloFusionToCuDnnGraph( } } else if (HloPredicateIsOp(hlo) || - (FusionLevel(fusion) >= 2 && - (HloPredicateIsOp( - hlo)))) { + ((HloPredicateIsOp( + hlo)))) { // All these are accounted for separately as transformations of strides. hlo_to_cudnn[hlo] = operand(0); } else if (hlo->IsElementwise()) { diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc index b1e0b98c319340..7bfdb137e47c12 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc @@ -88,7 +88,7 @@ absl::StatusOr DotDimensionSorter::Run( for (const HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : computation->instructions()) { - if (instr->opcode() != HloOpcode::kDot) { + if (HloPredicateIsNotOp(instr)) { continue; } // TODO(b/265688934): should non-default layouts be expected here at all? diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc index 7d217aac5674ee..c46c3f53f6a84b 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc @@ -118,9 +118,9 @@ absl::Status SetSendRecvValidationForPeeledInstr(HloInstruction* new_instr, TF_RET_CHECK( new_instr->opcode() == old_instr->opcode() && "cloned instruction and original instruction have different opcodes"); - if (!HloPredicateIsOp(old_instr)) { + if (HloPredicateIsNotOp(old_instr)) { return absl::OkStatus(); } @@ -188,9 +188,9 @@ absl::Status SetSendRecvValidation(HloInstruction* cp1, HloInstruction* cp2, TF_RET_CHECK( cp2->opcode() == cp1->opcode() && "cloned instruction and original instruction have different opcodes"); - if (!HloPredicateIsOp(cp1)) { + if (HloPredicateIsNotOp(cp1)) { return absl::OkStatus(); } const auto& attribute_map = cp2->frontend_attributes().map(); diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 8c492755fcb04c..fd36117c90f8da 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -33,18 +33,20 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_traversal.h" +#include "xla/service/call_graph.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -57,8 +59,6 @@ namespace gpu { namespace { -namespace m = ::xla::match; - // A dataflow path flowing from a definition to a user. using DefUseDataflowPath = absl::InlinedVector; @@ -151,70 +151,71 @@ bool IsAlignedSlice(const HloInstruction* slice) { return true; } -// Pattern matches the following IR (generated by `jax.lax.scan`) to check if -// the offset is a loop iteration number: - -// clang-format off -// param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0) -// // the index in `gte` has to be the loop iteration index -// gte = s32[] get-tuple-element(param), index=0 -// c0 = s32[] constant(0) -// compare = pred[] compare(gte, c0), direction=LT -// c_trip_count = s32[] constant(16) -// add = s32[] add(gte, c_trip_count) -// select = s32[] select(compare, add, gte) -// clang-format on - -bool IsLoopIterationNumber(const HloInstruction& offset) { - const HloComputation* parent = offset.parent(); - if (!parent->IsWhileBodyComputation()) return false; - - // Scan loops trip count must be known at compile time as it iterates over the - // leading dimension of the statically shaped input. - const HloInstruction* while_instr = parent->WhileCallInstruction(); - auto config = while_instr->backend_config(); - if (!config.ok() || !config->has_known_trip_count()) return false; - int32_t trip_count = config->known_trip_count().n(); - - // First lets check the offset computation pattern - if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)), - m::ConstantScalar(0)), - m::Add(m::GetTupleElement(m::Parameter(0)), - m::ConstantScalar(trip_count)), - m::GetTupleElement(m::Parameter())))) { +// Returns true if the `consumer` only depends on the `producer` and no other +// instructions. This is a recursive function checking all paths from the +// `consumer` to the parameters of the computation and if there is any path +// without `producer`, then it returns false. +bool IsOnlyDependentOn(const HloInstruction* consumer, + HloInstruction* producer) { + if (consumer == producer || + HloPredicateIsOp(consumer)) { + return true; + } + if (consumer->operand_count() == 0) { return false; } - - // Next, we check that the parameter used in offset computation is the loop - // induction variable - int64_t param_idx = offset.operand(2)->tuple_index(); - const HloInstruction* root = offset.parent()->root_instruction(); - if (HloPredicateIsNotOp(root)) { + return absl::c_all_of(consumer->operands(), + [producer](const HloInstruction* operand) { + return IsOnlyDependentOn(operand, producer); + }); +}; + +// Returns true if the value is a function of the induction variable within a +// while loop. +bool IsValueFunctionOfLoopInductionVariable(const HloInstruction& value, + CallGraph* call_graph) { + std::vector callers = + call_graph->GetComputationCallers(value.parent()); + if (callers.size() != 1) { + VLOG(2) << "Computation has multiple callers: " + << absl::StrJoin(callers, ",", + [](std::string* out, const HloInstruction* instr) { + out->append(instr->name()); + }); return false; } - // Check the update operation - const HloInstruction* updated_var = - offset.parent()->root_instruction()->operand(param_idx); - if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx), - m::ConstantScalar(1)))) { + HloInstruction* while_op = callers[0]; + if (HloPredicateIsNotOp(while_op)) { + VLOG(2) << "Computation caller is not while, it is " + << while_op->ToString(); return false; } - // Check that the condition considers this. - const HloInstruction* condition_root = - while_instr->while_condition()->root_instruction(); - if (!Match(condition_root, - m::Lt(m::GetTupleElement(m::Parameter(0), param_idx), - m::ConstantScalar(trip_count)))) { + HloComputation* while_body = while_op->while_body(); + std::optional loop_induction_variable_tuple_idx = + GetLoopInductionVarTupleIdx(while_op); + if (!loop_induction_variable_tuple_idx.has_value()) { + VLOG(2) << "Induction variable tuple index is nullopt"; return false; } - // Check init - const HloInstruction* init_loop_iter = - while_instr->operand(0)->operand(param_idx); - if (!Match(init_loop_iter, m::ConstantScalar(0))) { + // The verifier makes sure that there is exactly one parameter. So, it is okay + // to directly access the parameter here. The function + // `GetLoopInductionVarTupleIdx` above makes sure that the parameter is a + // tuple. + HloInstruction* indvar = hlo_query::GetUniqueGteInstruction( + while_body->parameter_instruction(0), *loop_induction_variable_tuple_idx); + if (!indvar) { + VLOG(2) << "Unable to find unique GTE for while induction variable idx: " + << *loop_induction_variable_tuple_idx + << ", while op: " << while_op->ToString(); return false; } + const HloInstruction* update = while_body->root_instruction()->operand( + *loop_induction_variable_tuple_idx); - return true; + // The `update` instruction and `value` should only depend on the induction + // variable. + return IsOnlyDependentOn(/*consumer=*/update, /*producer=*/indvar) && + IsOnlyDependentOn(/*consumer=*/&value, /*producer=*/indvar); } // This returns true for the constants that are handled in the dynamic slice @@ -237,15 +238,17 @@ bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) { // This checks whether a dynamic index operation has all offsets that are either // constant or loop iteration offsets. -bool HasConstantOrLoopIterationOffsets( - const HloDynamicIndexInstruction& instr) { - return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) { - return IsLoopIterationNumber(*offset) || - IsHandledConstantForDynamicSliceFusion(*offset); - }); +bool HasConstantOrLoopIterationOffsets(const HloDynamicIndexInstruction& instr, + CallGraph* call_graph) { + return absl::c_all_of( + instr.index_operands(), [call_graph](const HloInstruction* offset) { + return IsValueFunctionOfLoopInductionVariable(*offset, call_graph) || + IsHandledConstantForDynamicSliceFusion(*offset); + }); } -UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { +UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, + CallGraph* call_graph) { UseDefDataflowPaths sliced_operand_paths; // This set is used to avoid duplicates in the matched results. It contains @@ -296,10 +299,10 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { auto dynamic_index_operation = DynCast(maybe_slice_instr.value()); bool valid_slice_found = - slice_found && - ((dynamic_index_operation && - HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) || - (*maybe_slice_instr)->opcode() == HloOpcode::kSlice); + slice_found && ((dynamic_index_operation && + HasConstantOrLoopIterationOffsets( + *dynamic_index_operation, call_graph)) || + (*maybe_slice_instr)->opcode() == HloOpcode::kSlice); if (valid_slice_found || processed_instrs.contains(maybe_slice_instr.value())) { // Even in the case of stopping at a match that has been processed, we @@ -321,7 +324,8 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { // vector. // Each entry contains the sliced paths for that user, i.e. the sequence of ops // following the dataflow from the user itself to the DUS (included). -DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { +DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr, + CallGraph* call_graph) { DefUseDataflowPaths sliced_user_paths; // This set is used to avoid duplicates in the matched results. It contains // the matched instructions that we have seen so far. @@ -352,7 +356,7 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { DynCast(maybe_dus_instr.value()); bool valid_dus_found = dus_found && dynamic_index_operation && - HasConstantOrLoopIterationOffsets(*dynamic_index_operation); + HasConstantOrLoopIterationOffsets(*dynamic_index_operation, call_graph); if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced user path @@ -520,6 +524,7 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( matches_kv; std::vector matches; + std::unique_ptr call_graph = CallGraph::Build(module); // Collect all potential custom call matches in the non-fusion computations. for (HloComputation* computation : module->computations()) { if (computation->IsFusionComputation()) continue; @@ -527,9 +532,11 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if ((HloPredicateIsOp(instr) && instr->shape().IsArray()) || IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) { - UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); + UseDefDataflowPaths sliced_operand_paths = + GetSlicedOperandPaths(instr, call_graph.get()); bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); + DefUseDataflowPaths sliced_user_paths = + GetSlicedUserPaths(instr, call_graph.get()); bool has_sliced_user_paths = absl::c_any_of( sliced_user_paths, [&](auto& sliced_user_path) { return !sliced_user_path.empty(); }); diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index c71f74444dcfaf..622fe832785c27 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -1988,7 +1988,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { HloModule lax_scan // This is the HLO generated for the following: - // + // // inp = jax.random.uniform(jax.random.key(128), (128, 128, 128)) // init = jnp.identity(128) // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp) @@ -2143,4 +2143,57 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDynamicSlice) { RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } +TEST_F(DynamicSliceFusionRewriterTest, + OffsetAsFunctionOfInductionVariableShouldFuse) { + const char* hlo = R"( + HloModule test, replica_count=2 + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + body { + param.1 = (s32[], s32[32,32], s32[32,32]) parameter(0) + iter.1 = s32[] get-tuple-element(param.1), index=0 + src = s32[32,32] get-tuple-element(param.1), index=1 + dest = s32[32,32] get-tuple-element(param.1), index=2 + + // offset as a function of only the loop induction variable. + add.1 = s32[] add(iter.1, iter.1) + c3 = s32[] constant(3) + multiply.1 = s32[] multiply(add.1, c3) + c16 = s32[] constant(16) + offset.1 = s32[] subtract(multiply.1, c16) + + c0 = s32[] constant(0) + rs = s32[16,32] reduce-scatter(src), dimensions={0}, replica_groups={{0,1}}, to_apply=add + dus = s32[32,32] dynamic-update-slice(dest, rs, offset.1, c0) + c1 = s32[] constant(1) + add.2 = s32[] add(iter.1, c1) + ROOT tuple = tuple(add.2, src, dus) + } + condition { + param.2 = (s32[], s32[32,32], s32[32,32]) parameter(0) + iter.2 = s32[] get-tuple-element(param.2), index=0 + c16 = s32[] constant(16) + ROOT compare = pred[] compare(iter.2, c16), direction=LT + } + ENTRY main { + src = s32[32,32] parameter(0) + dest = s32[32,32] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], s32[32,32], s32[32,32]) tuple(c0, src, dest) + ROOT while = (s32[], s32[32,32], s32[32,32]) while(tuple), body=body, condition=condition + } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), R"( + // CHECK: dynamic-slice-fusion + // CHECK: %[[rs:.+]] = {{.+}} reduce-scatter({{.+}}) + // CHECK: ROOT %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}}) + // CHECK: body + // CHECK: %[[fusion:.+]] = {{.+}} fusion({{.+}}), kind=kCustom, calls=%dynamic-slice-fusion, + // CHECK-SAME: "fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation" + )"); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc index d574fc106282ad..d78dc65be97720 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc @@ -47,7 +47,7 @@ namespace { using ::tsl::testing::IsOkAndHolds; bool HasTritonBlockLevelFusionConfig(const HloInstruction* fusion) { - return fusion->opcode() == HloOpcode::kFusion && + return HloPredicateIsOp(fusion) && fusion->has_backend_config() && fusion->backend_config().ok() && fusion->backend_config() diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index 32bb18fb2e77b9..0ac2e0fb167605 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -637,11 +637,11 @@ class Decision { static Decision Allow() { return {FusionDecision::Allow(), true}; }; - static Decision Deny(std::string_view value) { + static Decision Deny(absl::string_view value) { return {FusionDecision::Forbid(value), false}; } - static Decision NotProfitable(std::string_view value) { + static Decision NotProfitable(absl::string_view value) { return {FusionDecision::Forbid(value), true}; } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc index 509cc8d76b320b..d2e60f6b547403 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -213,8 +213,8 @@ ENTRY e { p1 = f32[101,16] parameter(1) d = f32[16,7] dot(p1, s0), lhs_contracting_dims={0}, rhs_contracting_dims={1} - s1 = f32[3,33] slice(p0), slice={[10:13], [20:53]} - ROOT t = tuple(d, s1) + sout1 = f32[3,33] slice(p0), slice={[10:13], [20:53]} + ROOT t = tuple(d, sout1) })")); const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; @@ -247,9 +247,9 @@ ENTRY e { slice={[0:1], [0:1], [0:256], [0:256]} r0 = f32[256,256] reshape(s0) p1 = f16[2,2,256,256] parameter(1) - s1 = f16[1,1,256,256] slice(p1), + sout1 = f16[1,1,256,256] slice(p1), slice={[0:1], [0:1], [0:256], [0:256]} - r1 = f16[256,256] reshape(s1) + r1 = f16[256,256] reshape(sout1) ROOT d = f32[256,256] dot(r0, r1), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index ef034658c5059c..3accf17dbbe0cc 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -210,21 +210,21 @@ std::optional FindF8SubgraphRecursive( // The initial operand index is meaningless. Arbitrarily use -1. return InstrPath{{instr, -1}}; } - if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide || - instr->opcode() == HloOpcode::kDynamicSlice || - instr->opcode() == HloOpcode::kPad) { + if (instr->operand_count() == 1 || + HloPredicateIsOp(instr)) { std::optional subgraph = FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs); if (subgraph) { subgraph->emplace_back(std::make_pair(instr, 0)); } return subgraph; - } else if (instr->opcode() == HloOpcode::kMultiply || - instr->opcode() == HloOpcode::kSelect) { + } else if (HloPredicateIsOp( + instr)) { for (int k = 0; k < 2; ++k) { // Iterate over operands 0 and 1 for multiply and operands 1 and 2 for // select. - int operand_idx = k + (instr->opcode() == HloOpcode::kSelect); + int operand_idx = k + (HloPredicateIsOp(instr)); std::optional subgraph = FindF8SubgraphRecursive( instr->mutable_operand(operand_idx), visited_instrs); if (subgraph) { @@ -650,7 +650,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { bool supported_by_cublaslt, GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); std::optional a, b; - if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot && + if (supported_by_cublaslt && HloPredicateIsOp(instr) && (a = MatchFp8Param( const_cast(instr->operand(0)))) && (b = MatchFp8Param( @@ -873,9 +873,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Do not fuse broadcast unless we can fuse its input, as it will cause // broadcast materialization. - auto is_not_broadcast = [](const HloInstruction *instr) { - return instr->opcode() != HloOpcode::kBroadcast; - }; + auto is_not_broadcast = HloPredicateIsNotOp; // add(bitcast(gemm(a, b)), bias) -> // bitcast(add(gemm(a, b), bitcast(bias))) -> @@ -1013,7 +1011,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { .WithOneUser()))) { return F8ConvertD( instr, existing_gemm, d_scale, clamp_lower, clamp_upper, - /*mult_scale=*/(binary && binary->opcode() == HloOpcode::kMultiply)); + /*mult_scale=*/ + (binary && HloPredicateIsOp(binary))); } return absl::OkStatus(); } @@ -1223,13 +1222,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { for (std::pair op : x_ops) { std::vector operands = {x}; // Insert the additional operands of dynamic-slice ops. - if (op.first->opcode() == HloOpcode::kDynamicSlice) { + if (HloPredicateIsOp(op.first)) { for (int i = 1; i < op.first->operand_count(); ++i) { operands.emplace_back(op.first->mutable_operand(i)); } } // Convert the second operand of pad ops. - if (op.first->opcode() == HloOpcode::kPad) { + if (HloPredicateIsOp(op.first)) { HloInstruction *convert = instr->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType(op.first->operand(1)->shape(), @@ -1238,7 +1237,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { operands.push_back(convert); } // Convert and insert the additional operands of select ops. - if (op.first->opcode() == HloOpcode::kSelect) { + if (HloPredicateIsOp(op.first)) { // The first operand is the predicate. operands.emplace(operands.begin(), op.first->mutable_operand(0)); // Convert the remaining operand. @@ -1367,8 +1366,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // If necessary, invert the scaling factor of D and convert to F32. TF_ASSIGN_OR_RETURN( - d_scale, - InvertAndConvertScalar(d_scale, instr->opcode() == HloOpcode::kDivide)); + d_scale, InvertAndConvertScalar( + d_scale, HloPredicateIsOp(instr))); TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale)); TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm)); @@ -1430,7 +1429,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { maybe_reduce = gemm_users[i]; } - if (maybe_reduce->opcode() == HloOpcode::kReduce && + if (HloPredicateIsOp(maybe_reduce) && maybe_reduce->operands().size() == 2 && maybe_reduce->operand(1)->opcode() == HloOpcode::kConstant && ShapeUtil::IsScalar(maybe_reduce->operand(1)->shape())) { @@ -1438,7 +1437,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloComputation *reduce_comp = reduce->to_apply(); HloInstruction *reduce_comp_root = reduce_comp->root_instruction(); if (reduce->operand(1)->literal().GetAsDouble({}) <= 0. && - reduce_comp_root->opcode() == HloOpcode::kMaximum && + HloPredicateIsOp(reduce_comp_root) && reduce_comp_root->operand(0)->opcode() == HloOpcode::kParameter && reduce_comp_root->operand(1)->opcode() == HloOpcode::kParameter) { reduce_damax = reduce; @@ -1571,7 +1570,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - if (bias->opcode() != HloOpcode::kParameter) { + if (HloPredicateIsNotOp(bias)) { // Not a parameter; can overwrite. return true; } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc index a24b51daaa6e26..c892c29e93ac6c 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.cc @@ -31,6 +31,11 @@ const auto& GemmRewriteTestBase::device_desc() const { return backend().default_stream_executor()->GetDeviceDescription(); } +stream_executor::SemanticVersion GemmRewriteTestBase::GetRuntimeVersion() + const { + return device_desc().runtime_version(); +} + const stream_executor::GpuComputeCapability& GemmRewriteTestBase::Capability() const { return device_desc().gpu_compute_capability(); diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h index c31b2e0fad6ecb..44d92d9cccca88 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test_lib.h @@ -32,7 +32,7 @@ class GemmRewriteTestBase : public GpuCodegenTest { const stream_executor::GpuComputeCapability& Capability() const; stream_executor::SemanticVersion GetToolkitVersion() const; - + stream_executor::SemanticVersion GetRuntimeVersion() const; bool IsCuda() const; bool IsRocm() const; diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index dd99da9bcddf35..a6cb22add0cd10 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/hlo_creation_utils.h" @@ -70,6 +71,16 @@ PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) { return first_output_type; } +bool IsShapeDefaultMemorySpace(const Shape& shape) { + bool are_all_subshapes_default_space = true; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { + are_all_subshapes_default_space &= + LayoutUtil::MemorySpace(subshape) == Layout::kDefaultMemorySpace; + }); + return are_all_subshapes_default_space; +} + class HorizontalLoopFusionImpl { public: explicit HorizontalLoopFusionImpl( @@ -155,16 +166,15 @@ bool IsConcatenationInputFusion(const HloInstruction& instr) { } bool IsDynamicUpdateSliceFusion(const HloInstruction* instr) { - if (instr->opcode() != HloOpcode::kFusion) { + if (HloPredicateIsNotOp(instr)) { return false; } auto root = instr->fused_expression_root(); - if (root->opcode() == HloOpcode::kTuple) { - return absl::c_any_of(root->operands(), [&](const HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDynamicUpdateSlice; - }); + if (HloPredicateIsOp(root)) { + return absl::c_any_of(root->operands(), + HloPredicateIsOp); } - return root->opcode() == HloOpcode::kDynamicUpdateSlice; + return HloPredicateIsOp(root); } bool IsFusibleCandidate(const HloInstruction& instr, @@ -180,6 +190,13 @@ bool IsFusibleCandidate(const HloInstruction& instr, return false; } + // Only consider instructions with default memory space operands and outputs + // to be fusable. + if (!IsShapeDefaultMemorySpace(instr.shape())) return false; + for (auto operand : instr.operands()) { + if (!IsShapeDefaultMemorySpace(operand->shape())) return false; + } + // Require no further check for element-wise instructions. if (instr.IsElementwise() && instr.operand_count() > 0) { return true; @@ -300,14 +317,13 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( << " rejects may-not-be profitable fusion instr" << instr->ToString(); continue; - } else if ((sliced_input_fusion_ || IsDynamicUpdateSliceFusion(instr)) && + } else if (IsDynamicUpdateSliceFusion(instr) && AnyOperandIsSharedAmongFusions(instr, fusible_candidates)) { - // Don't fuse fusions with at least one shared operand because we cannot - // i/o alias the produced horizontal fusion due to the concat insertion - // (or run into aliasing problems with DynamicUpdateSlice fusions). + // Don't fuse DUS fusions with shared operands because we cannot + // i/o alias the produced horizontal fusion due to the concat insertion. VLOG(2) << "sliced_input_fusion=" << sliced_input_fusion_ - << " rejects the fusion instr because it shares parameter with" - << " other fusion candidates, instr: " << instr->ToString(); + << " rejects the DUS fusion because it shares an operand with" + << " other fusion candidates: " << instr->ToString(); continue; } else { // Encapsulate it into a fusion computation for unified representation diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index f79c4a60b59ffd..85974ab568524e 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -60,7 +60,7 @@ auto MakeDeviceDescription() { class HorizontalLoopFusionTest : public HloTestBase { public: static bool IsFusion(const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kFusion; + return HloPredicateIsOp(instr); } const se::DeviceDescription device_description_{MakeDeviceDescription()}; }; @@ -189,6 +189,23 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } +TEST_F(HorizontalLoopFusionTest, NegativeTestForDifferentMemorySpace) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule NegativeTestForIncompatibleSpaces + ENTRY main { + arg0 = f32[1]{0} parameter(0) + arg1 = f32[1]{0:S(5)} parameter(1) + cp1 = f32[1]{0} copy(arg0) + cp2 = f32[1]{0:S(5)} copy(arg1) + ROOT tuple_out = (f32[1]{0}, f32[1]{0:S(5)}) tuple(cp1, cp2) + } +)") + .value(); + + EXPECT_FALSE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); +} + TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { auto module = ParseAndReturnVerifiedModule(R"( HloModule FusingIntoKLoopAndKInputTogether @@ -279,7 +296,7 @@ TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { int input_fusion_count = 0; int loop_fusion_count = 0; for (auto inst : module->entry_computation()->MakeInstructionPostOrder()) { - if (inst->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(inst)) { input_fusion_count += (inst->fusion_kind() == HloInstruction::FusionKind::kInput) ? 1 : 0; loop_fusion_count += @@ -648,7 +665,7 @@ TEST_F(HorizontalLoopFusionTest, GmockMatch(m::Tuple(m::Multiply(), m::Add()))); } -TEST_F(HorizontalLoopFusionTest, ForbidSharedParametersWhenUsingConcatenation) { +TEST_F(HorizontalLoopFusionTest, AllowSharedOperandsWhenUsingConcatenation) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( f { p = f16[] parameter(0) @@ -668,9 +685,7 @@ e { // As fusions f and g have different output shapes, the horizontal fusion // algorithm would only consider merging them using concatenation/slicing. - // The horizontal fusion is not supposed to happen in this - // example though because f and g share an input parameter. - EXPECT_FALSE( + EXPECT_TRUE( HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 4e7234cedb52ae..5bbd64145fbc19 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -103,8 +103,9 @@ HeuristicLayoutAssignment(const HloInstruction* instr, instr->convolution_dimension_numbers(); Shape input_shape = instr->operand(0)->shape(); PrimitiveType input_ty = instr->operand(0)->shape().element_type(); + int num_spatial_dimensions = dnums.input_spatial_dimensions_size(); if (primitive_util::IsIntegralType(input_ty)) { - if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 && + if (input_ty == S8 && num_spatial_dimensions == 2 && input_shape.dimensions_size() == 5) { VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString(); return kAllNCHW_VECT_C; @@ -131,6 +132,31 @@ HeuristicLayoutAssignment(const HloInstruction* instr, return kAllNHWC; } + // Despite the specialized logic below for Volta, we expect GPUs with Tensor + // Cores work best using NHWC layouts for cuDNN convolutions---as per + // https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout. + if (auto* cc = std::get_if(&gpu_version)) { + // TODO(b/383560056): investigate chips below Hopper as well. + if (cc->IsAtLeast(se::CudaComputeCapability::HOPPER)) { + // With that said, cuDNN's documentation states that NHWC is not supported + // for float64, so we use NCHW instead. + if (input_ty == F64) { + VLOG(2) << "Using NCHW for F64 conv " << instr->ToString() << " on " + << cc->ToString(); + return kAllNCHW; + // TODO(b/383560056): find the right filter for 3D convolutions. 3D + // convolutions also have a much smaller surface of support. We filter + // them out completely as well for now. + } else if (num_spatial_dimensions > 2) { + VLOG(2) << "Using NHWC for " << num_spatial_dimensions << "D conv " + << instr->ToString() << " on " << cc->ToString(); + return kAllNCHW; + } else { + return kAllNHWC; + } + } + } + const auto* rocm_compute_capability = std::get_if(&gpu_version); if (rocm_compute_capability && input_ty == F16) return kAllNHWC; diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 058d47e509c0cc..e38dd8e3b7548c 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/computation_layout.h" @@ -533,6 +534,76 @@ ENTRY entry { << ". Output: " << output_layout; } +TEST_F(LayoutAssignmentTest, CuDNNConvolutionHasNHWCLayoutPostHopper) { + const char* hlo = R"( +ENTRY entry { + p0 = f32[1,64,64,16]{3,2,1,0} parameter(0) + p1 = f32[3,16,3,32]{3,2,1,0} parameter(1) + ROOT conv = (f32[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call(p0, p1), + window={size=3x3 pad=1_1x1_1}, dim_labels=b10f_o10i->b10f, + custom_call_target="__cudnn$convForwardGraph" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo)); + ComputationLayout computation_layout( + hlo_module->entry_computation()->ComputeProgramShape()); + + GpuLayoutAssignment layout_assignment( + &computation_layout, se::CudaComputeCapability::Hopper(), GetDnnVersion(), + GetDeviceDescription()); + + EXPECT_THAT(layout_assignment.Run(hlo_module.get()), IsOkAndHolds(true)); + + // We start from b10f_o10i->b10f, meaning that the inputs start out as + // NWHC_OWHI->NWHC. Layout assignment should yield layouts of the form + // {3,1,2,0} (transpose the middle dimensions) for both inputs and for the + // output, therefore, in order to get to the desired NHWC_OHWI->NHWC layout. + EXPECT_THAT( + RunFileCheck(hlo_module->ToString(HloPrintOptions::ShortParsable()), R"( +// CHECK-DAG: [[P0:[^ ]+]] = {{.*}} parameter(0) +// CHECK-DAG: [[P1:[^ ]+]] = {{.*}} parameter(1) +// CHECK-DAG: [[COPY_P0:[^ ]+]] = {{.*}}{3,1,2,0} copy([[P0]]) +// CHECK-DAG: [[COPY_P1:[^ ]+]] = {{.*}}{3,1,2,0} copy([[P1]]) +// CHECK: [[CONV:[^ ]+]] = {{.*}}{3,1,2,0}, {{.*}} custom-call([[COPY_P0]], [[COPY_P1]]) +)"), + IsOkAndHolds(true)); +} + +TEST_F(LayoutAssignmentTest, F64CuDNNConvolutionHasNCHWLayoutPostHopper) { + const char* hlo = R"( +ENTRY entry { + p0 = f64[2,64,64,16]{3,2,1,0} parameter(0) + p1 = f64[6,16,3,32]{3,2,1,0} parameter(1) + ROOT conv = (f64[2,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call(p0, p1), + window={size=3x3 pad=1_1x1_1}, dim_labels=b10f_o10i->b10f, + custom_call_target="__cudnn$convForwardGraph" +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo)); + ComputationLayout computation_layout( + hlo_module->entry_computation()->ComputeProgramShape()); + + GpuLayoutAssignment layout_assignment( + &computation_layout, se::CudaComputeCapability::Hopper(), GetDnnVersion(), + GetDeviceDescription()); + + EXPECT_THAT(layout_assignment.Run(hlo_module.get()), IsOkAndHolds(true)); + + // We start from b10f_o10i->b10f, meaning that the inputs start out as + // NWHC_OWHI->NWHC. Layout assignment should yield layouts of the form + // {1,2,3,0} for both inputs and for the output, therefore, in order to get to + // the desired NCHW_OIHW->NCHW layout. + EXPECT_THAT( + RunFileCheck(hlo_module->ToString(HloPrintOptions::ShortParsable()), R"( +// CHECK-DAG: [[P0:[^ ]+]] = {{.*}} parameter(0) +// CHECK-DAG: [[P1:[^ ]+]] = {{.*}} parameter(1) +// CHECK-DAG: [[COPY_P0:[^ ]+]] = {{.*}}{1,2,3,0} copy([[P0]]) +// CHECK-DAG: [[COPY_P1:[^ ]+]] = {{.*}}{1,2,3,0} copy([[P1]]) +// CHECK: [[CONV:[^ ]+]] = {{.*}}{1,2,3,0}, {{.*}} custom-call([[COPY_P0]], [[COPY_P1]]) +)"), + IsOkAndHolds(true)); +} + TEST_F(LayoutAssignmentTest, ConvCuDNNF8) { if (!GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::HOPPER)) { diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 4d46e105f48c12..88906d6361c767 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -142,7 +142,7 @@ int FusionPriority(const HloInstruction* instr) { if (instr->IsMultiOutputFusion()) { return 2; } - if (instr->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(instr)) { return 1; } return 0; @@ -170,7 +170,7 @@ FusionDecision OperandReachableFromProducer( // map, it has been created by fusion in this pass. Simply move // on to its operand, which is in the reachability map. if (!reachability.IsPresent(operand) && - operand->opcode() == HloOpcode::kGetTupleElement) { + HloPredicateIsOp(operand)) { operand = operand->operand(0); } CHECK(reachability.IsPresent(operand) && reachability.IsPresent(&producer)) @@ -274,9 +274,8 @@ bool IsSiblingFusionCandidate(const HloInstruction* instr, // If this is the case, we bail out because the transformation assumes // the users are get-tuple-element. return (!instr->IsMultiOutputFusion() || - absl::c_all_of(instr->users(), [&](const HloInstruction* user) { - return user->opcode() == HloOpcode::kGetTupleElement; - })); + absl::c_all_of(instr->users(), + HloPredicateIsOp)); } FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, @@ -386,7 +385,7 @@ bool MultiOutputFusion::FuseSiblings(HloInstruction* parent, "| inside multi-output fusion"), /*producer=*/fused); - if (fused->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(fused)) { remaining->MergeFusionInstructionIntoMultiOutput(fused); if (fused->IsInputFusion()) { remaining->set_fusion_kind(HloInstruction::FusionKind::kInput); @@ -427,7 +426,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { auto* producer = *it; // Never multi-output fuse constants. To the extent that we want to fuse // constants, that should be handled by the regular fusion pass. - if (producer->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(producer)) { VLOG(3) << producer->name() << " is a constant."; continue; } @@ -462,7 +461,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion)); HloInstruction* input_fusion; - if (consumer_for_fusion->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(consumer_for_fusion)) { input_fusion = consumer_for_fusion; VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " << consumer_for_fusion->name(); @@ -484,7 +483,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { "| inside multi-output fusion"), /*producer=*/producer); - if (producer->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(producer)) { input_fusion->MergeFusionInstructionIntoMultiOutput(producer); } else { input_fusion->FuseInstructionIntoMultiOutput(producer); diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index e0151fa7ac9628..cc82c1ed8971b8 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -86,7 +86,7 @@ absl::Status FuseInstructionsForConsumer( continue; } - if (instruction->opcode() == HloOpcode::kParameter) { + if (HloPredicateIsOp(instruction)) { add_parameter(instruction); continue; } @@ -328,9 +328,8 @@ absl::Status MakeNestedFusionFromGemmFusion(HloFusionInstruction* fusion, } size_t GetDotCount(HloComputation* computation) { - return absl::c_count_if(computation->instructions(), [](HloInstruction* hlo) { - return hlo->opcode() == HloOpcode::kDot; - }); + return absl::c_count_if(computation->instructions(), + HloPredicateIsOp); } // Returns the set of instructions that are reachable from 'instruction' using @@ -426,7 +425,7 @@ absl::Status HoistBitcastUpwardsToCallers( Shape shape = bitcast->shape(); for (HloInstruction* instruction : producers) { *instruction->mutable_shape() = shape; - if (instruction->opcode() != HloOpcode::kParameter) { + if (HloPredicateIsNotOp(instruction)) { continue; } // For parameters, we need to bitcast the caller's operand. @@ -490,7 +489,7 @@ absl::Status TryHoistBitcastsInComputationToCallers(HloInstruction* dot, CallGraph* call_graph) { auto callers = call_graph->GetComputationCallers(dot->parent()); for (HloInstruction* instruction : GetProducerSet(dot)) { - if (instruction->opcode() != HloOpcode::kBitcast) { + if (HloPredicateIsNotOp(instruction)) { continue; } VLOG(2) << "Hoisting bitcast upwards " << instruction->ToString(); @@ -500,7 +499,7 @@ absl::Status TryHoistBitcastsInComputationToCallers(HloInstruction* dot, } } for (HloInstruction* instruction : GetConsumerSet(dot)) { - if (instruction->opcode() != HloOpcode::kBitcast) { + if (HloPredicateIsNotOp(instruction)) { continue; } VLOG(2) << "Hoisting bitcast downwards " << instruction->ToString(); diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc index 5702daa8f531b2..378935dc6a81d6 100644 --- a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc @@ -100,7 +100,7 @@ HloInstruction* FindUniqueGTEUserWithIndex(const HloInstruction* op, HloInstruction* gte = nullptr; for (auto user : op->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { continue; } if (user->tuple_index() == idx) { @@ -119,7 +119,7 @@ bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) { CHECK(op->shape().IsTuple()); for (auto user : op->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(user)) { continue; } if (user->tuple_index() == idx) { @@ -139,12 +139,12 @@ bool HasGTEUserWithIndex(const HloInstruction* op, int64_t idx) { // TODO(bixia): investigate the possible of implementing // m::TrivialTuple(m::RecvDone(&instr)) as suggested by code review. HloInstruction* MaySkipTrivialTuple(HloInstruction* op) { - if (op->opcode() != HloOpcode::kTuple) { + if (HloPredicateIsNotOp(op)) { return op; } HloInstruction* hidden_op = nullptr; for (auto opnd : op->mutable_operands()) { - if (opnd->opcode() != HloOpcode::kGetTupleElement) { + if (HloPredicateIsNotOp(opnd)) { return op; } if (hidden_op == nullptr) { @@ -182,10 +182,9 @@ FindConsecutiveAndBalanceBlockOfSendDoneRecvDone( // tuple, find such block. for (int64_t i = 0; i < while_init->operand_count(); ++i) { const HloInstruction* op = while_init->operand(i); - if ((op->opcode() == HloOpcode::kRecvDone || - op->opcode() == HloOpcode::kSendDone) && + if ((HloPredicateIsOp(op)) && op->frontend_attributes().map().count(kSendRecvPipelineAttr) > 0) { - if (op->opcode() == HloOpcode::kRecvDone) { + if (HloPredicateIsOp(op)) { difference++; } else { difference--; @@ -212,8 +211,7 @@ FindConsecutiveAndBalanceBlockOfSendDoneRecvDone( for (int64_t i = pipelined_p2p_info.opnd_end; i < while_init->operand_count(); ++i) { const HloInstruction* op = while_init->operand(i); - if (op->opcode() == HloOpcode::kRecvDone || - op->opcode() == HloOpcode::kSendDone) { + if (HloPredicateIsOp(op)) { VLOG(10) << "SendDone/RecvDone outside the consecutive block"; return std::nullopt; break; @@ -258,7 +256,7 @@ std::optional FindPipelinedP2P( const HloInstruction* while_op) { VLOG(10) << "while_op: " << while_op->ToString(); const HloInstruction* while_init = while_op->while_init(); - if (while_init->opcode() != HloOpcode::kTuple || + if (HloPredicateIsNotOp(while_init) || while_init->user_count() != 1) { return std::nullopt; } @@ -287,7 +285,7 @@ std::optional FindPipelinedP2P( for (int64_t i = pipelined_p2p_info->opnd_start; i < pipelined_p2p_info->opnd_end; ++i) { const HloInstruction* op = while_init->operand(i); - if (op->opcode() == HloOpcode::kRecvDone) { + if (HloPredicateIsOp(op)) { if (!FindUniqueGTEUserWithIndex(while_op, i)) { VLOG(10) << "While result get-tuple-element user with index " << i << " not unique"; @@ -300,7 +298,7 @@ std::optional FindPipelinedP2P( return std::nullopt; } } else { - CHECK(op->opcode() == HloOpcode::kSendDone); + CHECK(HloPredicateIsOp(op)); if (HasGTEUserWithIndex(while_op, i) || HasGTEUserWithIndex(while_body->parameter_instruction(0), i)) { VLOG(10) << "SendDone with index " << i << " has unexpected users"; @@ -375,7 +373,7 @@ absl::Status RemoveDoneOpsAndUpdateSequence( return absl::OkStatus(); }; for (auto op : ops) { - if (op->opcode() == HloOpcode::kTuple) { + if (HloPredicateIsOp(op)) { InstructionVector to_remove; HloInstruction* tuple_op = op; op = MaySkipTrivialTuple(tuple_op); @@ -460,7 +458,7 @@ absl::Status RewritePipelinedP2PWhileBody( for (int64_t i = opnd_start; i < opnd_end; ++i) { const HloInstruction* op = root->operand(i); op = MaySkipTrivialTuple(op); - if (op->opcode() == HloOpcode::kRecvDone) { + if (HloPredicateIsOp(op)) { HloInstruction* gte = FindUniqueGTEUserWithIndex(param, i); CHECK(gte != nullptr); recv_dones.push_back(gte); @@ -473,7 +471,7 @@ absl::Status RewritePipelinedP2PWhileBody( new_recv_dones.push_back(recv_done); continue; } - CHECK(op->opcode() == HloOpcode::kSendDone); + CHECK(HloPredicateIsOp(op)); // Create the new SendDone using the new while-op result. HloInstruction* send = computation->AddInstruction( HloInstruction::CreateGetTupleElement(param, i)); @@ -575,7 +573,7 @@ absl::Status TransformLoop( for (int64_t i = opnd_start; i < opnd_end; ++i) { HloInstruction* op = while_init->mutable_operand(i); done_ops.push_back(op); - if (op->opcode() == HloOpcode::kRecvDone) { + if (HloPredicateIsOp(op)) { HloInstruction* gte = FindUniqueGTEUserWithIndex(while_op, i); CHECK(gte != nullptr); recv_dones.push_back(gte); @@ -590,7 +588,7 @@ absl::Status TransformLoop( CopyInstructionInfo(op, recv_done); continue; } - CHECK(op->opcode() == HloOpcode::kSendDone); + CHECK(HloPredicateIsOp(op)); // Create the new SendDone using the new while-op result. HloInstruction* send = computation->AddInstruction( HloInstruction::CreateGetTupleElement(new_while_op, i)); @@ -654,7 +652,7 @@ absl::StatusOr ProcessComputation( collective_in_computation[computation] = true; } - if (hlo->opcode() != HloOpcode::kWhile) { + if (HloPredicateIsNotOp(hlo)) { idx++; continue; } diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc index deb99d9b6a1bdb..400901ec6a65fd 100644 --- a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc @@ -710,7 +710,7 @@ TEST_F(PipelinedP2pRewriterTest, NoCrashOnDynamicSliceFusion) { ENTRY %main (data.1: s32[8,32]) -> s32[2,32] { %data.1 = s32[8,32]{1,0} parameter(0) - ROOT %address-computation.1 = s32[2,32]{1,0} fusion(s32[8,32]{1,0} %data.1), kind=kCustom, calls=%dynamic-slice-fusion, + ROOT %address-computation.1 = s32[2,32]{1,0} fusion(s32[8,32]{1,0} %data.1), kind=kCustom, calls=%dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} })"; diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index 994bf5a8524498..fde1bc29c08e4d 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "llvm/ADT/STLExtras.h" @@ -65,7 +66,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/blocking_counter.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -176,10 +176,10 @@ class PriorityFusionQueue { std::vector instructions; for (auto* instruction : computation->MakeInstructionPostOrder()) { TF_CHECK_OK(UpdatePerformanceModelCache(instruction)); - if (instruction->opcode() == HloOpcode::kParameter || + if (HloPredicateIsOp(instruction) || instruction->user_count() == 0 || !instruction->IsFusible() || - instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kGetTupleElement) { + HloPredicateIsOp( + instruction)) { continue; } instructions.push_back(instruction); @@ -226,7 +226,7 @@ class PriorityFusionQueue { fn(); } }; - tsl::BlockingCounter counter(instructions.size()); + absl::BlockingCounter counter(instructions.size()); std::vector priorities(instructions.size()); for (size_t i = 0; i < instructions.size(); ++i) { @@ -255,7 +255,7 @@ class PriorityFusionQueue { current_consumers_ = current_producer_->users(); - if (current_producer_->opcode() == HloOpcode::kBitcast) { + if (HloPredicateIsOp(current_producer_)) { // We don't check if bitcasts can be fused with all consumers, so we // have to do it here. llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) { @@ -423,8 +423,8 @@ class PriorityFusionQueue { // Collect the instructions whose priorities need to be updated. for (HloInstruction* operand : fusion->operands()) { if (operand == original_producer || - operand->opcode() == HloOpcode::kConstant || - operand->opcode() == HloOpcode::kGetTupleElement) { + HloPredicateIsOp( + operand)) { continue; } // Need to consider only instructions that are fusible, e.g., rng with @@ -476,13 +476,13 @@ class PriorityFusionQueue { // users. Priority CalculateProducerPriority(HloInstruction* producer) { // Bitcasts should always be fused first, since they are no-ops. - if (producer->opcode() == HloOpcode::kBitcast) { + if (HloPredicateIsOp(producer)) { return absl::InfiniteDuration(); } // We always fuse constants, but the cost model doesn't handle them very // well: fusing constants changes costs significantly. Also, there's no // point recomputing priorities. Therefore, we fuse all of them at the end. - if (producer->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(producer)) { return -absl::InfiniteDuration(); } @@ -678,7 +678,7 @@ class PriorityFusionQueue { return can_fuse_triton; } - if (consumer->opcode() == HloOpcode::kBitcast) { + if (HloPredicateIsOp(consumer)) { return FusionDecision::Forbid( "not fusing into a single bitcast as consumer"); } @@ -784,7 +784,7 @@ class PriorityFusionQueue { bool has_non_bitcast_user = false; for (const auto& user : producer->users()) { - if (user->opcode() == HloOpcode::kBitcast) { + if (HloPredicateIsOp(user)) { continue; } has_non_bitcast_user = true; @@ -896,8 +896,8 @@ class PriorityFusionQueue { // // This function matches the emitter logic. bool IsSmallConstant(const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kConstant && instr->shape().IsArray() && - ShapeUtil::ElementsIn(instr->shape()) <= 1; + return HloPredicateIsOp(instr) && + instr->shape().IsArray() && ShapeUtil::ElementsIn(instr->shape()) <= 1; } bool PriorityFusion::ConsumeFuel(HloInstruction* producer, @@ -1003,7 +1003,7 @@ absl::StatusOr PriorityFusion::Run( for (auto* consumer : fusion_queue->current_consumers()) { // Don't fuse into single bitcasts. We ignore them in the check // CanFuseWithAllNonBitcastUsers(), so we need to check it here. - if (consumer->opcode() == HloOpcode::kBitcast) { + if (HloPredicateIsOp(consumer)) { continue; } if (!ConsumeFuel(producer, consumer)) continue; @@ -1117,7 +1117,7 @@ HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, auto kind = ChooseKind(producer, consumer); HloInstruction* fusion_instruction = consumer; - if (fusion_instruction->opcode() != HloOpcode::kFusion) { + if (HloPredicateIsNotOp(fusion_instruction)) { fusion_instruction = computation->AddInstruction( HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction)); @@ -1129,7 +1129,7 @@ HloInstruction* PriorityFusion::Fuse(HloInstruction* producer, computation->execution_thread(), /*skip_async_execution_thread_overwrite=*/false); - if (producer->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(producer)) { fusion_instruction->MergeFusionInstruction(producer); } else { fusion_instruction->FuseInstruction(producer); diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index f790320fb5cf39..79ccd4d06b45d1 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -521,8 +521,8 @@ TEST_F(PriorityFusionTest, DontFuseIntoFirstOperandOfScatter) { ENTRY FuseIntoScatter { p0 = s32[3,3] parameter(0) operand = s32[3,3] add(p0, p0) - p1 = s32[2] parameter(1) - indices = s32[2] add(p1, p1) + p1 = s32[2,1] parameter(1) + indices = s32[2,1] add(p1, p1) p2 = s32[2,3] parameter(2) updates = s32[2,3] add(p2, p2) scatter = s32[3,3] scatter(operand, indices, updates), diff --git a/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc index c7160d87b3ffd6..be1ddb782e3f66 100644 --- a/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer_test.cc @@ -56,7 +56,7 @@ ENTRY main { send_sizes = s32[2] parameter(3) output_offsets = s32[2] parameter(4) recv_sizes = s32[2] parameter(5) - ROOT ra2a = bf16[16] ragged-all-to-all(input, output, input_offsets, + ROOT ra2a = bf16[16] ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}} } )")); diff --git a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc index e17547857ac05d..c08f3794408130 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc @@ -50,7 +50,7 @@ absl::StatusOr ReduceScatterCreator::Run( module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction *instruction : computation->MakeInstructionPostOrder()) { - if (instruction->opcode() != HloOpcode::kAllReduce) { + if (HloPredicateIsNotOp(instruction)) { continue; } auto *ar = Cast(instruction); diff --git a/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc index 29f3edf968fb3c..ac396b3fd5915f 100644 --- a/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc @@ -78,7 +78,7 @@ absl::StatusOr RenameFusions::Run( const absl::flat_hash_set& execution_threads) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion || + if (HloPredicateIsNotOp(instruction) || instruction->fusion_kind() == HloInstruction::FusionKind::kCustom) { continue; } diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc index d7962130a2eeb8..3177ce781bb928 100644 --- a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc @@ -41,7 +41,7 @@ absl::StatusOr AnnotateSchedulingInstructionNames( // We skip constants as we might have to sanitize them in order to satisfy // LLVM backend. I.e. we allow `GpuSanitizeConstantNames` pass to run post // scheduling. - if (inst->opcode() == HloOpcode::kConstant) { + if (HloPredicateIsOp(inst)) { continue; } inst->set_metadata_scheduling_name(inst->name()); diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index 02ea89849a3d87..ebb3cf9eff3de6 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/gpu/transforms/softmax_rewriter_triton.h" -#include #include #include #include @@ -60,9 +59,9 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -81,45 +80,6 @@ bool HasDefaultLayout(const Shape& shape) { LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); } -// Returns true if a trivially connected producer of 'consumer' with opcode -// 'opcode' exists. If such an instruction is found, the value of 'producer' is -// set to it. The definition of "trivial" operations is as given in -// 'IsTriviallyFusible'. -bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, - HloOpcode opcode, const se::GpuComputeCapability& gpu_version); - -bool BitcastIsTilingNoop(HloInstruction* bitcast, - const se::GpuComputeCapability& gpu_version) { - CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); - - if (ShapeUtil::IsEffectiveScalar(bitcast->shape())) { - return true; - } - - // In the Softmax rewriter for now, tiling is derived from a hero reduction - // operation, which should be reducing its input on the last axis. Therefore, - // a bitcast is always a no-op with regards to a tile if - // (1) it does not change the size of the reduction dimension of its input - // (the last one); if its input is already reduced, then (1) is true - // by default - // (2) the layout of its output is ordered in the same way as the layout of - // its input. This is a fuzzy definition, but since we assume fusible - // ops to always have a default layout, we can just check if both the - // bitcast and its input have a default layout - auto last_dimension = [](const HloInstruction* instr) { - return instr->shape().dimensions().back(); - }; - - HloInstruction* reduce = nullptr; - TrivialEdge(&reduce, bitcast->mutable_operand(0), HloOpcode::kReduce, - gpu_version); - - return (HasDefaultLayout(bitcast->shape()) && - HasDefaultLayout(bitcast->operand(0)->shape()) && - (reduce != nullptr || - last_dimension(bitcast->operand(0)) == last_dimension(bitcast))); -} - inline bool HasOneUse(const HloInstruction* instr) { return instr->user_count() == 1; } @@ -152,8 +112,7 @@ bool IsTriviallyFusible(HloInstruction* instr, return false; } - if (HloPredicateIsOp(instr) && - BitcastIsTilingNoop(instr, gpu_version)) { + if (HloPredicateIsOp(instr)) { return true; } @@ -188,6 +147,10 @@ bool IsTriviallyFusible(HloInstruction* instr, return false; } +// Returns true if a trivially connected producer of 'consumer' with opcode +// 'opcode' exists. If such an instruction is found, the value of 'producer' is +// set to it. The definition of "trivial" operations is as given in +// 'IsTriviallyFusible'. bool TrivialEdge(HloInstruction** producer, HloInstruction* consumer, HloOpcode opcode, const se::GpuComputeCapability& gpu_version) { @@ -227,36 +190,16 @@ bool IsTriviallyConnectedProducerOf( return false; } -// Finds the first non-fusible producer of a diamond. This instruction is either -// 1. the direct producer of the diamond, if that producer is used more than -// twice and/or is not otherwise trivially fusible -// 2. the first parent instruction of the producer of the diamond such that -// that instruction is used more than once, and/or is not trivially -// fusible. -HloInstruction* FindFirstNonFusibleDiamondProducer( - HloInstruction* diamond_producer, - const se::GpuComputeCapability& gpu_version) { - if (IsTriviallyFusible(diamond_producer, gpu_version, - /*num_allowed_users=*/2)) { - diamond_producer = ChooseOperandForFusionProcessing(diamond_producer); - while (IsTriviallyFusible(diamond_producer, gpu_version)) { - diamond_producer = ChooseOperandForFusionProcessing(diamond_producer); - } - } - - return diamond_producer; -} - -// Creates a fusion corresponding to the input diamond chain. The resulting +// Creates a fusion corresponding to the input diamond. The resulting // fusion instruction is added to the module, but is not yet inserted into the // graph as a replacement of the original instructions. // // TODO(b/347956491): this awkward abstraction is needed to work around // limitations of HloFusionAdaptor, which underpins the implementation of // SymbolicTileAnalysis. We need to come up with a better solution. -absl::StatusOr MakeFusionForDiamondChain( - const DiamondChainDescriptor& diamond_chain) { - auto [root, producer] = diamond_chain; +absl::StatusOr MakeFusionForDiamond( + const DiamondDescriptor& diamond) { + auto [root, producer] = diamond; std::string suggested_name = "triton_softmax"; HloComputation::Builder builder(absl::StrCat(suggested_name, "_computation")); @@ -299,20 +242,20 @@ absl::StatusOr MakeFusionForDiamondChain( root->GetModule()->AddComputationAndUnifyNamesAndIds(builder.Build(), /*is_entry=*/false); - HloInstruction* softmax_fusion = + HloInstruction* normalization_fusion = root->parent()->AddInstruction(HloInstruction::CreateFusion( root->shape(), HloInstruction::FusionKind::kCustom, parameters, computation)); - softmax_fusion->GetModule()->SetAndUniquifyInstrName(softmax_fusion, - "triton_softmax"); + normalization_fusion->GetModule()->SetAndUniquifyInstrName( + normalization_fusion, "triton_softmax"); TF_ASSIGN_OR_RETURN(auto gpu_config, - softmax_fusion->backend_config()); + normalization_fusion->backend_config()); FusionBackendConfig& backend_config = *gpu_config.mutable_fusion_backend_config(); backend_config.set_kind(std::string(kTritonFusionKind)); - TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(gpu_config)); - return xla::Cast(softmax_fusion); + TF_RETURN_IF_ERROR(normalization_fusion->set_backend_config(gpu_config)); + return xla::Cast(normalization_fusion); } // Runs an HLO pipeline to convert the `module` to the stage as it would look @@ -346,8 +289,8 @@ absl::Status RunFusionPipeline( // Returns a run time estimate for instructions in the `fusion` if they were // fused without SoftmaxRewriterTriton. // -// This can help us understand how effective are ReductionSplitter and -// PriorityFusion for this fusion. +// This can help us understand how effective `ReductionSplitter` and +// `PriorityFusion` are for this fusion. // // In the bigger module, the instructions in the normalization diamond will be // fused with other instructions around it, so it's not an exact estimate, but @@ -365,7 +308,7 @@ EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( TF_RETURN_IF_ERROR( RunFusionPipeline(new_module.get(), device_info, shape_size)); - VLOG(10) << "priority fusion module: " << new_module->ToString(); + VLOG(3) << "priority fusion module: " << new_module->ToString(); HloComputation* entry_computation = new_module->entry_computation(); GpuHloCostAnalysis::Options cost_analysis_options{ @@ -399,12 +342,12 @@ EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( // returns a `FusionDecision` to indicate that the function should not happen. absl::StatusOr DecideIfShouldFuseAndMaybeSetBlockLevelParameters( - HloFusionInstruction* softmax_fusion, + HloFusionInstruction* normalization_fusion, GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model, const se::DeviceDescription& device_info, const HloCostAnalysis::ShapeSizeFunction& shape_size, bool use_cost_model_to_evaluate_fusions) { - auto fusion_adaptor = HloFusionAdaptor::ForInstruction(softmax_fusion); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(normalization_fusion); TF_ASSIGN_OR_RETURN( TiledRunTimeDataOrError tiled_runtime_data_or, @@ -422,11 +365,11 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( if (use_cost_model_to_evaluate_fusions) { TF_ASSIGN_OR_RETURN(absl::Duration run_time_without_softmax_rewriter, EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( - softmax_fusion, device_info, shape_size)); + normalization_fusion, device_info, shape_size)); - VLOG(5) << "run time estimate if normalization diamond fused together: " + VLOG(2) << "run time estimate if normalization diamond fused together: " << tiled_runtime_data.runtime_data.exec_time; - VLOG(5) + VLOG(2) << "run time estimate if normalization diamond is not fused together: " << run_time_without_softmax_rewriter; @@ -439,73 +382,73 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( } TF_ASSIGN_OR_RETURN(auto backend_config, - softmax_fusion->backend_config()); + normalization_fusion->backend_config()); *backend_config.mutable_fusion_backend_config() ->mutable_block_level_fusion_config() = tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig(); - TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config)); - VLOG(5) << "Fusing with backend config: " << backend_config.DebugString(); + TF_RETURN_IF_ERROR(normalization_fusion->set_backend_config(backend_config)); + VLOG(2) << "Fusing with backend config: " << backend_config.DebugString(); return FusionDecision::Allow(); } -absl::StatusOr MaybeFuseDiamondChainImpl( - const DiamondChainDescriptor& diamond_chain, +absl::StatusOr MaybeFuseDiamondImpl( + const DiamondDescriptor& diamond, GpuPerformanceModelWithIndexingAnalysis& indexing_performance_model, const se::DeviceDescription& device_info, const HloCostAnalysis::ShapeSizeFunction& shape_size, bool use_cost_model_to_evaluate_fusions) { - TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion, - MakeFusionForDiamondChain(diamond_chain)); - HloInstruction* root = diamond_chain.root; + TF_ASSIGN_OR_RETURN(HloFusionInstruction * normalization_fusion, + MakeFusionForDiamond(diamond)); + HloInstruction* root = diamond.root; - VLOG(5) << "MaybeFuseDiamondChainImpl: " << softmax_fusion->ToString(); + VLOG(2) << "MaybeFuseDiamondImpl: " << normalization_fusion->ToString(); TF_ASSIGN_OR_RETURN( FusionDecision fusion_decision, DecideIfShouldFuseAndMaybeSetBlockLevelParameters( - softmax_fusion, indexing_performance_model, device_info, shape_size, - use_cost_model_to_evaluate_fusions)); + normalization_fusion, indexing_performance_model, device_info, + shape_size, use_cost_model_to_evaluate_fusions)); if (!fusion_decision.CanFuse()) { - VLOG(5) << "Not fusing: " << fusion_decision.Explain(); - softmax_fusion->DetachFromOperandsAndUsers(); - TF_RETURN_IF_ERROR( - softmax_fusion->parent()->RemoveInstruction(softmax_fusion)); + VLOG(2) << "Not fusing: " << fusion_decision.Explain(); + normalization_fusion->DetachFromOperandsAndUsers(); + TF_RETURN_IF_ERROR(normalization_fusion->parent()->RemoveInstruction( + normalization_fusion)); return false; } if (root->IsRoot()) { - root->parent()->set_root_instruction(softmax_fusion); + root->parent()->set_root_instruction(normalization_fusion); TF_RETURN_IF_ERROR( root->parent()->RemoveInstructionAndUnusedOperands(root)); } else { TF_RETURN_IF_ERROR( - root->parent()->ReplaceInstruction(root, softmax_fusion)); + root->parent()->ReplaceInstruction(root, normalization_fusion)); } return true; } -// Returns `true` if the diamond chain passed as a parameter can be tiled -// correctly using `SymbolicTileAnalysis`. -absl::StatusOr CanSymbolicTileAnalysisTileDiamondChain( - const DiamondChainDescriptor& diamond_chain, +// Returns `true` if the diamond passed as a parameter can be tiled correctly +// using `SymbolicTileAnalysis`. +absl::StatusOr CanSymbolicTileAnalysisTileDiamond( + const DiamondDescriptor& diamond, const se::DeviceDescription& device_info) { - TF_ASSIGN_OR_RETURN(HloFusionInstruction * softmax_fusion, - MakeFusionForDiamondChain(diamond_chain)); + TF_ASSIGN_OR_RETURN(HloFusionInstruction * normalization_fusion, + MakeFusionForDiamond(diamond)); mlir::MLIRContext context; SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error = SymbolicTileAnalysis::AnalyzeComputation( - *softmax_fusion->called_computation(), &context, + *normalization_fusion->called_computation(), &context, TritonEmitterConstraints::GetBuilder(device_info)); bool can_tile = std::holds_alternative( symbolic_tile_analysis_or_error); - TF_RETURN_IF_ERROR(diamond_chain.root->GetModule()->RemoveEmbeddedComputation( - softmax_fusion->called_computation())); + TF_RETURN_IF_ERROR(diamond.root->GetModule()->RemoveEmbeddedComputation( + normalization_fusion->called_computation())); TF_RETURN_IF_ERROR( - diamond_chain.root->parent()->RemoveInstruction(softmax_fusion)); + diamond.root->parent()->RemoveInstruction(normalization_fusion)); return can_tile; } @@ -624,24 +567,30 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( return FusionDecision::Forbid("Unsupported root-producer connection."); } - VLOG(5) << "Matched Softmax diamond with: "; - VLOG(5) << "root: " << instr->ToString(); - VLOG(5) << "producer: " << producer->ToString(); - VLOG(5) << "broadcast: " << broadcast->ToString(); - VLOG(5) << "reduce: " << reduce->ToString(); + VLOG(2) << "Matched Softmax diamond with: "; + VLOG(2) << "root: " << instr->ToString(); + VLOG(2) << "producer: " << producer->ToString(); + VLOG(2) << "broadcast: " << broadcast->ToString(); + VLOG(2) << "reduce: " << reduce->ToString(); return producer; } -// Returns a vector containing all the single diamonds in the parameter module. -// The diamonds are returned in def-before-use order, and grouped by -// computation. -absl::StatusOr> FindAllFusibleDiamonds( +} // anonymous namespace + +DiamondMatchingDecision +SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( + HloInstruction* instr) const { + return MatchesTritonCompatibleClosedReductionDiamondImpl( + instr, device_info_.gpu_compute_capability()); +} + +absl::StatusOr> +SoftmaxRewriterTriton::FindAllFusibleNormalizationDiamonds( HloModule& module, - const absl::flat_hash_set& execution_threads, - const se::DeviceDescription& device_info) { - const se::GpuComputeCapability& cc = device_info.gpu_compute_capability(); - std::vector matched_diamonds; + const absl::flat_hash_set& execution_threads) const { + const se::GpuComputeCapability& cc = device_info_.gpu_compute_capability(); + std::vector matched_diamonds; for (HloComputation* comp : module.MakeNonfusionComputations(execution_threads)) { @@ -652,24 +601,24 @@ absl::StatusOr> FindAllFusibleDiamonds( auto producer = MatchesTritonCompatibleClosedReductionDiamondImpl(instr, cc); if (std::holds_alternative(producer)) { - DiamondChainDescriptor diamond_chain{ + DiamondDescriptor diamond{ /*root=*/instr, /*producer=*/std::get(producer)}; - // We filter out the diamond chains that cannot be tiled correctly using + // We filter out the diamonds that cannot be tiled correctly using // `SymbolicTileAnalysis`. - TF_ASSIGN_OR_RETURN(bool can_tile_diamond_chain, - CanSymbolicTileAnalysisTileDiamondChain( - diamond_chain, device_info)); - if (can_tile_diamond_chain) { - matched_diamonds.push_back(diamond_chain); + TF_ASSIGN_OR_RETURN( + bool can_tile_diamond, + CanSymbolicTileAnalysisTileDiamond(diamond, device_info_)); + if (can_tile_diamond) { + matched_diamonds.push_back(diamond); } else { - VLOG(5) << "Cannot tile the diamond pattern described by " + VLOG(2) << "Cannot tile the diamond pattern described by " << "instructions " << instr->ToString() << " and " << std::get(producer)->ToString() << "."; continue; } } else { - VLOG(5) << "Cannot match the diamond pattern for instruction " + VLOG(2) << "Cannot match the diamond pattern for instruction " << instr->ToString() << ". Reason: " << std::get(producer).Explain(); } @@ -679,154 +628,14 @@ absl::StatusOr> FindAllFusibleDiamonds( return matched_diamonds; } -// Returns the size of the reduction dimension of the input diamond. -int64_t GetReductionDimensionSizeForDiamond( - const DiamondChainDescriptor& diamond_chain) { - HloInstruction* diamond_root = diamond_chain.root; - HloInstruction* instr = diamond_root->mutable_operand(1); - while (HloPredicateIsNotOp(instr)) { - instr = ChooseOperandForFusionProcessing(instr); - } - - int operand_rank = instr->operand(0)->shape().rank(); - CHECK_EQ(instr->dimensions().size(), 1); - CHECK_EQ(instr->dimensions(0), operand_rank - 1); - return instr->operand(0)->shape().dimensions(operand_rank - 1); -} - -// Returns a pointer to the last user of `instr` that is trivially fusible. -HloInstruction* GetLastTriviallyFusibleUser( - HloInstruction* instr, const se::GpuComputeCapability& cc) { - while (HasOneUse(instr) && !instr->IsRoot() && - IsTriviallyFusible(instr->users().front(), cc)) { - instr = instr->users().front(); - } - - // We do not care about the number of users for the last instruction of the - // fusion, so attempt to fuse one more instruction with this relaxed - // restriction. - if (HasOneUse(instr) && !instr->IsRoot() && - IsTriviallyFusible( - instr->users().front(), cc, - /*num_allowed_users=*/instr->users().front()->user_count())) { - instr = instr->users().front(); - } - return instr; -} - -} // anonymous namespace - -DiamondMatchingDecision -SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( - HloInstruction* instr) const { - return MatchesTritonCompatibleClosedReductionDiamondImpl( - instr, device_info_.gpu_compute_capability()); -} - -absl::StatusOr> -SoftmaxRewriterTriton::FindAllFusibleDiamondChains( - HloModule& module, - const absl::flat_hash_set& execution_threads) const { - TF_ASSIGN_OR_RETURN( - std::vector matched_diamonds, - FindAllFusibleDiamonds(module, execution_threads, device_info_)); - - if (matched_diamonds.empty()) { - return std::vector(); - } - - // If we matched several diamonds, it may be possible for some of them to be - // fused together. This is the case if the following conditions hold: - // 1. The path between the root of diamond n towards the producer of - // diamond n+1 is composed only of trivially fusible operations. In that - // case, the first non-trivially fusible producer of diamond n+1 must be - // exactly the root of diamond n. - // 2. The root of diamond n/first non-fusible producer of diamond n+1 must - // have - // a. exactly one user if it is not exactly the producer of diamond - // n+1; - // b/ exactly two users otherwise. - // 3. The axis being reduced must have the same length in all the diamonds - // being fused together. - // - // Crucially, this approach relies on a diamond root never being considered a - // trivially fusible operation. - std::vector diamond_chains; - diamond_chains.reserve(matched_diamonds.size()); - - const se::GpuComputeCapability& cc = device_info_.gpu_compute_capability(); - HloInstruction* current_fusion_producer = - FindFirstNonFusibleDiamondProducer(matched_diamonds.front().producer, cc); - int current_reduce_dimension_size = - GetReductionDimensionSizeForDiamond(matched_diamonds.front()); - - for (int diamond_idx = 1; diamond_idx < matched_diamonds.size(); - ++diamond_idx) { - HloInstruction* diamond_producer = matched_diamonds[diamond_idx].producer; - HloInstruction* previous_diamond_root = - matched_diamonds[diamond_idx - 1].root; - - HloInstruction* first_non_fusible_diamond_producer = - FindFirstNonFusibleDiamondProducer(diamond_producer, cc); - - int diamond_reduce_dimension_size = - GetReductionDimensionSizeForDiamond(matched_diamonds[diamond_idx]); - - if (first_non_fusible_diamond_producer == previous_diamond_root && // 1 - ((first_non_fusible_diamond_producer != diamond_producer && - HasOneUse(first_non_fusible_diamond_producer)) || // 2.a - (first_non_fusible_diamond_producer == diamond_producer && - first_non_fusible_diamond_producer->user_count() == 2)) && // 2.b - diamond_reduce_dimension_size == current_reduce_dimension_size) { // 3 - continue; - } - - // The "last trivially fusible user" chain of diamond chain n should never - // intersect with the "first non fusible diamond producer" chain of diamond - // chain n+1: if these chains intersected, then all the intermediate ops - // between the diamond chains could be trivially fused, and both diamond - // chains could be fused into a single diamond chain. Note that this only - // holds insofar as we do not allow fusing in bitcasts that modify the last - // dimension of the input array. It is however possible for the last - // trivially fusible user of diamond chain n to be the first non fusible - // diamond producer of diamond chain n+1. - diamond_chains.push_back(DiamondChainDescriptor{ - GetLastTriviallyFusibleUser(previous_diamond_root, cc), - current_fusion_producer, - }); - - current_fusion_producer = first_non_fusible_diamond_producer; - current_reduce_dimension_size = diamond_reduce_dimension_size; - } - - // The last diamond chain is still open; close it. - diamond_chains.push_back(DiamondChainDescriptor{ - GetLastTriviallyFusibleUser(matched_diamonds.back().root, cc), - current_fusion_producer}); - - // We filter out the diamond chains that cannot be tiled correctly using - // `SymbolicTileAnalysis`. - std::vector filtered_diamond_chains; - for (const DiamondChainDescriptor& diamond_chain : diamond_chains) { - TF_ASSIGN_OR_RETURN( - bool can_tile_diamond_chain, - CanSymbolicTileAnalysisTileDiamondChain(diamond_chain, device_info_)); - if (can_tile_diamond_chain) { - filtered_diamond_chains.push_back(diamond_chain); - } - } - return filtered_diamond_chains; -} - -absl::StatusOr SoftmaxRewriterTriton::MaybeFuseDiamondChain( - const DiamondChainDescriptor& diamond_chain) { +absl::StatusOr SoftmaxRewriterTriton::MaybeFuseNormalizationDiamond( + const DiamondDescriptor& diamond) { HloFusionAnalysisCache fusion_analysis_cache(device_info_); GpuPerformanceModelWithIndexingAnalysis indexing_performance_model( &device_info_, &fusion_analysis_cache, shape_size_, &mlir_context_); - return MaybeFuseDiamondChainImpl(diamond_chain, indexing_performance_model, - device_info_, shape_size_, - use_cost_model_to_evaluate_fusions_); + return MaybeFuseDiamondImpl(diamond, indexing_performance_model, device_info_, + shape_size_, use_cost_model_to_evaluate_fusions_); } absl::StatusOr SoftmaxRewriterTriton::Run( @@ -835,16 +644,17 @@ absl::StatusOr SoftmaxRewriterTriton::Run( TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability( device_info_.gpu_compute_capability())); - TF_ASSIGN_OR_RETURN(std::vector diamond_chains, - FindAllFusibleDiamondChains(*module, execution_threads)); + TF_ASSIGN_OR_RETURN( + std::vector diamonds, + FindAllFusibleNormalizationDiamonds(*module, execution_threads)); bool changed = false; - // The diamond chains must be emitted in reverse order, to make sure that - // producer instructions are emitted correctly when the root of - // diamond chain n is exactly the producer of diamond chain n+1. - for (auto diamond_chain = diamond_chains.rbegin(); - diamond_chain != diamond_chains.rend(); ++diamond_chain) { - TF_ASSIGN_OR_RETURN(bool fused, MaybeFuseDiamondChain(*diamond_chain)); + // The diamonds must be emitted in reverse order, to make sure that producer + // instructions are emitted correctly when the root of diamond n is exactly + // the producer of diamond n+1. + for (auto diamond = diamonds.rbegin(); diamond != diamonds.rend(); + ++diamond) { + TF_ASSIGN_OR_RETURN(bool fused, MaybeFuseNormalizationDiamond(*diamond)); changed |= fused; } return changed; diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h index 22b26304cfc3ba..8f904cf800d5fd 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h @@ -22,13 +22,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/time/time.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" @@ -36,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -struct DiamondChainDescriptor { +struct DiamondDescriptor { HloInstruction* root = nullptr; HloInstruction* producer = nullptr; }; @@ -66,21 +63,22 @@ class SoftmaxRewriterTriton : public HloModulePass { HloModule* module, const absl::flat_hash_set& execution_threads) override; - // Finds and returns all the fusible diamond chains in the module. The + // Finds and returns all the fusible normalization diamonds in the module. The // resulting vector is sorted according to a post-order matching (i.e. within // the same computation, producer diamonds appear before consumer diamonds). - absl::StatusOr> - FindAllFusibleDiamondChains( + absl::StatusOr> + FindAllFusibleNormalizationDiamonds( HloModule& module, const absl::flat_hash_set& execution_threads) const; - // Constructs a Softmax fusion containing all the instructions between the - // root and the producer of a diamond chain. The producer is excluded from the + // Constructs a normalization fusion containing all the instructions between + // the root and the producer of a diamond. The producer is excluded from the // fusion. - // Returns `true` if the diamond chain was successfully fused. Otherwise, + // + // Returns `true` if the diamond was successfully fused. Otherwise, // returns `false` if, for example, the resulting fusion cannot be tiled. - absl::StatusOr MaybeFuseDiamondChain( - const DiamondChainDescriptor& diamond_chain); + absl::StatusOr MaybeFuseNormalizationDiamond( + const DiamondDescriptor& diamond_chain); // Return the producer of the following pattern: // diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc index 08f124ebd1882c..a1a80bb826f544 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc @@ -47,7 +47,7 @@ namespace m = ::xla::match; using ::testing::HasSubstr; bool HasBlockLevelFusionConfig(const HloInstruction* fusion) { - return fusion->opcode() == HloOpcode::kFusion && + return HloPredicateIsOp(fusion) && fusion->has_backend_config() && fusion->backend_config().ok() && fusion->backend_config() @@ -64,7 +64,7 @@ class SoftmaxRewriterTritonTest HloCostAnalysis::DefaultShapeSize}; }; -TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxF32) { +TEST_F(SoftmaxRewriterTritonTest, CanFuseSingleNormalizationF32) { const std::string hlo_string = R"( HloModule softmax max_computation { @@ -73,23 +73,17 @@ max_computation { ROOT maximum = f32[] maximum(arg_0, arg_1) } add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT add = f32[] add(arg_0, arg_1) } ENTRY main { param_0 = f32[127,125]{1,0} parameter(0) constant_neg_inf = f32[] constant(-inf) reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - exponential = f32[127,125]{1,0} exponential(subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast) -} -)"; + ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast) +})"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); @@ -103,7 +97,7 @@ ENTRY main { } TEST_F(SoftmaxRewriterTritonTest, - CanFuseSoftmaxLikeComputationWithNonF32DataType) { + CanFuseSignleNormalizationWithNonF32DataType) { const std::string hlo_string = R"( HloModule softmax max_computation { @@ -112,25 +106,17 @@ max_computation { ROOT maximum = f16[] maximum(arg_0, arg_1) } add_computation { - arg_0.1 = f16[] parameter(0) - arg_1.1 = f16[] parameter(1) - ROOT add = f16[] add(arg_0.1, arg_1.1) + arg_0 = f16[] parameter(0) + arg_1 = f16[] parameter(1) + ROOT add = f16[] add(arg_0, arg_1) } ENTRY main { param_0 = f16[127,125]{1,0} parameter(0) constant_neg_inf = f16[] constant(-inf) reduce = f16[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = f16[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f16[127,125]{1,0} subtract(param_0, broadcast) - exp = f16[127,125]{1,0} exponential(subtract) - constant_zero = f16[] constant(0) - second_reduce = f16[127]{0} reduce(exp, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f16[127,125]{1,0} broadcast(second_reduce), dimensions={0} - // Replace divide with multiply, because Triton doesn't support f16 - // divisions. - ROOT multiply = f16[127,125]{1,0} multiply(exp, second_broadcast) -} -)"; + ROOT subtract = f16[127,125]{1,0} subtract(param_0, broadcast) +})"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); @@ -345,107 +331,6 @@ ENTRY main { EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value()); } -TEST_F(SoftmaxRewriterTritonTest, - CanFuseSoftmaxWithIntermediateUnaryElementwise) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - abs = f32[127,125]{1,0} abs(subtract) - exponential = f32[127,125]{1,0} exponential(abs) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - ROOT divide = f32[127,125]{1,0} divide(exponential, second_broadcast) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - -TEST_F(SoftmaxRewriterTritonTest, - CanFuseTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(subtract, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - ROOT divide = f32[127,125]{1,0} divide(subtract, second_broadcast) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - -TEST_F(SoftmaxRewriterTritonTest, - CanFuseDiamondWithTrailingUnaryElementwiseAtTheRoot) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - ROOT abs = f32[127,125]{1,0} abs(subtract) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - TEST_F(SoftmaxRewriterTritonTest, CanFuseDiamondWithUnaryElementwisePrefix) { const std::string hlo_string = R"( HloModule softmax @@ -599,153 +484,6 @@ ENTRY main { EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value()); } -TEST_F(SoftmaxRewriterTritonTest, - CanNotFuseTwoDiamondsWithDifferentReductionAxisSizeTogether) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,625]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,625]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,625]{1,0} subtract(param_0, broadcast) - bitcasted_subtract = f32[127,5,125] bitcast(subtract) - exponential = f32[127,5,125] exponential(bitcasted_subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127,5] reduce(exponential, constant_zero), dimensions={2}, to_apply=add_computation - second_broadcast = f32[127,5,125] broadcast(second_reduce), dimensions={0,1} - ROOT divide = f32[127,5,125] divide(exponential, second_broadcast) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Bitcast(m::Fusion(m::Parameter()) - .WithPredicate(HasBlockLevelFusionConfig))) - .WithPredicate(HasBlockLevelFusionConfig))); -} - -TEST_F(SoftmaxRewriterTritonTest, - CanNotFuseTwoDiamondsWithExtraUsageForFirstDiamondRoot) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - exponential = f32[127,125]{1,0} exponential(subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - divide = f32[127,125]{1,0} divide(exponential, second_broadcast) - ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, subtract) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple( - m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig), - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)))); -} - -TEST_F(SoftmaxRewriterTritonTest, - CanNotFuseTwoDiamondsWithExtraUsageForSecondDiamondProducer) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - exponential = f32[127,125]{1,0} exponential(subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(exponential, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - divide = f32[127,125]{1,0} divide(exponential, second_broadcast) - ROOT tuple = (f32[127,125]{1,0}, f32[127,125]{1,0}) tuple(divide, exponential) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple( - m::Fusion(m::Fusion()).WithPredicate(HasBlockLevelFusionConfig), - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig)))); -} - -TEST_F(SoftmaxRewriterTritonTest, - CanFuseSoftmaxDiamondWithTritonIncompatibleProducer) { - const std::string hlo_string = R"( -HloModule softmax -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} - -ENTRY main { - param_0 = f16[127,125]{1,0} parameter(0) - round-nearest-even = f16[127,125] round-nearest-even(param_0) - convert = f32[127,125] convert(round-nearest-even) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(convert, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - ROOT subtract = f32[127,125]{1,0} subtract(convert, broadcast) -})"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::RoundNearestEven(m::Parameter())) - .WithPredicate(HasBlockLevelFusionConfig))); -} - TEST_F(SoftmaxRewriterTritonTest, CanNotFuseSoftmaxDiamondWithNonFusibleBitcastBetweenReduceAndProducer) { const std::string hlo_string = R"( @@ -771,8 +509,7 @@ ENTRY main { EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value()); } -TEST_F(SoftmaxRewriterTritonTest, - CanFuseSoftmaxDiamondWithBitcastProducerFollowedByBitcastsOnEachUse) { +TEST_F(SoftmaxRewriterTritonTest, CanFuseSoftmaxDiamondWithBitcastsOnEachUse) { const std::string hlo_string = R"( HloModule softmax @@ -783,10 +520,9 @@ max_computation { } ENTRY main { - param_0 = f32[1,1,127,125]{3,2,1,0} parameter(0) - bitcast_parent = f32[127,125]{1,0} bitcast(param_0) - bitcast_0 = f32[127,125]{1,0} bitcast(bitcast_parent) - bitcast_1 = f32[127,125]{1,0} bitcast(bitcast_parent) + param_0 = f32[127,125]{1,0} parameter(0) + bitcast_0 = f32[127,125]{1,0} bitcast(param_0) + bitcast_1 = f32[127,125]{1,0} bitcast(param_0) constant_neg_inf = f32[] constant(-inf) reduce = f32[127]{0} reduce(bitcast_0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} @@ -858,32 +594,6 @@ ENTRY main { .ok()); } -TEST_F(SoftmaxRewriterTritonTest, - CanFuseBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { - const std::string hlo_string = R"( -HloModule fusible_diamond -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - multiply = f32[127,125]{1,0} multiply(param_0, param_0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(multiply, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - ROOT subtract = f32[127,125]{1,0} subtract(multiply, broadcast) -})"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - TEST_F( SoftmaxRewriterTritonTest, CanFuseIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) @@ -912,74 +622,6 @@ ENTRY main { m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } -TEST_F(SoftmaxRewriterTritonTest, - CanFuseBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { - const std::string hlo_string = R"( -HloModule fusible_diamonds -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - multiply = f32[127,125]{1,0} multiply(subtract, subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - ROOT subtract_second = f32[127,125]{1,0} subtract(multiply, second_broadcast) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - -TEST_F(SoftmaxRewriterTritonTest, - CanFuseBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) { - const std::string hlo_string = R"( -HloModule fusible_diamond -max_computation { - arg_0 = f32[] parameter(0) - arg_1 = f32[] parameter(1) - ROOT maximum = f32[] maximum(arg_0, arg_1) -} -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - ROOT multiply = f32[127,125]{1,0} multiply(subtract, subtract) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - TEST_F( SoftmaxRewriterTritonTest, DoesNotFuseIntermediateBinaryElementwiseWithBothSplatOperandsIntoDiamond) { @@ -1070,74 +712,6 @@ ENTRY main.30 { m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } -TEST_F( - SoftmaxRewriterTritonTest, - CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) - const std::string hlo_string = R"( -HloModule fusible_diamonds -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - constant = f32[] constant(0.333333343) - broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={} - multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - -TEST_F( - SoftmaxRewriterTritonTest, - CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) - const std::string hlo_string = R"( -HloModule fusible_diamonds -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - constant = f32[] constant(0.333333343) - broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={} - multiply = f32[127,125]{1,0} multiply(subtract, broadcast_splat) - constant_zero = f32[] constant(0) - second_reduce = f32[127]{0} reduce(multiply, constant_zero), dimensions={1}, to_apply=add_computation - second_broadcast = f32[127,125]{1,0} broadcast(second_reduce), dimensions={0} - ROOT second_subtract = f32[127,125]{1,0} subtract(multiply, second_broadcast) -} -)"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} - TEST_F( SoftmaxRewriterTritonTest, CanFuseBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { @@ -1168,33 +742,6 @@ ENTRY main { m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); } -TEST_F(SoftmaxRewriterTritonTest, - CanFuseBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstant) { - const std::string hlo_string = R"( -HloModule fusible_diamond -add_computation { - arg_0.1 = f32[] parameter(0) - arg_1.1 = f32[] parameter(1) - ROOT add = f32[] add(arg_0.1, arg_1.1) -} -ENTRY main { - param_0 = f32[127,125]{1,0} parameter(0) - constant_neg_inf = f32[] constant(-inf) - reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=add_computation - broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - constant = f32[] constant(0.333333343) - broadcast_splat = f32[127,125]{1,0} broadcast(constant), dimensions={} - ROOT multiply = f32[127,125]{1,0} multiply(broadcast_splat, subtract) -})"; - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); - EXPECT_TRUE(verifier().Run(module.get()).status().ok()); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Fusion(m::Parameter()).WithPredicate(HasBlockLevelFusionConfig))); -} TEST_F(SoftmaxRewriterTritonTest, CanFuseBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducer) { @@ -1570,10 +1117,8 @@ ENTRY main { reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation add = f32[127]{0} add(broadcast_from_scalar, reduce) broadcast = f32[127,125]{1,0} broadcast(add), dimensions={0} - subtract = f32[127,125]{1,0} subtract(param_0, broadcast) - ROOT abs = f32[127,125]{1,0} abs(subtract) -} -)"; + ROOT subtract = f32[127,125]{1,0} subtract(param_0, broadcast) +})"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc index d1172eaaf893e8..ad46e3847a36da 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc @@ -44,7 +44,7 @@ namespace { bool IsOnlyRootNonDefaultStream(HloComputation* computation) { HloInstruction* root = computation->root_instruction(); auto root_gpu_config = root->backend_config(); - if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) { + if (!root_gpu_config.ok() || HloPredicateIsOp(root)) { return false; } int64_t root_stream_id = root_gpu_config->operation_queue_id(); @@ -155,7 +155,7 @@ absl::StatusOr AnnotateStreamAttributesForUsers( } std::vector all_consumers; for (auto user : instr->users()) { - if (user->opcode() == HloOpcode::kGetTupleElement) { + if (HloPredicateIsOp(user)) { user = user->users()[0]; } all_consumers.push_back(user); @@ -194,13 +194,12 @@ absl::StatusOr StreamAttributeAnnotator::Run( // For fusion instruction, only annotate // when the root of fusion is a single instruction // running on non-default stream. - if (instr->opcode() == HloOpcode::kFusion) { + if (HloPredicateIsOp(instr)) { TF_ASSIGN_OR_RETURN(bool comp_result, AnnotateStreamAttributesForInstruction( instr, instr_gpu_config.value())); changed |= comp_result; - } else if (instr->opcode() == HloOpcode::kCopyStart && - module->has_schedule()) { + } else if (instr->opcode() == HloOpcode::kCopyStart) { TF_ASSIGN_OR_RETURN(bool comp_result, AnnotateStreamAttributesForCopyStart( instr, channel_id, instr_gpu_config.value())); @@ -208,8 +207,7 @@ absl::StatusOr StreamAttributeAnnotator::Run( continue; } else if (comp->IsAsyncComputation() && (instr->opcode() == HloOpcode::kDynamicSlice || - instr->opcode() == HloOpcode::kDynamicUpdateSlice) && - module->has_schedule()) { + instr->opcode() == HloOpcode::kDynamicUpdateSlice)) { TF_ASSIGN_OR_RETURN(bool comp_result, WrapIntoFusionAndAnnotateStreamAttributes( instr, channel_id, instr_gpu_config.value(), diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc index 5e0e1f50d2ccc5..247286b99c211d 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc @@ -185,7 +185,7 @@ TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) { TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) { constexpr absl::string_view kHloString = R"( - HloModule offloading, is_scheduled=true + HloModule offloading ENTRY %main (param_0: f32[1024], param_1: f32[1024]) -> f32[1024] { %param_1 = f32[1024]{0} parameter(1) %param_0 = f32[1024]{0} parameter(0) @@ -250,7 +250,7 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) { TF_ASSERT_OK_AND_ASSIGN( bool changed, - StreamAttributeAnnotator(device_description()).Run(module.get())); + StreamAttributeAnnotator{device_description()}.Run(module.get())); EXPECT_TRUE(changed); // Check that the dynamic-update-slice instruction is wrapped in a fusion @@ -314,7 +314,7 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) { EXPECT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN( bool changed, - StreamAttributeAnnotator(device_description()).Run(module.get())); + StreamAttributeAnnotator{device_description()}.Run(module.get())); EXPECT_TRUE(changed); // Check that the dynamic-slice instruction is wrapped in a fusion diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc index c8a3a54639d85b..f21681f39336a3 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -56,7 +55,7 @@ using ::testing::Values; // - batch_size // - dtype using ParameterizedInterface = - ::testing::WithParamInterface>; + ::testing::WithParamInterface>; class TopkTest : public HloTestBase, public ParameterizedInterface { public: @@ -74,7 +73,7 @@ class TopkTest : public HloTestBase, public ParameterizedInterface { protected: absl::StatusOr> TopkHlo(int n, int k, int batch_size, - std::string_view dtype) { + absl::string_view dtype) { return ParseAndReturnVerifiedModule(absl::Substitute( R"( %compare { diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc index 41ba13500c4182..385e06077b9c2c 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc @@ -75,7 +75,7 @@ class TopkSplitterVisitor : public DfsHloRewriteVisitor { if (n % kRequiredAlignment != 0) { return absl::OkStatus(); } - if (n < split_threshold_) return absl::OkStatus(); + if (n <= split_threshold_) return absl::OkStatus(); int new_batch = std::min(absl::bit_floor(n / split_threshold_), kMaximumBatchSize); int new_n = n / new_batch; diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc index 0814b0ef71b726..ee69e380ce8ebd 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc @@ -42,6 +42,7 @@ namespace xla { namespace gpu { namespace { +using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; using TopkSplitterTest = HloTestBase; @@ -204,6 +205,26 @@ ENTRY cluster { EXPECT_TRUE(RunAndCompare(std::move(module), std::nullopt, round_trip)); } +TEST_F(TopkSplitterTest, HandlesDimensionsEqualToThresholdCorrectly) { + // This test was added since initially TopkSplitter was going into an + // infinite loop when the split threshold was equal to the dimension of the + // input. + const std::string hlo_string = absl::Substitute(R"( +HloModule module +$0 +ENTRY cluster { + %arg.1 = f32[1,1024] parameter(0) + ROOT %topk.1 = (f32[1,5], s32[1,5]) custom-call(%arg.1), custom_call_target= "TopK", to_apply=%compare +})", + kComparator); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_THAT(RunHloPass(TopKSplitter(1024), module.get()), IsOk()); + // We expect idempotency - No change on the second run. + EXPECT_THAT(RunHloPass(TopKSplitter(1024), module.get()), + IsOkAndHolds(false)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index f25e5ee407fa5f..01f6c891a48cc6 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -64,7 +64,7 @@ using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput; // Triton fusion. Otherwise, returns nullptr. absl::StatusOr AsTritonFusion( const HloInstruction* hlo) { - if (hlo->opcode() != HloOpcode::kFusion) { + if (HloPredicateIsNotOp(hlo)) { return nullptr; } const HloFusionInstruction* fusion = Cast(hlo); @@ -90,7 +90,7 @@ absl::StatusOr> NewHloModuleWithoutTritonFromFusion( new_module->mutable_config().set_debug_options(debug_opts); new_module->mutable_config() .mutable_debug_options() - .clear_xla_gpu_experimental_enable_triton_softmax_priority_fusion(); + .add_xla_disable_hlo_passes("triton-softmax-rewriter"); TreeReductionRewriter tree_reduction_rewriter(gpu_device_info); TF_RETURN_IF_ERROR(tree_reduction_rewriter.Run(new_module.get()).status()); @@ -138,15 +138,12 @@ absl::StatusOr CompileAndRunFusion( fusion, config, debug_opts, RedzoneBuffers::kAllInputs)); TF_ASSIGN_OR_RETURN(auto stream, config.GetStream()); - TF_ASSIGN_OR_RETURN(std::optional profiling_output, + TF_ASSIGN_OR_RETURN(ProfilingOutput profiling_output, util.ProfileExecutable(executable.get(), stream, rz_buffers.input_buffers(), rz_buffers.input_shapes())); - if (!profiling_output.has_value()) { - return Internal("No output after a successful verification run."); - } - return std::move(profiling_output->output); + return std::move(profiling_output).output; } absl::Status CompareBuffers(const ScopedShapedBuffer& current, @@ -246,9 +243,8 @@ absl::StatusOr TritonFusionNumericsVerifier::Run( debug_options.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( false); - TF_ASSIGN_OR_RETURN(std::optional opt_compile_util, + TF_ASSIGN_OR_RETURN(AutotunerCompileUtil compile_util, AutotunerCompileUtil::Create(config_, debug_options)); - TF_RET_CHECK(opt_compile_util.has_value()); TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions( *module, execution_threads, [&](const HloFusionInstruction& fusion) { @@ -258,8 +254,8 @@ absl::StatusOr TritonFusionNumericsVerifier::Run( ++cache_hits_; return it->second; } - auto result = VerifyTritonFusion(*opt_compile_util, fusion, config_, - debug_options); + auto result = + VerifyTritonFusion(compile_util, fusion, config_, debug_options); fusion_result_cache_[key] = result; return result; })); diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc index f615f4b1dc4aac..73f85084737c7b 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc @@ -46,8 +46,6 @@ class TritonFusionNumericsVerifierTest public: DebugOptions GetDebugOptionsForTest() const override { auto options = HloTestBase::GetDebugOptionsForTest(); - options.set_xla_gpu_experimental_enable_triton_softmax_priority_fusion( - true); options.set_xla_gpu_verify_triton_fusion_numerics(true); return options; } @@ -84,11 +82,10 @@ class TritonFusionNumericsVerifierTest } AutotunerCompileUtil CreateAutotunerCompileUtil(AutotuneConfig& config) { - auto opt_compile_util_or = + auto compile_util_or = AutotunerCompileUtil::Create(config, GetDebugOptionsForTest()); - TF_EXPECT_OK(opt_compile_util_or); - EXPECT_TRUE(opt_compile_util_or->has_value()); - return std::move(opt_compile_util_or->value()); + TF_EXPECT_OK(compile_util_or); + return std::move(compile_util_or).value(); } }; diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index 1a62c6011208d8..df86c0901e3bc2 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -86,11 +86,9 @@ absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { HloInstruction* operand = param_tuple->mutable_operand(k); // Capture bitcast, broadcast, copy, reshape and transpose ops between // dequantization and the loop. - while (operand->opcode() == HloOpcode::kBitcast || - operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kCopy || - operand->opcode() == HloOpcode::kReshape || - operand->opcode() == HloOpcode::kTranspose) { + while (HloPredicateIsOp(operand)) { unaries[k].push_back(operand); operand = operand->mutable_operand(0); } @@ -358,7 +356,7 @@ bool FindDusSliceForCachedActivation(HloInstruction* inst, HloInstruction** slice_indices, bool is_first_slice) { // We are only interested in DUS in the loop body. - if (inst->opcode() != HloOpcode::kDynamicUpdateSlice) { + if (HloPredicateIsNotOp(inst)) { return false; } // Check that the first operand of DUS is a: @@ -425,7 +423,7 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( // collective-permute HloInstruction* first_cp_output; for (HloInstruction* gte_user : input_gte->users()) { - if (gte_user->opcode() == HloOpcode::kCollectivePermute) { + if (HloPredicateIsOp(gte_user)) { first_cp_output = gte_user; break; } @@ -690,7 +688,7 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) { SetForceDelayForInstruction(matched_cp, /*force_delay=*/true)); } - if (inst->opcode() == HloOpcode::kDot) { + if (HloPredicateIsOp(inst)) { // Dispatch the dot to additional compute stream. TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; @@ -746,7 +744,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { allowed_intermediate_ops.insert(allowed_intermediate_ops.end(), std::begin(curr->operands()), std::end(curr->operands())); - } else if (curr->opcode() == HloOpcode::kAllToAll && + } else if (HloPredicateIsOp(curr) && curr->user_count() == 1) { matched_a2a = DynCast(curr); allowed_intermediate_ops.pop_back(); @@ -767,7 +765,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { int64_t split_dimension = *matched_a2a->split_dimension(); for (int64_t i = allowed_intermediate_ops.size() - 1; i >= 0; i--) { HloInstruction* current_op = allowed_intermediate_ops[i]; - if (current_op->opcode() == HloOpcode::kReshape) { + if (HloPredicateIsOp(current_op)) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape( current_op->operand(0)->shape(), current_op->shape()); @@ -786,7 +784,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { } // Assign the new split dim. split_dimension = it->second; - } else if (current_op->opcode() == HloOpcode::kTranspose) { + } else if (HloPredicateIsOp(current_op)) { const auto& transpose_dims = current_op->dimensions(); for (int64_t j = 0; j < transpose_dims.size(); j++) { if ((int64_t)transpose_dims[j] == split_dimension) { @@ -961,6 +959,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms // to minimize communication overhead. To do this, the original input will // be sliced into replica_group size and perform all-to-all+gemm. + if (!dot->GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_alltoall_windowed_einsum()) { + return absl::OkStatus(); + } HloInstruction* lhs; HloInstruction* rhs; std::vector replica_groups; @@ -1120,7 +1124,8 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { allowed_intermediate_ops.insert(allowed_intermediate_ops.end(), std::begin(curr->operands()), std::end(curr->operands())); - } else if (curr->opcode() == HloOpcode::kDot && curr->user_count() == 1) { + } else if (HloPredicateIsOp(curr) && + curr->user_count() == 1) { matched_dot = curr; allowed_intermediate_ops.pop_back(); break; @@ -1136,7 +1141,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { int64_t split_dimension = *a2a->split_dimension(); for (int64_t i = 0; i < allowed_intermediate_ops.size(); i++) { HloInstruction* current_op = allowed_intermediate_ops[i]; - if (current_op->opcode() == HloOpcode::kReshape) { + if (HloPredicateIsOp(current_op)) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape( current_op->operand(0)->shape(), current_op->shape()); @@ -1155,7 +1160,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { } // Assign the new split dim. split_dimension = it->first; - } else if (current_op->opcode() == HloOpcode::kTranspose) { + } else if (HloPredicateIsOp(current_op)) { const auto& transpose_dims = current_op->dimensions(); split_dimension = transpose_dims[split_dimension]; } @@ -1184,6 +1189,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { absl::Status HandleAllToAll(HloInstruction* inst) override { CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll); HloComputation* comp = inst->parent(); + if (!inst->GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_alltoall_windowed_einsum()) { + return absl::OkStatus(); + } // Rewrites a gemm+alltoall into multiple independent partial gemm+a2as // to minimize communication overhead. std::vector replica_groups; diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc index 48b09bea966122..12b44f5029c643 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc @@ -387,6 +387,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, WindowedEinsumHandler gpu_handler; bool changed; + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true); TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, RunFileCheck(module->ToString(), kExpected)); @@ -459,6 +462,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1, WindowedEinsumHandler gpu_handler; bool changed; + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true); TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, RunFileCheck(module->ToString(), kExpected)); @@ -541,6 +547,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, WindowedEinsumHandler gpu_handler; bool changed; + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true); TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, @@ -625,6 +634,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 WindowedEinsumHandler gpu_handler; bool changed; + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true); TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, @@ -825,11 +837,11 @@ ENTRY main.9_spmd { constant.20 = u32[] constant(0) scale_lhs = f32[] parameter(3) scale_lhs_bcast = f32[2,2048,24576]{2,1,0} broadcast(scale_lhs), dimensions={} - lhs_bf16 = f32[2,2048,24576]{2,1,0} convert(param.8) + lhs_bf16 = f32[2,2048,24576]{2,1,0} convert(param.8) lhs_scaled = f32[2,2048,24576]{2,1,0} multiply(lhs_bf16, scale_lhs_bcast) scale_rhs = f32[] parameter(4) scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={} - rhs_bf16 = f32[24576,24576]{1,0} convert(param.6) + rhs_bf16 = f32[24576,24576]{1,0} convert(param.6) rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf16, scale_rhs_bcast) tuple.3 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, param.7, param.7, constant.20) while.1 = (f32[2,2048,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[]) while(tuple.3), condition=windowed_dot_general_cond_rs, body=windowed_dot_general_body_rs diff --git a/third_party/xla/xla/service/gpu/triton_call.cc b/third_party/xla/xla/service/gpu/triton_call.cc index 5ca36c74e34c96..515145630ce4d4 100644 --- a/third_party/xla/xla/service/gpu/triton_call.cc +++ b/third_party/xla/xla/service/gpu/triton_call.cc @@ -16,9 +16,9 @@ limitations under the License. #include "xla/service/gpu/triton_call.h" #include -#include #include +#include "absl/strings/string_view.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" @@ -27,7 +27,7 @@ limitations under the License. namespace xla::gpu { -TritonCall TritonCall::Parse(std::string_view backend_config, +TritonCall TritonCall::Parse(absl::string_view backend_config, mlir::MLIRContext* mlir_context) { // TODO(slebedev): Plumb through num_ctas and enable_wrap_specialization. auto attrs = mlir::cast( diff --git a/third_party/xla/xla/service/gpu/triton_call.h b/third_party/xla/xla/service/gpu/triton_call.h index 853f45e01c3417..d931bc93505a6e 100644 --- a/third_party/xla/xla/service/gpu/triton_call.h +++ b/third_party/xla/xla/service/gpu/triton_call.h @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include +#include "absl/strings/string_view.h" #include "mlir/IR/MLIRContext.h" namespace xla::gpu { @@ -34,7 +34,7 @@ struct TritonCall { int32_t grid_z; // Parse the metadata of a __gpu$xla.gpu.triton call. - static TritonCall Parse(std::string_view backend_config, + static TritonCall Parse(absl::string_view backend_config, mlir::MLIRContext* mlir_context); }; diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index e19395f8c0e062..ab0d25d0542501 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -255,7 +255,10 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion( // Currently supported is one fusion output and one path from dot to it. // Propagate dimension order from dot to root. while (!output->IsRoot()) { - TF_RET_CHECK(output->user_count() == 1); + if (output->user_count() != 1) { + return absl::FailedPreconditionError( + absl::StrCat("Expected one user for ", output->ToString())); + } const HloInstruction* input = output; // Tuple with a custom call can be added at root to allocate a workspace // buffer. These do not need to participate in propagation of dimensions. @@ -271,14 +274,21 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion( return FailedPrecondition("Failed to propagate tiling with error: %s", decision.Explain()); } - TF_RET_CHECK( - context.CombineDimOrdersAndReqs(std::get(result))); + if (!context.CombineDimOrdersAndReqs(std::get(result))) { + return absl::InternalError( + "Failed to combine dim orders and requirements."); + } } - TF_RET_CHECK( + + bool spec_was_inserted = iter_specs_[Scope::OUTPUT] .insert( {output, context.dim_orders().at(output).ToTensorIterationSpec()}) - .second); + .second; + if (!spec_was_inserted) { + return absl::InternalError( + "Failed to insert output spec for the output fusion."); + } parameters_[Scope::OUTPUT] = {}; if (output != &dot) { // Propagate back to parameters of the output fusion. diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index ac32fd717a2016..36b11d64b851a3 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -23,10 +23,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/transforms/gemm_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/xfeed_queue.h b/third_party/xla/xla/service/gpu/xfeed_queue.h index 18f63a934a17ce..737bc921a2e3e3 100644 --- a/third_party/xla/xla/service/gpu/xfeed_queue.h +++ b/third_party/xla/xla/service/gpu/xfeed_queue.h @@ -42,7 +42,6 @@ class XfeedQueue { void EnqueueDestination(BufferType buffers) { absl::MutexLock l(&mu_); enqueued_buffers_.push_back(std::move(buffers)); - enqueue_cv_.Signal(); EnqueueHook(); } @@ -57,10 +56,8 @@ class XfeedQueue { bool became_empty; BufferType current_buffer; { - absl::MutexLock l(&mu_); - while (enqueued_buffers_.empty()) { - enqueue_cv_.Wait(&mu_); - } + absl::MutexLock l(&mu_, + absl::Condition(this, &XfeedQueue::IsBufferEnqueued)); current_buffer = std::move(enqueued_buffers_.front()); enqueued_buffers_.pop_front(); DequeueHook(); @@ -94,8 +91,10 @@ class XfeedQueue { std::deque enqueued_buffers_ ABSL_GUARDED_BY(mu_); private: - // Condition variable that is signaled every time a buffer is enqueued. - absl::CondVar enqueue_cv_; + // Returns true if there is a buffer in the queue. + bool IsBufferEnqueued() const ABSL_SHARED_LOCKS_REQUIRED(mu_) { + return !enqueued_buffers_.empty(); + } // List of callbacks which will be called when 'enqueued_buffers_' becomes // empty. @@ -122,14 +121,9 @@ class BlockingXfeedQueue : public XfeedQueue { : max_pending_xfeeds_(max_pending_xfeeds) {} void BlockUntilEnqueueSlotAvailable() { - absl::MutexLock l{&this->mu_}; - while (pending_buffers_ + this->enqueued_buffers_.size() >= - max_pending_xfeeds_) { - VLOG(2) << "Capacity " - << (pending_buffers_ + this->enqueued_buffers_.size()) - << " >= max capacity " << max_pending_xfeeds_; - dequeue_cv_.Wait(&this->mu_); - } + absl::MutexLock l{ + &this->mu_, + absl::Condition(this, &BlockingXfeedQueue::IsEnqueueSlotAvailable)}; pending_buffers_++; } @@ -139,15 +133,18 @@ class BlockingXfeedQueue : public XfeedQueue { pending_buffers_--; } - void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override { - dequeue_cv_.Signal(); - } + void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {} private: const int max_pending_xfeeds_; - // Condition variable that is signaled every time a buffer is dequeued. - absl::CondVar dequeue_cv_; + bool IsEnqueueSlotAvailable() const ABSL_SHARED_LOCKS_REQUIRED(this->mu_) { + VLOG(2) << "Capacity " + << (pending_buffers_ + this->enqueued_buffers_.size()) + << " >= max capacity " << max_pending_xfeeds_; + return pending_buffers_ + this->enqueued_buffers_.size() < + max_pending_xfeeds_; + } // Keeps track of the number of buffers reserved but not added to // enqueued_buffers_. diff --git a/third_party/xla/xla/service/graphcycles/graphcycles.cc b/third_party/xla/xla/service/graphcycles/graphcycles.cc index c4648bb1cc91fd..15329e981bd3e3 100644 --- a/third_party/xla/xla/service/graphcycles/graphcycles.cc +++ b/third_party/xla/xla/service/graphcycles/graphcycles.cc @@ -125,7 +125,7 @@ int32_t GraphCycles::NewNode() { Node n; n.visited = false; n.rank = rep_->nodes_.size(); - rep_->nodes_.emplace_back(n); + rep_->nodes_.push_back(n); rep_->node_io_.emplace_back(); rep_->node_data_.push_back(nullptr); return n.rank; diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc index 9ceb861e0fce2d..e3cf85615a8331 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc @@ -944,6 +944,36 @@ bool BufferIntervalTree::Remove(int64_t start, int64_t end, return true; } +int BufferIntervalTree::NumChunksOverlappingInTime(int64_t start, + int64_t end) const { + int result = 0; + if (root_ == nullptr) { + return result; + } + std::vector visiting_stack; + visiting_stack.push_back(root_); + while (!visiting_stack.empty()) { + const BufferIntervalTreeNode* top = visiting_stack.back(); + visiting_stack.pop_back(); + if (start > top->subtree_end) { + continue; + } + if (top->left != nullptr) { + visiting_stack.push_back(top->left); + } + if (top->start <= end && top->end >= start) { + ++result; + } + if (end < top->start) { + continue; + } + if (top->right != nullptr) { + visiting_stack.push_back(top->right); + } + } + return result; +} + std::vector BufferIntervalTree::ChunksOverlappingInTime( int64_t start, int64_t end) const { std::vector result; @@ -2304,7 +2334,7 @@ GlobalDecreasingSizeBestFitHeap::Finish() { VLOG(1) << "result heap_size: " << result_.heap_size; Result result; result.heap_size = result_.heap_size; - result.heap_results.emplace_back(result_); + result.heap_results.push_back(result_); return result; } diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.h b/third_party/xla/xla/service/heap_simulator/heap_simulator.h index 7328f87722b600..d81b29b52ad451 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.h +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.h @@ -363,6 +363,10 @@ class BufferIntervalTree { // Remove the interval from the tree. Returns true if the chunk is removed. bool Remove(int64_t start, int64_t end, const Chunk& chunk); + // Returns the number of allocated chunks that overlap with the given time + // interval. + int NumChunksOverlappingInTime(int64_t start, int64_t end) const; + // Returns vector of allocated chunks that overlap with the given time // interval. std::vector ChunksOverlappingInTime(int64_t start, int64_t end) const; diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc index d27dbd14d81cce..612e7b060d886d 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc @@ -1862,10 +1862,16 @@ TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) { BufferIntervalTree tree; tree.Add(20, 36, chunk); tree.Add(1, 45, chunk); + EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 2); + EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 1); EXPECT_TRUE(tree.Remove(1, 45, chunk)); + EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 1); + EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 0); EXPECT_EQ(tree.GetRoot()->subtree_end, 36); EXPECT_TRUE(tree.Remove(20, 36, chunk)); ASSERT_EQ(tree.GetRoot(), nullptr); + EXPECT_EQ(tree.NumChunksOverlappingInTime(10, 25), 0); + EXPECT_EQ(tree.NumChunksOverlappingInTime(5, 15), 0); } TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) { diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index 37283ede9d8b77..4858f4153feff0 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -590,6 +590,7 @@ message HloModuleProto { LAYOUT = 3; DOT = 4; FLAGNET = 5; + SHARDING = 6; } // The type of profile generation strategy used to generate the profile. enum ProfileGenerationStrategy { diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index d46996dd3accd5..a4dcc36d979cc2 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -56,7 +55,7 @@ using ::testing::UnorderedElementsAre; class HloComputationTest : public HloTestBase { protected: - HloComputationTest() {} + HloComputationTest() = default; // Create a computation which takes a scalar and returns its negation. std::unique_ptr CreateNegateComputation() { @@ -849,7 +848,7 @@ ENTRY entry { } TEST_F(HloComputationTest, ComparisonWithCustomComparator) { - std::string_view mod_txt = R"( + absl::string_view mod_txt = R"( HloModule Module region_X { Arg_0.5 = s32[] parameter(0) @@ -890,7 +889,7 @@ TEST_F(HloComputationTest, ComparisonWithCustomComparator) { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(mod_txt)); - absl::flat_hash_map replace_map; + absl::flat_hash_map replace_map; replace_map["region_X"] = "region_A"; replace_map["region_Y"] = "region_B"; auto compare_func = [&replace_map](const HloComputation* a, @@ -974,7 +973,7 @@ TEST_F(HloComputationTest, CompositeCall) { } TEST_F(HloComputationTest, CloneComputationWithAsyncInstructions) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main comp.0 { diff --git a/third_party/xla/xla/service/hlo_creation_utils_test.cc b/third_party/xla/xla/service/hlo_creation_utils_test.cc index 252345fbbbc5ff..debabe09c3c51e 100644 --- a/third_party/xla/xla/service/hlo_creation_utils_test.cc +++ b/third_party/xla/xla/service/hlo_creation_utils_test.cc @@ -15,19 +15,29 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" +#include #include +#include +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "xla/array2d.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_domain_test.cc b/third_party/xla/xla/service/hlo_domain_test.cc index c80155b75659c6..11acf73bf6cfff 100644 --- a/third_party/xla/xla/service/hlo_domain_test.cc +++ b/third_party/xla/xla/service/hlo_domain_test.cc @@ -372,7 +372,7 @@ ENTRY entry { sharding={{maximal device=-1},{maximal device=-1}} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} c = f32[4] add(b_element, b_element), sharding={maximal device=-1} - d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={{maximal device=-1},{maximal device=-1},{maximal device=-1}} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc index 339feeb8fd2d4e..960f107c9117b9 100644 --- a/third_party/xla/xla/service/hlo_module_test.cc +++ b/third_party/xla/xla/service/hlo_module_test.cc @@ -24,25 +24,37 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/test_compilation_environment.pb.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_runner.h b/third_party/xla/xla/service/hlo_runner.h index f2387c04bca6d4..4314fce1655930 100644 --- a/third_party/xla/xla/service/hlo_runner.h +++ b/third_party/xla/xla/service/hlo_runner.h @@ -195,6 +195,8 @@ class HloRunner : public HloRunnerInterface { return backend().compiler()->ShapeSizeBytesFunction(); } + int device_count() const override { return backend().device_count(); } + private: absl::StatusOr ExecuteWithExecutionInputs( Executable* executable, std::vector arguments, diff --git a/third_party/xla/xla/service/hlo_runner_interface.cc b/third_party/xla/xla/service/hlo_runner_interface.cc index f3f3303851952a..510ccdba6e4453 100644 --- a/third_party/xla/xla/service/hlo_runner_interface.cc +++ b/third_party/xla/xla/service/hlo_runner_interface.cc @@ -15,7 +15,23 @@ limitations under the License. #include "xla/service/hlo_runner_interface.h" +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/literal.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_module_config.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -61,14 +77,15 @@ HloRunnerInterface::ReadModuleFromBinaryProtoFile( } /*static*/ absl::StatusOr> -HloRunnerInterface::ReadModuleFromHloTextFile( - const std::string& filename, const DebugOptions& debug_options) { +HloRunnerInterface::ReadModuleFromHloTextFile(const std::string& filename, + const DebugOptions& debug_options, + const HloParserOptions& options) { std::string hlo_string; TF_RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return ParseAndReturnUnverifiedModule(hlo_string, config); + return ParseAndReturnUnverifiedModule(hlo_string, config, options); } /*static*/ absl::StatusOr> diff --git a/third_party/xla/xla/service/hlo_runner_interface.h b/third_party/xla/xla/service/hlo_runner_interface.h index ab6ab7f121b13b..23f29591a37b1e 100644 --- a/third_party/xla/xla/service/hlo_runner_interface.h +++ b/third_party/xla/xla/service/hlo_runner_interface.h @@ -18,21 +18,22 @@ limitations under the License. #include #include -#include #include -#include #include #include #include +#include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/literal.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" -#include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/shape.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -112,7 +113,8 @@ class HloRunnerInterface { // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. static absl::StatusOr> ReadModuleFromHloTextFile( - const std::string& filename, const DebugOptions& debug_options); + const std::string& filename, const DebugOptions& debug_options, + const HloParserOptions& options = HloParserOptions()); // Creates an executable object given an HLO module. If run_hlo_passes is // true, the HLO passes will be run as part of compilation. @@ -226,6 +228,10 @@ class HloRunnerInterface { // This function is used e.g. to create a VerifiedHloModule. It returns an // integer representing the size of the shape in bytes as opposed to a Shape. virtual DeviceShapeSizeFn device_shape_size_fn() const = 0; + + // Returns the number of devices which are known. Not all of these devices may + // be usable by XLA. + virtual int device_count() const = 0; }; } // namespace xla diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.cc b/third_party/xla/xla/service/hlo_runner_pjrt.cc index 6aec116ffca612..dce3bc9e1ca5be 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.cc +++ b/third_party/xla/xla/service/hlo_runner_pjrt.cc @@ -23,25 +23,37 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" +#include "xla/literal.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" #include "xla/service/executable.h" -#include "xla/service/hlo_module_util.h" +#include "xla/service/hlo_runner_interface.h" +#include "xla/service/service_executable_run_options.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/casts.h" namespace xla { @@ -109,6 +121,43 @@ absl::StatusOr GenerateExecuteOptions(const HloModule& module) { return execute_options; } +inline PjRtGlobalDeviceId DeviceIdForInvocation( + const DeviceAssignment& device_assignment, const int64_t i) { + const int64_t computation_count = device_assignment.computation_count(); + return PjRtGlobalDeviceId( + device_assignment(i / computation_count, i % computation_count)); +} + +absl::StatusOr GetStaticDeviceAssignmentOrComputeDefault( + const HloModule& module, PjRtClient& client) { + if (module.config().has_static_device_assignment()) { + return module.config().static_device_assignment(); + } + return client.GetDefaultDeviceAssignment(module.config().replica_count(), + module.config().num_partitions()); +} + +std::vector BufferVecToPointerVec( + const absl::Span> buffer) { + std::vector argument_ptrs; + argument_ptrs.resize(buffer.size()); + for (int i = 0; i < buffer.size(); ++i) { + argument_ptrs[i] = buffer[i].get(); + } + + return argument_ptrs; +} + +std::vector> BufferMatToPointerMat( + const absl::Span>> buffer) { + std::vector> argument_ptrs; + argument_ptrs.reserve(buffer.size()); + for (int i = 0; i < buffer.size(); ++i) { + argument_ptrs.push_back(BufferVecToPointerVec(buffer[i])); + } + return argument_ptrs; +} + } // namespace // TODO(b/245550554): Remove the use of PjRtWrappedExecutable. @@ -156,9 +205,8 @@ HloRunnerPjRt::~HloRunnerPjRt() = default; absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( HloModule* module, bool run_hlo_passes) { TF_ASSIGN_OR_RETURN( - auto device_assignment, - pjrt_client_->GetDefaultDeviceAssignment( - module->config().replica_count(), module->config().num_partitions())); + const DeviceAssignment device_assignment, + GetStaticDeviceAssignmentOrComputeDefault(*module, *pjrt_client_)); CompileOptions compile_options; @@ -189,6 +237,9 @@ absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( compile_options.executable_build_options.set_result_layout( module->entry_computation_layout().result_shape()); + compile_options.executable_build_options.set_use_spmd_partitioning( + module->config().use_spmd_partitioning()); + return compile_options; } @@ -280,36 +331,12 @@ absl::StatusOr HloRunnerPjRt::Execute( ExecutionProfile* profile) { // TODO (b/245550554) : Remove UpdateEntryComputationLayout from runner. UpdateEntryComputationLayout(module.get()); - TF_ASSIGN_OR_RETURN(auto compile_options, GenerateDefaultCompileOptions( - module.get(), run_hlo_passes)); - TF_ASSIGN_OR_RETURN(auto executable, CreateExecutable(std::move(module), run_hlo_passes)); return ExecuteWithExecutable(executable.get(), arguments, {}); } -std::vector HloRunnerPjRt::BufferVecToPointerVec( - const std::vector>& buffer) { - std::vector argument_ptrs; - argument_ptrs.resize(buffer.size()); - for (int i = 0; i < buffer.size(); ++i) { - argument_ptrs[i] = buffer[i].get(); - } - - return argument_ptrs; -} - -std::vector> HloRunnerPjRt::BufferMatToPointerMat( - std::vector>>& buffer) { - std::vector> argument_ptrs; - argument_ptrs.reserve(buffer.size()); - for (int i = 0; i < buffer.size(); ++i) { - argument_ptrs.push_back(BufferVecToPointerVec(buffer[i])); - } - return argument_ptrs; -} - absl::StatusOr> HloRunnerPjRt::CreateExecutable(HloModule* module, CompileOptions compile_options) { @@ -417,7 +444,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile) { return ExecuteReplicatedImpl( - [&](absl::Span>& argument_buffer_slices) + [&](absl::Span> argument_buffer_slices) -> absl::StatusOr>> { PjRtWrappedExecutable* wrapped_executable = static_cast(executable); @@ -448,70 +475,193 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( std::function argument_provider, const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { - return Unimplemented("Unimplemeneted ExecuteReplicated"); + TF_RET_CHECK(device_assignment->computation_count() == 1) + << "Only single-computation execution is supported."; + return ExecuteReplicatedImpl( + [&](absl::Span> argument_buffer_slices) + -> absl::StatusOr>> { + TF_RET_CHECK(options.use_threads); + + // The underlying data is modified concurrently. We don't need to + // protect access as each replica writes only to its own slot. + std::vector>>> + per_replica_results(options.num_replicas); + absl::c_fill(per_replica_results, + absl::InternalError("No result for replica.")); + + { + // NB: `pool` is joined on destruction. + tsl::thread::ThreadPool pool(tsl::Env::Default(), "replicas", + options.num_replicas); + for (int64_t i = 0; i < options.num_replicas; ++i) { + for (const PjRtBuffer* const buffer : argument_buffer_slices[i]) { + TF_RET_CHECK(buffer != nullptr); + } + PjRtWrappedExecutable* executable = + tensorflow::down_cast( + executable_provider(i)); + if (executable == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to cast executable for replica %d " + "to PjRtWrappedExecutable.", + i)); + } + TF_ASSIGN_OR_RETURN( + PjRtDevice * device_ptr, + pjrt_client_->LookupDevice( + DeviceIdForInvocation(*device_assignment, i))); + pool.Schedule([&per_replica_results, i, executable, + args = argument_buffer_slices[i], device_ptr]() { + per_replica_results[i] = + executable->GetPjRtLoadedExecutable()->ExecuteSharded( + args, device_ptr, {}); + }); + } + } + // Aggregate results. + std::vector> results; + for (int64_t i = 0; i < options.num_replicas; ++i) { + absl::StatusOr>>& + replica_result = per_replica_results[i]; + if (!replica_result.ok()) { + return replica_result.status(); + } + if (replica_result->size() != 1) { + return absl::InternalError(absl::StrFormat( + "Expected a single result for replica %d, got %d results.", i, + replica_result->size())); + } + results.push_back(std::move(std::move(replica_result)->front())); + } + return results; + }, + argument_count_provider, argument_provider, options, device_assignment); } absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( std::function>>( - absl::Span>&)> + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { - absl::Span devices = pjrt_client_->devices(); + TF_RET_CHECK(options.infeed_values.empty() || + options.infeed_values.size() == options.num_replicas); + std::vector replica_devices(options.num_replicas, nullptr); std::vector>> argument_buffer_slices; - argument_buffer_slices.reserve(pjrt_client_->addressable_device_count()); - + argument_buffer_slices.reserve(options.num_replicas); for (int64_t i = 0; i < options.num_replicas; ++i) { - PjRtDevice* device_ptr = devices[i]; + // Amortize device lookup. + TF_ASSIGN_OR_RETURN(PjRtDevice* const device_ptr, + pjrt_client_->LookupDevice( + DeviceIdForInvocation(*device_assignment, i))); + replica_devices[i] = device_ptr; // Transfer literals to device. const int64_t argument_count = argument_count_provider(i); - std::vector> replica_buffers; replica_buffers.reserve(argument_count); - for (int64_t arg_index = 0; arg_index < argument_count; arg_index++) { const Literal* const argument = argument_provider(i, arg_index); TF_RET_CHECK(argument != nullptr); - TF_ASSIGN_OR_RETURN(auto assignment, pjrt_client_->BufferFromHostLiteral( - *argument, device_ptr)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + use_parameter_layout_on_device_ + ? pjrt_client_->BufferFromHostLiteral(*argument, device_ptr, + &argument->shape().layout()) + : pjrt_client_->BufferFromHostLiteral(*argument, device_ptr)); replica_buffers.push_back(std::move(assignment)); } - argument_buffer_slices.push_back(std::move(replica_buffers)); } - TF_RET_CHECK(options.infeed_values.empty() || - options.infeed_values.size() == options.num_replicas); - - if (!options.infeed_values.empty()) { - // TODO(b/245550554): Infeed/Outfeed + // Handle infeed and outfeed. + const bool has_infeed = !options.infeed_values.empty(); + const bool has_outfeed = ShapeUtil::IsInitialized(options.outfeed_shape); + std::unique_ptr pool = nullptr; + absl::Mutex infeed_outfeed_status_mu; + absl::Status infeed_outfeed_status = absl::OkStatus(); + if (has_infeed || has_outfeed) { + // One infeed per infeed value and one outfeed per replica. + const int64_t num_threads = + options.infeed_values.size() + (has_outfeed ? options.num_replicas : 0); + pool = std::make_unique( + tsl::Env::Default(), "infeed_outfeed", num_threads); } - - if (ShapeUtil::IsInitialized(options.outfeed_shape)) { - // TODO(b/245550554): Infeed/Outfeed + if (has_infeed) { + for (int64_t i = 0; i < options.num_replicas; ++i) { + pool->Schedule( + [device = replica_devices[i], + &infeed_literal = *ABSL_DIE_IF_NULL(options.infeed_values[i]), + infeed_steps = options.infeed_steps, &infeed_outfeed_status_mu, + &infeed_outfeed_status]() { + VLOG(1) << "Starting infeed on device " << device->ToString(); + absl::Status per_feed_status = absl::OkStatus(); + for (int64_t step = 1; infeed_steps < 0 || step <= infeed_steps; + ++step) { + per_feed_status.Update(device->TransferToInfeed(infeed_literal)); + if (step % 100 == 0) { + VLOG(1) << "Infeed step " << step; + } + } + absl::MutexLock lock(&infeed_outfeed_status_mu); + infeed_outfeed_status.Update(per_feed_status); + }); + } + } + if (has_outfeed) { + if (options.outfeed_values != nullptr) { + options.outfeed_values->resize(options.num_replicas); + } + for (int64_t i = 0; i < options.num_replicas; ++i) { + pool->Schedule([i, device = replica_devices[i], + outfeed_values = options.outfeed_values, + outfeed_shape = options.outfeed_shape, + infeed_steps = options.infeed_steps, + &infeed_outfeed_status_mu, &infeed_outfeed_status]() { + VLOG(1) << "Starting outfeed on device " << device->ToString(); + absl::Status per_feed_status = absl::OkStatus(); + for (int64_t step = 1; infeed_steps < 0 || step <= infeed_steps; + ++step) { + Literal literal(outfeed_shape); + per_feed_status.Update(device->TransferFromOutfeed(&literal)); + if (outfeed_values != nullptr) { + outfeed_values->at(i) = std::move(literal); + } + if (step % 100 == 0) { + VLOG(1) << "Outfeed step " << step; + } + } + absl::MutexLock lock(&infeed_outfeed_status_mu); + infeed_outfeed_status.Update(per_feed_status); + }); + } } - auto mat = BufferMatToPointerMat(argument_buffer_slices); - - auto span = absl::Span>(mat); - - TF_ASSIGN_OR_RETURN(auto results, execution_helper(span)); - std::vector exec_results; - exec_results.reserve(options.num_replicas); + VLOG(1) << "Replicated execution started"; + TF_ASSIGN_OR_RETURN( + const std::vector> result_buffers, + execution_helper(BufferMatToPointerMat(argument_buffer_slices))); + VLOG(1) << "Replicated execution terminated"; + // Get the result from execution. + std::vector result_literals; + result_literals.reserve(options.num_replicas); for (int64_t i = 0; i < options.num_replicas; ++i) { TF_ASSIGN_OR_RETURN(Literal literal, - TransferLiteralFromDevice(*results[i])); - - exec_results.push_back(std::move(literal)); + TransferLiteralFromDevice(*result_buffers[i])); + result_literals.push_back(std::move(literal)); } - return std::move(exec_results); + // Join infeed and outfeed threads, if they exist. The thread pool's threads + // are joined on destruction. No-op otherwise. + pool = nullptr; + TF_RETURN_IF_ERROR(infeed_outfeed_status); + + return std::move(result_literals); } absl::string_view HloRunnerPjRt::Name() const { return "HloRunnerPjRt"; } diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.h b/third_party/xla/xla/service/hlo_runner_pjrt.h index dc4ec3921b4a6e..0d7c92beb00789 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.h +++ b/third_party/xla/xla/service/hlo_runner_pjrt.h @@ -25,7 +25,13 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_runner_interface.h" #include "xla/xla_data.pb.h" @@ -117,29 +123,25 @@ class HloRunnerPjRt : public HloRunnerInterface { return device_shape_size_fn_; } - private: - std::unique_ptr pjrt_client_; - DeviceShapeRepresentationFn device_shape_representation_fn_; - DeviceShapeSizeFn device_shape_size_fn_; - bool use_parameter_layout_on_device_ = false; - - std::vector BufferVecToPointerVec( - const std::vector>& buffer); - - std::vector> BufferMatToPointerMat( - std::vector>>& buffer); + int device_count() const override { return pjrt_client_->device_count(); } + private: absl::StatusOr GenerateDefaultCompileOptions( HloModule* module, bool run_hlo_passes); absl::StatusOr> ExecuteReplicatedImpl( std::function>>( - absl::Span>&)> + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment); + + std::unique_ptr pjrt_client_; + DeviceShapeRepresentationFn device_shape_representation_fn_; + DeviceShapeSizeFn device_shape_size_fn_; + bool use_parameter_layout_on_device_ = false; }; } // namespace xla diff --git a/third_party/xla/xla/service/hlo_schedule_test.cc b/third_party/xla/xla/service/hlo_schedule_test.cc index d18c8527893c81..fd89bcc5b23fc5 100644 --- a/third_party/xla/xla/service/hlo_schedule_test.cc +++ b/third_party/xla/xla/service/hlo_schedule_test.cc @@ -22,19 +22,20 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/log/log.h" -#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc index 9878e0805f6669..3a552c5542cfe7 100644 --- a/third_party/xla/xla/service/hlo_unstacker_test.cc +++ b/third_party/xla/xla/service/hlo_unstacker_test.cc @@ -52,7 +52,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } @@ -63,8 +63,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -80,7 +80,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } @@ -106,7 +106,7 @@ TEST_F(UnstackerTest, NotUnstackDSFusionPattern) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } @@ -142,7 +142,7 @@ TEST_F(UnstackerTest, NotUnstackDSFusionPattern) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } @@ -161,7 +161,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternMultipleLoopRootUse) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } @@ -172,8 +172,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternMultipleLoopRootUse) { p2 = s8[3,128,128] get-tuple-element(wide_p), index=3 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p2, i), kind=kLoop, calls=%fused_computation.slice - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p2, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(inc, conv, p2, p2) } @@ -191,7 +191,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternMultipleLoopRootUse) { zero = s8[] constant(0) buffer = s8[3,128,128] broadcast(zero), dimensions={} while.input = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(init, p1, p0, buffer) - while.out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while.out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -216,7 +216,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternWithUnusedOperand) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } @@ -227,8 +227,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternWithUnusedOperand) { p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(inc, conv, p1, p1) } @@ -246,7 +246,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternWithUnusedOperand) { zero = s8[] constant(0) buffer = s8[3,128,128] broadcast(zero), dimensions={} while.input = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(init, p1, p0, buffer) - while.out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while.out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } @@ -290,8 +290,8 @@ TEST_F(UnstackerTest, UnstackReduceFusionPattern) { p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.1096.clone - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.1096.clone + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -307,8 +307,8 @@ TEST_F(UnstackerTest, UnstackReduceFusionPattern) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body - while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -328,7 +328,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} } %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { @@ -340,7 +340,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) { inc = s32[] add(i, one) %fusion.67830 = s8[1,128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice bitcast.102 = s8[128,128] bitcast(s8[1,128,128] %fusion.67830) - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -356,8 +356,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body - while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -382,7 +382,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} } %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { @@ -394,7 +394,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) { inc = s32[] add(i, one) %fusion.67830 = s8[1,128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice bitcast.102 = s8[128,128] bitcast(s8[1,128,128] %fusion.67830) - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -410,8 +410,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body - while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -438,7 +438,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternKeepFused) { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) - %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT out = s8[128,128] bitcast(%dynamic-slice.22040) } @@ -450,7 +450,7 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternKeepFused) { one = s32[] constant(1) inc = s32[] add(i, one) %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice - conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -466,8 +466,8 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternKeepFused) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body - while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -662,7 +662,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithMultipleIndex) { %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[4,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } - + %fused_computation.slice.2 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[4,128,128] parameter(0) p1 = s32[] parameter(1) @@ -678,7 +678,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithMultipleIndex) { %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice.1 ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf } - + %fused_computation.inner.2 (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { %param_0.34523 = bf16[8,128] parameter(0) %param_1.30691 = s8[4,128,128] parameter(1) @@ -799,7 +799,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) { %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } - + %fused_computation.slice.2 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) @@ -815,7 +815,7 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) { %fusion.67830 = s8[128,128] fusion(s8[3,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice.1 ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf } - + %fused_computation.inner.2 (param_0.34523: bf16[8,128], param_1.30691: s8[3,128,128], p2: s32[]) -> bf16[8,128] { %param_0.34523 = bf16[8,128] parameter(0) %param_1.30691 = s8[3,128,128] parameter(1) @@ -875,7 +875,7 @@ TEST_F(UnstackerTest, %constant.85694 = s32[] constant(0) ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} } - + %fused_computation.slice.2 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[3,128,128] parameter(0) p1 = s32[] parameter(1) @@ -1214,7 +1214,7 @@ TEST_F(UnstackerTest, UnstackDSAndDUSPatternNestedLoop) { offset = s32[] parameter(1) zero = s32[] constant(0) %dynamic-slice.22040 = bf16[1,1,8,257,128] - dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} + dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} ROOT %bitcast.31250 = bf16[1,8,257,128] bitcast(%dynamic-slice.22040) } @@ -1222,19 +1222,19 @@ TEST_F(UnstackerTest, UnstackDSAndDUSPatternNestedLoop) { %param_0.51117 = bf16[4,1,8,257,128] parameter(0) offset = s32[] parameter(1) zero = s32[] constant(0) - %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} + %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} ROOT %bitcast.31250 = bf16[1,8,257,128] bitcast(%dynamic-slice.22040) } inner.body { - loop_var.1 = (s32[], bf16[4,1,8,257,128], bf16[4,1,8,257,128]) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 - get-tuple-element.3 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=2 - sliced = bf16[1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1), kind=kLoop, calls=%fused_computation.slice - sliced.2 = bf16[1,8,257,128] fusion(get-tuple-element.3, get-tuple-element.1), kind=kLoop,calls=%fused_computation.slice.2 - temp = bf16[1,8,257,128] add(sliced, sliced.2) - one = s32[] constant(1) idx = s32[] add(get-tuple-element.1, one) + loop_var.1 = (s32[], bf16[4,1,8,257,128], bf16[4,1,8,257,128]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=2 + sliced = bf16[1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1), kind=kLoop, calls=%fused_computation.slice + sliced.2 = bf16[1,8,257,128] fusion(get-tuple-element.3, get-tuple-element.1), kind=kLoop,calls=%fused_computation.slice.2 + temp = bf16[1,8,257,128] add(sliced, sliced.2) + one = s32[] constant(1) idx = s32[] add(get-tuple-element.1, one) ROOT out = tuple(idx, get-tuple-element.2, get-tuple-element.3) } inner.condition { @@ -1245,7 +1245,7 @@ TEST_F(UnstackerTest, UnstackDSAndDUSPatternNestedLoop) { } outer.body { - loop_var.1 = (s32[], bf16[4,1,8,257,128], bf16[4,1,8,257,128]) parameter(0) + loop_var.1 = (s32[], bf16[4,1,8,257,128], bf16[4,1,8,257,128]) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 get-tuple-element.3 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=2 @@ -1306,12 +1306,12 @@ TEST_F(UnstackerTest, UnstackDSAndDUSPatternLoopFeedingLoop) { %param_0.51117 = bf16[4,1,8,257,128] parameter(0) offset = s32[] parameter(1) zero = s32[] constant(0) - %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} + %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} ROOT %bitcast.31250 = bf16[1,8,257,128] bitcast(%dynamic-slice.22040) } first.body { - loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) + loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 constant = bf16[1,8,257,128] constant({...}) @@ -1322,14 +1322,14 @@ TEST_F(UnstackerTest, UnstackDSAndDUSPatternLoopFeedingLoop) { ROOT out = tuple(idx, get-tuple-element.2) } first.condition { - loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) + loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - constant.2 = s32[] constant(4) + constant.2 = s32[] constant(4) ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT } - + next.body { - loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) + loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 constant = bf16[1,8,257,128] constant({...}) @@ -1341,7 +1341,7 @@ TEST_F(UnstackerTest, UnstackDSAndDUSPatternLoopFeedingLoop) { next.condition { loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 - constant.2 = s32[] constant(4) + constant.2 = s32[] constant(4) ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT } @@ -1444,13 +1444,13 @@ TEST_F(UnstackerTest, UnstackDUSFusionWithPadPatternLoopFeedingLoop) { TEST_F(UnstackerTest, UnstackDUSFusionWithAddPattern) { std::string hlo_string = R"( HloModule SimpleLoop - + add.2771.reduce_sub_computation { lhs.44 = bf16[] parameter(0) rhs.44 = bf16[] parameter(1) ROOT add.3079 = bf16[] add(lhs.44, rhs.44) } - + fused_computation.75.clone { param_0.31658 = bf16[2,4096]{1,0:T(8,128)(2,1)} parameter(0) param_1.26202 = s32[]{:T(128)} parameter(1) diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 88823f1dd9e5c1..9e84f287beb874 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -2483,6 +2483,27 @@ absl::Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { return absl::OkStatus(); } +// Verifies that leaf nodes in an original value contain values. +absl::Status VerifyOriginalValue(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (auto original_value = instruction->original_value()) { + // An original value is expected to have intermediate nodes that are + // always nullopt and leaves with actual values. + for (const auto& leaf : original_value->leaves()) { + if (!leaf.second.has_value()) { + return Internal( + "Leaf nodes in an original value is expected to contain values." + " Instruction: %s.", + instruction->ToString()); + } + } + } + } + } + return absl::OkStatus(); +} + // Checks various invariants of channel instructions (send/recv and // collectives). absl::Status VerifyChannels(const HloModule& module, @@ -3117,6 +3138,7 @@ absl::StatusOr HloVerifier::Run( TF_RETURN_IF_ERROR(module->buffer_donor_config().Verify(*module)); TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module)); + TF_RETURN_IF_ERROR(VerifyOriginalValue(*module)); return false; }(); if (status_or_changed.ok()) { diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 419156664e7f46..6e2207726caeb2 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -3635,5 +3635,19 @@ TEST_F(HloVerifierTest, UnaryOpWithResultAccuracy) { EXPECT_TRUE(status.ok()) << status; } +TEST_F(HloVerifierTest, EmptyLeafInOriginalValue) { + const std::string hlo_string = R"( +HloModule module +ENTRY %entry_computation { + ROOT op = ((f32[], f32[3]{0}), f32[2,3]) parameter(0), origin={(({}, {"v2"}), {"v3"})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_FALSE(status.ok()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/host_memory_offload_annotations.h b/third_party/xla/xla/service/host_memory_offload_annotations.h index a0b7e3decaea38..e230fdc8b60764 100644 --- a/third_party/xla/xla/service/host_memory_offload_annotations.h +++ b/third_party/xla/xla/service/host_memory_offload_annotations.h @@ -26,10 +26,15 @@ inline const absl::string_view kDevicePlacement = "annotate_device_placement"; inline const absl::string_view kMemoryTargetPinnedHost = "pinned_host"; inline const absl::string_view kMemoryTargetUnpinnedHost = "unpinned_host"; inline const absl::string_view kMemoryTargetDevice = "device"; +inline const absl::string_view kMemoryTargetDeviceSram = "device_sram"; +inline const absl::string_view kMemoryTargetPinnedDevice = "pinned_device"; // Internal annotations: inline const absl::string_view kMoveToHostCustomCallTarget = "MoveToHost"; inline const absl::string_view kMoveToDeviceCustomCallTarget = "MoveToDevice"; +inline const absl::string_view kPinToDeviceCustomCallTarget = "PinToDevice"; +inline const absl::string_view kPinToDeviceSramCustomCallTarget = + "PinToDeviceSram"; } // namespace host_memory_offload_annotations } // namespace xla diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index 7f2d7d2f187892..d199e1f046daa0 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/hlo_reachability.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -57,8 +58,6 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { @@ -315,10 +314,11 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl( ResourceUsageType::kResourceRelease) : std::make_pair(ResourceTypeToIndex(ResourceType::kSendRecv), ResourceUsageType::kResourceRelease)}; - case HloOpcode::kRecvDone: + case HloOpcode::kRecvDone: { + const HloSendRecvInstruction* recv = + DynCast(hlo.operand(0)); return ResourcesVector{ - static_cast(hlo.operand(0)) - ->is_host_transfer() + (recv != nullptr && recv->is_host_transfer()) ? std::make_pair( config_.force_send_recv_to_use_same_resource ? ResourceTypeToIndex(ResourceType::kSendHost) @@ -326,14 +326,17 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl( ResourceUsageType::kResourceOccupy) : std::make_pair(ResourceTypeToIndex(ResourceType::kSendRecv), ResourceUsageType::kResourceOccupy)}; - case HloOpcode::kSendDone: + } + case HloOpcode::kSendDone: { + const HloSendRecvInstruction* send = + DynCast(hlo.operand(0)); return ResourcesVector{ - static_cast(hlo.operand(0)) - ->is_host_transfer() + (send != nullptr && send->is_host_transfer()) ? std::make_pair(ResourceTypeToIndex(ResourceType::kSendHost), ResourceUsageType::kResourceOccupy) : std::make_pair(ResourceTypeToIndex(ResourceType::kSendRecv), ResourceUsageType::kResourceOccupy)}; + } default: return ResourcesVector{}; } @@ -381,19 +384,17 @@ AsyncTracker::RecursivelyComputeResourceMap( int64_t AsyncTracker::GetNumResourcesPerInstruction( int64_t resource_type, const HloInstruction& instr) const { - // For instructions not calling a computation then return 1 if the instruction - // has opcode equal to 'async_done' + // For instructions not calling a computation, or async start/done + // instructions, we directly check the resources from the instruction. if (instr.called_computations().empty() || instr.opcode() == HloOpcode::kAsyncStart || instr.opcode() == HloOpcode::kAsyncDone) { - return absl::c_any_of(GetResourcesFromInstruction(instr), - [resource_type](const ResourcePair& resource) { - return resource.second == - ResourceUsageType::kResourceOccupy && - (resource_type == resource.first); - }) - ? 1 - : 0; + return absl::c_count_if(GetResourcesFromInstruction(instr), + [resource_type](const ResourcePair& resource) { + return resource.second == + ResourceUsageType::kResourceOccupy && + (resource_type == resource.first); + }); } int64_t num_resources = 0; for (const HloComputation* computation : instr.called_computations()) { @@ -1320,29 +1321,6 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable( } absl::InlinedVector, 2> skipped_nodes_and_reasons; - if (!scheduling_instruction_crosses_overlap_limit_) { - scheduling_instruction_crosses_overlap_limit_ = - [](const SchedulingState& sched_state, const HloGraphNode* node) { - for (const auto& [resource, limit] : - sched_state.max_concurrent_resource) { - // No resources in flight of this kind. Continue. - auto it = sched_state.resource_occupiers_in_flight.find(resource); - if (it == sched_state.resource_occupiers_in_flight.end() || - it->second.empty()) { - continue; - } - // Number of instances of 'resource' needed if this instruction was - // to be scheduled. - const int64_t num_resources_needed = - sched_state.async_tracker->GetNumResourcesPerInstruction( - resource, node->GetInstr()); - if (limit < num_resources_needed) { - return true; - } - } - return false; - }; - } VLOG(2) << "Current time: " << sched_state.current_time; ReadySetLt ready_lt{&sched_state, target_scheduling_rule_, early_target_scheduling_rule_}; @@ -1913,13 +1891,18 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( ++sched_state->scheduled_count; for (auto& resource : n->GetResources()) { if (resource.second == ResourceUsageType::kResourceRelease) { - sched_state->resource_occupiers_in_flight.at(resource.first) - .erase(&n->GetInstr()); + // Some recv-dones exist without a corresponding recv op in the same + // computation. In this case, we cannot find the corresponding start op + // and thus cannot erase the start op from the map. + if (sched_state->resource_occupiers_in_flight.contains(resource.first)) { + sched_state->resource_occupiers_in_flight.at(resource.first) + .erase(&n->GetInstr()); + } } else if (resource.second == ResourceUsageType::kResourceOccupy) { - // For async collective done ops, save their corresponding start ops to - // the map - if (async_tracker_->IsSupportedAsyncDone(n->GetInstr())) { - CHECK(async_tracker_->IsSupportedAsyncStart(*n->GetInstr().operand(0))); + // For supported async collective done ops, save their corresponding start + // ops in the map + if (async_tracker_->IsSupportedAsyncDone(n->GetInstr()) && + async_tracker_->IsSupportedAsyncStart(*n->GetInstr().operand(0))) { sched_state->resource_occupiers_in_flight[resource.first].insert( n->GetInstr().operand(0)); } else { @@ -2280,6 +2263,29 @@ absl::Status DefaultSchedulerCore::InitializeScheduler( if (VLOG_IS_ON(2)) { annotation_tracker_->PrintAnnotationSets(2); } + if (!scheduling_instruction_crosses_overlap_limit_) { + scheduling_instruction_crosses_overlap_limit_ = + [](const SchedulingState& sched_state, const HloGraphNode* node) { + for (const auto& [resource, limit] : + sched_state.max_concurrent_resource) { + // No resources in flight of this kind. Continue. + auto it = sched_state.resource_occupiers_in_flight.find(resource); + if (it == sched_state.resource_occupiers_in_flight.end() || + it->second.empty()) { + continue; + } + // Number of instances of 'resource' needed if this instruction was + // to be scheduled. + const int64_t num_resources_needed = + sched_state.async_tracker->GetNumResourcesPerInstruction( + resource, node->GetInstr()); + if (limit < num_resources_needed) { + return true; + } + } + return false; + }; + } return absl::OkStatus(); } @@ -2298,6 +2304,17 @@ absl::Status DefaultSchedulerCore::SchedulingStep( return absl::OkStatus(); } +bool DefaultSchedulerCore::SchedulingAnnotationCrossesOverlapLimit( + const SchedulingState& sched_state, int64_t annotation) { + for (const HloInstruction* instr : + annotation_tracker_->GetInstructions(annotation)) { + if (scheduling_instruction_crosses_overlap_limit_( + sched_state, &sched_state.sched_graph.GetNode(instr))) { + return true; + } + } + return false; +} absl::StatusOr> DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { const HloSchedule& module_schedule = computation->parent()->schedule(); @@ -2364,16 +2381,30 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { return absl::StrJoin(sched_state.ready_set, "\n", LogFormatter()); }()); if (!sched_state.ready_annotations.empty()) { - // TODO (sacer): If more than one annotations are ready, decide which one - // to schedule next with a heuristic. - int64_t annotation = sched_state.ready_annotations.back(); - sched_state.ready_annotations.pop_back(); - VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------"; - sched_state.ongoing_annotation = annotation; - TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state)); - VLOG(2) << "------- END ANNOTATION: " << annotation << " --------"; - sched_state.ongoing_annotation = -1; - continue; + // Pick the first ready annotation whose scheduling will not cross the + // overlap limit. If there is no such annotation, continue with scheduling + // non-annotated ops. + int64_t annotation_index = -1; + for (int64_t i = 0; i < sched_state.ready_annotations.size(); ++i) { + if (SchedulingAnnotationCrossesOverlapLimit( + sched_state, sched_state.ready_annotations[i])) { + continue; + } + annotation_index = i; + break; + } + if (annotation_index != -1) { + std::swap(sched_state.ready_annotations[annotation_index], + sched_state.ready_annotations.back()); + int64_t annotation = sched_state.ready_annotations.back(); + sched_state.ready_annotations.pop_back(); + VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------"; + sched_state.ongoing_annotation = annotation; + TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state)); + VLOG(2) << "------- END ANNOTATION: " << annotation << " --------"; + sched_state.ongoing_annotation = -1; + continue; + } } TF_RETURN_IF_ERROR(SchedulingStep(&sched_state)); } diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 1733c8b2fe8f9e..e1dffa0851a156 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -43,12 +43,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/ir/ptrvec.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/map_util.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/shape_util.h" +#include "xla/side_effect_util.h" #include "xla/status_macros.h" #include "xla/xla.pb.h" @@ -359,8 +361,8 @@ class AnnotationTracker { } std::optional GetAnnotation(const HloInstruction* instr) const { const auto& attrs = instr->frontend_attributes().map(); - if (attrs.contains("_scheduling_group_id")) { - return std::stoi(attrs.at("_scheduling_group_id")); + if (attrs.contains(kXlaSchedulingGroupIdAttr)) { + return std::stoi(attrs.at(kXlaSchedulingGroupIdAttr)); } return std::nullopt; } @@ -376,10 +378,13 @@ class AnnotationTracker { annotations_[annotation].begin(), annotations_[annotation].end()); for (const HloInstruction* instr : annotations_.at(annotation)) { bool has_annotated_user = false; - for (HloInstruction* user : instr->users()) { - if (seen_instructions.contains(user)) { - has_annotated_user = true; - break; + for (const PtrVec& users : + {instr->users(), instr->control_successors()}) { + for (HloInstruction* user : users) { + if (seen_instructions.contains(user)) { + has_annotated_user = true; + break; + } } } if (!has_annotated_user) { @@ -1050,6 +1055,8 @@ class DefaultSchedulerCore : public SchedulerCore { this->config_.memory_limit = new_limit; } int64_t GetRerunTimes() override { return config_.rerun; } + bool SchedulingAnnotationCrossesOverlapLimit( + const SchedulingState& sched_state, int64_t annotation); protected: virtual void LogInstruction(const HloInstruction* instr) const; diff --git a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc index 32495806b71581..a5508f7553a3fc 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc @@ -152,10 +152,15 @@ absl::StatusOr RunScheduler( /*convert_collective_permute=*/HloPredicateTrue}; TF_ASSIGN_OR_RETURN(bool value, AsyncCollectiveCreator(std::move(config)).Run(module)); - TF_ASSIGN_OR_RETURN(value, LegalizeSchedulingAnnotations().Run(module)); + TF_ASSIGN_OR_RETURN(value, LegalizeSchedulingAnnotations( + LegalizeSchedulingAnnotations::Config()) + .Run(module)); HloCostAnalysis::ShapeSizeFunction shape_size_bytes = [&shape_size_bytes](const Shape& shape) -> int64_t { int64_t shape_size = 0; + if (shape.IsToken()) { + return 0; + } if (shape.IsTuple()) { for (auto& sub_shape : shape.tuple_shapes()) { shape_size += shape_size_bytes(sub_shape); @@ -2555,9 +2560,9 @@ ENTRY entry { cp3d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp3s) slice = f32[16,64,256]{2,1,0} slice(f32[512,2048,2048]{2,1,0} cp1d), slice={[0:16], [0:64], [0:256]} c0 = f32[16,256,256]{2,1,0} convolution(p0, slice), - window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb c1 = f32[16,256,256]{2,1,0} convolution(p0, slice), - window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}) tuple(c0, c1, cp2d, cp3d) } )"; @@ -3083,22 +3088,22 @@ while_body { param = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, pred[]) parameter(0) gte0 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} get-tuple-element(param), index=0 gte1 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} get-tuple-element(param), index=1 - %add.0 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} add(gte0, gte1) + add.0 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} add(gte0, gte1) gte2 = pred[] get-tuple-element(param), index=2 - ROOT tuple = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, pred[]) tuple(%add.0, gte1, gte2) + ROOT tuple = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, pred[]) tuple(add.0, gte1, gte2) } ENTRY %entry { - %p0 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} parameter(0) - %p1 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} parameter(1) - %after-all = token[] after-all() - %send = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, u32[], token[]) send(bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} %p0, token[] %after-all), channel_id=1246, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="xla_megascale_runtime",_xla_host_transfer_rendezvous="collective-permute.145_0",_xla_megascale_target="{{200000->100000},{200001->100001},{200002->100002},{200003->100003},{200004->100004},{200005->100005},{200006->100006},{200007->100007},{200008->100008},{200009->100009},{200010->100010},{200011->100011},{200012->100012},{200013->100013},{200014->100014},{200015->100015},{200016->100016},{200017->100017},{200018->100018},{200019->100019},{200020->100020},{200021->100021},{200022->100022},{200023->100023},{200024->100024},{200025->100025},{200026->100026},{200027->100027},{200028->100028},{200029->100029},{200030->100030},{200031->100031},{200032->100032},{200033->100033},{200034->100034},{200035->100035},{200036->100036},{200037->100037},{200038->100038},{200039->100039},{200040->100040},{200041->100041},{200042->100042},{200043->100043},{200044->100044},{200045->100045},{200046->100046},{200047->100047},{200048->100048},{200049->100049},{200050->100050},{200051->100051},{200052->100052},{200053->100053},{200054->100054},{200055->100055},{200056->100056},{200057->100057},{200058->100058},{200059->100059},{200060->100060},{200061->100061},{200062->100062},{200063->100063},{200064->100064},{200065->100065},{200066->100066},{200067->100067},{200068->100068},{200069->100069},{200070->100070},{200071->100071},{200072->100072},{200073->100073},{200074->100074},{200075->100075},{200076->100076},{200077->100077},{200078->100078},{200079->100079},{200080->100080},{200081->100081},{200082->100082},{200083->100083},{200084->100084},{200085->100085},{200086->100086},{200087->100087},{200088->100088},{200089->100089},{200090->100090},{200091->100091},{200092->100092},{200093->100093},{200094->100094},{200095->100095},{200096->100096},{200097->100097},{200098->100098},{200099->100099},{200100->100100},{200101->100101},{200102->100102},{200103->100103},{200104->100104},{200105->100105},{200106->100106},{200107->100107},{200108->100108},{200109->100109},{200110->100110},{200111->100111},{200112->100112},{200113->100113},{200114->100114},{200115->100115},{200116->100116},{200117->100117},{200118->100118},{200119->100119},{200120->100120},{200121->100121},{200122->100122},{200123->100123},{200124->100124},{200125->100125},{200126->100126},{200127->100127}}",_xla_megascale_transfer_type="ONE_TO_ONE"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[],"customized_send_recv_config":{"dcn_collective_permute_send":{"non_source_slice_ids":[0]}}} - %recv = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, u32[], token[]) recv(token[] %after-all), channel_id=1247, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="xla_megascale_runtime",_xla_host_transfer_rendezvous="collective-permute.145_0",_xla_megascale_target="{{200000->100000},{200001->100001},{200002->100002},{200003->100003},{200004->100004},{200005->100005},{200006->100006},{200007->100007},{200008->100008},{200009->100009},{200010->100010},{200011->100011},{200012->100012},{200013->100013},{200014->100014},{200015->100015},{200016->100016},{200017->100017},{200018->100018},{200019->100019},{200020->100020},{200021->100021},{200022->100022},{200023->100023},{200024->100024},{200025->100025},{200026->100026},{200027->100027},{200028->100028},{200029->100029},{200030->100030},{200031->100031},{200032->100032},{200033->100033},{200034->100034},{200035->100035},{200036->100036},{200037->100037},{200038->100038},{200039->100039},{200040->100040},{200041->100041},{200042->100042},{200043->100043},{200044->100044},{200045->100045},{200046->100046},{200047->100047},{200048->100048},{200049->100049},{200050->100050},{200051->100051},{200052->100052},{200053->100053},{200054->100054},{200055->100055},{200056->100056},{200057->100057},{200058->100058},{200059->100059},{200060->100060},{200061->100061},{200062->100062},{200063->100063},{200064->100064},{200065->100065},{200066->100066},{200067->100067},{200068->100068},{200069->100069},{200070->100070},{200071->100071},{200072->100072},{200073->100073},{200074->100074},{200075->100075},{200076->100076},{200077->100077},{200078->100078},{200079->100079},{200080->100080},{200081->100081},{200082->100082},{200083->100083},{200084->100084},{200085->100085},{200086->100086},{200087->100087},{200088->100088},{200089->100089},{200090->100090},{200091->100091},{200092->100092},{200093->100093},{200094->100094},{200095->100095},{200096->100096},{200097->100097},{200098->100098},{200099->100099},{200100->100100},{200101->100101},{200102->100102},{200103->100103},{200104->100104},{200105->100105},{200106->100106},{200107->100107},{200108->100108},{200109->100109},{200110->100110},{200111->100111},{200112->100112},{200113->100113},{200114->100114},{200115->100115},{200116->100116},{200117->100117},{200118->100118},{200119->100119},{200120->100120},{200121->100121},{200122->100122},{200123->100123},{200124->100124},{200125->100125},{200126->100126},{200127->100127}}",_xla_megascale_transfer_type="ONE_TO_ONE"}, control-predecessors={%send}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[],"customized_send_recv_config":{"dcn_collective_permute_recv":{"non_target_slice_ids":[1]}}} - %recv-done = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, token[]) recv-done((bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, u32[], token[]) %recv), channel_id=1247, is_host_transfer=true, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[],"customized_send_recv_config":{"dcn_collective_permute_recv":{"non_target_slice_ids":[1]}}} - %get-tuple-element = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} get-tuple-element((bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, token[]) %recv-done), index=0 - %send-done = token[] send-done((bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, u32[], token[]) %send), channel_id=1246, is_host_transfer=true, control-predecessors={%recv-done}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[],"customized_send_recv_config":{"dcn_collective_permute_send":{"non_source_slice_ids":[0]}}} - %p2 = pred[] parameter(2) - tuple = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, pred[]) tuple(%get-tuple-element, %p1, %p2) + p0 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} parameter(0) + p1 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} parameter(1) + after-all = token[] after-all() + send = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, u32[], token[]) send(p0, after-all), channel_id=1246 + recv = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, u32[], token[]) recv(after-all), channel_id=1247 + recv-done = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, token[]) recv-done(recv), channel_id=1247 + get-tuple-element = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} get-tuple-element(recv-done), index=0 + send-done = token[] send-done(send), channel_id=1246, control-predecessors={recv-done} + p2 = pred[] parameter(2) + tuple = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, pred[]) tuple(get-tuple-element, p1, p2) while = (bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)}, pred[]) while(tuple), condition=while_cond, body=while_body ROOT gte0 = bf16[1,1,4096,1344]{2,3,1,0:T(8,128)(2,1)} get-tuple-element(while), index=0 } @@ -3632,7 +3637,7 @@ ENTRY entry { cp1d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="1"} f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"} f1 = f32[1,256,256]{2,1,0} fusion(f0, f0), kind=kOutput, calls=fused_computation.1, frontend_attributes={_scheduling_group_id="1"} - ROOT tuple = (f32[128,2048,2048]{2,1,0}, f32[1,256,256]{2,1,0}) tuple(cp1d, f1) + ROOT tuple = (f32[128,2048,2048]{2,1,0}, f32[1,256,256]{2,1,0}) tuple(cp1d, f1) } )"; @@ -3680,7 +3685,7 @@ ENTRY entry { p1 = f32[128,2048,2048]{2,1,0} parameter(1) cp0s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"} cp0d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp0s), frontend_attributes={_scheduling_group_id="0"} - ROOT f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"} + ROOT f0 = f32[16,256,256]{2,1,0} fusion(p0, p0), kind=kOutput, calls=fused_computation, frontend_attributes={_scheduling_group_id="0"} } )"; @@ -3757,4 +3762,107 @@ ENTRY entry { GetIndex(new_instruction_sequence, "cpd")); } +TEST_F(LatencyHidingSchedulerTest, OutOfOrderStartAndDone) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +while_condition { + tuple = ((f32[16,16], u32[], token[]), f32[16,16], u32[]) parameter(0) + i = get-tuple-element(tuple), index=2 + n = u32[] constant(2) + ROOT predicate = pred[] compare(i, n), direction=LT +} + +while_body { + tuple = ((f32[16,16], u32[], token[]), f32[16,16], u32[]) parameter(0) + gte = get-tuple-element(tuple), index=0 + param = get-tuple-element(tuple), index=1 + i = get-tuple-element(tuple), index=2 + dot = f32[16,16] dot(param, param), lhs_contracting_dims={0}, rhs_contracting_dims={1} + recv_done = (f32[16], token[]) recv-done(gte), frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + after_all = token[] after-all() + recv = (f32[16,16], u32[], token[]) recv(after_all), frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, control-predecessors={recv_done} + c1 = u32[] constant(1) + add = add(i, c1) + ROOT tuple_ = ((f32[16,16], u32[], token[]), f32[16,16], u32[]) tuple(recv, dot, add) +} + +ENTRY main { + param0 = f32[16,16] parameter(0) + after_all0 = token[] after-all() + recv0 = (f32[16,16], u32[], token[]) recv(after_all0), frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + c0 = u32[] constant(0) + tuple = ((f32[16,16], u32[], token[]), f32[16,16], u32[]) tuple(recv0, param0, c0) + while = ((f32[16,16], u32[], token[]), f32[16,16], u32[]) while(tuple), body=while_body, condition=while_condition + gte0 = (f32[16,16], u32[], token[]) get-tuple-element(while), index=0 + ROOT recv_done0 = (f32[16], token[]) recv-done(gte0), frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + auto sched_config = GetDefaultSchedConfig(); + sched_config.schedule_send_recvs = true; + sched_config.send_recv_host_overlap_limit = 2; + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config, + std::make_unique()) + .ok()); + EXPECT_TRUE(hlo_module->has_entry_computation()); + + std::vector new_instruction_sequence = + module_schedule.sequence(hlo_module->entry_computation()).instructions(); + if (VLOG_IS_ON(1)) { + for (auto* new_i : new_instruction_sequence) { + VLOG(1) << new_i->ToString(); + } + } +} + +TEST_F(LatencyHidingSchedulerTest, SchedulingAnnotationCrossesOverlapLimit) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[128,2048,2048]{2,1,0} parameter(1) + cp1s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}}, frontend_attributes={_scheduling_group_id="0"} + cp1d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp1s), frontend_attributes={_scheduling_group_id="0"} + cp2s = (f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}, u32[], u32[]) collective-permute-start(p1), source_target_pairs={{1,0},{0,3},{3,2}} + cp2d = f32[128,2048,2048]{2,1,0} collective-permute-done(cp2s) + slice = f32[16,64,256]{2,1,0} slice(cp1d), slice={[0:16], [0:64], [0:256]} + c1 = f32[16,256,256]{2,1,0} convolution(p0, p0), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"} + c2 = f32[16,256,256]{2,1,0} convolution(p0, slice), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[128,2048,2048]{2,1,0}) tuple(c1, c2, cp2d) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + auto sched_config = GetDefaultSchedConfig(); + sched_config.collective_permute_overlap_limit = 1; + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config, + std::make_unique()) + .ok()); + EXPECT_TRUE(hlo_module->has_entry_computation()); + + std::vector new_instruction_sequence = + module_schedule.sequence(hlo_module->entry_computation()).instructions(); + if (VLOG_IS_ON(1)) { + for (auto* new_i : new_instruction_sequence) { + VLOG(1) << new_i->ToString(); + } + } + + // With the overlap limit of 1 on collective permutes, we cannot schedule the + // scheduling group with annotation 0 right after it becomes ready, because + // cp2's overlap would be open at that moment. cp1 can be scheduled only after + // cp2 is closed (in the reverse order). + EXPECT_LT(GetIndex(new_instruction_sequence, "cp1d"), + GetIndex(new_instruction_sequence, "cp2s")); +} + } // namespace xla diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index eef57904b2f296..58af55e7faa6d3 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -657,6 +657,20 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction, return absl::OkStatus(); } +absl::Status ResetMemorySpaceInLayout(ShapeLayout& mutable_shape_layout) { + Shape shape = mutable_shape_layout.shape(); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + &shape, [](Shape* subshape, const ShapeIndex& shape_index) { + if (subshape->has_layout() && subshape->IsArray()) { + subshape->mutable_layout()->set_memory_space( + Layout::kDefaultMemorySpace); + } + return absl::OkStatus(); + })); + TF_RETURN_IF_ERROR(mutable_shape_layout.CopyLayoutFromShape(shape)); + return absl::OkStatus(); +} + } // namespace absl::Status LayoutAssignment::AddMandatoryConstraints( @@ -693,27 +707,18 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( entry_computation_layout_->AnyLayoutSet()) || (conditional_mismatch_.count(constraints->computation()) == 0 && constraints->computation_constraint().parameter_layout_is_set())) { - const ShapeLayout& parameter_layout = + ShapeLayout parameter_layout = constraints->computation_layout().parameter_layout( instruction->parameter_number()); // Allow some paramter/result layouts to be unset in the entry // computation. if (parameter_layout.AnyLayoutIsSet()) { + // Clear out memory space in layout. Host offloader will do the + // analysis later. + TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(parameter_layout)); // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. Shape param_shape = parameter_layout.shape(); - // Clear out memory space in layout. Host offloader will do the - // analysis later. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( - ¶m_shape, [](Shape* subshape, const ShapeIndex& index) { - if (!subshape->has_layout() || !subshape->IsArray()) { - return absl::OkStatus(); - } - subshape->mutable_layout()->set_memory_space( - Layout::kDefaultMemorySpace); - return absl::OkStatus(); - })); - TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction)); if (reverse_computation_order_) { TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers( @@ -771,10 +776,6 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( get_channel_constraints(instruction) ->LayoutShapeForChannel(buffer_shape, channel_id); TF_RETURN_IF_ERROR(SetInstructionLayout(new_buffer_shape, instruction)); - } else if (instruction->preserve_layout()) { - TF_RETURN_IF_ERROR(SetInstructionLayout(instruction->shape(), instruction, - /*mandatory=*/true, /*dfs=*/true, - /*allow_alias=*/true)); } } @@ -2033,16 +2034,7 @@ absl::Status LayoutAssignment::PropagateResultConstraint( // Clear out memory space in layout for entry computation root. Host offloader // will do the analysis later and add back the memory space for host outputs. if (constraints->computation()->IsEntryComputation()) { - Shape result_shape = result_layout.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( - &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) { - if (subshape->has_layout() && subshape->IsArray()) { - subshape->mutable_layout()->set_memory_space( - Layout::kDefaultMemorySpace); - } - return absl::OkStatus(); - })); - TF_RETURN_IF_ERROR(result_layout.CopyLayoutFromShape(result_shape)); + TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout)); } // Propagate the use constraint of the root instruction up to the logical @@ -2232,25 +2224,29 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { // layout constraint. if (constraints.ResultLayout() != nullptr && constraints.ResultLayout()->LayoutIsSet()) { + ShapeLayout result_layout = *constraints.ResultLayout(); + // Clear out memory space in layout. Host offloader will do the + // analysis later. + TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout)); // Layout assignment at this point only does minor-to-major assignment so // tiling info should be ignored here for comparison. VLOG(5) << "Computation result layout needs root copying\n"; - if (!constraints.ResultLayout()->MatchesLayoutInShape( + if (!result_layout.MatchesLayoutInShape( computation->root_instruction()->shape(), /*minor_to_major_only=*/true)) { TF_ASSIGN_OR_RETURN( HloInstruction * new_root, - CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + CreateCopyWithNewLayout(result_layout.shape(), computation->root_instruction())); computation->set_root_instruction(new_root); } else { // Copy the tiling info/tail_padding_alignment_in_elements specified in // result layout. - auto copy_tiling = [&constraints](xla::Shape* subshape, - const xla::ShapeIndex& index) { + auto copy_tiling = [&result_layout](xla::Shape* subshape, + const xla::ShapeIndex& index) { if (subshape->IsArray()) { - const Shape& result_shape = ShapeUtil::GetSubshape( - constraints.ResultLayout()->shape(), index); + const Shape& result_shape = + ShapeUtil::GetSubshape(result_layout.shape(), index); if (result_shape.layout().tiles_size() != 0) { subshape->mutable_layout()->mutable_tiles()->assign( result_shape.layout().tiles().begin(), @@ -2418,8 +2414,7 @@ absl::Status LayoutAssignment::ClearComputationLayouts( // Some instructions carry mandatory layouts in their shape. if (instruction->opcode() != HloOpcode::kInfeed && !IsLayoutConstrainedCustomCall(instruction) && - !IsLayoutConstrainedCollective(instruction) && - !instruction->preserve_layout()) { + !IsLayoutConstrainedCollective(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 3cd4a872bff55d..ee547b10f3fbf5 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -1367,6 +1367,59 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); } +TEST_F(LayoutAssignmentTest, MemorySpaceRemoved) { + const char* module_str = R"( +HloModule MixedHostDeviceResult + +ENTRY %MixedHostDeviceResult { + %p0 = f32[4,4] parameter(0) + %d = f32[4,4]{1,0} custom-call(%p0), custom_call_target="MoveToDevice", metadata={} + ROOT %tuple = (f32[4,4], f32[4,4]) tuple(%p0, %d) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + + // Set the parameter to be in host memory. + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout( + F32, {4, 4}, {1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, + Layout::kHostMemorySpace)); + // Set one result component to be in host memory, the other one on device. + // Also make sure to request incompatible result layout so that the layout + // assignment pass has to copy the layout from the entry computation layout. + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithDenseLayout( + F32, {4, 4}, {1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, + /*element_size_in_bits=*/0, Layout::kHostMemorySpace), + ShapeUtil::MakeShapeWithDenseLayout( + F32, {4, 4}, {0, 1}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, + /*element_size_in_bits=*/0, Layout::kDefaultMemorySpace)})); + AssignLayouts(m.get(), &computation_layout); + + // Verify that the memory space did not leak from the entry computation layout + // to the parameter or to the result. + Shape result_shape = m->entry_computation()->root_instruction()->shape(); + EXPECT_EQ( + ShapeUtil::GetTupleElementShape(result_shape, 0).layout().memory_space(), + Layout::kDefaultMemorySpace); + EXPECT_EQ( + ShapeUtil::GetTupleElementShape(result_shape, 1).layout().memory_space(), + Layout::kDefaultMemorySpace); + + const HloInstruction* parameter = FindInstruction(m.get(), "p0"); + EXPECT_EQ(parameter->shape().layout().memory_space(), + Layout::kDefaultMemorySpace); + + ExpectTupleLayoutIs(result_shape, {{1, 0}, {0, 1}}); +} + absl::Status AssignLayoutsToComputation( HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) { if (!m->entry_computation_layout().result_layout().LayoutIsSet()) { @@ -1673,33 +1726,6 @@ TEST_F(LayoutAssignmentTest, PropagateOperandLayout2) { ExpectLayoutIs(reshape_3->shape(), {3, 1, 2, 0}); } -// Test the ability to preset layout for instruction. -TEST_F(LayoutAssignmentTest, PreserveInstructionLayout) { - const char* module_str = R"( - HloModule TensorFlowGather, entry_computation_layout={(f32[32,650]{1,0},s32[16,1,18]{0,1,2})->(f32[16,1,18,32]{3,1,2,0})} - - ENTRY %main { - %operand = f32[32,650]{1,0} parameter(0) - %transpose = f32[650,32]{0,1} transpose(f32[32,650]{1,0} %operand), dimensions={1,0} - %indices = s32[16,1,18]{0,1,2} parameter(1) - %reshape.1 = s32[288,1]{1,0} reshape(s32[16,1,18]{0,1,2} %indices) - %gather.1 = f32[288,1,32]{2,1,0} gather(f32[650,32]{0,1} %transpose, s32[288,1]{1,0} %reshape.1), offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,32} - %reshape.3 = f32[16,1,18,32]{3,2,1,0} reshape(f32[288,1,32]{2,1,0} %gather.1), metadata={preserve_layout=true} - ROOT %tuple.1 = (f32[16,1,18,32]{3,1,2,0}) tuple(reshape.3) - } )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, - ParseAndReturnVerifiedModule(module_str)); - - LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(), - nullptr); - EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); - const HloInstruction* reshape_1 = FindInstruction(m.get(), "reshape.1"); - ExpectLayoutIs(reshape_1->shape(), {1, 0}); - const HloInstruction* reshape_3 = FindInstruction(m.get(), "reshape.3"); - ExpectLayoutIs(reshape_3->shape(), {3, 2, 1, 0}); -} - // Different instructions should not share buffers when assigning layout. TEST_F(LayoutAssignmentTest, BreakBufferAliasAcrossInstructions) { const char* module_str = R"( @@ -1714,7 +1740,7 @@ called_computation { ENTRY main { init = f32[256,8] parameter(0) - ROOT start = f32[256,8]{1,0} custom-call(init), custom_call_target="baz", to_apply=called_computation, custom_call_has_side_effect=true, output_to_operand_aliasing={{}: (0, {})}, metadata={preserve_layout=true} + ROOT start = f32[256,8]{1,0} custom-call(init), custom_call_target="baz", to_apply=called_computation, custom_call_has_side_effect=true, output_to_operand_aliasing={{}: (0, {})}, metadata={} } )"; @@ -1768,7 +1794,7 @@ TEST_F(LayoutAssignmentTest, TupleEntryParameterLayoutNoResultConstraint) { ENTRY %main { p = (f32[32,650],s32[16,1,18]) parameter(0) - operand = f32[32,650] get-tuple-element(p), index=0 + operand = f32[32,650] get-tuple-element(p), index=0 reshape = f32[208,100] reshape(operand) indices = s32[16,1,18] get-tuple-element(p), index=1 reshape_indices = s32[2,144] reshape(indices) @@ -1802,7 +1828,7 @@ TEST_F(LayoutAssignmentTest, ENTRY %main { p = (f32[32,650],s32[16,1,18]) parameter(0) - operand = f32[32,650] get-tuple-element(p), index=0 + operand = f32[32,650] get-tuple-element(p), index=0 reshape = f32[208,100] reshape(operand) indices = s32[16,1,18] get-tuple-element(p), index=1 reshape_indices = s32[2,144] reshape(indices) diff --git a/third_party/xla/xla/service/legalize_scheduling_annotations.cc b/third_party/xla/xla/service/legalize_scheduling_annotations.cc index 3f863c5796812b..4cb57a7fcafd9a 100644 --- a/third_party/xla/xla/service/legalize_scheduling_annotations.cc +++ b/third_party/xla/xla/service/legalize_scheduling_annotations.cc @@ -32,6 +32,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/ptrvec.h" +#include "xla/side_effect_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -41,19 +43,40 @@ absl::StatusOr ExtractAnnotation( const ::google::protobuf::Map& attrs, absl::string_view instr_name) { int64_t annotation_id; - if (!absl::SimpleAtoi(attrs.at("_scheduling_group_id"), &annotation_id)) { + if (!absl::SimpleAtoi(attrs.at(kXlaSchedulingGroupIdAttr), &annotation_id)) { return absl::InvalidArgumentError(absl::StrCat( "Instruction has a non-integer scheduling annotation, inst: ", - instr_name, ", annotation: ", attrs.at("_scheduling_group_id"))); + instr_name, ", annotation: ", attrs.at(kXlaSchedulingGroupIdAttr))); } if (annotation_id < 0) { return absl::InvalidArgumentError(absl::StrCat( "Instruction has a negative scheduling annotation, inst: ", instr_name, - ", annotation: ", attrs.at("_scheduling_group_id"))); + ", annotation: ", attrs.at(kXlaSchedulingGroupIdAttr))); } return annotation_id; } +void DropSchedulingAnnotation(HloInstruction* instr) { + VLOG(2) << "Dropping annotation from " << instr->name(); + FrontendAttributes frontend_attributes = instr->frontend_attributes(); + frontend_attributes.mutable_map()->erase("_scheduling_group_id"); + instr->set_frontend_attributes(frontend_attributes); +} + +bool IsSupportedAsyncOp(HloInstruction* instr) { + return HloPredicateIsOp< + HloOpcode::kAllGatherDone, HloOpcode::kAllGatherStart, + HloOpcode::kAllReduceDone, HloOpcode::kAllReduceStart, + HloOpcode::kCollectivePermuteDone, HloOpcode::kCollectivePermuteStart, + HloOpcode::kAsyncDone, HloOpcode::kAsyncStart, HloOpcode::kSendDone, + HloOpcode::kSend, HloOpcode::kRecvDone, HloOpcode::kRecv>(instr); +} + +bool LegalizeSchedulingAnnotations::KeepSchedulingAnnotation( + HloInstruction* instr) { + return IsSupportedAsyncOp(instr) || config_.keep_sync_annotation(instr); +} + absl::StatusOr LegalizeSchedulingAnnotations::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -61,16 +84,28 @@ absl::StatusOr LegalizeSchedulingAnnotations::Run( absl::flat_hash_map annotation_to_computation; absl::flat_hash_map> annotation_to_instructions; + // Filter the annotated ops (using config) to keep the annotations only in the + // desired sync ops. Annotations in all async ops are kept. + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (!instr->frontend_attributes().map().contains( + "_scheduling_group_id") || + KeepSchedulingAnnotation(instr)) { + continue; + } + DropSchedulingAnnotation(instr); + } + } // Find the annotated instructions and save relevant information. for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : computation->instructions()) { const auto& attrs = instr->frontend_attributes().map(); - if (!attrs.contains("_scheduling_group_id")) { + if (!attrs.contains(kXlaSchedulingGroupIdAttr)) { continue; } VLOG(1) << "Annotated instruction: " << instr->name() << " " - << attrs.at("_scheduling_group_id"); + << attrs.at(kXlaSchedulingGroupIdAttr); TF_ASSIGN_OR_RETURN(int64_t annotation_id, ExtractAnnotation(attrs, instr->name())); if (annotation_to_computation.contains(annotation_id) && @@ -93,13 +128,14 @@ absl::StatusOr LegalizeSchedulingAnnotations::Run( // there are some fused instructions with different annotations. for (HloComputation* computation : module->computations(execution_threads)) { if (!computation->IsFusionComputation() || + !config_.keep_sync_annotation(computation->FusionInstruction()) || annotation.contains(computation->FusionInstruction())) { continue; } int64_t seen_annotation = -1; for (HloInstruction* instr : computation->instructions()) { const auto& attrs = instr->frontend_attributes().map(); - if (!attrs.contains("_scheduling_group_id")) { + if (!attrs.contains(kXlaSchedulingGroupIdAttr)) { continue; } TF_ASSIGN_OR_RETURN(int64_t annotation_id, @@ -123,13 +159,14 @@ absl::StatusOr LegalizeSchedulingAnnotations::Run( FrontendAttributes frontend_attributes = computation->FusionInstruction()->frontend_attributes(); frontend_attributes.mutable_map()->insert( - {"_scheduling_group_id", std::to_string(seen_annotation)}); + {kXlaSchedulingGroupIdAttr, std::to_string(seen_annotation)}); computation->FusionInstruction()->set_frontend_attributes( frontend_attributes); } if (annotation_to_computation.empty()) { return false; } + absl::flat_hash_map parent; for (const auto& [id, annotated_instructions] : annotation_to_instructions) { // First find the frontier nodes that are not annotated with id but use an // annotated instruction with id. @@ -147,13 +184,17 @@ absl::StatusOr LegalizeSchedulingAnnotations::Run( "Done instruction's operand is not annotated with the same id: ", instr->operand(0)->name(), ", annotation: ", id)); } - for (HloInstruction* user : instr->users()) { - if (!visited.contains(user) && - (!annotation.contains(user) || annotation[user] != id)) { - stack.push_back(user); - visited.insert(user); - VLOG(2) << "Annotation group: " << id - << ", frontier using a root: " << user->name(); + for (const PtrVec& users : + {instr->users(), instr->control_successors()}) { + for (HloInstruction* user : users) { + if (!visited.contains(user) && + (!annotation.contains(user) || annotation[user] != id)) { + stack.push_back(user); + parent[user] = instr; + visited.insert(user); + VLOG(2) << "Annotation group: " << id + << ", frontier using a root: " << user->name(); + } } } } @@ -165,20 +206,31 @@ absl::StatusOr LegalizeSchedulingAnnotations::Run( while (!stack.empty()) { HloInstruction* instr = stack.back(); stack.pop_back(); - for (HloInstruction* user : instr->users()) { - if (annotation.contains(user) && annotation[user] == id) { - return absl::UnimplementedError( - absl::StrCat("Support for annotation groups with gaps doesn't " - "exist yet, annotation: ", - id, ", instr: ", user->name(), - " has the same annotation in its operand tree but " - "has gaps on the way from that operand to itself.")); - } - if (visited.contains(user)) { - continue; + for (const PtrVec& users : + {instr->users(), instr->control_successors()}) { + for (HloInstruction* user : users) { + if (annotation.contains(user) && annotation[user] == id) { + LOG(INFO) << "PATH: " << user->name(); + HloInstruction* current = instr; + LOG(INFO) << "PATH: " << current->name(); + while (parent.contains(current)) { + current = parent[current]; + LOG(INFO) << "PATH: " << current->name(); + } + return absl::UnimplementedError(absl::StrCat( + "Support for annotation groups with gaps doesn't " + "exist yet, annotation: ", + id, ", instr: ", user->name(), + " has the same annotation in its operand tree but " + "has gaps on the way from that operand to itself.")); + } + if (visited.contains(user)) { + continue; + } + stack.push_back(user); + parent[user] = instr; + visited.insert(user); } - stack.push_back(user); - visited.insert(user); } } } diff --git a/third_party/xla/xla/service/legalize_scheduling_annotations.h b/third_party/xla/xla/service/legalize_scheduling_annotations.h index e83301745c526f..49b02271110b86 100644 --- a/third_party/xla/xla/service/legalize_scheduling_annotations.h +++ b/third_party/xla/xla/service/legalize_scheduling_annotations.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_ #define XLA_SERVICE_LEGALIZE_SCHEDULING_ANNOTATIONS_H_ +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/util.h" namespace xla { @@ -28,7 +31,12 @@ namespace xla { // LatencyHidingScheduler). class LegalizeSchedulingAnnotations : public HloModulePass { public: - LegalizeSchedulingAnnotations() = default; + struct Config { + HloPredicate keep_sync_annotation = HloPredicateTrue; + }; + + explicit LegalizeSchedulingAnnotations(Config config) + : config_(std::move(config)) {} absl::string_view name() const override { return "legalize-scheduling-annotations"; } @@ -36,6 +44,10 @@ class LegalizeSchedulingAnnotations : public HloModulePass { absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + bool KeepSchedulingAnnotation(HloInstruction* instr); + Config config_; }; } // namespace xla diff --git a/third_party/xla/xla/service/legalize_scheduling_annotations_test.cc b/third_party/xla/xla/service/legalize_scheduling_annotations_test.cc index 41ca53294fd841..b724fac21307fd 100644 --- a/third_party/xla/xla/service/legalize_scheduling_annotations_test.cc +++ b/third_party/xla/xla/service/legalize_scheduling_annotations_test.cc @@ -20,11 +20,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test_helpers.h" +#include "xla/side_effect_util.h" #include "xla/test_helpers.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -46,9 +49,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, NonIntegerAnnotation) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, MultipleAnnotations) { @@ -68,9 +71,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MultipleAnnotations) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, NegativeAnnotation) { @@ -88,9 +91,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, NegativeAnnotation) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, CrossComputationAnnotation) { @@ -128,9 +131,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, CrossComputationAnnotation) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps) { @@ -152,9 +155,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps2) { @@ -176,9 +179,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, AnnotationWithGaps2) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, MissingAnnotationInStart) { @@ -196,9 +199,9 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MissingAnnotationInStart) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); } TEST_F(LegalizeSchedulingAnnotationsTest, MoveFusedOpAnnotationToCaller) { @@ -219,13 +222,14 @@ TEST_F(LegalizeSchedulingAnnotationsTest, MoveFusedOpAnnotationToCaller) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - - EXPECT_IS_OK(LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations::Config config; + EXPECT_IS_OK( + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); HloInstruction* fusion = hlo_module->entry_computation()->root_instruction(); const auto& attrs = fusion->frontend_attributes().map(); - EXPECT_TRUE(attrs.contains("_scheduling_group_id")); - EXPECT_EQ(attrs.at("_scheduling_group_id"), "1"); + EXPECT_TRUE(attrs.contains(kXlaSchedulingGroupIdAttr)); + EXPECT_EQ(attrs.at(kXlaSchedulingGroupIdAttr), "1"); } TEST_F(LegalizeSchedulingAnnotationsTest, FusedOpsWithDifferentAnnotationIds) { @@ -247,10 +251,74 @@ TEST_F(LegalizeSchedulingAnnotationsTest, FusedOpsWithDifferentAnnotationIds) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_string)); - + LegalizeSchedulingAnnotations::Config config; EXPECT_IS_NOT_OK( - LegalizeSchedulingAnnotations().Run(hlo_module.get()).status()); + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); +} + +TEST_F(LegalizeSchedulingAnnotationsTest, DropAnnotationFromBitcast) { + constexpr absl::string_view hlo_string = R"( + HloModule test + ENTRY entry { + p0 = f32[256,1024]{1,0} parameter(0) + p1 = f32[16,64,256]{2,1,0} parameter(1) + ags0 = (f32[256,1024]{1,0}, f32[1024,1024]{1,0}) all-gather-start(p0), replica_groups={{0,1,2,3}}, dimensions={0}, frontend_attributes={_scheduling_group_id="0"} + bitcast = f32[16,64,256]{2,1,0} bitcast(p1), frontend_attributes={_scheduling_group_id="0"} + agd0 = f32[1024,1024]{1,0} all-gather-done(ags0), frontend_attributes={_scheduling_group_id="0"} + ROOT tuple = (f32[16,64,256]{2,1,0}, f32[1024,1024]{1,0}) tuple(bitcast, agd0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + LegalizeSchedulingAnnotations::Config config; + config.keep_sync_annotation = [](const HloInstruction* instr) { + return instr->opcode() != HloOpcode::kBitcast; + }; + EXPECT_IS_OK( + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); + HloInstruction* bitcast = + hlo_module->entry_computation()->root_instruction()->mutable_operand(0); + EXPECT_FALSE( + bitcast->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr)); } +TEST_F(LegalizeSchedulingAnnotationsTest, OpsWithControlDependencies) { + constexpr absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[16,64,256]{2,1,0} parameter(0) + p2 = f32[512,2048,2048]{2,1,0} parameter(2) + after-all = token[] after-all() + send = (f32[512,2048,2048]{2,1,0}, u32[], token[]) send(p2, after-all), channel_id=1 + send-done = token[] send-done(send), channel_id=1 + recv = (f32[512,2048,2048]{2,1,0}, u32[], token[]) recv(after-all), channel_id=2 + recv-done = (f32[512,2048,2048]{2,1,0}, token[]) recv-done(recv), channel_id=2, control-predecessors={send-done} + get-tuple-element = f32[512,2048,2048]{2,1,0} get-tuple-element(recv-done), index=0 + slice = f32[16,64,256]{2,1,0} slice(get-tuple-element), slice={[0:16], [0:64], [0:256]} + c0 = f32[16,256,256]{2,1,0} convolution(p0, slice), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + c1 = f32[16,256,256]{2,1,0} convolution(p0, slice), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"} + p1 = f32[128,2048,2048]{2,1,0} parameter(1) + after-all.1 = token[] after-all() + send.1 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) send(p1, after-all.1), channel_id=3, frontend_attributes={_scheduling_group_id="0"} + send-done.1 = token[] send-done(send.1), channel_id=3, frontend_attributes={_scheduling_group_id="0"} + recv.1 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) recv(after-all.1), channel_id=4, frontend_attributes={_scheduling_group_id="0"} + recv-done.1 = (f32[128,2048,2048]{2,1,0}, token[]) recv-done(recv.1), channel_id=4, frontend_attributes={_scheduling_group_id="0"}, control-predecessors={send-done.1} + get-tuple-element.1 = f32[128,2048,2048]{2,1,0} get-tuple-element(recv-done.1), index=0 + after-all.2 = token[] after-all() + send.2 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) send(get-tuple-element.1, after-all.2), channel_id=5 + send-done.2 = token[] send-done(send.2), channel_id=5 + recv.2 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) recv(after-all.2), channel_id=6 + recv-done.2 = (f32[128,2048,2048]{2,1,0}, token[]) recv-done(recv.2), channel_id=6, control-predecessors={send-done.2} + get-tuple-element.2 = f32[128,2048,2048]{2,1,0} get-tuple-element(recv-done.2), index=0 + ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}) tuple(c0, c1, get-tuple-element.1, get-tuple-element.2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + LegalizeSchedulingAnnotations::Config config; + EXPECT_IS_OK( + LegalizeSchedulingAnnotations(config).Run(hlo_module.get()).status()); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD index ca4a500cae0409..3f4b9ae56b59db 100644 --- a/third_party/xla/xla/service/llvm_ir/BUILD +++ b/third_party/xla/xla/service/llvm_ir/BUILD @@ -340,7 +340,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", - "//xla/tests:filecheck", + "//xla/hlo/testlib:filecheck", "//xla/tests:xla_internal_test_main", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", diff --git a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc index 7cebcb28eb770d..bb7cbb8d0f115e 100644 --- a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc @@ -109,7 +109,8 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant( /*isExternallyInitialized=*/false); global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global); - llvm::Type* shape_type = llvm_ir::ShapeToIrType(constant.shape(), module); + llvm::Type* shape_type = + llvm_ir::ShapeToIrType(constant.shape(), module->getContext()); IrArray array(global, shape_type, constant.shape()); return [&, b, array = std::move(array)](const IrArray::Index& index) { @@ -123,7 +124,8 @@ absl::StatusOr FusedIrEmitter::HandleTuple( element_ir_types.reserve(tuple.operand_count()); for (const HloInstruction* operand : tuple.operands()) { element_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( - operand->shape().element_type(), elemental_emitter_.module())); + operand->shape().element_type(), + elemental_emitter_.module()->getContext())); } llvm::IRBuilderBase* b = elemental_emitter_.b(); diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 8a05c7c55e75ae..a1d87039fba093 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -567,8 +567,8 @@ llvm::Value* IrArray::EmitLinearArrayElementAddress( const IrArray::Index& index, llvm::IRBuilderBase* b, absl::string_view name, llvm::Value** bit_offset) const { CHECK(index.LinearValidOnShape(shape_)); - llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - llvm::Type* type = PrimitiveTypeToIrType(shape_.element_type(), module); + llvm::Type* type = + PrimitiveTypeToIrType(shape_.element_type(), b->getContext()); if (!primitive_util::IsSubByteNonPredType(shape_.element_type())) { auto linear_index = llvm::dyn_cast(index.linear()); if (linear_index && (linear_index->getOpcode() == llvm::Instruction::Add)) { @@ -671,8 +671,7 @@ IrArray IrArray::CastToShape(const Shape& new_shape, llvm::IRBuilderBase* b) const { if (shape_ == new_shape) return *this; - llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); + llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, b->getContext()); IrArray new_irarray(base_ptr_, new_ir_type, new_shape); new_irarray.metadata_ = metadata_; return new_irarray; diff --git a/third_party/xla/xla/service/llvm_ir/ir_array_test.cc b/third_party/xla/xla/service/llvm_ir/ir_array_test.cc index 63ca0d8fa30d79..74289ece2a214a 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array_test.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array_test.cc @@ -25,11 +25,11 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/filecheck.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -92,7 +92,7 @@ TEST_F(IrArrayTest, EmitArrayElementAddress) { llvm::Argument* array_index = function->getArg(1); Shape shape = ShapeUtil::MakeShape(F32, {3, 5}); - llvm::Type* type = llvm_ir::ShapeToIrType(shape, &module_); + llvm::Type* type = llvm_ir::ShapeToIrType(shape, module_.getContext()); IrArray ir_array(array_ptr, type, shape); IrArray::Index index(array_index, shape, &builder_); @@ -116,7 +116,7 @@ TEST_F(IrArrayTest, EmitArrayElementAddressNonLinear) { llvm::Argument* array_index = function->getArg(1); Shape shape = ShapeUtil::MakeShape(F32, {3, 5}); - llvm::Type* type = llvm_ir::ShapeToIrType(shape, &module_); + llvm::Type* type = llvm_ir::ShapeToIrType(shape, module_.getContext()); IrArray ir_array(array_ptr, type, shape); IrArray::Index index(array_index, shape, &builder_); @@ -144,7 +144,7 @@ TEST_F(IrArrayTest, EmitArrayElementAddressInt4) { llvm::Argument* array_index = function->getArg(1); Shape shape = ShapeUtil::MakeShape(S4, {3, 5}); - llvm::Type* type = llvm_ir::ShapeToIrType(shape, &module_); + llvm::Type* type = llvm_ir::ShapeToIrType(shape, module_.getContext()); IrArray ir_array(array_ptr, type, shape); IrArray::Index index(array_index, shape, &builder_); @@ -177,7 +177,7 @@ TEST_F(IrArrayTest, EmitArrayElementAddressInt4NonLinear) { llvm::Argument* array_index1 = function->getArg(2); Shape shape = ShapeUtil::MakeShape(S4, {3, 5}); - llvm::Type* type = llvm_ir::ShapeToIrType(shape, &module_); + llvm::Type* type = llvm_ir::ShapeToIrType(shape, module_.getContext()); IrArray ir_array(array_ptr, type, shape); IrArray::Index index({array_index0, array_index1}, shape, @@ -212,7 +212,7 @@ TEST_F(IrArrayTest, EmitReadArrayElementInt4) { llvm::Argument* array_index = function->getArg(1); Shape shape = ShapeUtil::MakeShape(S4, {3, 5}); - llvm::Type* type = llvm_ir::ShapeToIrType(shape, &module_); + llvm::Type* type = llvm_ir::ShapeToIrType(shape, module_.getContext()); IrArray ir_array(array_ptr, type, shape); IrArray::Index index(array_index, shape, &builder_); @@ -249,7 +249,7 @@ TEST_F(IrArrayTest, EmitWriteArrayElementInt4) { llvm::Argument* val_to_write = function->getArg(2); Shape shape = ShapeUtil::MakeShape(S4, {3, 5}); - llvm::Type* type = llvm_ir::ShapeToIrType(shape, &module_); + llvm::Type* type = llvm_ir::ShapeToIrType(shape, module_.getContext()); IrArray ir_array(array_ptr, type, shape); IrArray::Index index(array_index, shape, &builder_); diff --git a/third_party/xla/xla/service/llvm_ir/llvm_command_line_options.cc b/third_party/xla/xla/service/llvm_ir/llvm_command_line_options.cc index 88937eff4b6d79..ef24817fcd6802 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_command_line_options.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_command_line_options.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include "absl/algorithm/container.h" @@ -25,6 +24,7 @@ limitations under the License. #include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "llvm/Support/CommandLine.h" #include "tsl/platform/logging.h" @@ -56,12 +56,12 @@ LLVMCommandLineOptionsLock::LLVMCommandLineOptionsLock( std::vector fake_argv(client_options.size() + GetGlobalOptions().size() + 1); fake_argv[0] = "xla"; - for (std::string_view client_option : client_options) { + for (absl::string_view client_option : client_options) { VLOG(1) << absl::StrFormat("XLA LLVM arg[%d]: %s", idx, client_option); fake_argv[idx] = client_option.data(); ++idx; } - for (std::string_view global_option : GetGlobalOptions()) { + for (absl::string_view global_option : GetGlobalOptions()) { VLOG(1) << absl::StrFormat("XLA LLVM arg[%d]: %s", idx, global_option); fake_argv[idx] = global_option.data(); ++idx; diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 229b7f87b7d2c1..ff7c4e84a19b00 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -48,6 +48,7 @@ limitations under the License. #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Type.h" #include "llvm/Support/Alignment.h" @@ -183,21 +184,21 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type, } llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, - llvm::Module* module) { + llvm::LLVMContext& context) { switch (element_type) { case S2: case U2: - return llvm::Type::getIntNTy(module->getContext(), 2); + return llvm::Type::getIntNTy(context, 2); case S4: case U4: - return llvm::Type::getIntNTy(module->getContext(), 4); + return llvm::Type::getIntNTy(context, 4); case PRED: case S8: case U8: - return llvm::Type::getInt8Ty(module->getContext()); + return llvm::Type::getInt8Ty(context); case S16: case U16: - return llvm::Type::getInt16Ty(module->getContext()); + return llvm::Type::getInt16Ty(context); case F8E5M2: case F8E5M2FNUZ: case F8E4M3: @@ -206,24 +207,23 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case F8E4M3FNUZ: case F8E3M4: // We represent F8 as an int since there is no LLVM F8 dtype. - return llvm::Type::getInt8Ty(module->getContext()); + return llvm::Type::getInt8Ty(context); case BF16: - return llvm::Type::getBFloatTy(module->getContext()); + return llvm::Type::getBFloatTy(context); case F16: - return llvm::Type::getHalfTy(module->getContext()); + return llvm::Type::getHalfTy(context); case S32: case U32: - return llvm::Type::getInt32Ty(module->getContext()); + return llvm::Type::getInt32Ty(context); case S64: case U64: - return llvm::Type::getInt64Ty(module->getContext()); + return llvm::Type::getInt64Ty(context); case F32: - return llvm::Type::getFloatTy(module->getContext()); + return llvm::Type::getFloatTy(context); case F64: - return llvm::Type::getDoubleTy(module->getContext()); + return llvm::Type::getDoubleTy(context); case C64: { - auto cplx_t = - llvm::StructType::getTypeByName(module->getContext(), "complex64"); + auto cplx_t = llvm::StructType::getTypeByName(context, "complex64"); if (cplx_t == nullptr) { // C++ standard dictates the memory layout of std::complex is contiguous // real followed by imaginary. C++11 section 26.4 [complex.numbers]: @@ -233,31 +233,28 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // z, and reinterpret_cast(z)[1] shall designate the // imaginary part of z. return llvm::StructType::create( - {llvm::Type::getFloatTy(module->getContext()), - llvm::Type::getFloatTy(module->getContext())}, + {llvm::Type::getFloatTy(context), llvm::Type::getFloatTy(context)}, "complex64", /*isPacked=*/true); } return cplx_t; } case C128: { - auto cplx_t = - llvm::StructType::getTypeByName(module->getContext(), "complex128"); + auto cplx_t = llvm::StructType::getTypeByName(context, "complex128"); if (cplx_t == nullptr) { - return llvm::StructType::create( - {llvm::Type::getDoubleTy(module->getContext()), - llvm::Type::getDoubleTy(module->getContext())}, - "complex128", /*isPacked=*/true); + return llvm::StructType::create({llvm::Type::getDoubleTy(context), + llvm::Type::getDoubleTy(context)}, + "complex128", /*isPacked=*/true); } return cplx_t; } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE_TYPE: - return llvm::PointerType::getUnqual(module->getContext()); + return llvm::PointerType::getUnqual(context); case TOKEN: // Tokens do not have a physical representation, but the compiler needs // some placeholder type, so use int8_t*. - return llvm::PointerType::getUnqual(module->getContext()); + return llvm::PointerType::getUnqual(context); default: LOG(FATAL) << "unsupported type " << element_type; } @@ -278,8 +275,9 @@ int GetSizeInBits(llvm::Type* type) { return bits; } -llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { - llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); +llvm::Type* ShapeToIrType(const Shape& shape, llvm::LLVMContext& context) { + llvm::Type* result_type = + PrimitiveTypeToIrType(shape.element_type(), context); if (shape.IsTuple()) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); @@ -471,8 +469,8 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, } // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 // arrays. So we extend it to i8 so that it's addressable. - return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType( - PRED, ModuleFromIRBuilder(b))); + return b->CreateZExt(comparison_result, + llvm_ir::PrimitiveTypeToIrType(PRED, b->getContext())); } // Internal helper that is called from emitted code to log an int64_t value with @@ -689,8 +687,8 @@ void DumpIrIfEnabled(const HloModule& hlo_module, // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. std::string suffix = - absl::StrCat("ir-", optimized ? "with" : "no", "-opt", - filename_suffix.empty() ? "" : ".", filename_suffix); + absl::StrCat(filename_suffix, filename_suffix.empty() ? "" : ".", "ir-", + optimized ? "with" : "no", "-opt"); DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"), DumpToString(&llvm_module)); } diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index e5f1ea13000876..88c1287d2f236d 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "mlir/IR/BuiltinOps.h" @@ -130,14 +131,14 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type, // Returns the LLVM type which represents the given XLA primitive type. llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, - llvm::Module* module); + llvm::LLVMContext& context); // Returns the type size in bits. If "type" is a struct, it must be packed. int GetSizeInBits(llvm::Type* type); // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. -llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); +llvm::Type* ShapeToIrType(const Shape& shape, llvm::LLVMContext& context); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. diff --git a/third_party/xla/xla/service/llvm_ir/sort_util.cc b/third_party/xla/xla/service/llvm_ir/sort_util.cc index 726973612458a3..1be41989c7b666 100644 --- a/third_party/xla/xla/service/llvm_ir/sort_util.cc +++ b/third_party/xla/xla/service/llvm_ir/sort_util.cc @@ -131,8 +131,8 @@ absl::Status EmitCompareLoopBody( values_to_compare_types.push_back( element_address_pointee_type(i, current_keys_index)); } - llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - llvm::Type* pred_type = llvm_ir::PrimitiveTypeToIrType(PRED, module); + llvm::Type* pred_type = + llvm_ir::PrimitiveTypeToIrType(PRED, b->getContext()); llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( pred_type, "compare_return_buffer", b); TF_RETURN_IF_ERROR( @@ -366,7 +366,7 @@ absl::Status EmitSortInPlace( for (int64_t i = 0; i < values_arrays.size(); ++i) { llvm::Type* tile_type = llvm::ArrayType::get( llvm_ir::PrimitiveTypeToIrType( - values_arrays[i].GetShape().element_type(), module), + values_arrays[i].GetShape().element_type(), b->getContext()), std::max(tile_size, static_cast(64))); param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( module, tile_type, absl::StrCat(name, "_tile_param_", i)); diff --git a/third_party/xla/xla/service/llvm_ir/tuple_ops.cc b/third_party/xla/xla/service/llvm_ir/tuple_ops.cc index bb9088c409cdee..65d47114e07113 100644 --- a/third_party/xla/xla/service/llvm_ir/tuple_ops.cc +++ b/third_party/xla/xla/service/llvm_ir/tuple_ops.cc @@ -45,10 +45,9 @@ static llvm::Module* getModuleFromBuilder(llvm::IRBuilderBase* b) { void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilderBase* b) { - llvm::Module* module = getModuleFromBuilder(b); for (size_t i = 0; i < operands.size(); ++i) { - auto* cast = - b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)); + auto* cast = b->CreatePointerCast( + operands[i], PrimitiveTypeToIrType(TUPLE, b->getContext())); auto* store = b->CreateStore( cast, b->CreateInBoundsGEP(tuple.GetBasePointeeType(), tuple.GetBasePointer(), @@ -69,8 +68,6 @@ void EmitTuple(const IrArray& tuple, absl::Span buffers, std::vector EmitTupleAllocasAtFunctionEntry( const Shape& tuple_shape, llvm::IRBuilderBase* b) { - llvm::Module* module = b->GetInsertBlock()->getModule(); - llvm::IRBuilderBase::InsertPointGuard guard(*b); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), @@ -82,8 +79,8 @@ std::vector EmitTupleAllocasAtFunctionEntry( for (int i = 0; i < tuple_size; i++) { const Shape& element_shape = tuple_shape.tuple_shapes(i); CHECK(ShapeUtil::IsScalar(element_shape)); - llvm::Type* type = - llvm_ir::PrimitiveTypeToIrType(element_shape.element_type(), module); + llvm::Type* type = llvm_ir::PrimitiveTypeToIrType( + element_shape.element_type(), b->getContext()); llvm::AllocaInst* alloca = b->CreateAlloca( type, /*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i))); diff --git a/third_party/xla/xla/service/lockable_test.cc b/third_party/xla/xla/service/lockable_test.cc index 9118fb9e7276bf..67bf41cef0617b 100644 --- a/third_party/xla/xla/service/lockable_test.cc +++ b/third_party/xla/xla/service/lockable_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include "absl/synchronization/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/map_inliner.cc b/third_party/xla/xla/service/map_inliner.cc index deb7b6755f6ce0..7f96c1e8aa80a4 100644 --- a/third_party/xla/xla/service/map_inliner.cc +++ b/third_party/xla/xla/service/map_inliner.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" diff --git a/third_party/xla/xla/service/map_inliner_test.cc b/third_party/xla/xla/service/map_inliner_test.cc index de1511e2a6ff43..c9387108a19fae 100644 --- a/third_party/xla/xla/service/map_inliner_test.cc +++ b/third_party/xla/xla/service/map_inliner_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc b/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc index ca738619aa8ab8..bb1b55ccdd646b 100644 --- a/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc +++ b/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include "absl/functional/bind_front.h" #include "absl/log/log.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index c8b5e507061f51..f6dbfc0995f093 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -84,11 +84,13 @@ xla_cc_test( ":cost_analysis", ":memory_space_assignment", ":memory_space_assignment_proto_cc", + ":memory_space_assignment_test_base", ":options", ":prefetch_interval_picker", ":repacking", ":slice", ":testing_utils", + ":utils", "//xla:comparison_util", "//xla:literal_util", "//xla:shape_util", @@ -97,19 +99,21 @@ xla_cc_test( "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:instruction_hoister", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", - "//xla/service:buffer_value", "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", "//xla/service/heap_simulator", "//xla/service/heap_simulator:allocation_block", - "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -120,12 +124,8 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) @@ -160,6 +160,34 @@ cc_library( ], ) +cc_library( + name = "memory_space_assignment_test_base", + testonly = True, + hdrs = ["memory_space_assignment_test_base.h"], + deps = [ + ":buffer_interval_comparator", + ":cost_analysis", + ":memory_space_assignment", + ":options", + ":prefetch_interval_picker", + "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms/simplifiers:instruction_hoister", + "//xla/hlo/utils:hlo_live_range", + "//xla/service:buffer_value", + "//xla/service:hlo_buffer", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "utils", srcs = ["utils.cc"], @@ -279,7 +307,7 @@ cc_library( cc_library( name = "options", - srcs = [], + srcs = ["options.cc"], hdrs = ["options.h"], deps = [ ":allocation_value", @@ -294,9 +322,10 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_value", "//xla/service:hlo_value", - "//xla/service/heap_simulator", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", ], @@ -490,12 +519,12 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -538,6 +567,7 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", "//xla/service:call_graph", + "//xla/service:computation_layout", "//xla/service:hlo_buffer", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 5209907a67624d..6d2a5eb49c83a4 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -29,7 +29,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -57,6 +56,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" +#include "xla/service/computation_layout.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_buffer.h" @@ -77,7 +77,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -244,9 +243,6 @@ bool IsCrossProgramPrefetchCandidate(const HloValue& value, return value.defining_instruction()->parent() == value.defining_instruction()->GetModule()->entry_computation() && value.defining_instruction()->opcode() == HloOpcode::kParameter && - (!value.shape().has_layout() || - value.shape().layout().memory_space() != - options.alternate_memory_space) && value.index().size() <= 1 && value.shape().IsArray() && !uses.empty() && options.size_fn(value) <= options.max_size_in_bytes && absl::c_all_of(uses, [&](const HloUse& use) { @@ -270,34 +266,78 @@ bool IsCrossProgramPrefetchCandidate(const HloValue& value, }); } -struct CrossProgramPrefetchBufferSortValues { - int64_t latest_use = 0; - int64_t use_size = 0; +bool IsUserAnnotatedCrossProgramPrefetch(const HloValue& value, + const Options& options) { + const HloInstruction* defining_instruction = value.defining_instruction(); + if (defining_instruction->parent() != + defining_instruction->GetModule()->entry_computation() || + defining_instruction->opcode() != HloOpcode::kParameter) { + return false; + } + const ComputationLayout& entry_computation_layout = + defining_instruction->GetModule()->entry_computation_layout(); + if (defining_instruction->parameter_number() >= + entry_computation_layout.parameter_count()) { + return false; + } + const Shape& shape = + entry_computation_layout + .parameter_layout(defining_instruction->parameter_number()) + .shape(); + return shape.has_layout() && + shape.layout().memory_space() == options.alternate_memory_space; +} + +MsaBufferInterval CreateMsaBufferInterval(const HloBuffer& buffer, + const HloValue* value, + const HloLiveRange& hlo_live_range, + const Options& options) { + MsaBufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; + interval.need_allocation = true; + return interval; +} + +struct CrossProgramPrefetches { + std::vector prefetches; + std::vector candidates; }; -std::vector FindCrossProgramPrefetchCandidates( +CrossProgramPrefetches FindCrossProgramPrefetches( const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, const Options& options) { - std::vector candidates; + CrossProgramPrefetches cross_program_prefetches; for (const HloBuffer& buffer : alias_analysis.buffers()) { CHECK_GE(buffer.values().size(), 1); const HloValue* value = buffer.values().at(0); - MsaBufferInterval interval; - interval.buffer = value; - interval.size = options.size_fn(*value); - interval.start = 0; - interval.end = hlo_live_range.schedule_end_time(); - interval.need_allocation = true; - interval.colocations = {++buffer.values().begin(), buffer.values().end()}; - if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) { - candidates.emplace_back(interval); + MsaBufferInterval buffer_interval = + CreateMsaBufferInterval(buffer, value, hlo_live_range, options); + if (IsUserAnnotatedCrossProgramPrefetch(*value, options)) { + cross_program_prefetches.prefetches.push_back(buffer_interval); + } else if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, + options)) { + cross_program_prefetches.candidates.push_back(buffer_interval); } else if (MemorySpaceAssignmentUtils:: DoesCrossProgramPrefetchBufferMatchAnyFilter( - options.msa_sort_order_overrides, interval)) { - candidates.emplace_back(interval); + options.msa_sort_order_overrides, buffer_interval)) { + cross_program_prefetches.candidates.push_back(buffer_interval); } } + for (auto& prefetch : cross_program_prefetches.prefetches) { + VLOG(3) << "User annotated cross-program prefetch: " + << prefetch.buffer->ToString(); + } + + for (auto& prefetch : cross_program_prefetches.prefetches) { + VLOG(3) << "User annotated cross-program prefetch: " + << prefetch.buffer->ToString(); + } + DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator( hlo_live_range, options.msa_sort_order_overrides); BufferIntervalComparator* comparator = @@ -305,16 +345,18 @@ std::vector FindCrossProgramPrefetchCandidates( options.buffer_interval_comparator ? options.buffer_interval_comparator : &default_comparator); - absl::c_sort(candidates, comparator->GetComparisonFunctor()); + absl::c_sort(cross_program_prefetches.candidates, + comparator->GetComparisonFunctor()); - VLOG(3) << "Cross-program prefetch candidates: " << candidates.size() + VLOG(3) << "Cross-program prefetch candidates: " + << cross_program_prefetches.candidates.size() << ". Sorting criteria: " << comparator->DescribeComparisonCriteria(); - for (auto& candidate : candidates) { + for (auto& candidate : cross_program_prefetches.candidates) { VLOG(3) << "Cross-program prefetch candidate. Sorting criteria: " << comparator->CriteriaToString(candidate) << ". Candidate: " << candidate.buffer->ToString(); } - return candidates; + return cross_program_prefetches; } } // namespace @@ -1591,6 +1633,46 @@ bool MsaAlgorithm::RepackAllocationsIncludeConvertedSyncMemOp() { return false; } +namespace { + +// Fixes the AllocationSequence after post-allocation transformation: +// 1. Remove the allocations with to_be_removed instructions as the defining +// positions. +// 2. Update the vector of uses for all allocations according to the +// update_use_map. +// Note that to_be_removed instructions will later be removed from the module +// during SimplifyGraph() call in memory_space_assignment.cc +void FixAllocationSequenceAfterPostAllocationTransformation( + AllocationSequence* allocations, + const PostAllocationTransformationUpdate& transformation_info) { + VLOG(3) << "Fixing AllocationSequence after post-allocation transformation"; + + // (1) + allocations->erase( + std::remove_if( + allocations->begin(), allocations->end(), + [transformation_info](const std::unique_ptr& allocation) { + return std::find(transformation_info.to_be_removed.begin(), + transformation_info.to_be_removed.end(), + allocation->defining_position().instruction) != + transformation_info.to_be_removed.end(); + }), + allocations->end()); + + // (2) + for (auto& allocation : *allocations) { + for (const HloUse& use : allocation->uses()) { + auto new_use_it = transformation_info.update_use_map.find(use); + if (new_use_it != transformation_info.update_use_map.end()) { + allocation->RemoveUse(use); + allocation->AddUse(new_use_it->second); + } + } + } +} + +} // namespace + absl::StatusOr> MsaAlgorithm::Finish() { // Note: Memory Space Assignment creates a HeapSimulator and passes an // MsaAlgorithm object to it. buffer_intervals_ is populated by calling the @@ -1642,11 +1724,27 @@ absl::StatusOr> MsaAlgorithm::Finish() { } VLOG(1) << "Memory pressure = " << memory_pressure_; + CrossProgramPrefetches cross_program_prefetches = + FindCrossProgramPrefetches(alias_analysis_, hlo_live_range_, options_); + // Crash if cross program prefetch is disabled and user has requested + // cross program prefetch. + CHECK(options_.enable_cross_program_prefetch || + cross_program_prefetches.prefetches.empty()) + << "Cross program prefetch is disabled but user has requested cross " + "program prefetch."; + // Crash if number of user requested cross program prefetches is greater than + // the maximum number of cross program prefetches allowed. + CHECK(cross_program_prefetches.prefetches.size() <= + options().max_cross_program_prefetches) + << "Number of user requested cross program prefetches is greater than " + "the maximum number of cross program prefetches allowed."; + // Allocate user requested cross program prefetches first. + for (auto& prefetch : cross_program_prefetches.prefetches) { + HloModule* module = prefetch.buffer->instruction()->GetModule(); + AllocateCrossProgramPrefetchBuffer(module, prefetch); + } if (options_.enable_cross_program_prefetch) { - std::vector prefetch_candidates = - FindCrossProgramPrefetchCandidates(alias_analysis_, hlo_live_range_, - options_); - for (auto& prefetch_candidate : prefetch_candidates) { + for (auto& prefetch_candidate : cross_program_prefetches.candidates) { HloModule* module = prefetch_candidate.buffer->instruction()->GetModule(); if (0 <= options().max_cross_program_prefetches && options().max_cross_program_prefetches <= @@ -1848,6 +1946,10 @@ absl::StatusOr> MsaAlgorithm::Finish() { if (VLOG_IS_ON(3)) { VLOG(3) << "Sync copy replacement summary: "; + VLOG(3) << "\tnumber of successful async conversion: " + << successful_async_conversion_set_.size(); + VLOG(3) << "\tnumber of failed async conversion: " + << failed_async_conversions_.size(); for (const HloInstruction* inst : successful_async_conversion_set_) { VLOG(3) << "Successful copy replacement: " << inst->ToString(); } @@ -1857,6 +1959,53 @@ absl::StatusOr> MsaAlgorithm::Finish() { } } + // Run post allocation transformation and fix the allocation sequence if + // needed. + if (options_.post_allocation_transformation_fn) { + PostAllocationTransformationUpdate all_changes; + VLOG(3) << "Running post allocation transformation on module"; + for (HloComputation* comp : alias_analysis_.dataflow_analysis() + .module() + .MakeNonfusionComputations()) { + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + // If the operand is in alternate memory, we don't run the + // post-allocation transformation. + auto operand_it = operands_in_alternate_memory_map_.find(instr); + if (operand_it != operands_in_alternate_memory_map_.end()) { + continue; + } + + // If the instruction is a successful async conversion, we don't run the + // post-allocation transformation. + if (successful_async_conversion_set_.contains(instr)) { + continue; + } + + // If any of the operands of the instruction has an in-place user, we + // don't run the post-allocation transformation. + for (HloInstruction* operand : instr->operands()) { + for (HloInstruction* user : operand->users()) { + if (HloDataflowAnalysis::IsInPlaceOperation(user->opcode())) { + continue; + } + } + } + + TF_ASSIGN_OR_RETURN(PostAllocationTransformationUpdate changes, + options_.post_allocation_transformation_fn(instr)); + all_changes.to_be_removed.insert(all_changes.to_be_removed.end(), + changes.to_be_removed.begin(), + changes.to_be_removed.end()); + all_changes.update_use_map.insert(changes.update_use_map.begin(), + changes.update_use_map.end()); + } + } + VLOG(3) << "Post allocation transformation info: \n" + << all_changes.ToString(); + FixAllocationSequenceAfterPostAllocationTransformation(allocations_, + all_changes); + } + HeapSimulator::Result result; result.heap_size = result_.heap_size; result.heap_results.emplace_back(std::move(result_)); @@ -2982,8 +3131,8 @@ bool AsynchronousCopyResource::ConsumeResource( // that was freed when removing the copy. float old_resource = std::max(0.0f, initial_resources_[time] - delay_[time]); - if (delay_change_map && !delay_change_map->contains(time)) { - (*delay_change_map)[time] = delay_[time]; + if (delay_change_map) { + delay_change_map->emplace(time, delay_[time]); } delay_[time] = std::max(0.0f, resource - resource_to_free); float new_resource = @@ -3178,7 +3327,7 @@ std::string AsynchronousCopyResource::Dump( std::vector col_sizes; std::vector> rows; rows.push_back({"time", "initial", "delay", "avail", "overlapping copies"}); - for (std::string_view col : rows.front()) { + for (absl::string_view col : rows.front()) { col_sizes.push_back(col.size()); } for (int i = 0; i < time_dump_data.size(); ++i) { @@ -3239,6 +3388,25 @@ void MsaAlgorithm::CreateOrAddToAliasedOffset(const Allocation& allocation, return nullptr; } +namespace { + +void SetDefaultMemorySpace(const HloValue* value, const Options& options) { + for (auto& position : value->positions()) { + Shape* shape = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + if (!shape->has_layout() || + shape->layout().memory_space() != options.alternate_memory_space) { + continue; + } + shape->mutable_layout()->set_memory_space(options.default_memory_space); + } + HloModule* module = value->defining_instruction()->GetModule(); + module->mutable_config().SetComputationLayoutIfExists( + module->entry_computation()->ComputeProgramShape()); +} + +} // namespace + void MsaAlgorithm::AllocateCrossProgramPrefetchBuffer( HloModule* module, const MsaBufferInterval& prefetch_candidate) { Chunk chunk_candidate = FindChunkCandidate(prefetch_candidate); @@ -3250,6 +3418,7 @@ void MsaAlgorithm::AllocateCrossProgramPrefetchBuffer( const HloValue* buffer = prefetch_candidate.buffer; int64_t parameter = buffer->instruction()->parameter_number(); int cross_program_prefetch_index = module->CrossProgramPrefetches().size(); + SetDefaultMemorySpace(buffer, options_); module->AddCrossProgramPrefetch(parameter, buffer->index()); AllocationSequence allocations; @@ -3285,15 +3454,10 @@ void MsaAlgorithm::AllocateCrossProgramPrefetchBuffer( } int64_t end_of_program_prefetch_end_time = instruction_schedule.size(); - int64_t end_of_program_prefetch_latest_start_time = - options_.prefetch_interval_picker->LatestPrefetchStartTime( - buffer->defining_position().shape(), last_use_time, - end_of_program_prefetch_end_time, nullptr); int64_t end_of_program_inclusive_prefetch_start_time = options_.prefetch_interval_picker->PreferredPrefetchStartTime( buffer->defining_position().shape(), last_use_time, - end_of_program_prefetch_latest_start_time, - end_of_program_prefetch_end_time); + end_of_program_prefetch_end_time, end_of_program_prefetch_end_time); VLOG(2) << "last use time = " << last_use_time << ", end-of-program inclusive prefetch start time = " << end_of_program_inclusive_prefetch_start_time; @@ -4669,19 +4833,15 @@ bool MsaAlgorithm::ViolatesMaximumOutstandingAsyncCopies( // Count the prefetches/evictions in the interval tree for the given interval. if (is_prefetch) { - int64_t num_prefetches = - prefetch_interval_tree_ - .ChunksOverlappingInTime(inclusive_start_time, end_time) - .size() + - num_additional_copies; + int64_t num_prefetches = prefetch_interval_tree_.NumChunksOverlappingInTime( + inclusive_start_time, end_time) + + num_additional_copies; return num_prefetches >= options_.max_outstanding_prefetches + extra_async_copy_limit; } else { - int64_t num_evictions = - eviction_interval_tree_ - .ChunksOverlappingInTime(inclusive_start_time, end_time) - .size() + - num_additional_copies; + int64_t num_evictions = eviction_interval_tree_.NumChunksOverlappingInTime( + inclusive_start_time, end_time) + + num_additional_copies; return num_evictions >= options_.max_outstanding_evictions + extra_async_copy_limit; } diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc index 438fa9778eb38a..aa287b15a52e32 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc @@ -128,6 +128,12 @@ bool Allocation::is_in_default_mem() const { return memory_space_ == MemorySpace::kDefault; } +void Allocation::RemoveUse(HloUse use) { + uses_.erase(std::remove_if(uses_.begin(), uses_.end(), + [=](const auto& u) { return u == use; }), + uses_.end()); +} + void Allocation::AddUse(HloUse use) { HloInstruction* operand = use.instruction->mutable_operand(use.operand_number); @@ -919,19 +925,28 @@ absl::Status WindowPrefetchedAllocation::InsertWindowPrefetchInstruction( HloInstruction* producing_instruction, HloInstruction* use_instruction, HloComputation* computation) { // Derive the shape for window buffer. - Shape shape = ShapeUtil::MakeShape(U8, {options_.bytes}); + Shape buffer_shape = ShapeUtil::MakeShape(U8, {options_.bytes}); Layout layout = LayoutUtil::MakeLayout({0}); layout.set_memory_space(options_.alternate_memory_space); - *shape.mutable_layout() = layout; - - // Insert async WindowPrefetch instructions as operands to the fusion. - HloInstruction* prefetch = + *buffer_shape.mutable_layout() = layout; + // Sync flag shape + Shape sflag_shape = ShapeUtil::MakeShape(S32, {}); + // Output shape of the WindowPrefetch op. + Shape output_shape = ShapeUtil::MakeTupleShape({buffer_shape, sflag_shape}); + + // Insert WindowPrefetch op. + HloInstruction* custom_call = computation->AddInstruction(HloInstruction::CreateCustomCall( - shape, {producing_instruction}, "WindowPrefetch")); - TF_ASSIGN_OR_RETURN(prefetch_instruction_, - computation->CreateAsyncInstructions(prefetch, {})); - use_instruction->AppendOperand(prefetch_instruction_); - + output_shape, {producing_instruction}, "WindowPrefetch")); + HloInstruction* get_buffer = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(buffer_shape, custom_call, 0)); + HloInstruction* get_sflag = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(sflag_shape, custom_call, 1)); + use_instruction->AppendOperand(get_buffer); + use_instruction->AppendOperand(get_sflag); + + // The buffer's defining position is the get_tuple_element instruction. + prefetch_instruction_ = get_buffer; return absl::OkStatus(); } @@ -939,6 +954,7 @@ absl::Status WindowPrefetchedAllocation::Process() { HloInstruction* producing_instruction = AddGetTupleElements(); HloComputation* computation = producing_instruction->parent(); HloInstruction* use_instruction = use_.instruction; + int64_t use_operand = use_instruction->operand_count(); CHECK_EQ(use_instruction->opcode(), HloOpcode::kFusion); TF_RETURN_IF_ERROR(InsertWindowPrefetchInstruction( @@ -946,7 +962,6 @@ absl::Status WindowPrefetchedAllocation::Process() { // Notify the backend that an operand has been appended as a window prefetch // buffer. - int64_t use_operand = use_instruction->operand_count() - 1; options_.notify_operand_appended_fn(use_instruction, options_.uid, use_operand); diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h index 81ac4199c5b86f..0e1d688a9ace92 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.h +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h @@ -127,6 +127,7 @@ class Allocation { bool has_no_uses() const { return uses_.empty(); } // Adds a use to this allocation. void AddUse(HloUse use); + void RemoveUse(HloUse use); // Replaces all uses of the allocation with the copy_complete instruction. absl::Status UpdateUses(HloComputation* computation, HloInstruction* producing_instruction); @@ -238,8 +239,6 @@ class PinnedAllocation final : public Allocation { // before `copy_done_schedule_before_time`. class CopyAllocation final : public Allocation { public: - // TODO(b/307342076): Reorder scheduling times to be - // copy_start_schedule_after_time, copy_done_schedule_before_time, end_time CopyAllocation( Allocation& prev_allocation, MemorySpace memory_space, std::optional chunk, diff --git a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc index 177aed23bc4e0b..365993b1bc5969 100644 --- a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc +++ b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.cc @@ -507,7 +507,7 @@ class BestFitRepacker Result result; result.heap_size = result_.heap_size; - result.heap_results.emplace_back(result_); + result.heap_results.push_back(result_); return result; } diff --git a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.h b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.h index 8fd0f7c1550dc8..e22daba33991f6 100644 --- a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.h +++ b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ #include +#include #include "absl/status/statusor.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker_test.cc b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker_test.cc index 3003bd69e617e8..2b47d1223f800b 100644 --- a/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/best_fit_repacker_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/service/memory_space_assignment/best_fit_repacker.h" #include +#include +#include #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index 29b9e453d7fad6..0b62def3fc7ea3 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/hlo_cost_analysis.h" @@ -54,7 +55,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index 5216d08860d66f..d6f0e41ad6340f 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -286,11 +285,10 @@ void TransformAllocationSequenceToSpill(AllocationSequence& allocations, } // namespace absl::StatusOr -MemorySpaceAssignment::CalculateAsyncCopyStats() const { +MemorySpaceAssignment::CalculateAsyncCopyStats( + const HloDataflowAnalysis& dataflow_analysis) const { AsyncCopyStats stats; int64_t current_copies = 0; - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, - HloDataflowAnalysis::Run(*module_)); for (const HloComputation* computation : module_->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { @@ -305,7 +303,7 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const { HloOpcode::kSlice)) { current_copies--; int64_t size = - options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction)); + options_.size_fn(dataflow_analysis.GetUniqueValueAt(instruction)); if (instruction->shape().layout().memory_space() == options_.alternate_memory_space) { ++stats.num_prefetches; @@ -388,11 +386,13 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( if (options_.cost_analysis) { runtime_simulator.emplace(options_.cost_analysis, options_.alternate_memory_space); - float estimated_time = - runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes( - hlo_live_range, allocations_); - VLOG(1) << "Estimated elapsed time without async copies (sec): " - << estimated_time; + if (VLOG_IS_ON(1)) { + float estimated_time = + runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes( + hlo_live_range, allocations_); + LOG(INFO) << "Estimated elapsed time without async copies (sec): " + << estimated_time; + } } TF_RETURN_IF_ERROR(Process(hlo_live_range)); @@ -409,35 +409,34 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( ScheduleAsynchronousCopies(); TF_RETURN_IF_ERROR(SimplifyGraph()); TF_RETURN_IF_ERROR(FixSchedule()); - TF_RETURN_IF_ERROR(ExportAndColorBuffers()); + TF_ASSIGN_OR_RETURN(auto alias, HloAliasAnalysis::Run(module_)); + TF_RETURN_IF_ERROR(ExportAndColorBuffers(*alias)); std::vector alt_mem_bytes_occupied; // alt_mem_bytes_occupied is used for logging in the RuntimeSimulator below. // We only populate it in VerifyAndExportHeapSimulatorTrace if the // RuntimeSimulator is present. TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace( + *alias, runtime_simulator.has_value() ? &alt_mem_bytes_occupied : nullptr)); - if (runtime_simulator.has_value()) { - float estimated_time = runtime_simulator->SimulateElapsedTime( - module_, allocations_, &alt_mem_bytes_occupied); - VLOG(1) << "Estimated elapsed time with async copies (sec): " - << estimated_time; - } if (VLOG_IS_ON(3)) { LOG(INFO) << "Module after memory space assignment: "; XLA_LOG_LINES(INFO, module_->ToString()); } TF_CHECK_OK(module_->schedule().Verify()); - TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats()); - VLOG(1) << "Maximum number of outstanding async copies/slices: " - << stats.max_outstanding_async_copies; - VLOG(1) << "Number of prefetches: " << stats.num_prefetches - << ", in bytes: " << stats.prefetch_bytes; - VLOG(1) << "Number of sliced prefetches: " << stats.num_sliced_prefetches - << ", consuming number of slices: " - << stats.num_sliced_prefetch_slices; - VLOG(1) << "Number of evictions: " << stats.num_evictions - << ", in bytes: " << stats.eviction_bytes; + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, + CalculateAsyncCopyStats(alias->dataflow_analysis())); + LOG(INFO) << "Maximum number of outstanding async copies/slices: " + << stats.max_outstanding_async_copies; + LOG(INFO) << "Number of prefetches: " << stats.num_prefetches + << ", in bytes: " << stats.prefetch_bytes; + LOG(INFO) << "Number of sliced prefetches: " << stats.num_sliced_prefetches + << ", consuming number of slices: " + << stats.num_sliced_prefetch_slices; + LOG(INFO) << "Number of evictions: " << stats.num_evictions + << ", in bytes: " << stats.eviction_bytes; + } return std::move(preset_assignments_); } @@ -539,15 +538,15 @@ absl::Status MemorySpaceAssignment::Process( return absl::OkStatus(); } -absl::Status MemorySpaceAssignment::ExportAndColorBuffers() { +absl::Status MemorySpaceAssignment::ExportAndColorBuffers( + const HloAliasAnalysis& alias_analysis) { VLOG(1) << "Exporting buffers..."; - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_)); absl::flat_hash_map seen_buffer_offsets; VLOG(3) << "Exported alternate memory allocations:"; for (const auto& position_and_chunk : alternate_memory_assignments_) { const HloPosition& defining_position = position_and_chunk.first; const HeapSimulator::Chunk& chunk = position_and_chunk.second; - const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt( + const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt( defining_position.instruction, defining_position.index); auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id()); if (seen_buffer_offset_it != seen_buffer_offsets.end()) { @@ -589,7 +588,7 @@ absl::Status MemorySpaceAssignment::ExportAndColorBuffers() { for (const auto& defining_position_and_chunk : preset_assignments_->chunks()) { const HloPosition& defining_position = defining_position_and_chunk.first; - for (auto& buffer : alias_analysis->ComputeBuffersAt( + for (auto& buffer : alias_analysis.ComputeBuffersAt( defining_position.instruction, defining_position.index)) { for (auto& value : buffer->values()) { for (auto& position : value->positions()) { @@ -1049,12 +1048,11 @@ absl::Status MemorySpaceAssignment::FixSchedule() { } absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace( + const HloAliasAnalysis& alias_analysis, std::vector* alt_mem_bytes_occupied) { VLOG(1) << "Verifying..."; - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module_)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, - HloLiveRange::Run(module_->schedule(), *alias_analysis, + HloLiveRange::Run(module_->schedule(), alias_analysis, module_->entry_computation())); BufferIntervalTree interval_tree; @@ -1120,7 +1118,7 @@ absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace( const HloPosition& position = position_and_chunk.first; const HeapSimulator::Chunk& chunk = position_and_chunk.second; const HloBuffer& buffer = - alias_analysis->GetUniqueBufferAt(position.instruction, position.index); + alias_analysis.GetUniqueBufferAt(position.instruction, position.index); CHECK(!seen_buffers.contains(buffer.id())) << "Multiple preset assignments for the same buffer: " << buffer.ToString() << ", pos: " << position.ToString() diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h index e2ff35441e4d51..d2bcccc161684f 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h @@ -190,6 +190,7 @@ Useful logging and error messages #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" @@ -305,7 +306,8 @@ class MemorySpaceAssignment { const HloAliasAnalysis& alias_analysis, const Options& options); // Calculates asynchronous copy statistics. - absl::StatusOr CalculateAsyncCopyStats() const; + absl::StatusOr CalculateAsyncCopyStats( + const HloDataflowAnalysis& dataflow_analysis) const; // Verify that allocations_ are free of overlapping Allocations in time and // space. This is a post-processing step called after all allocations have @@ -318,6 +320,7 @@ class MemorySpaceAssignment { // If alt_mem_bytes_occupied is not null, it will be populated with the number // of bytes occupied in the alternate memory space at each instruction time. absl::Status VerifyAndExportHeapSimulatorTrace( + const HloAliasAnalysis& alias_analysis, std::vector* alt_mem_bytes_occupied = nullptr); protected: @@ -372,7 +375,7 @@ class MemorySpaceAssignment { // Export the alternate memory assignments to the PresetAssignments and color // the HLO graph with the determined memory spaces. - absl::Status ExportAndColorBuffers(); + absl::Status ExportAndColorBuffers(const HloAliasAnalysis& alias_analysis); // Schedules asynchronous copies and ensures that the CopyStarts and their // corresponding CopyDones follow the same order. diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index d243e9c8afea22..badefa8a0a3951 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -25,9 +25,7 @@ limitations under the License. #include #include #include -#include #include -#include #include #include #include @@ -55,12 +53,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/transforms/simplifiers/instruction_hoister.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" #include "xla/literal_util.h" -#include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_buffer.h" @@ -72,24 +69,25 @@ limitations under the License. #include "xla/service/memory_space_assignment/buffer_interval_comparator.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/memory_space_assignment_test_base.h" #include "xla/service/memory_space_assignment/options.h" #include "xla/service/memory_space_assignment/prefetch_interval_picker.h" #include "xla/service/memory_space_assignment/repacking.h" #include "xla/service/memory_space_assignment/slice.h" #include "xla/service/memory_space_assignment/testing_utils.h" +#include "xla/service/memory_space_assignment/utils.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep -#include "tsl/platform/status.h" -#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace memory_space_assignment { @@ -101,404 +99,10 @@ using ::testing::_; using ::testing::Return; using ::testing::UnorderedElementsAre; -constexpr float kDefaultMemBandwidth = 100; -constexpr float kAlternateMemBandwidth = 1000; constexpr float kBytesPerSecond = 100; -constexpr float kFlopsPerSecond = 1000; -constexpr float kTranscendentalsPerSecond = 10; const auto& ShapeSize = HloCostAnalysis::DefaultShapeSize; -int64_t SizeFunction(const BufferValue& value) { - return ShapeSize(value.shape()); -} - -int64_t ReservedScopedMemoryFn( - const HloInstruction* instruction, - const absl::flat_hash_set>& - operands_in_alternate_memory, - const absl::flat_hash_set& outputs_in_alternate_memory) { - return 0; -} - -class TestBufferIntervalComparator : public BufferIntervalComparator { - public: - explicit TestBufferIntervalComparator(MsaBufferIntervalCompare compare_method) - : BufferIntervalComparator(), compare_method_(compare_method) {} - - ~TestBufferIntervalComparator() override = default; - - std::string DescribeComparisonCriteria() const override { - return "internal to test"; - } - std::string CriteriaToString( - const MsaBufferInterval& buffer_interval) override { - return "internal to test"; - } - bool LessThan(const MsaBufferInterval& lhs, - const MsaBufferInterval& rhs) override { - return compare_method_(lhs, rhs); - } - - private: - MsaBufferIntervalCompare compare_method_; -}; - -class MemorySpaceAssignmentTestBase : public HloTestBase { - protected: - // We use the following two memory space values to describe the default (slow - // and large) and alternate (fast and small) memory spaces. - const int64_t kDefaultMemorySpace = 0; - const int64_t kAlternateMemorySpace = 1; - - HloCostAnalysis::Options DefaultHloCostAnalysisOptions() { - HloCostAnalysis::Options options; - options.set_flops_per_second(kFlopsPerSecond); - options.set_bytes_per_second(kBytesPerSecond); - options.set_transcendentals_per_second(kTranscendentalsPerSecond); - - return options; - } - - Options DefaultMemorySpaceOptions() { - Options options; - options.max_size_in_bytes = 128; - options.alignment_in_bytes = 8; - options.verify = false; - options.alternate_memory_space = kAlternateMemorySpace; - options.max_outstanding_prefetches = -1; - options.max_outstanding_evictions = -1; - - return options; - } - - CostAnalysisOptions DefaultCostAnalysisOptions() { - CostAnalysisOptions options; - options.default_mem_bandwidth_bytes_per_second = kDefaultMemBandwidth; - options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth; - return options; - } - - Options UpdateMaxAsyncCopies(Options options, int64_t max_async_copies) { - options.max_outstanding_prefetches = max_async_copies; - options.max_outstanding_evictions = max_async_copies; - - return options; - } - - std::unique_ptr AssignMemorySpaceUsingCostAnalysis( - HloModule* module, - std::optional memory_space_options_override = std::nullopt, - std::optional cost_analysis_options_override = - std::nullopt, - std::optional hlo_cost_options_override = - std::nullopt, - std::optional optional_msa_sort_order_overrides = - std::nullopt) { - HloCostAnalysis::Options hlo_cost_options = DefaultHloCostAnalysisOptions(); - if (hlo_cost_options_override) { - hlo_cost_options = *hlo_cost_options_override; - } - - HloCostAnalysis hlo_cost_analysis(hlo_cost_options); - for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); - } - auto alias_analysis = HloAliasAnalysis::Run(module).value(); - - Options memory_space_options = DefaultMemorySpaceOptions(); - if (memory_space_options_override) { - memory_space_options = *memory_space_options_override; - } - CostAnalysisOptions cost_analysis_options = DefaultCostAnalysisOptions(); - if (cost_analysis_options_override) { - cost_analysis_options = *cost_analysis_options_override; - } - HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); - - auto cost_analysis = CostAnalysis::Create(hlo_cost_analysis_costs, - cost_analysis_options, *module) - .value(); - memory_space_options.cost_analysis = cost_analysis.get(); - CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( - CostAnalysisPrefetchIntervalPicker( - *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8, - /*preferred_overlap_to_async_copy_ratio=*/1.5, - /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, - /*mem_size_bytes=*/memory_space_options.max_size_in_bytes)); - MsaSortOrderOverrides msa_sort_order_overrides; - if (optional_msa_sort_order_overrides.has_value()) { - msa_sort_order_overrides = optional_msa_sort_order_overrides.value(); - } - MemoryBoundednessBufferIntervalComparator comparator( - *cost_analysis, &cache_, msa_sort_order_overrides); - return AssignMemorySpace( - module, memory_space_options, - [&comparator](const MsaBufferInterval& lhs, - const MsaBufferInterval& rhs) { - return comparator.LessThan(lhs, rhs); - }, - &prefetch_interval_picker); - } - - std::unique_ptr AssignMemorySpace( - HloModule* module, std::optional options_override = std::nullopt, - int64_t max_prefetch_interval = 10, int64_t min_prefetch_interval = 2) { - InstructionHoister instruction_hoister; - TF_CHECK_OK(instruction_hoister.Run(module).status()); - InstructionCountPrefetchIntervalPicker prefetch_interval_picker( - min_prefetch_interval, max_prefetch_interval); - return AssignMemorySpace(module, options_override, - /*buffer_interval_compare=*/{}, - &prefetch_interval_picker); - } - - std::unique_ptr AssignMemorySpace( - HloModule* module, std::optional options_override, - std::optional buffer_interval_compare, - PrefetchIntervalPicker* prefetch_interval_picker) { - auto status_or = AssignMemorySpaceAndReturnStatus(module, options_override, - buffer_interval_compare, - prefetch_interval_picker); - TF_EXPECT_OK(status_or.status()); - return std::move(status_or.value()); - } - - absl::StatusOr> - AssignMemorySpaceAndReturnStatus( - HloModule* module, std::optional options_override, - std::optional buffer_interval_compare, - PrefetchIntervalPicker* prefetch_interval_picker) { - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - auto is_allowed_in_alternate_mem = [](const HloValue& value) { - // Check if the value belongs to the entry computation. - HloInstruction* instruction = value.instruction(); - HloComputation* computation = instruction->parent(); - bool in_entry_computation = - (computation == computation->parent()->entry_computation()); - if (in_entry_computation && - instruction->opcode() == HloOpcode::kParameter) { - return false; - } - return true; - }; - - // Only check parameters in default memory if the original module didn't - // have the parameters in alternate memory. - bool check_parameters_in_default_memory = true; - for (const HloInstruction* parameter : - module->entry_computation()->parameter_instructions()) { - ShapeUtil::ForEachSubshape( - parameter->shape(), - [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (subshape.has_layout() && - subshape.layout().memory_space() == kAlternateMemorySpace) { - check_parameters_in_default_memory = false; - } - }); - } - - Options options = DefaultMemorySpaceOptions(); - if (options_override) { - options = *options_override; - } - std::unique_ptr test_comparator; - if (buffer_interval_compare.has_value()) { - test_comparator = std::make_unique( - *buffer_interval_compare); - options.buffer_interval_comparator = test_comparator.get(); - } - options.prefetch_interval_picker = prefetch_interval_picker; - options.size_fn = size_fn; - if (options.is_allowed_in_alternate_mem_fn == nullptr) { - options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; - } - - TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, - HloLiveRange::Run(module->schedule(), *alias_analysis, - module->entry_computation())); - - TF_ASSIGN_OR_RETURN(std::unique_ptr preset_assignments, - MemorySpaceAssignment::Run(module, *hlo_live_range, - *alias_analysis, options)); - if (check_parameters_in_default_memory) { - CheckParametersInDefaultMemory(module); - } - CheckRootInDefaultMemory(module); - CheckPresetAssignments(preset_assignments.get()); - return preset_assignments; - } - - void CheckPresetAssignments(const PresetAssignments* preset_assignments) { - // Ensure that the exported preset assignments point to layouts in the - // alternate memory. Also ensure that the positions are unique. Note that - // we're using a std::set instead of absl::flat_hash_set because we can make - // use of HloPosition's comparator logic instead of providing a hasher. - std::set positions_in_preset_assignments; - for (auto& position_and_chunk : preset_assignments->chunks()) { - HloPosition position = position_and_chunk.first; - EXPECT_EQ(positions_in_preset_assignments.find(position), - positions_in_preset_assignments.end()); - positions_in_preset_assignments.insert(position); - const Shape& subshape = - ShapeUtil::GetSubshape(position.instruction->shape(), position.index); - EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace) - << "Exported position is not in alternate mem: " - << position.ToString(); - } - } - - void CheckParametersInDefaultMemory(const HloModule* module) { - // Check that all the entry parameter subshapes are placed in default - // memory. - const HloComputation* entry_computation = module->entry_computation(); - for (const HloInstruction* parameter : - entry_computation->parameter_instructions()) { - ShapeUtil::ForEachSubshape( - parameter->shape(), - [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (subshape.has_layout()) { - EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace) - << "Parameter not in default memory: " - << parameter->ToString(); - } - }); - } - } - - void CheckRootInDefaultMemory(const HloModule* module) { - const HloInstruction* root = - module->entry_computation()->root_instruction(); - if (root->shape().IsArray()) { - EXPECT_EQ(root->shape().layout().memory_space(), kDefaultMemorySpace); - } - } - - struct OutstandingAsyncCopies { - int64_t max_copies; - int64_t max_prefetches; - int64_t max_evictions; - }; - - /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies( - const HloModule& module) { - OutstandingAsyncCopies copies{0, 0, 0}; - int64_t current_copies = 0; - int64_t current_prefetches = 0; - int64_t current_evictions = 0; - for (HloInstruction* instruction : module.schedule() - .sequence(module.entry_computation()) - .instructions()) { - if (instruction->opcode() == HloOpcode::kCopyStart) { - current_copies++; - if (ShapeUtil::GetSubshape(instruction->shape(), {0}) - .layout() - .memory_space() == kAlternateMemorySpace) { - current_prefetches++; - } else { - current_evictions++; - } - } else if (instruction->opcode() == HloOpcode::kCopyDone) { - current_copies--; - if (instruction->shape().layout().memory_space() == - kAlternateMemorySpace) { - current_prefetches--; - } else { - current_evictions--; - } - } - copies.max_copies = std::max(copies.max_copies, current_copies); - copies.max_prefetches = - std::max(copies.max_prefetches, current_prefetches); - copies.max_prefetches = std::max(copies.max_evictions, current_evictions); - } - return copies; - } - - int64_t GetAlternateMemoryOffset(const PresetAssignments& preset_assignments, - const HloInstruction* instruction, - const ShapeIndex& index = {}) const { - // Returns the offset of the assignment, -1 if it's not in the alternate - // memory. - const HloModule* module = instruction->GetModule(); - auto alias_analysis = HloAliasAnalysis::Run(module).value(); - HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index); - for (auto& pos_and_chunk : preset_assignments.chunks()) { - for (auto& value : buffer.values()) { - if (pos_and_chunk.first == value->defining_position()) { - return pos_and_chunk.second.offset; - } - } - } - return -1; - } - - std::unique_ptr CreateEvictAndPrefetchModule() { - HloComputation::Builder builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - HloInstruction* p0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - HloInstruction* p1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - HloInstruction* tanh = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); - // tanh should be placed in the alternate memory since there isn't much - // contention in the beginning. However, tanh has another consumer at the - // end. So it should be kicked out to default memory and prefetched back in. - // The graph below is meant to increase the contention to force - // eviction/prefetch behavior. - HloInstruction* a = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh)); - HloInstruction* b = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); - HloInstruction* c = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); - HloInstruction* d = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); - HloInstruction* e = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b)); - HloInstruction* f = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c)); - HloInstruction* g = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d)); - HloInstruction* h = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c)); - HloInstruction* i = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d)); - HloInstruction* j = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d)); - HloInstruction* k = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f)); - HloInstruction* l = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h)); - HloInstruction* m = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j)); - HloInstruction* n = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l)); - HloInstruction* o = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m)); - // tanh is being used at the root instruction, and this should be - // prefetched. - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh)); - - auto module = CreateNewVerifiedModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); - - HloSchedule schedule(module.get()); - schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i, - j, k, l, m, n, o, add}); - TF_CHECK_OK(module->set_schedule(schedule)); - return module; - } - - CostAnalysis::Cache cache_; -}; - using MemorySpaceAssignmentTest = MemorySpaceAssignmentTestBase; TEST_F(MemorySpaceAssignmentTest, ParameterOnly) { @@ -634,6 +238,53 @@ TEST_F(MemorySpaceAssignmentTest, NegateChain) { EXPECT_THAT(sequence.instructions()[10], op::CopyDone()); } +TEST_F(MemorySpaceAssignmentTest, PinnedDefaultMemorySpace) { + absl::string_view hlo_string = R"( + HloModule NegateChain, is_scheduled=true, entry_computation_layout={(f32[2,3]{1,0}, f32[2,3]{1,0:S(2)})->f32[2,3]{1,0}} + + ENTRY %NegateChain (p0: f32[2,3], p1: f32[2,3]) -> f32[2,3] { + %p0 = f32[2,3]{1,0} parameter(0) + %p1 = f32[2,3]{1,0:S(2)} parameter(1) + %negate = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0} %p0) + %negate.1 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate) + %negate.2 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.1) + %negate.3 = f32[2,3]{1,0} negate(f32[2,3]{1,0:S(2)} %negate.2) + %negate.4 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0} %negate.3) + %negate.5 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.4) + %negate.6 = f32[2,3]{1,0:S(2)} negate(f32[2,3]{1,0:S(2)} %negate.5) + ROOT %add = f32[2,3]{1,0} add(f32[2,3]{1,0:S(2)} %negate.6, f32[2,3]{1,0:S(2)} %p1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + HloInstruction* p1 = FindInstruction(module.get(), "p1"); + HloInstruction* negate = FindInstruction(module.get(), "negate"); + HloInstruction* negate_1 = FindInstruction(module.get(), "negate.1"); + HloInstruction* negate_2 = FindInstruction(module.get(), "negate.2"); + HloInstruction* negate_3 = FindInstruction(module.get(), "negate.3"); + HloInstruction* negate_4 = FindInstruction(module.get(), "negate.4"); + HloInstruction* negate_5 = FindInstruction(module.get(), "negate.5"); + HloInstruction* negate_6 = FindInstruction(module.get(), "negate.6"); + HloInstruction* add = FindInstruction(module.get(), "add"); + std::vector pinned_hbm_instructions = { + p1, negate, negate_1, negate_2, negate_4, negate_5, negate_6}; + for (const HloInstruction* instruction : pinned_hbm_instructions) { + EXPECT_EQ(instruction->shape().layout().memory_space(), + kPinnedDefaultMemorySpace); + } + // Check p0 and add are in the default memory space. + EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); + EXPECT_EQ(add->shape().layout().memory_space(), kDefaultMemorySpace); + // Check negate_3 is in pinned to alternate memory space. + EXPECT_EQ(negate_3->shape().layout().memory_space(), kAlternateMemorySpace); + // Check that p1 is only used once at the add instruction. ie, the there is no + // copy/prefetch. + CHECK_EQ(p1->users().size(), 1); + EXPECT_EQ(p1->users()[0], add); +} + // A simple case where the synchronous copy is actually redundant, because its // operand ends up getting prefetched and the its output is only used once, so // we remove the sync copy. @@ -5890,6 +5541,11 @@ TEST_F(MemorySpaceAssignmentTest, /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, kAlternateMemorySpace); + Shape shape_in_default_mem = ShapeUtil::MakeShapeWithDenseLayout( + F32, {2, 3}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, + /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0, + kDefaultMemorySpace); // p0 is in the default memory space. HloInstruction* p0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); @@ -5930,13 +5586,14 @@ TEST_F(MemorySpaceAssignmentTest, options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) { return true; }; + XLA_VLOG_LINES(3, module->ToString()); std::unique_ptr preset_assignments = AssignMemorySpace(module.get(), options); - + XLA_VLOG_LINES(3, module->ToString()); // Ensure that p1 is in the alternate memory and add, which has p1 as an // operand, has a direct dependency to p1 (no CopyStart/CopyDone). - EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem)); - EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1))); + EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_default_mem)); + EXPECT_THAT(add, op::Add(op::Negate(), op::CopyDone())); // Make sure add is still in the alternate memory space. EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem)); @@ -5945,6 +5602,7 @@ TEST_F(MemorySpaceAssignmentTest, // alternate memory space are left to BufferAssignment to be allocated. for (const auto& position_and_chunk : preset_assignments->chunks()) { const HloPosition& position = position_and_chunk.first; + XLA_VLOG_LINES(3, position.instruction->ToString()); EXPECT_NE(position.instruction, p1); EXPECT_NE(position.instruction, add); } @@ -8767,7 +8425,7 @@ ENTRY main { Options options = DefaultMemorySpaceOptions(); options.position_requires_contiguous_allocation_fn = [](const HloPosition& position) { - std::string_view inst_name = position.instruction->name(); + absl::string_view inst_name = position.instruction->name(); if (inst_name == "fusion1" || (inst_name == "fusion2" && position.index != ShapeIndex({0}))) { return true; @@ -8873,7 +8531,7 @@ ENTRY main { Options options = DefaultMemorySpaceOptions(); options.position_requires_contiguous_allocation_fn = [](const HloPosition& position) { - std::string_view inst_name = position.instruction->name(); + absl::string_view inst_name = position.instruction->name(); if (inst_name == "fusion1" || (inst_name == "fusion2" && position.index != ShapeIndex({0}))) { return true; @@ -8995,7 +8653,7 @@ ENTRY main { Options options = DefaultMemorySpaceOptions(); options.position_requires_contiguous_allocation_fn = [](const HloPosition& position) { - std::string_view inst_name = position.instruction->name(); + absl::string_view inst_name = position.instruction->name(); if (inst_name == "fusion1" || (inst_name == "fusion2" && position.index != ShapeIndex({0})) || (inst_name == "fusion3" && position.index != ShapeIndex({0}))) { @@ -9110,23 +8768,15 @@ entry { AssignMemorySpace(module.get(), options, /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/0); const HloInstruction* fusion = FindInstruction(module.get(), "fusion"); - // The fusion instruction should have 5 operands: the 3 original operands - // plus 2 window prefetch buffers. - EXPECT_EQ(fusion->operand_count(), 5); - - // The 2 added operands are async calls to WindowPrefetch. - for (int i = 3; i < 5; i++) { - const HloInstruction* async_done = fusion->operand(i); - EXPECT_EQ(async_done->opcode(), HloOpcode::kAsyncDone); - EXPECT_EQ(async_done->operand_count(), 1); - EXPECT_TRUE(async_done->async_wrapped_instruction()->IsCustomCall( - "WindowPrefetch")); + // The fusion instruction should have 7 operands: the 3 original operands + // plus 2 window prefetch buffers, plus 2 sync flags. + EXPECT_EQ(fusion->operand_count(), 7); - const HloInstruction* async_start = async_done->operand(0); - EXPECT_EQ(async_start->opcode(), HloOpcode::kAsyncStart); - EXPECT_EQ(async_start->operand_count(), 1); - EXPECT_TRUE(async_start->async_wrapped_instruction()->IsCustomCall( - "WindowPrefetch")); + // The added operands are GetTupleElements of WindowPrefetch custom calls. + for (int i = 3; i < 7; i++) { + EXPECT_EQ(fusion->operand(i)->opcode(), HloOpcode::kGetTupleElement); + const HloInstruction* window_prefetch = fusion->operand(i)->operand(0); + EXPECT_TRUE(window_prefetch->IsCustomCall("WindowPrefetch")); } VLOG(2) << "module: " << module->ToString(); @@ -10225,7 +9875,7 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) { AssignMemorySpace(module.get(), options); auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 0); + EXPECT_GT(cross_program_prefetches.size(), 0); } TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTupleTest) { @@ -10272,7 +9922,7 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTupleTest) { AssignMemorySpace(module.get(), options); auto cross_program_prefetches = module->CrossProgramPrefetches(); - EXPECT_EQ(cross_program_prefetches.size(), 0); + EXPECT_GT(cross_program_prefetches.size(), 0); } TEST_F(MemorySpaceAssignmentTest, CrossProgramRootDupMayAlias) { @@ -10526,8 +10176,10 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) { } TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) { - // This test is for checking if the cross-program-prefetched buffer is freed - // after its last use and there is an end-of-program prefetch. + // This test is same as above, but with an override to cross-program prefetch + // parameter0 as opposed to p0 and limiting the max alternate memory + // size to 256 bytes so that both p0 and p1 cannot be assigned to alternate + // memory and priority is given to p0. absl::string_view hlo_string = R"( HloModule cross_program_prefetch, is_scheduled=true @@ -10615,6 +10267,203 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) { EXPECT_TRUE(has_zero_offset_allocations); } +TEST_F(MemorySpaceAssignmentTest, UserAnnotatedCrossProgramPrefetchNoReuse) { + // This test is same as above, but with user directive to cross-program + // prefetch parameter0 as opposed to p0 and limiting the max alternate memory + // size to 256 bytes so that both p0 and p1 cannot be assigned to alternate + // memory and priority is given to p0. + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true, entry_computation_layout={(f32[8,8]{1,0:S(1)}, f32[8,2]{1,0})->f32[8,2]{1,0}} + + ENTRY CrossProgramPrefetch { + p0 = f32[8,8]{1,0:S(1)} parameter(0) + p1 = f32[8,2]{1,0} parameter(1) + dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT negate.9 = f32[8,2]{1,0} negate(negate.8) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 256; + auto preset_assignments = AssignMemorySpace(module.get(), options, + /*max_prefetch_interval=*/5, + /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + LOG(ERROR) << "module: " << module->ToString(); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {}); + // Expect that there are two prefetches that use this value, one is the + // cross-program prefetch, the other is the end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_end_of_program_prefetch), + 1); + // Also verify that the copy-done for the end-of-program prefetch is the last + // instruction in schedule. + const HloInstruction* last_instruction = + module->schedule() + .sequence(module->entry_computation()) + .instructions()[module->entry_computation()->instruction_count() - 1]; + EXPECT_THAT(last_instruction, op::CopyDone()); + EXPECT_NE(last_instruction, module->entry_computation()->root_instruction()); + // Cross program prefetch would use offset 0 because that's the first + // assignment. Since we are freeing the cross-program prefetch buffer, we + // would also expect to see some of the intermediate computations (one of the + // negate ops) to also get 0 offset allocations. + bool has_zero_offset_allocations = false; + for (auto pos_and_chunk : preset_assignments->chunks()) { + if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate && + pos_and_chunk.second.offset == 0) { + has_zero_offset_allocations = true; + } + } + EXPECT_TRUE(has_zero_offset_allocations); + XLA_VLOG_LINES(3, module->ToString()); + bool found = false; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->name() == "p0") { + found = true; + EXPECT_EQ(instr->shape().layout().memory_space(), 0); + EXPECT_EQ(module->entry_computation_layout() + .parameter_layout(0) + .shape() + .layout() + .memory_space(), + 0); + } + } + } + EXPECT_TRUE(found); +} + +TEST_F(MemorySpaceAssignmentTest, + UserAnnotatedCrossProgramPrefetchWithoutPropagationToParameterNoReuse) { + // This test is same as above, but the S(1) memory space specified in the + // layout to cross-program prefetch p0 is only present in the entry + // computation layout and has not been propagated to the parameter + // instruction. This still works as the previous test. + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true, entry_computation_layout={(f32[8,8]{1,0:S(1)}, f32[8,2]{1,0})->f32[8,2]{1,0}} + + ENTRY CrossProgramPrefetch { + p0 = f32[8,8]{1,0} parameter(0) + p1 = f32[8,2]{1,0} parameter(1) + dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT negate.9 = f32[8,2]{1,0} negate(negate.8) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 256; + auto preset_assignments = AssignMemorySpace(module.get(), options, + /*max_prefetch_interval=*/5, + /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + LOG(ERROR) << "module: " << module->ToString(); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {}); + // Expect that there are two prefetches that use this value, one is the + // cross-program prefetch, the other is the end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_end_of_program_prefetch), + 1); + // Also verify that the copy-done for the end-of-program prefetch is the last + // instruction in schedule. + const HloInstruction* last_instruction = + module->schedule() + .sequence(module->entry_computation()) + .instructions()[module->entry_computation()->instruction_count() - 1]; + EXPECT_THAT(last_instruction, op::CopyDone()); + EXPECT_NE(last_instruction, module->entry_computation()->root_instruction()); + // Cross program prefetch would use offset 0 because that's the first + // assignment. Since we are freeing the cross-program prefetch buffer, we + // would also expect to see some of the intermediate computations (one of the + // negate ops) to also get 0 offset allocations. + bool has_zero_offset_allocations = false; + for (auto pos_and_chunk : preset_assignments->chunks()) { + if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate && + pos_and_chunk.second.offset == 0) { + has_zero_offset_allocations = true; + } + } + EXPECT_TRUE(has_zero_offset_allocations); + XLA_VLOG_LINES(3, module->ToString()); + bool found = false; + for (auto* c : module->computations()) { + for (auto* instr : c->instructions()) { + if (instr->name() == "p0") { + found = true; + EXPECT_EQ(instr->shape().layout().memory_space(), 0); + EXPECT_EQ(module->entry_computation_layout() + .parameter_layout(0) + .shape() + .layout() + .memory_space(), + 0); + } + } + } + EXPECT_TRUE(found); +} + TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) { // This test is for checking if the cross-program-prefetched buffer is freed // after its last use and there is an end-of-program prefetch. @@ -10692,6 +10541,75 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) { EXPECT_TRUE(has_zero_offset_allocations); } +TEST_F(MemorySpaceAssignmentTest, + CrossProgramPrefetchEndOfProgramPrefetchAndWhile) { + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + while_condition { + param1 = (f32[8,2]{1,0}, f32[8,2]{1,0}) parameter(0) + ROOT cond = pred[] constant(true) + } + + while_body { + param2 = (f32[8,2]{1,0}, f32[8,2]{1,0}) parameter(0) + gte2 = f32[8,2]{1,0} get-tuple-element(param2), index=0 + gte3 = f32[8,2]{1,0} get-tuple-element(param2), index=1 + add = f32[8,2]{1,0} add(gte2, gte3) + negate.2 = f32[8,2]{1,0} negate(add) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT tuple2 = (f32[8,2]{1,0}, f32[8,2]{1,0}) tuple(negate.8, gte3) + } + + ENTRY CrossProgramPrefetch { + p0 = f32[8,8]{1,0} parameter(0) + p1 = f32[8,2]{1,0} parameter(1) + dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + tuple = (f32[8,2]{1,0}, f32[8,2]{1,0}) tuple(negate.1, dot) + while = (f32[8,2]{1,0}, f32[8,2]{1,0}) while(tuple), condition=while_condition, body=while_body + ROOT gte0 = f32[8,2]{1,0} get-tuple-element(while), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto preset_assignments = AssignMemorySpaceUsingCostAnalysis(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + EXPECT_EQ(cross_program_prefetches[0].parameter, 1); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + LOG(ERROR) << "module: " << module->ToString(); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(1), {}); + // Expect that there are two prefetches that use this value, one is the + // cross-program prefetch, the other is the end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_end_of_program_prefetch), + 1); +} + TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) { // This tests the scenario that the cross-program-prefetched buffer is used // again close to the end of the computation. In this case, it is better not @@ -10894,7 +10812,7 @@ ENTRY entry { // - Test: prefetch p1, after p0 is unallocated from alternate memory (after // instruction c). TEST_F(MemorySpaceAssignmentTest, CopyResourceIntegration) { - std::string_view hlo_string = R"( + absl::string_view hlo_string = R"( HloModule module, is_scheduled=true ENTRY main { @@ -10983,7 +10901,7 @@ ENTRY main { // Check the schedule const std::vector& schedule = module->schedule().sequence(module->entry_computation()).instructions(); - auto find_schedule_index = [&schedule](std::string_view name) -> int { + auto find_schedule_index = [&schedule](absl::string_view name) -> int { for (int i = 0; i < schedule.size(); ++i) { if (schedule[i]->name() == name) { return i; @@ -11275,7 +11193,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { static bool MatchMemorySpace(const HloInstruction* instruction, int64_t expected_memory_space, - std::string_view error_message_identifier, + absl::string_view error_message_identifier, ::testing::MatchResultListener* listener) { if (!instruction->shape().has_layout()) { *listener << " contains " << error_message_identifier << " named " @@ -11432,7 +11350,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // Returns the index of the first instruction with the given name. static absl::StatusOr FindScheduleIndexOfInstruction( - const std::vector& schedule, std::string_view name, + const std::vector& schedule, absl::string_view name, InstructionClass c) { for (int i = 0; i < schedule.size(); ++i) { if (schedule[i]->name() == name) { @@ -11448,7 +11366,7 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // Returns a scheduled instruction with the specified name or null. static const HloInstruction* FindNamedScheduledInstruction( - const HloModule& module, std::string_view name) { + const HloModule& module, absl::string_view name) { for (const HloInstruction* i : module.entry_computation()->instructions()) { if (i->name() == name) { return i; @@ -11703,8 +11621,8 @@ class SlicedPrefetchTest : public MemorySpaceAssignmentTestBase { // - concat_bitcast comes after all slice dones AND static absl::Status CheckSchedule( const HloModule& module, const HloInstruction* concat_bitcast, - std::string_view slices_start_after_instruction_name, - std::string_view slices_done_before_instruction_name, + absl::string_view slices_start_after_instruction_name, + absl::string_view slices_done_before_instruction_name, bool expect_slices_started_at_different_times) { CHECK(concat_bitcast->IsCustomCall(kConcatBitcastCustomCall)); @@ -12011,7 +11929,7 @@ ENTRY main { p1_copy1 = f32[8,8] copy(p1) p1_copy2 = f32[8,8] copy(p1) - + r1 = f32[8,8] add(c, p1_copy1) r2 = f32[8,8] add(c, p1_copy2) @@ -12687,8 +12605,8 @@ ENTRY main { // A lambda for generating HLO with 2 while loops called back to back. The // first while loop will execute while_computation1 and the second while loop // will execute while_computation2. - auto gen_hlo = [&](std::string_view while_computation1, - std::string_view while_computation2) { + auto gen_hlo = [&](absl::string_view while_computation1, + absl::string_view while_computation2) { return absl::StrReplaceAll( module_text, { @@ -12729,7 +12647,7 @@ ENTRY main { // Define a lambda for running MSA on the specified HLO, with the // configuration above. auto run_msa = - [&](std::string_view hlo_text) -> absl::StatusOr { + [&](absl::string_view hlo_text) -> absl::StatusOr { ModuleAndAssignments module_and_assignments; TF_ASSIGN_OR_RETURN(module_and_assignments.module, ParseAndReturnVerifiedModule(hlo_text)); diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h new file mode 100644 index 00000000000000..c81035e25dc954 --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test_base.h @@ -0,0 +1,448 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_TEST_BASE_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_TEST_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/transforms/simplifiers/instruction_hoister.h" +#include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/buffer_value.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/buffer_interval_comparator.h" +#include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.h" +#include "xla/service/memory_space_assignment/options.h" +#include "xla/service/memory_space_assignment/prefetch_interval_picker.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace memory_space_assignment { + +constexpr int64_t kPointerSize = 8; +constexpr float kDefaultMemBandwidth = 100; +constexpr float kAlternateMemBandwidth = 1000; +constexpr float kBytesPerSecond = 100; +constexpr float kFlopsPerSecond = 1000; +constexpr float kTranscendentalsPerSecond = 10; + +class TestBufferIntervalComparator : public BufferIntervalComparator { + public: + explicit TestBufferIntervalComparator(MsaBufferIntervalCompare compare_method) + : compare_method_(std::move(compare_method)) {} + + ~TestBufferIntervalComparator() override = default; + + std::string DescribeComparisonCriteria() const override { + return "internal to test"; + } + std::string CriteriaToString( + const MsaBufferInterval& buffer_interval) override { + return "internal to test"; + } + bool LessThan(const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) override { + return compare_method_(lhs, rhs); + } + + private: + MsaBufferIntervalCompare compare_method_; +}; + +class MemorySpaceAssignmentTestBase : public HloTestBase { + protected: + // We use the following two memory space values to describe the default (slow + // and large) and alternate (fast and small) memory spaces. + const int64_t kDefaultMemorySpace = 0; + const int64_t kAlternateMemorySpace = 1; + const int64_t kPinnedDefaultMemorySpace = 2; + + static HloCostAnalysis::Options DefaultHloCostAnalysisOptions() { + HloCostAnalysis::Options options; + options.set_flops_per_second(kFlopsPerSecond); + options.set_bytes_per_second(kBytesPerSecond); + options.set_transcendentals_per_second(kTranscendentalsPerSecond); + + return options; + } + + Options DefaultMemorySpaceOptions() const { + Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = false; + options.alternate_memory_space = kAlternateMemorySpace; + options.max_outstanding_prefetches = -1; + options.max_outstanding_evictions = -1; + + return options; + } + + static CostAnalysisOptions DefaultCostAnalysisOptions() { + CostAnalysisOptions options; + options.default_mem_bandwidth_bytes_per_second = kDefaultMemBandwidth; + options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth; + return options; + } + + static Options UpdateMaxAsyncCopies(Options options, + int64_t max_async_copies) { + options.max_outstanding_prefetches = max_async_copies; + options.max_outstanding_evictions = max_async_copies; + + return options; + } + + std::unique_ptr AssignMemorySpaceUsingCostAnalysis( + HloModule* module, + std::optional memory_space_options_override = std::nullopt, + std::optional cost_analysis_options_override = + std::nullopt, + std::optional hlo_cost_options_override = + std::nullopt, + std::optional optional_msa_sort_order_overrides = + std::nullopt) { + HloCostAnalysis::Options hlo_cost_options = DefaultHloCostAnalysisOptions(); + if (hlo_cost_options_override) { + hlo_cost_options = *hlo_cost_options_override; + } + + HloCostAnalysis hlo_cost_analysis(hlo_cost_options); + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); + } + TF_CHECK_OK(HloAliasAnalysis::Run(module).status()); + + Options memory_space_options = DefaultMemorySpaceOptions(); + if (memory_space_options_override) { + memory_space_options = *memory_space_options_override; + } + CostAnalysisOptions cost_analysis_options = DefaultCostAnalysisOptions(); + if (cost_analysis_options_override) { + cost_analysis_options = *cost_analysis_options_override; + } + HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); + + auto status_or_cost_analysis = CostAnalysis::Create( + hlo_cost_analysis_costs, cost_analysis_options, *module); + TF_CHECK_OK(status_or_cost_analysis.status()); + auto cost_analysis = std::move(status_or_cost_analysis.value()); + + memory_space_options.cost_analysis = cost_analysis.get(); + CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( + CostAnalysisPrefetchIntervalPicker( + *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8f, + /*preferred_overlap_to_async_copy_ratio=*/1.5, + /*max_overlap_to_mem_size_async_copy_ratio=*/10.0, + /*mem_size_bytes=*/memory_space_options.max_size_in_bytes)); + MsaSortOrderOverrides msa_sort_order_overrides; + if (optional_msa_sort_order_overrides.has_value()) { + msa_sort_order_overrides = optional_msa_sort_order_overrides.value(); + } + MemoryBoundednessBufferIntervalComparator comparator( + *cost_analysis, &cache_, msa_sort_order_overrides); + return AssignMemorySpace( + module, memory_space_options, + [&comparator](const MsaBufferInterval& lhs, + const MsaBufferInterval& rhs) { + return comparator.LessThan(lhs, rhs); + }, + &prefetch_interval_picker); + } + + std::unique_ptr AssignMemorySpace( + HloModule* module, std::optional options_override = std::nullopt, + int64_t max_prefetch_interval = 10, int64_t min_prefetch_interval = 2) { + InstructionHoister instruction_hoister; + TF_CHECK_OK(instruction_hoister.Run(module).status()); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker( + min_prefetch_interval, max_prefetch_interval); + return AssignMemorySpace(module, std::move(options_override), + /*buffer_interval_compare=*/{}, + &prefetch_interval_picker); + } + + std::unique_ptr AssignMemorySpace( + HloModule* module, std::optional options_override, + std::optional buffer_interval_compare, + PrefetchIntervalPicker* prefetch_interval_picker) { + auto status_or = AssignMemorySpaceAndReturnStatus( + module, std::move(options_override), std::move(buffer_interval_compare), + prefetch_interval_picker); + TF_EXPECT_OK(status_or.status()); + return std::move(status_or.value()); + } + + absl::StatusOr> + AssignMemorySpaceAndReturnStatus( + HloModule* module, std::optional options_override, + std::optional buffer_interval_compare, + PrefetchIntervalPicker* prefetch_interval_picker) { + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + auto is_allowed_in_alternate_mem = [](const HloValue& value) { + // Check if the value belongs to the entry computation. + HloInstruction* instruction = value.instruction(); + HloComputation* computation = instruction->parent(); + bool in_entry_computation = + (computation == computation->parent()->entry_computation()); + + return (!in_entry_computation || + instruction->opcode() != HloOpcode::kParameter); + }; + + // Only check parameters in default memory if the original module didn't + // have the parameters in alternate memory. + bool check_parameters_in_default_memory = true; + for (const HloInstruction* parameter : + module->entry_computation()->parameter_instructions()) { + ShapeUtil::ForEachSubshape( + parameter->shape(), + [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (subshape.has_layout() && + subshape.layout().memory_space() == kAlternateMemorySpace) { + check_parameters_in_default_memory = false; + } + }); + } + + Options options = DefaultMemorySpaceOptions(); + if (options_override) { + options = *options_override; + } + std::unique_ptr test_comparator; + if (buffer_interval_compare.has_value()) { + test_comparator = std::make_unique( + *buffer_interval_compare); + options.buffer_interval_comparator = test_comparator.get(); + } + options.prefetch_interval_picker = prefetch_interval_picker; + options.size_fn = size_fn; + if (options.is_allowed_in_alternate_mem_fn == nullptr) { + options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; + } + + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(module->schedule(), *alias_analysis, + module->entry_computation())); + + TF_ASSIGN_OR_RETURN(std::unique_ptr preset_assignments, + MemorySpaceAssignment::Run(module, *hlo_live_range, + *alias_analysis, options)); + if (check_parameters_in_default_memory) { + CheckParametersInDefaultMemory(module); + } + CheckRootInDefaultMemory(module); + CheckPresetAssignments(preset_assignments.get()); + return preset_assignments; + } + + void CheckPresetAssignments(const PresetAssignments* preset_assignments) { + // Ensure that the exported preset assignments point to layouts in the + // alternate memory. Also ensure that the positions are unique. Note that + // we're using a std::set instead of absl::flat_hash_set because we can make + // use of HloPosition's comparator logic instead of providing a hasher. + std::set positions_in_preset_assignments; + for (auto& position_and_chunk : preset_assignments->chunks()) { + HloPosition position = position_and_chunk.first; + EXPECT_EQ(positions_in_preset_assignments.find(position), + positions_in_preset_assignments.end()); + positions_in_preset_assignments.insert(position); + const Shape& subshape = + ShapeUtil::GetSubshape(position.instruction->shape(), position.index); + EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace) + << "Exported position is not in alternate mem: " + << position.ToString(); + } + } + + void CheckParametersInDefaultMemory(const HloModule* module) { + // Check that all the entry parameter subshapes are placed in default + // memory. + const HloComputation* entry_computation = module->entry_computation(); + for (const HloInstruction* parameter : + entry_computation->parameter_instructions()) { + ShapeUtil::ForEachSubshape( + parameter->shape(), + [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (subshape.has_layout()) { + EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace) + << "Parameter not in default memory: " + << parameter->ToString(); + } + }); + } + } + + void CheckRootInDefaultMemory(const HloModule* module) { + const HloInstruction* root = + module->entry_computation()->root_instruction(); + if (root->shape().IsArray()) { + EXPECT_EQ(root->shape().layout().memory_space(), kDefaultMemorySpace); + } + } + + struct OutstandingAsyncCopies { + int64_t max_copies; + int64_t max_prefetches; + int64_t max_evictions; + }; + + /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies( + const HloModule& module) const { + OutstandingAsyncCopies copies{0, 0, 0}; + int64_t current_copies = 0; + int64_t current_prefetches = 0; + int64_t current_evictions = 0; + for (HloInstruction* instruction : module.schedule() + .sequence(module.entry_computation()) + .instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + if (ShapeUtil::GetSubshape(instruction->shape(), {0}) + .layout() + .memory_space() == kAlternateMemorySpace) { + current_prefetches++; + } else { + current_evictions++; + } + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + if (instruction->shape().layout().memory_space() == + kAlternateMemorySpace) { + current_prefetches--; + } else { + current_evictions--; + } + } + copies.max_copies = std::max(copies.max_copies, current_copies); + copies.max_prefetches = + std::max(copies.max_prefetches, current_prefetches); + copies.max_prefetches = std::max(copies.max_evictions, current_evictions); + } + return copies; + } + + static int64_t GetAlternateMemoryOffset( + const PresetAssignments& preset_assignments, + const HloInstruction* instruction, const ShapeIndex& index = {}) { + // Returns the offset of the assignment, -1 if it's not in the alternate + // memory. + const HloModule* module = instruction->GetModule(); + auto status_or_alias_analysis = HloAliasAnalysis::Run(module); + TF_CHECK_OK(status_or_alias_analysis.status()); + auto alias_analysis = std::move(status_or_alias_analysis.value()); + HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index); + for (auto& pos_and_chunk : preset_assignments.chunks()) { + for (auto& value : buffer.values()) { + if (pos_and_chunk.first == value->defining_position()) { + return pos_and_chunk.second.offset; + } + } + } + return -1; + } + + std::unique_ptr CreateEvictAndPrefetchModule() { + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* tanh = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + // tanh should be placed in the alternate memory since there isn't much + // contention in the beginning. However, tanh has another consumer at the + // end. So it should be kicked out to default memory and prefetched back in. + // The graph below is meant to increase the contention to force + // eviction/prefetch behavior. + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); + HloInstruction* d = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* e = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b)); + HloInstruction* f = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c)); + HloInstruction* g = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d)); + HloInstruction* h = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c)); + HloInstruction* i = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d)); + HloInstruction* j = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d)); + HloInstruction* k = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f)); + HloInstruction* l = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h)); + HloInstruction* m = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j)); + HloInstruction* n = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l)); + HloInstruction* o = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m)); + // tanh is being used at the root instruction, and this should be + // prefetched. + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i, + j, k, l, m, n, o, add}); + TF_CHECK_OK(module->set_schedule(schedule)); + return module; + } + + CostAnalysis::Cache cache_; +}; + +} // namespace memory_space_assignment +} // namespace xla + +#endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_TEST_BASE_H_ diff --git a/third_party/xla/xla/service/memory_space_assignment/options.cc b/third_party/xla/xla/service/memory_space_assignment/options.cc new file mode 100644 index 00000000000000..31953cccbe800f --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/options.cc @@ -0,0 +1,43 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/options.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +namespace xla { +namespace memory_space_assignment { + +std::string PostAllocationTransformationUpdate::ToString() const { + return absl::StrCat("to_be_removed: ", + absl::StrJoin(to_be_removed, ", ", + [](std::string* out, const auto& entry) { + absl::StrAppend(out, entry->name()); + }), + "\n", "update_use_map: ", + absl::StrJoin(update_use_map, ", ", + [](std::string* out, const auto& entry) { + absl::StrAppend( + out, "<", entry.first.ToString(), + " -> ", entry.second.ToString(), ">"); + }), + "\n"); +} + +} // namespace memory_space_assignment +} // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h index 2148784c9d266c..96de950050ba08 100644 --- a/third_party/xla/xla/service/memory_space_assignment/options.h +++ b/third_party/xla/xla/service/memory_space_assignment/options.h @@ -24,8 +24,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -64,8 +67,26 @@ using WindowPrefetchNotifyOperandAppendedFunction = using IsAsyncSliceImplementedFunction = std::function; +// MSA allows for custom post-allocation transformations. When a post-allocation +// transformation is performed on an instruction, this result is returned. It +// tells MSA: +// 1. A list of instructions that MSA should delete. +// 2. A list of HloUses that the transformation replaced. +// +// This information is then processed via +// FixAllocationSequenceAfterPostAllocationTransformation call. +struct PostAllocationTransformationUpdate { + std::vector to_be_removed; + absl::flat_hash_map update_use_map; + + std::string ToString() const; +}; + // The different options to be passed to the Run() API. struct Options { + // The backend-specific integer value that describes the default memory. + int64_t default_memory_space = 0; + // Backend-specific integer value that describes the alternate memory. int64_t alternate_memory_space = 0; @@ -145,6 +166,28 @@ struct Options { std::function allocation_request_modifier_testing_fn = nullptr; + // Applies post-allocation transformations to the given instruction. This + // function is called after the allocations are found in the MsaAlgorithm. It + // is called on each instruction I that meets the following conditions: + // 1. I is called from a non-fusion computation + // 2. I's operands are not in alternate memory + // 3. I is not successfully converted to async instruction. + // 4. I's operands don't have in-place users, e.g., a dynamic-update-slice. + // + // The transformation function is allowed to do the following: + // 1. Mark instructions for removal. + // 2. Modify existing instructions. + // + // This transformation is NOT allowed to: + // 1. Directly remove instructions (or nullify them). + // 2. Add new instructions. + // + // Note that it is up to the transformation function to ensure that the + // changes to the module preserves the semantics of the original program. + std::function( + HloInstruction*)> + post_allocation_transformation_fn; + // If true, we will try to reduce scoped allocation buffer size for all // instructions if their operand/output has been allocated in alternate // memory. diff --git a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc index 93d20fb5dba397..1e5e97a8504798 100644 --- a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -37,7 +36,6 @@ limitations under the License. #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" @@ -106,7 +104,8 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { cost_analysis_.get(), kAlternateMemorySpace); return absl::OkStatus(); } - absl::flat_hash_map instruction_map_; + absl::flat_hash_map + instruction_map_; std::unique_ptr hlo_cost_analysis_; std::unique_ptr hlo_cost_analysis_costs_; diff --git a/third_party/xla/xla/service/memory_space_assignment/slice.h b/third_party/xla/xla/service/memory_space_assignment/slice.h index da3fab681d3f8b..f0caa04e92ee41 100644 --- a/third_party/xla/xla/service/memory_space_assignment/slice.h +++ b/third_party/xla/xla/service/memory_space_assignment/slice.h @@ -38,6 +38,7 @@ limitations under the License. #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SLICE_H_ #include +#include #include #include #include diff --git a/third_party/xla/xla/service/memory_space_assignment/utils.cc b/third_party/xla/xla/service/memory_space_assignment/utils.cc index b4b37ff0677bac..43f04c263f27ee 100644 --- a/third_party/xla/xla/service/memory_space_assignment/utils.cc +++ b/third_party/xla/xla/service/memory_space_assignment/utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "re2/re2.h" @@ -30,6 +31,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/pattern_matcher_gmock.h b/third_party/xla/xla/service/pattern_matcher_gmock.h index eeb7b1caabb4e1..f8bea2cff482a7 100644 --- a/third_party/xla/xla/service/pattern_matcher_gmock.h +++ b/third_party/xla/xla/service/pattern_matcher_gmock.h @@ -16,94 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ #define XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/layout.h" -#include "xla/service/pattern_matcher.h" -#include "xla/shape.h" -#include "xla/test.h" -#include "tsl/platform/test.h" - -namespace xla { - -namespace pattern_matcher_gmock_detail { -template -class GmockMatcher { - public: - explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {} - - // In service of better error messages, list out the overloads explicitly - // rather than just using a template. gMock's polymorphism plus - // pattern_matcher yields some pretty gnarly stuff. - bool MatchAndExplain(const Layout& l, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(&l, listener); - } - bool MatchAndExplain(const Layout* l, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(l, listener); - } - bool MatchAndExplain(Layout* l, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(l, listener); - } - - bool MatchAndExplain(const Shape& s, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(&s, listener); - } - bool MatchAndExplain(const Shape* s, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(s, listener); - } - bool MatchAndExplain(Shape* s, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(s, listener); - } - - bool MatchAndExplain(const HloInstruction& instr, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(&instr, listener); - } - bool MatchAndExplain(const HloInstruction* instr, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(instr, listener); - } - bool MatchAndExplain(HloInstruction* instr, - ::testing::MatchResultListener* listener) const { - return MatchAndExplainImpl(instr, listener); - } - - void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); } - - void DescribeNegationTo(std::ostream* os) const { - *os << "is NOT: "; - DescribeTo(os); - } - - private: - template - bool MatchAndExplainImpl(T* t, - ::testing::MatchResultListener* listener) const { - MatchOption options{/*.capture=*/true, /*.single_user_only=*/false, - /*.explain_os=*/listener->stream()}; - return Match(t, pattern_, options); - } - - Pattern pattern_; -}; -} // namespace pattern_matcher_gmock_detail - -template -::testing::PolymorphicMatcher< - pattern_matcher_gmock_detail::GmockMatcher> -GmockMatch(Pattern&& p) { - return ::testing::MakePolymorphicMatcher( - pattern_matcher_gmock_detail::GmockMatcher( - std::forward(p))); -} - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/testlib/pattern_matcher_gmock.h" #endif // XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ diff --git a/third_party/xla/xla/service/platform_util.cc b/third_party/xla/xla/service/platform_util.cc index b0101ed9e73124..34bdf4808e70b6 100644 --- a/third_party/xla/xla/service/platform_util.cc +++ b/third_party/xla/xla/service/platform_util.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/statusor.h" @@ -51,7 +50,7 @@ constexpr char kInterpreter[] = "interpreter"; namespace { -std::string CanonicalPlatformName(std::string_view platform_name) { +std::string CanonicalPlatformName(absl::string_view platform_name) { std::string lowercase_platform_name = absl::AsciiStrToLower(platform_name); // "cpu" and "host" mean the same thing. if (lowercase_platform_name == "cpu") { @@ -89,7 +88,7 @@ absl::StatusOr> GetSupportedPlatforms() { } // namespace absl::StatusOr PlatformUtil::CanonicalPlatformName( - std::string_view platform_name) { + absl::string_view platform_name) { return xla::CanonicalPlatformName(platform_name); } @@ -131,7 +130,7 @@ absl::StatusOr PlatformUtil::GetDefaultPlatform() { } /*static*/ absl::StatusOr PlatformUtil::GetPlatform( - std::string_view platform_name) { + absl::string_view platform_name) { TF_ASSIGN_OR_RETURN(se::Platform * platform, se::PlatformManager::PlatformWithName( xla::CanonicalPlatformName(platform_name))); diff --git a/third_party/xla/xla/service/platform_util.h b/third_party/xla/xla/service/platform_util.h index 7b0ee854e9dc65..1162ebfeb282b8 100644 --- a/third_party/xla/xla/service/platform_util.h +++ b/third_party/xla/xla/service/platform_util.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/status/statusor.h" @@ -38,7 +37,7 @@ class PlatformUtil { // there are multiple implementations. For example, GPU platform may be // cuda(Nvidia) or rocm(AMD) static absl::StatusOr CanonicalPlatformName( - std::string_view platform_name); + absl::string_view platform_name); // Returns the platforms present on the system and supported by XLA. // @@ -56,7 +55,7 @@ class PlatformUtil { // Returns the platform according to the given name. Returns error if there is // no such platform. static absl::StatusOr GetPlatform( - std::string_view platform_name); + absl::string_view platform_name); // Returns a vector of StreamExecutors for the given platform. // If populated, only the devices in allowed_devices will have diff --git a/third_party/xla/xla/service/rendezvous.cc b/third_party/xla/xla/service/rendezvous.cc index b4be7d39e1c815..e9c88dbca2e5a6 100644 --- a/third_party/xla/xla/service/rendezvous.cc +++ b/third_party/xla/xla/service/rendezvous.cc @@ -19,12 +19,11 @@ limitations under the License. #include #include #include -#include #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" namespace xla { @@ -67,7 +66,7 @@ static bool WaitForReadyWithTimeout(RendezvousStateSynchronization& state, } void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, - std::string_view name, + absl::string_view name, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout) { // Wait for `warn_stuck_timeout` for the rendezvous to be ready. @@ -138,13 +137,12 @@ inline constexpr int32_t kPending = 0; inline constexpr int32_t kCompleted = std::numeric_limits::max(); } // namespace -RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {} +RendezvousFlag::RendezvousFlag() : state_(kPending) {} -RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous( - RendezvousSingleFlag* flag) +RendezvousFlag::InFlightRendezvous::InFlightRendezvous(RendezvousFlag* flag) : flag_(flag) {} -RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { +RendezvousFlag::InFlightRendezvous::~InFlightRendezvous() { if (flag_ == nullptr) return; // Reload state and use CAS to decide if we are the one who @@ -163,11 +161,11 @@ RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { } } -RendezvousSingleFlag::InFlightRendezvous::operator bool() const { +RendezvousFlag::InFlightRendezvous::operator bool() const { return flag_ != nullptr; } -RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { +RendezvousFlag::InFlightRendezvous RendezvousFlag::TryJoin() { // If `state_` is `kCompleted` it means that we have at least one completed // rendezvous for this flag and can skip it. if (state_.load() == kCompleted) return InFlightRendezvous(nullptr); @@ -185,8 +183,6 @@ RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { return InFlightRendezvous(this); } -bool RendezvousSingleFlag::IsCompleted() const { - return state_.load() == kCompleted; -} +bool RendezvousFlag::IsCompleted() const { return state_.load() == kCompleted; } } // namespace xla diff --git a/third_party/xla/xla/service/rendezvous.h b/third_party/xla/xla/service/rendezvous.h index a1b6585d07c655..ffd4c431003726 100644 --- a/third_party/xla/xla/service/rendezvous.h +++ b/third_party/xla/xla/service/rendezvous.h @@ -21,19 +21,19 @@ limitations under the License. #include #include #include -#include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "absl/types/span.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" namespace xla { @@ -68,38 +68,45 @@ struct RendezvousResult> { static Type Empty() { return {std::shared_ptr()}; } }; +template <> +struct RendezvousResult { + using Type = absl::Status; + + static Type Wrap(absl::Status result) { return result; } + static Type Empty() { return absl::OkStatus(); } +}; + template using RendezvousResultType = typename RendezvousResult::Type; // The group of threads identifies itself with a key that must be unique to -// the the group. When all threads have arrived at the rendezvous, one thread +// the group. When all threads have arrived at the rendezvous, one thread // executes the given function with the values supplied by each thread, and // all threads receive the result. Rendezvous must have a human readable name to // make easy to debug stuck and timed out attempts. template -RendezvousResultType RendezvousSingle( - std::string_view name, const K& key, const V& value, size_t num_threads, +RendezvousResultType Rendezvous( + absl::string_view name, const K& key, const V& value, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); // A rendezvous for a group of threads that do not have any value arguments. template -RendezvousResultType RendezvousSingle( - std::string_view name, const K& key, size_t num_threads, Fn fn, +RendezvousResultType Rendezvous( + absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); // A rendezvous for a group of threads that do not have any computation to run // and simply acts as a barrier for a group of thread. template -void RendezvousSingle( - std::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +void Rendezvous(absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); -// An `std::once_flag`-like primitive for executing RendezvousSingle operations. +// An `std::once_flag`-like primitive for executing Rendezvous operations. // -// RendezvousSingleFlag guarantees that all or none participants in a rendezvous +// RendezvousFlag guarantees that all or none participants in a rendezvous // join the rendezvous process and once rendezvous is completed flag marked as // `completed` and all further rendezvous using this flag will be skipped. It // has a weaker than exactly-once guarantee and multiple racing rendezvous can @@ -111,17 +118,17 @@ void RendezvousSingle( // and prefer simpler implementation with weaker guarantees. // // See: https://en.cppreference.com/w/cpp/thread/once_flag -class RendezvousSingleFlag { +class RendezvousFlag { public: - RendezvousSingleFlag(); + RendezvousFlag(); - RendezvousSingleFlag(const RendezvousSingleFlag&) = delete; - RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete; + RendezvousFlag(const RendezvousFlag&) = delete; + RendezvousFlag& operator=(const RendezvousFlag&) = delete; // RAII wrapper to exit from in-flight rendezvous when destructed. class InFlightRendezvous { public: - explicit InFlightRendezvous(RendezvousSingleFlag* flag); + explicit InFlightRendezvous(RendezvousFlag* flag); ~InFlightRendezvous(); InFlightRendezvous(const InFlightRendezvous&) = delete; @@ -130,7 +137,7 @@ class RendezvousSingleFlag { operator bool() const; // NOLINT private: - RendezvousSingleFlag* flag_; + RendezvousFlag* flag_; }; // Returns InFlightRendezvous convertible to `true` if the caller should join @@ -151,8 +158,8 @@ class RendezvousSingleFlag { // rendezvous. If rendezvous will not be executed it will return empty shared // pointer result. template -RendezvousResultType RendezvousSingle( - RendezvousSingleFlag& flag, std::string_view name, const K& key, +RendezvousResultType Rendezvous( + RendezvousFlag& flag, absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -161,11 +168,10 @@ RendezvousResultType RendezvousSingle( // not in `completed` state and will switch it to `completed` after finishing a // rendezvous. template -void RendezvousSingle( - RendezvousSingleFlag& flag, std::string_view name, const K& key, - size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); //===----------------------------------------------------------------------===// // Internal implementation details. @@ -199,7 +205,7 @@ struct RendezvousState : public RendezvousStateSynchronization { explicit RendezvousState(size_t n_threads) : RendezvousStateSynchronization(n_threads), values(n_threads, nullptr), - result(nullptr) {} + result(RendezvousResult::Empty()) {} std::vector values; RendezvousResultType result; @@ -273,7 +279,7 @@ class RendezvousMap { }; void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, - std::string_view name, + absl::string_view name, absl::Duration warn_stuck_timeout, absl::Duration terminate_timeout); } // namespace internal @@ -283,11 +289,10 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, //===----------------------------------------------------------------------===// template -RendezvousResultType RendezvousSingle(std::string_view name, const K& key, - const V& value, size_t num_threads, - Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +RendezvousResultType Rendezvous(absl::string_view name, const K& key, + const V& value, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { // Check that `fn` is callable with a span of values and returns `R`. static_assert(std::is_invocable_r_v>, "invalid rendezvous function signature"); @@ -311,7 +316,7 @@ RendezvousResultType RendezvousSingle(std::string_view name, const K& key, tsl::profiler::TraceMe trace([&] { return tsl::profiler::TraceMeEncode( - "RendezvousSingle", + "Rendezvous", {{"num_threads", num_threads}, {"name", name}, {"id", id}}); }); @@ -347,46 +352,44 @@ RendezvousResultType RendezvousSingle(std::string_view name, const K& key, } template -RendezvousResultType RendezvousSingle(std::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - return RendezvousSingle( +RendezvousResultType Rendezvous(absl::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + return Rendezvous( name, key, std::nullopt, num_threads, [fn](auto) { return fn(); }, warn_stuck_timeout, terminate_timeout); } template -void RendezvousSingle(std::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - RendezvousSingle( +void Rendezvous(absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + Rendezvous( name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; }, warn_stuck_timeout, terminate_timeout); } template -RendezvousResultType RendezvousSingle(RendezvousSingleFlag& flag, - std::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +RendezvousResultType Rendezvous(RendezvousFlag& flag, absl::string_view name, + const K& key, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - return RendezvousSingle(name, key, num_threads, std::move(fn), - warn_stuck_timeout, terminate_timeout); + return Rendezvous(name, key, num_threads, std::move(fn), + warn_stuck_timeout, terminate_timeout); } else { return RendezvousResult::Empty(); } } template -void RendezvousSingle(RendezvousSingleFlag& flag, std::string_view name, - const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - RendezvousSingle(name, key, num_threads, warn_stuck_timeout, - terminate_timeout); + Rendezvous(name, key, num_threads, warn_stuck_timeout, + terminate_timeout); } } diff --git a/third_party/xla/xla/service/rendezvous_test.cc b/third_party/xla/xla/service/rendezvous_test.cc index 867d24971f078b..c47550a63de17a 100644 --- a/third_party/xla/xla/service/rendezvous_test.cc +++ b/third_party/xla/xla/service/rendezvous_test.cc @@ -41,8 +41,7 @@ tsl::thread::ThreadPool CreateThreadPool(int32_t size) { } TEST(RendezvousTest, OneParticipant) { - auto result = - RendezvousSingle("rendezvous_test", 0, 1, [] { return 42; }); + auto result = Rendezvous("rendezvous_test", 0, 1, [] { return 42; }); ASSERT_EQ(*result, 42); } @@ -53,7 +52,7 @@ TEST(RendezvousTest, TwoParticipants) { auto task = [&](int32_t id) { return [&, id] { results[id] = - RendezvousSingle("rendezvous_test", 0, 2, [] { return 42; }); + Rendezvous("rendezvous_test", 0, 2, [] { return 42; }); counter.DecrementCount(); }; }; @@ -81,7 +80,7 @@ TEST(RendezvousTest, TwoParticipantsWithValues) { auto task = [&](int32_t id) { return [&, id] { results[id] = - RendezvousSingle("rendezvous_test", 0, id, 2, accumulate); + Rendezvous("rendezvous_test", 0, id, 2, accumulate); counter.DecrementCount(); }; }; @@ -103,7 +102,7 @@ TEST(RendezvousTest, RepeatRendezvous) { absl::BlockingCounter counter(2); auto task = [&] { - RendezvousSingle("rendezvous_test", i, 2, [] { return 42; }); + Rendezvous("rendezvous_test", i, 2, [] { return 42; }); counter.DecrementCount(); }; @@ -119,8 +118,8 @@ TEST(RendezvousTest, ReturningStatusOr) { auto task = [&](int32_t id) { return [&, id] { - results[id] = RendezvousSingle>( - "rendezvous_test", 0, 2, [] { return 42; }); + results[id] = Rendezvous>("rendezvous_test", 0, 2, + [] { return 42; }); counter.DecrementCount(); }; }; @@ -135,8 +134,8 @@ TEST(RendezvousTest, ReturningStatusOr) { ASSERT_EQ(**results[1], 42); } -TEST(RendezvousTest, RendezvousSingleFlag) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlag) { + RendezvousFlag flag; auto thread_pool = CreateThreadPool(2); int32_t num_executed = 0; @@ -146,7 +145,7 @@ TEST(RendezvousTest, RendezvousSingleFlag) { auto task = [&](absl::BlockingCounter& counter) { return [&] { - RendezvousSingle( + Rendezvous( flag, "rendezvous_test", 0, 2, [&] { return ++num_executed; }, Timeout(), Terminate()); counter.DecrementCount(); @@ -169,8 +168,8 @@ TEST(RendezvousTest, RendezvousSingleFlag) { ASSERT_EQ(num_executed, 1); } -TEST(RendezvousTest, RendezvousSingleFlagRace) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlagRace) { + RendezvousFlag flag; static constexpr int32_t kNumRendezvous = 16; static constexpr int32_t kNumThreads = 8; @@ -179,8 +178,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRace) { auto task = [&](int32_t key) { return [&, key] { - RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); }; }; @@ -191,8 +190,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRace) { } } -TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlagRaceWithBarriers) { + RendezvousFlag flag; static constexpr int32_t kNumRendezvous = 16; static constexpr int32_t kNumThreads = 8; @@ -209,8 +208,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { return [&, key] { participants_ready.DecrementCount(); participants_notification.WaitForNotification(); - RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); participants_done.DecrementCount(); }; }; @@ -238,8 +237,8 @@ static void BM_Rendezvous(benchmark::State& state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { - RendezvousSingle("rendezvous_test", 0, num_threads, - [] { return 42; }); + Rendezvous("rendezvous_test", 0, num_threads, + [] { return 42; }); counter.DecrementCount(); }); } @@ -256,8 +255,8 @@ static void BM_RendezvousWithValues(benchmark::State& state) { for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { int32_t value = i; - RendezvousSingle("rendezvous_test", 0, value, num_threads, - [](auto) { return 42; }); + Rendezvous("rendezvous_test", 0, value, num_threads, + [](auto) { return 42; }); counter.DecrementCount(); }); } diff --git a/third_party/xla/xla/service/scan_loop_accumulator_input_unification_test.cc b/third_party/xla/xla/service/scan_loop_accumulator_input_unification_test.cc index a8a1911663eb1f..902d3ef3b4a936 100644 --- a/third_party/xla/xla/service/scan_loop_accumulator_input_unification_test.cc +++ b/third_party/xla/xla/service/scan_loop_accumulator_input_unification_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/copy_insertion.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { @@ -55,14 +55,14 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput) { get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 - + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} reshape.2 = s32[] reshape(dynamic-slice.0) add.1 = s32[] add(get-tuple-element.47, reshape.2) reshape.3 = s32[1] reshape(add.1) dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT tuple.10 = (s32[], s32[], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54) @@ -92,7 +92,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput) { add.0 = s32[] add(get-tuple-element.46, const) ROOT out = (s32[], s32[], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.40) } - + outer_cond { constant.5 = s32[] constant(8) wide.arg_tuple.30 = (s32[], s32[], s32[8]) parameter(0) @@ -108,7 +108,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput) { while = (s32[], s32[], s32[8]) while(tuple.8), condition=outer_cond, body=outer_body ROOT get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 } // main.43 - + )"; auto module = ParseAndReturnVerifiedModule(kModule).value(); @@ -144,21 +144,21 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput2) { get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 get-tuple-element.55 = s32[8] get-tuple-element(wide.arg_tuple.8), index=4 get-tuple-element.56 = s32[8] get-tuple-element(wide.arg_tuple.8), index=5 - + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} reshape.2 = s32[] reshape(dynamic-slice.0) add.1 = s32[] add(get-tuple-element.47, reshape.2) reshape.3 = s32[1] reshape(add.1) dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) - + dynamic-slice.1 = s32[1] dynamic-slice(get-tuple-element.56, get-tuple-element.46), dynamic_slice_sizes={1} reshape.4 = s32[] reshape(dynamic-slice.1) add.2 = s32[] multiply(get-tuple-element.47, reshape.4) reshape.5 = s32[1] reshape(add.2) dynamic-update-slice.1 = s32[8] dynamic-update-slice(get-tuple-element.55, reshape.5, get-tuple-element.46) - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT tuple.10 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54, dynamic-update-slice.1, get-tuple-element.56) @@ -186,12 +186,12 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput2) { while = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 get-tuple-element.41 = s32[8] get-tuple-element(while), index=4 - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT out = (s32[], s32[], s32[8], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.40, get-tuple-element.41) } - + outer_cond { constant.5 = s32[] constant(8) wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) @@ -210,7 +210,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, UnifyAccumulatorInput2) { get-tuple-element.41 = s32[8] get-tuple-element(while), index=3 ROOT out = (s32[8],s32[8]) tuple(get-tuple-element.40, get-tuple-element.41) } // main.43 - + )"; auto module = ParseAndReturnVerifiedModule(kModule).value(); @@ -246,14 +246,14 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, AccumulatorAllocateOutside) { get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 - + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} reshape.2 = s32[] reshape(dynamic-slice.0) add.1 = s32[] add(get-tuple-element.47, reshape.2) reshape.3 = s32[1] reshape(add.1) dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT tuple.10 = (s32[], s32[], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54) @@ -282,7 +282,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, AccumulatorAllocateOutside) { add.0 = s32[] add(get-tuple-element.46, const) ROOT out = (s32[], s32[], s32[8], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.48, get-tuple-element.40) } - + outer_cond { constant.5 = s32[] constant(8) wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) @@ -299,7 +299,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, AccumulatorAllocateOutside) { while = (s32[], s32[], s32[8], s32[8]) while(tuple.8), condition=outer_cond, body=outer_body ROOT get-tuple-element.40 = s32[8] get-tuple-element(while), index=3 } // main.43 - + )"; auto module = ParseAndReturnVerifiedModule(kModule).value(); @@ -321,11 +321,11 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, InputDifferentShape) { get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 get-tuple-element.54 = s32[8,10] get-tuple-element(wide.arg_tuple.8), index=3 - + zero = s32[] constant(0) dynamic-slice.0 = s32[1,10] dynamic-slice(get-tuple-element.54, get-tuple-element.46, zero), dynamic_slice_sizes={1,10} reshape.2 = s32[10] reshape(dynamic-slice.0) - + dynamic-slice.1 = s32[1] dynamic-slice(reshape.2, get-tuple-element.46), dynamic_slice_sizes={1} reshape.3 = s32[] reshape(dynamic-slice.1) @@ -333,7 +333,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, InputDifferentShape) { reshape.4 = s32[1] reshape(add.1) dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.4, get-tuple-element.46) - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT tuple.10 = (s32[], s32[], s32[8], s32[8,10]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54) @@ -351,13 +351,13 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, InputDifferentShape) { init = s32[] constant(0) array = s32[8,10] parameter(0) broadcast.5 = s32[8] broadcast(constant.3), dimensions={} - + tuple.8 = (s32[], s32[], s32[8], s32[8,10]) tuple(constant.3, init, broadcast.5, array) while = (s32[], s32[], s32[8], s32[8,10]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 get-tuple-element.39 = s32[] get-tuple-element(while), index=1 ROOT get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 } // main.43 - + )"; auto module = ParseAndReturnVerifiedModule(kModule).value(); @@ -383,24 +383,24 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, MultipleUsersInput) { get-tuple-element.55 = s32[8] get-tuple-element(wide.arg_tuple.8), index=4 // input get-tuple-element.56 = s32[8] get-tuple-element(wide.arg_tuple.8), index=5 - + // this is here only to have another user for gte.54 mult = s32[8] multiply(get-tuple-element.54, get-tuple-element.54) - + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} reshape.2 = s32[] reshape(dynamic-slice.0) add.1 = s32[] add(get-tuple-element.47, reshape.2) reshape.3 = s32[1] reshape(add.1) dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) - + dynamic-slice.1 = s32[1] dynamic-slice(get-tuple-element.56, get-tuple-element.46), dynamic_slice_sizes={1} reshape.4 = s32[] reshape(dynamic-slice.1) add.2 = s32[] multiply(get-tuple-element.47, reshape.4) reshape.5 = s32[1] reshape(add.2) dynamic-update-slice.1 = s32[8] dynamic-update-slice(get-tuple-element.55, reshape.5, get-tuple-element.46) - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT tuple.10 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54, dynamic-update-slice.1, get-tuple-element.56) @@ -412,14 +412,14 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, MultipleUsersInput) { get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT } - + outer_body { wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8]) parameter(0) get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 get-tuple-element.56 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 - + constant.3 = s32[] constant(0) broadcast = s32[8] broadcast(constant.3), dimensions={} broadcast2 = s32[8] broadcast(constant.3), dimensions={} @@ -433,7 +433,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, MultipleUsersInput) { add.0 = s32[] add(get-tuple-element.46, const) ROOT out = (s32[], s32[], s32[8], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.40, get-tuple-element.41) } - + outer_cond { constant.5 = s32[] constant(8) wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) @@ -452,7 +452,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, MultipleUsersInput) { get-tuple-element.41 = s32[8] get-tuple-element(while), index=3 ROOT out = (s32[8],s32[8]) tuple(get-tuple-element.40, get-tuple-element.41) } // main.43 - + )"; auto module = ParseAndReturnVerifiedModule(kModule).value(); @@ -494,7 +494,7 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, reshape.3 = s32[] reshape(dynamic-slice.1) add.1 = s32[] add(reshape.3, reshape.2) add.2 = s32[] add(add.1, get-tuple-element.47) - + reshape.4 = s32[1] reshape(add.2) dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.4, get-tuple-element.46) const = s32[] constant(1) @@ -508,26 +508,26 @@ TEST_F(ScanLoopAccumulatorInputUnificationTest, get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT } - + outer_body { wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[10]) parameter(0) get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 get-tuple-element.55 = s32[10] get-tuple-element(wide.arg_tuple.8), index=3 - + constant.3 = s32[] constant(0) broadcast = s32[8] broadcast(constant.3), dimensions={} - + tuple.8 = (s32[], s32[], s32[8], s32[8], s32[10]) tuple(constant.3, get-tuple-element.47, broadcast, get-tuple-element.48, get-tuple-element.55) while = (s32[], s32[], s32[8], s32[8], s32[10]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 - + const = s32[] constant(1) add.0 = s32[] add(get-tuple-element.46, const) ROOT out = (s32[], s32[], s32[8], s32[10]) tuple(add.0, get-tuple-element.47, get-tuple-element.40, get-tuple-element.55) } - + outer_cond { constant.5 = s32[] constant(8) wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[10]) parameter(0) diff --git a/third_party/xla/xla/service/scatter_determinism_expander_test.cc b/third_party/xla/xla/service/scatter_determinism_expander_test.cc index 27ed15b8220980..81078b0da54499 100644 --- a/third_party/xla/xla/service/scatter_determinism_expander_test.cc +++ b/third_party/xla/xla/service/scatter_determinism_expander_test.cc @@ -596,14 +596,14 @@ TEST_F(ScatterDeterminismExpanderTest, } ENTRY scatter_add_computation { - operand = f32[3, 3, 3] constant({{{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, {0, 0, 0}}}) indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) @@ -646,14 +646,14 @@ TEST_F(ScatterDeterminismExpanderTest, } ENTRY scatter_add_computation { - operand = f32[3, 3, 3] constant({{{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, {0, 0, 0}}}) indices = s32[4, 2] constant({{0, 0}, {0, 1}, {1, 1}, {1, 2}}) updates = f32[4, 2] constant({{1, 2}, {4, 7}, {10, 13}, {21, 27}}) @@ -696,14 +696,14 @@ TEST_F(ScatterDeterminismExpanderTest, } ENTRY scatter_add_computation { - operand = f32[3, 3, 3] constant({{{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, {0, 0, 0}}}) indices = s32[2, 2] constant({{0, 0}, {1, 1}}) updates = f32[2, 2, 2] constant({{{1, 2}, {4, 7}}, {{10, 13}, {21, 27}}}) @@ -746,14 +746,14 @@ TEST_F(ScatterDeterminismExpanderTest, } ENTRY scatter_add_computation { - operand = f32[3, 3, 3] constant({{{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, - {0, 0, 0}}, - {{0, 0, 0}, - {0, 0, 0}, + operand = f32[3, 3, 3] constant({{{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}}, + {{0, 0, 0}, + {0, 0, 0}, {0, 0, 0}}}) indices = s32[2, 2] constant({{0, 0}, {1, 1}}) updates = f32[2, 2, 2] constant({{{1, 2}, {4, 7}}, {{10, 13}, {21, 27}}}) @@ -893,27 +893,27 @@ TEST_F(ScatterDeterminismExpanderTest, ScalarScatterAddReproducibilityTest) { ENTRY scatter_add_computation { operand = f32[3] constant({0, 0, 0}) - indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, + indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2}, {2}, {3}, {2}, {2}, {0}, {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0}, {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3}, {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0}, {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0}, {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}}) - updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, - 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, - 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, - 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, - 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, - 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, - 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, - 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, - 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, - 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, - 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, - 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, - 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, - 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, + updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, + 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, + 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, + 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, + 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, + 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, + 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, + 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, + 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, + 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, + 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, + 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, + 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, + 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, 0.878675, 0.4102913}) ROOT scatter.48 = f32[3] scatter(operand, indices, updates), update_window_dims={}, inserted_window_dims={0}, @@ -965,14 +965,14 @@ TEST_F(ScatterDeterminismExpanderTest, NonScalarScatterAddReproducibilityTest) { ENTRY scatter_add_computation { operand = f32[3, 3] constant({{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}) - indices = s32[50, 2] constant({{0, 0}, {0, 1}, {1, 1}, {2, 2}, {0, 1}, {1, 0}, {2, 1}, {1, 2}, {0, 2}, {2, 0}, + indices = s32[50, 2] constant({{0, 0}, {0, 1}, {1, 1}, {2, 2}, {0, 1}, {1, 0}, {2, 1}, {1, 2}, {0, 2}, {2, 0}, {1, 1}, {2, 2}, {0, 0}, {0, 1}, {2, 1}, {1, 2}, {2, 0}, {0, 2}, {1, 0}, {1, 1}, {1, 2}, {2, 1}, {0, 0}, {1, 1}, {0, 2}, {2, 0}, {1, 0}, {2, 2}, {1, 2}, {0, 1}, {2, 1}, {1, 0}, {0, 2}, {2, 0}, {0, 1}, {2, 1}, {1, 1}, {1, 0}, {2, 2}, {0, 0}, {0, 1}, {1, 2}, {2, 0}, {1, 1}, {0, 2}, {2, 1}, {1, 2}, {2, 1}, {1, 1}, {0, 2}}) - updates = f32[50, 2] constant({{0.02379167, 0.8527204}, {0.8132185, 0.5140263}, {0.17172801, 0.8026866}, - {0.5124631, 0.34838438}, {0.50526905, 0.3370521}, {0.10868239, 0.10520637}, - {0.83827364, 0.78986526}, {0.34059846, 0.8349273}, {0.24575627, 0.21387374}, + updates = f32[50, 2] constant({{0.02379167, 0.8527204}, {0.8132185, 0.5140263}, {0.17172801, 0.8026866}, + {0.5124631, 0.34838438}, {0.50526905, 0.3370521}, {0.10868239, 0.10520637}, + {0.83827364, 0.78986526}, {0.34059846, 0.8349273}, {0.24575627, 0.21387374}, {0.02423227, 0.5617423}, {0.28066766, 0.94366455}, {0.61214995, 0.7383388}, {0.52419806, 0.65466726}, {0.41012764, 0.24028647}, {0.74443066, 0.03544927}, {0.851014, 0.02434528}, {0.47239733, 0.72706807}, {0.35055435, 0.6274171}, diff --git a/third_party/xla/xla/service/select_and_scatter_expander_test.cc b/third_party/xla/xla/service/select_and_scatter_expander_test.cc index 0daf6a7fa586a2..001dea8281766a 100644 --- a/third_party/xla/xla/service/select_and_scatter_expander_test.cc +++ b/third_party/xla/xla/service/select_and_scatter_expander_test.cc @@ -31,13 +31,13 @@ constexpr absl::string_view kModuleStr = %rhs = f32[] parameter(1) ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER } - + %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { %lhs.1 = f32[] parameter(0) %rhs.1 = f32[] parameter(1) ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) } - + ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) diff --git a/third_party/xla/xla/service/service_executable_run_options.h b/third_party/xla/xla/service/service_executable_run_options.h index f59dedde20999e..0cb9c0a28b4770 100644 --- a/third_party/xla/xla/service/service_executable_run_options.h +++ b/third_party/xla/xla/service/service_executable_run_options.h @@ -91,7 +91,6 @@ class ServiceExecutableRunOptions { private: ExecutableRunOptions run_options_; StreamBorrower stream_borrower_; - int64_t local_device_count_; }; } // namespace xla diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 8826f0b6e3bddf..996d1d66191546 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -129,7 +129,7 @@ struct BinaryOpTestCase { std::string rhs; absl::Span broadcast_dimensions; std::string expected; - std::optional error_message; + std::optional error_message; }; // Subclass for testing unbounded dynamic logical ops diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 66515af52c8903..dbe082daee0f44 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -397,6 +397,7 @@ bool SupportSpatialPartitioning( case HloOpcode::kReduce: case HloOpcode::kRngBitGenerator: case HloOpcode::kAllReduce: + case HloOpcode::kCollectivePermute: case HloOpcode::kReduceScatter: return true; case HloOpcode::kParameter: @@ -2459,7 +2460,7 @@ bool ShardingPropagation::InferShardingFromOperands( const int64_t sort_dim = sort->sort_dimension(); if (!operand->sharding().IsTileMaximal() && operand->sharding().tile_assignment().dim(sort_dim) != 1 && - !hlo_sharding_util::GetFirstMergeableDimForSortOperand( + !hlo_sharding_util::GetFirstTargetDimToMoveShardingTiles( operand->shape(), operand->sharding(), sort_dim) .has_value()) { // In case of a sort operand sharded along the sort dimension, the diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index 96d3d5d659c89e..1d4eb22cf5f8e4 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -2828,7 +2828,7 @@ HloModule module %count = u32[] get-tuple-element(%param), index=0 %after-all = token[] after-all() %recv = (f32[], u32[], token[]) recv(%after-all), channel_id=1, - sharding={{maximal device=1 metadata={op_name="a"}}, + sharding={{maximal device=1 metadata={op_name="a"}}, {maximal device=1}, {maximal device=1}} %recv-done = (f32[], token[]) recv-done(%recv), channel_id=1 %data = f32[] get-tuple-element(%recv-done), index=0 @@ -2889,7 +2889,7 @@ HloModule module sharding={maximal device=0 metadata={op_name="a"}} %after-all = token[] after-all() %recv = (f32[], u32[], token[]) recv(%after-all), channel_id=1, - sharding={{maximal device=1 metadata={op_name="b"}}, + sharding={{maximal device=1 metadata={op_name="b"}}, {maximal device=1}, {maximal device=1}} %recv-done = (f32[], token[]) recv-done(%recv), channel_id=1 %data = f32[] get-tuple-element(%recv-done), index=0 @@ -2934,7 +2934,7 @@ HloModule module %count = u32[] get-tuple-element(%param), index=0 %after-all = token[] after-all() %recv = (f32[], u32[], token[]) recv(%after-all), channel_id=1, - sharding={{maximal device=1 metadata={op_name="a"}}, + sharding={{maximal device=1 metadata={op_name="a"}}, {maximal device=1}, {maximal device=1}} %recv-done = (f32[], token[]) recv-done(%recv), channel_id=1 %data = f32[] get-tuple-element(%recv-done), index=0, @@ -2980,7 +2980,7 @@ HloModule module %count = u32[] get-tuple-element(%param), index=0 %after-all = token[] after-all() %recv = (f32[], u32[], token[]) recv(%after-all), channel_id=1, - sharding={{maximal device=1 metadata={op_name="a"}}, + sharding={{maximal device=1 metadata={op_name="a"}}, {maximal device=1}, {maximal device=1}} %recv-done = (f32[], token[]) recv-done(%recv), channel_id=1 %data = f32[] get-tuple-element(%recv-done), index=0 diff --git a/third_party/xla/xla/service/sharding_remover.cc b/third_party/xla/xla/service/sharding_remover.cc index ea26ab13bf9194..042e9f137ef1f0 100644 --- a/third_party/xla/xla/service/sharding_remover.cc +++ b/third_party/xla/xla/service/sharding_remover.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/spmd/shard_barrier_partitioner.h" #include "xla/service/spmd/shardy/constants.h" #include "tsl/platform/errors.h" @@ -41,9 +42,13 @@ absl::StatusOr ShardingRemover::Run( bool changed = false; const absl::flat_hash_set to_remove_sharding_ops = { - "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape", + "Sharding", + "SPMDShardToFullShape", + "SPMDFullToShardShape", sdy::kShardingGroupCustomCallTargetName, - sdy::kFuncResultShardingTargetName}; + sdy::kFuncResultShardingTargetName, + spmd::kShardBarrierFrom, + spmd::kShardBarrierTo}; for (HloComputation* computation : module->computations(execution_threads)) { auto instructions = computation->MakeInstructionPostOrder(); @@ -74,7 +79,9 @@ absl::StatusOr ShardingRemover::Run( // with a copy instead, so that it can be DCE-ed in later passes. if (instruction->custom_call_target() == "Sharding" || instruction->custom_call_target() == - sdy::kFuncResultShardingTargetName) { + sdy::kFuncResultShardingTargetName || + instruction->custom_call_target() == spmd::kShardBarrierFrom || + instruction->custom_call_target() == spmd::kShardBarrierTo) { auto copy = computation->AddInstruction( HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy, instruction->mutable_operand(0))); diff --git a/third_party/xla/xla/service/space_to_batch_converter_test.cc b/third_party/xla/xla/service/space_to_batch_converter_test.cc index a88d157314c7aa..6473c65dccf73b 100644 --- a/third_party/xla/xla/service/space_to_batch_converter_test.cc +++ b/third_party/xla/xla/service/space_to_batch_converter_test.cc @@ -33,12 +33,12 @@ namespace op = testing::opcode_matchers; TEST_F(SpaceToBatchConverterTest, SimpleBatch1) { std::string hlo_string = R"( - + HloModule module ENTRY computation { %p0 = bf16[1,258,258,32] parameter(0) %p1 = bf16[3,3,32,32] parameter(1) - ROOT %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3}, + ROOT %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3}, dim_labels=b01f_01io->b01f } @@ -68,12 +68,12 @@ ENTRY computation { TEST_F(SpaceToBatchConverterTest, SimpleBatch1ConvXpose) { std::string hlo_string = R"( - + HloModule module ENTRY computation { %p0 = bf16[1,258,258,32] parameter(0) %p1 = bf16[3,3,32,32] parameter(1) - %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3}, + %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3}, dim_labels=b01f_01io->b01f ROOT tr = bf16[1,256,256,32] transpose(%convolution), dimensions={0,2,1,3} } @@ -101,7 +101,7 @@ ENTRY computation { TEST_F(SpaceToBatchConverterTest, SimpleBatch1WithReduceWindow) { std::string hlo_string = R"( - HloModule module + HloModule module adder (lhs: bf16[], rhs: bf16[]) -> bf16[] { lhs = bf16[] parameter(0) rhs = bf16[] parameter(1) @@ -159,8 +159,8 @@ TEST_F(SpaceToBatchConverterTest, UnpropagatableOp) { ENTRY comp { %reduce-window = bf16[1,76,76,64]{3,2,1,0} parameter(0) %convert.13 = bf16[3,3,64,64]{3,2,1,0} parameter(1) - %convolution.1 = bf16[64,76,76,1]{0,2,1,3} convolution( - %reduce-window, %convert.13), window={size=3x3 pad=1_1x1_1}, + %convolution.1 = bf16[64,76,76,1]{0,2,1,3} convolution( + %reduce-window, %convert.13), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->f01b ROOT custom-call.5079 = bf16[64,152,152,1]{0,2,1,3} custom-call(%convolution.1), custom_call_target="ResizeNearest" @@ -181,8 +181,8 @@ TEST_F(SpaceToBatchConverterTest, Batch1WithStrideAndPad) { ENTRY computation { %p0 = bf16[1,224,224,3]{3,2,1,0} parameter(0) %p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1) - - ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1), + + ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f } )"; @@ -211,7 +211,7 @@ TEST_F(SpaceToBatchConverterTest, Batch1WithStrideAndPad) { TEST_F(SpaceToBatchConverterTest, Batch1WithBaseDilation) { std::string hlo_string = R"( - + HloModule module ENTRY computation { %p2 = bf16[1,28,28,128]{3,0,2,1} parameter(0) @@ -326,7 +326,7 @@ TEST_F(SpaceToBatchConverterTest, DoNotPropagateOnTupleReduce) { %select.2727 = f32[] select(pred[] %compare.2725, f32[] %minimum.2726, f32[] %select.2724) ROOT %tuple.4 = (f32[], f32[]) tuple(f32[] %select.2723, f32[] %select.2727) } - + ENTRY computation { %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0) %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1) @@ -359,15 +359,15 @@ TEST_F(SpaceToBatchConverterTest, ReduceDegenerateDim) { %Arg_1.39 = f32[] parameter(1) ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39) } - + ENTRY computation { %p0 = f32[2,1,84,84,3]{4,3,2,1,0} parameter(0) %p1 = f32[3,3,3,3,32]{4,3,2,1,0} parameter(1) %constant.10559 = f32[] constant(0) - %convolution.98 = f32[2,1,84,84,32]{4,3,2,1,0} convolution(%p0, %p1), + %convolution.98 = f32[2,1,84,84,32]{4,3,2,1,0} convolution(%p0, %p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f - - ROOT %reduce.2606 = f32[2,84,84]{2,1,0} reduce(f32[2,1,84,84,32]{4,3,2,1,0} + + ROOT %reduce.2606 = f32[2,84,84]{2,1,0} reduce(f32[2,1,84,84,32]{4,3,2,1,0} %convolution.98, f32[] %constant.10559), dimensions={1,4}, to_apply=%region_42.4982 } )"; diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index 1110028dc23cd5..87b6dc150cc03e 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -55,9 +55,9 @@ cc_library( "//xla/hlo/parser:hlo_lexer", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:flatten_call_graph", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:call_graph", @@ -112,6 +112,8 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -120,8 +122,6 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.cc b/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.cc index f7639aa633f5fe..745eed9ebb73a3 100644 --- a/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.cc +++ b/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/service/spmd/canonicalize_all_gather_for_cse.h" +#include +#include +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.h b/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.h index 8b322c20611084..113ffa17ee27d6 100644 --- a/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.h +++ b/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_SPMD_CANONICALIZE_ALL_GATHER_FOR_CSE_H_ #define XLA_SERVICE_SPMD_CANONICALIZE_ALL_GATHER_FOR_CSE_H_ +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc b/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc index 85b8ed4fce539c..593d6f7d36c32c 100644 --- a/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc +++ b/third_party/xla/xla/service/spmd/canonicalize_all_gather_for_cse_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/service/spmd/canonicalize_all_gather_for_cse.h" +#include +#include +#include + #include #include #include "absl/status/status.h" diff --git a/third_party/xla/xla/service/spmd/convolution_handler.cc b/third_party/xla/xla/service/spmd/convolution_handler.cc index a084c2ec98fae6..aaf27dcd30c194 100644 --- a/third_party/xla/xla/service/spmd/convolution_handler.cc +++ b/third_party/xla/xla/service/spmd/convolution_handler.cc @@ -793,7 +793,7 @@ absl::StatusOr PartitionConvolutionTiledOutput( lhs = lhs.Reshard(target_operand_sharding); // Replicate the RHS. - rhs = rhs.Reshard(HloSharding::Replicate()); + rhs = rhs.Replicate(); // Convolution window config does not include batch and feature dimensions, // whereas ReshardAsWindowedInput() expects the same number of window diff --git a/third_party/xla/xla/service/spmd/convolution_handler.h b/third_party/xla/xla/service/spmd/convolution_handler.h index 0799b0d53202e8..6df55c85e9fcc4 100644 --- a/third_party/xla/xla/service/spmd/convolution_handler.h +++ b/third_party/xla/xla/service/spmd/convolution_handler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ #define XLA_SERVICE_SPMD_CONVOLUTION_HANDLER_H_ +#include + #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/third_party/xla/xla/service/spmd/custom_call_handler.h b/third_party/xla/xla/service/spmd/custom_call_handler.h index ff3737279d43bb..cf54c5e272c012 100644 --- a/third_party/xla/xla/service/spmd/custom_call_handler.h +++ b/third_party/xla/xla/service/spmd/custom_call_handler.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SPMD_CUSTOM_CALL_HANDLER_H_ #define XLA_SERVICE_SPMD_CUSTOM_CALL_HANDLER_H_ +#include #include #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 2b7547b2ea35bc..ef619b7719e7ec 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -579,13 +580,13 @@ std::optional GetWindowedEinsumConfiguration( ? PartitionedHlo(partitioned_lhs->hlo(), partitioned_lhs->base_shape(), partitioned_lhs->state()) - .Reshard(HloSharding::Replicate()) + .Replicate() : *partitioned_lhs; auto new_rhs = rhs_needs_ag ? PartitionedHlo(partitioned_rhs->hlo(), partitioned_rhs->base_shape(), partitioned_rhs->state()) - .Reshard(HloSharding::Replicate()) + .Replicate() : *partitioned_rhs; dot = (*create_sharded_dot)(new_lhs.hlo(), new_rhs.hlo(), b, conv_window) .value(); @@ -2016,16 +2017,14 @@ absl::StatusOr PartitionBaseCase( if (lhs_non_contracting_partitions == num_partitions && output_lhs_non_contracting_partitions == num_partitions && lhs_sharding_transposed_to_match_output == output_sharding) { - auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); - return create_sharded_dot(lhs.hlo(), rhs_replicated, b, conv_window); + return create_sharded_dot(lhs.hlo(), rhs.Replicate().hlo(), b, conv_window); } // RHS and output have the same partitioned non-contracting dimensions. if (rhs_non_contracting_partitions == num_partitions && output_rhs_non_contracting_partitions == num_partitions && rhs_sharding_transposed_to_match_output == output_sharding) { - auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); - return create_sharded_dot(lhs_replicated, rhs.hlo(), b, conv_window); + return create_sharded_dot(lhs.Replicate().hlo(), rhs.hlo(), b, conv_window); } if (may_reshard_without_detecting_match) { @@ -2042,13 +2041,13 @@ absl::StatusOr PartitionBaseCase( if (output_lhs_non_contracting_partitions == num_partitions) { auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + auto replicated_rhs = rhs.Replicate(); return create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b, conv_window); } // Output is partitioned along RHS non-contracting dimensions. if (output_rhs_non_contracting_partitions == num_partitions) { - auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto replicated_lhs = lhs.Replicate(); auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); return create_sharded_dot(replicated_lhs.hlo(), resharded_rhs.hlo(), b, diff --git a/third_party/xla/xla/service/spmd/fft_handler.cc b/third_party/xla/xla/service/spmd/fft_handler.cc index 8b70f8d4b58e2d..7bff2e341d5da5 100644 --- a/third_party/xla/xla/service/spmd/fft_handler.cc +++ b/third_party/xla/xla/service/spmd/fft_handler.cc @@ -15,10 +15,12 @@ limitations under the License. #include -#include +#include #include #include +#include #include +#include #include #include "absl/log/check.h" diff --git a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc index ecf06378a266cc..57f13ca7d1c5fb 100644 --- a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc +++ b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc @@ -193,6 +193,44 @@ std::vector GatherOutputDimsByPriority( return priority_dims_for_output; } +PartitionedHlo ClampGatherIndices(const PartitionedHlo& indices, + const Shape& operand_base_shape, + absl::Span start_index_map, + int64_t index_vector_dim, SpmdBuilder* b) { + const PrimitiveType indices_type = indices.hlo()->shape().element_type(); + + HloInstruction* max_indices; + if (index_vector_dim < indices.rank()) { + std::vector max_indices_values; + max_indices_values.reserve(start_index_map.size()); + for (int64_t operand_dim : start_index_map) { + max_indices_values.push_back(operand_base_shape.dimensions(operand_dim) - + 1); + } + max_indices = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(max_indices_values))); + max_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), max_indices, {index_vector_dim})); + } else { + CHECK_EQ(start_index_map.size(), 1); + max_indices = CreateR0WithType( + indices_type, operand_base_shape.dimensions(start_index_map[0]) - 1, b); + max_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), max_indices, {})); + } + + HloInstruction* constant_zero = CreateR0WithType(indices_type, 0, b); + HloInstruction* min_indices = + b->AddInstruction(HloInstruction::CreateBroadcast(indices.hlo()->shape(), + constant_zero, {})); + + HloInstruction* clamped_indices = b->AddInstruction( + HloInstruction::CreateTernary(indices.hlo()->shape(), HloOpcode::kClamp, + min_indices, indices.hlo(), max_indices)); + clamped_indices->set_sharding(indices.sharding()); + return PartitionedHlo(clamped_indices, indices.base_shape(), indices.state()); +} + // Returns the min and max for the indices in a scatter/gather which has the // operand partitioned on trivial slice dimensions (slice size 1). std::pair @@ -451,11 +489,9 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( SpmdBuilder* b = visitor->builder(); const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); if (std::optional> trivial_slice_dims = GatherScatterOperandPartitionedOnTrivialSliceDims( - operand, start_index_map, slice_sizes)) { + operand, dnums.start_index_map(), slice_sizes)) { const HloSharding original_operand_sharding = operand.sharding(); const int64_t num_groups = operand.sharding().NumTiles(*trivial_slice_dims); const int64_t num_tiles = operand.sharding().TotalNumTiles(); @@ -504,6 +540,9 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( // Reshard indices to its intended sharding before clamping and adjusting. indices = indices.Reshard(hlo_sharding_util::UngroupSharding(indices_grouped)); + indices = ClampGatherIndices(indices, operand.base_shape(), + dnums.start_index_map(), + dnums.index_vector_dim(), b); // Now the operand is partitioned in trivial slice dimensions, and the // indices are replicated. We execute a gather on partitioned operand, // with full number of indices, where out-of-bounds indices are clamped, @@ -514,8 +553,9 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( HloInstruction* indices_max; std::tie(indices_min, indices_max) = IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( - operand, indices, operand.state().partition_id, start_index_map, - *trivial_slice_dims, dnums.index_vector_dim(), b); + operand, indices, operand.state().partition_id, + dnums.start_index_map(), *trivial_slice_dims, + dnums.index_vector_dim(), b); // Clamp the indices. auto adjusted_indices = b->AddInstruction( HloInstruction::CreateTernary(indices.hlo()->shape(), HloOpcode::kClamp, diff --git a/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.cc b/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.cc index 7c16628c70f21b..a4b9b5c6ee991f 100644 --- a/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.cc +++ b/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/service/spmd/schedule_aware_collective_ops_cse.h" +#include +#include +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" diff --git a/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.h b/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.h index 8eb52bbdbcdfa0..b23216be99f837 100644 --- a/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.h +++ b/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_SPMD_SCHEDULE_AWARE_COLLECTIVE_OPS_CSE_H_ #define XLA_SERVICE_SPMD_SCHEDULE_AWARE_COLLECTIVE_OPS_CSE_H_ +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc b/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc index c7f6b546851e9d..e39b802c935f65 100644 --- a/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc +++ b/third_party/xla/xla/service/spmd/schedule_aware_collective_ops_cse_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/service/spmd/schedule_aware_collective_ops_cse.h" +#include +#include +#include + #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index 27f784d2573d8a..a941cd9e21ddf0 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -35,8 +35,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/hlo/translate:stablehlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/hlo/utils:hlo_sharding_util", @@ -55,6 +56,8 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", @@ -83,8 +86,10 @@ cc_library( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@shardy//shardy/dialect/sdy/ir:dialect", "@shardy//shardy/dialect/sdy/ir:register", + "@stablehlo//:stablehlo_ops", ], ) @@ -101,9 +106,9 @@ xla_cc_test( ":constants", ":shardy_xla_pass", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", @@ -116,6 +121,7 @@ xla_cc_binary( deps = [ "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", + "//xla/service/spmd/shardy/mhlo_round_trip:export_callback_custom_calls", "//xla/service/spmd/shardy/mhlo_round_trip:export_ops", "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export", @@ -129,6 +135,7 @@ xla_cc_binary( "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding", "//xla/service/spmd/shardy/sdy_round_trip:export_ops", "//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs", + "//xla/service/spmd/shardy/sdy_round_trip:import_callback_custom_calls", "//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes", @@ -137,12 +144,11 @@ xla_cc_binary( "//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo", "//xla/service/spmd/shardy/sdy_round_trip/test_utils:testing_pipeline", "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/ir:register", "@shardy//shardy/dialect/sdy/transforms:passes", - "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/constants.h b/third_party/xla/xla/service/spmd/shardy/constants.h index 220a43e1b48cc9..4ebd8d3690d066 100644 --- a/third_party/xla/xla/service/spmd/shardy/constants.h +++ b/third_party/xla/xla/service/spmd/shardy/constants.h @@ -21,6 +21,9 @@ limitations under the License. namespace xla { namespace sdy { +// The attribute name for attributes in MHLO ops. +inline constexpr llvm::StringRef kMhloAttributesAttr = "mhlo.attributes"; + // The attribute name for xla::HloSharding. inline constexpr llvm::StringRef kXlaShardingAttr = "mhlo.sharding"; @@ -35,6 +38,14 @@ inline constexpr llvm::StringRef kSPMDFullToShardShapeCallTargetName = inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName = "SPMDShardToFullShape"; +// The target name of the Python CPU callback custom call. +inline constexpr llvm::StringRef kPythonCpuCallbackCustomCallTargetName = + "xla_python_cpu_callback"; + +// The target name of the Python GPU callback custom call. +inline constexpr llvm::StringRef kPythonGpuCallbackCustomCallTargetName = + "xla_python_gpu_callback"; + // The attribute name for backend config. inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config"; diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD index f427b97624b143..8e4337496dd5e2 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD @@ -13,6 +13,7 @@ package_group( name = "friends", packages = [ "//learning/deepmind/partir/...", + "//learning/deepmind/partir/compiler/mpmd/export/...", "//third_party/openxla/shardy/tools/...", "//xla/service/spmd/shardy/...", ], @@ -37,6 +38,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -54,6 +56,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -70,13 +73,29 @@ cc_library( "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", + ], +) + +cc_library( + name = "export_callback_custom_calls", + srcs = ["export_callback_custom_calls.cc"], + hdrs = ["export_callback_custom_calls.h"], + deps = [ + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -85,6 +104,7 @@ cc_library( srcs = ["mhlo_export.cc"], hdrs = ["mhlo_export.h"], deps = [ + ":export_callback_custom_calls", ":export_ops", ":export_shardings", ":shard_map_export", @@ -147,7 +167,6 @@ cc_library( hdrs = ["shard_map_import.h"], deps = [ "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/algorithm:container", @@ -163,5 +182,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc new file mode 100644 index 00000000000000..1a02da265ee971 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.cc @@ -0,0 +1,120 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h" + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ModuleOp; +using ::mlir::OperationPass; +using ::mlir::PassWrapper; +using ::mlir::StringRef; + +using ::mlir::stablehlo::CustomCallOp; + +// Attempts to replace the `CustomCallOp` with a tuple version of it, and a +// `GetTupleElementOp` that gets the first element of the tuple. +// +// This only happens if the op has a single result and the result type is not +// a tuple. +void replaceCallbackWithTupleVersion(CustomCallOp customCall, + mlir::IRRewriter& rewriter) { + if (customCall.getNumResults() != 1 || + mlir::isa(customCall->getResultTypes().front())) { + return; + } + CustomCallOp tupleCustomCall = cloneCustomCallWithNewResultTypes( + customCall, + mlir::TupleType::get(customCall->getContext(), + {customCall->getResultTypes()}), + rewriter); + auto getTupleElement = rewriter.create( + customCall.getLoc(), customCall->getResultTypes().front(), + tupleCustomCall.getResult(0), rewriter.getI32IntegerAttr(0)); + getTupleElement->setAttr(kXlaShardingAttr, + customCall->getAttr(kXlaShardingAttr)); + rewriter.replaceOp(customCall, getTupleElement); +} + +class MhloRoundTripExportCallbackCustomCallsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + MhloRoundTripExportCallbackCustomCallsPass) + + void runOnOperation() final { + getOperation().walk([&](CustomCallOp customCall) { + if (!isPythonCallbackCustomCall(customCall)) { + return; + } + mlir::IRRewriter rewriter(customCall); + if (!customCall->use_empty()) { + replaceCallbackWithTupleVersion(customCall, rewriter); + return; + } + CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes( + customCall, mlir::TypeRange(), rewriter); + newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr({})); + rewriter.eraseOp(customCall); + return; + }); + } + + StringRef getArgument() const override { + return "xla-sdy-mhlo-round-trip-export-callback-custom-calls"; + } + + StringRef getDescription() const override { + return "Converts the `CustomCallOp`s for host callbacks in XLA into the " + "pattern that the XLA compiler recognizes."; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +std::unique_ptr createMhloRoundTripExportCallbackCustomCallsPass() { + return std::make_unique(); +} + +void registerMhloRoundTripExportCallbackCustomCallsPass() { + mlir::registerPass(createMhloRoundTripExportCallbackCustomCallsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h new file mode 100644 index 00000000000000..b67955f7a80212 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_ +#define XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_ + +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace xla { +namespace sdy { + +// Creates a pass that converts the `CustomCallOp`s for host callbacks in XLA +// into the pattern that the XLA compiler recognizes. +// +// The rest of the XLA pipeline expects host callback custom calls to either be +// a tuple with a get_tuple_element or no results (which we changed due to +// shardy shardings expecting at least one result, and needing to attach a +// maximal sharding to the callbacks). +std::unique_ptr createMhloRoundTripExportCallbackCustomCallsPass(); + +// Registers the xla-sdy-mhlo-round-trip-export-callback-custom-calls pass. +void registerMhloRoundTripExportCallbackCustomCallsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_EXPORT_CALLBACK_CUSTOM_CALLS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc index fbc7beca1bf085..bc93d37128c31a 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -45,6 +46,7 @@ limitations under the License. #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/sharding_op_util.h" @@ -54,6 +56,7 @@ namespace sdy { namespace { +namespace stablehlo = ::mlir::stablehlo; namespace mhlo = ::mlir::mhlo; using ::mlir::ConversionPatternRewriter; @@ -73,7 +76,7 @@ using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Converts `sdy::ConstantOp` to `mhlo::ConstantOp`. +// Converts `sdy::ConstantOp` to `stablehlo::ConstantOp`. class ConstantPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -82,7 +85,7 @@ class ConstantPattern : public OpConversionPattern { ConversionPatternRewriter& rewriter) const override { // We use the generic op builder so that unregistered attributes will be // added to the new op. - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op->getResultTypes(), adaptor.getOperands(), op->getAttrs()); return success(); } @@ -134,7 +137,7 @@ class ExportOpsPass // ShardingConstraintOp should be replaced by ReshardOp before this pass. // Hence, we add ShardingConstraintOp as an illegal op. target.addIllegalOp(); - target.addLegalOp(); + target.addLegalOp(); mlir::RewritePatternSet patterns(&context); // After converting `sdy.constant` into `mhlo.constant`, the constants // should not be deduped via folding. Fortunately, folding only happens in diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc index fc65c24cc623e8..bd5834c8249333 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc @@ -56,6 +56,7 @@ limitations under the License. #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" @@ -70,6 +71,7 @@ namespace sdy { namespace { using ::mlir::ArrayRef; +using ::mlir::DictionaryAttr; using ::mlir::LogicalResult; using ::mlir::ModuleOp; using ::mlir::OpBuilder; @@ -84,6 +86,8 @@ using ::mlir::success; using ::mlir::SymbolTable; using ::mlir::func::FuncOp; +using ::mlir::stablehlo::CustomCallOp; + using ::mlir::sdy::AxisRefAttr; using ::mlir::sdy::DimensionShardingAttr; using ::mlir::sdy::kShardingAttr; @@ -195,6 +199,7 @@ class ExportMhloShardingsPass void runOnOperation() final { ModuleOp moduleOp = getOperation(); + mlir::SymbolTableCollection symbolTableCollection; SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp); @@ -206,6 +211,29 @@ class ExportMhloShardingsPass } } + moduleOp.walk([&](CustomCallOp customCall) { + // StableHLO doesn't have an equivalent of `erf` and `topk` ops. + // If they have a sharding annotation, we need to move it into + // `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up. + StringRef callTargetName = customCall.getCallTargetName(); + if (callTargetName != "mhlo.erf" && callTargetName != "mhlo.topk") { + return; + } + // TODO(bartchr): refactor `addFrontendAttribute` to take a key for the + // dictionary attribute. Then can re-use the logic instead of duplicating + // it here for `kMhloAttributesAttr`. + if (auto sdySharding = + customCall->getAttrOfType(kXlaShardingAttr)) { + customCall->removeAttr(kXlaShardingAttr); + SmallVector newAttributes( + customCall->getAttrOfType(kMhloAttributesAttr) + .getValue()); + newAttributes.push_back( + builder.getNamedAttr(kXlaShardingAttr, sdySharding)); + customCall->setAttr(kMhloAttributesAttr, + builder.getDictionaryAttr(newAttributes)); + } + }); // Remove all mesh symbols for (MeshOp meshOp : llvm::make_early_inc_range(moduleOp.getOps())) { diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc index 36aee9a64f266b..232e8c4d09da2c 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" +#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" @@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) { pm.addPass(createMhloRoundTripShardMapExportPass()); pm.addPass(createExportNamedComputationsPass()); pm.addPass(createExportMhloShardingsPass()); + pm.addPass(createMhloRoundTripExportCallbackCustomCallsPass()); } void registerMhloExportPipeline() { diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index 1f0cff4c61a75c..f377c5b465872e 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -413,7 +413,7 @@ TensorShardingAttr convertToSdySharding( // device. if (hloSharding.HasUniqueDevice()) { return TensorShardingAttr::getFullyClosed( - ctx, rank, + ctx, /*rank=*/0, deviceIdToMaximalMeshName.lookup(hloSharding.GetUniqueDevice())); } CHECK(!hloSharding.IsTuple()); @@ -658,8 +658,8 @@ void addMhloImportPipeline(mlir::OpPassManager& pm, void registerMhloImportPipeline() { mlir::PassPipelineRegistration<> importPipeline( "xla-sdy-mhlo-import-pipeline", - "Run passes to import an mhlo module with `mhlo.shardings` into the SDY " - "(Shardy) dialect.", + "Run passes to import a StableHLO module with `mhlo.shardings` into the " + "SDY (Shardy) dialect.", std::bind(addMhloImportPipeline, std::placeholders::_1, ArrayRef(), ArrayRef())); } diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc index e70720f4e8aa1d..73f48c698f9939 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc @@ -51,6 +51,7 @@ limitations under the License. #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -73,7 +74,7 @@ using ::mlir::StringAttr; using ::mlir::StringRef; using ::mlir::Value; using ::mlir::mhlo::CopyOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; namespace sdy = ::mlir::sdy; using sdy::kShardingAttr; diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc index a8098832a71d5a..d12f194e023f46 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc @@ -53,7 +53,7 @@ limitations under the License. #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/xla_data.pb.h" @@ -73,7 +73,7 @@ using ::mlir::StringRef; using ::mlir::Value; using ::mlir::func::CallOp; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; namespace sdy = ::mlir::sdy; using sdy::AxisRefAttr; diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD index af119242aa3437..b3ab4176a0be73 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD @@ -19,7 +19,6 @@ cc_library( hdrs = ["import_sdy_custom_calls.h"], deps = [ "//xla:sharding_op_util", - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", @@ -29,6 +28,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -55,6 +55,7 @@ cc_library( "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -86,7 +87,6 @@ cc_library( srcs = ["open_while_free_vars_sharding.cc"], hdrs = ["open_while_free_vars_sharding.h"], deps = [ - "//xla/mlir_hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -94,6 +94,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -109,6 +110,7 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc index 57a50d928d3bde..b2c0e517e7430d 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include #include #include "absl/log/check.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -72,7 +72,7 @@ class BackendFuncCallPattern : public OpConversionPattern { FuncOp func = symbolTable.lookup(adaptor.getCallee()); CHECK(func) << "Failed to lookup function: " - << std::string_view(adaptor.getCallee()); + << absl::string_view(adaptor.getCallee()); mlir::SmallVector namedCompAttrs; llvm::copy_if(callOp->getDiscardableAttrs(), std::back_inserter(namedCompAttrs), diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_constants.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_constants.h index 3de4603894bb9b..a83869ca3e93b0 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_constants.h +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_constants.h @@ -23,8 +23,8 @@ limitations under the License. namespace xla { namespace sdy { -// Creates a pass that converts an `mhlo.constant` (which is foldable) into an -// `sdy.constant` (which isn't foldable). +// Creates a pass that converts a `stablehlo.constant` (which is foldable) into +// an `sdy.constant` (which isn't foldable). std::unique_ptr createImportConstantsPass(); // Register the xla-sdy-import-constants pass. diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc index 8172a217e30a91..4a36c2ba3b1583 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" #include "xla/sharding_op_util.h" @@ -47,11 +47,11 @@ namespace { using ::mlir::IntegerAttr; using ::mlir::StringRef; -using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; -using ::mlir::mhlo::CustomCallOpAdaptor; +using ::mlir::stablehlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOpAdaptor; mlir::LogicalResult rewriteShardingCustomCall( CustomCallOp op, CustomCallOpAdaptor adaptor, diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc index 603b270eefa46f..6fe201ccb4fb4d 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mlir/Transforms/RegionUtils.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" namespace xla { namespace sdy { @@ -49,7 +49,7 @@ class OpenWhileFreeVarsShardingPass FuncOp funcOp = getOperation(); mlir::IRRewriter rewriter(funcOp); - funcOp.walk([&](mlir::mhlo::WhileOp op) { + funcOp.walk([&](mlir::stablehlo::WhileOp op) { llvm::SetVector freeVars; mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars); rewriter.setInsertionPoint(op); diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index 68592c1918a3e3..e8970270353550 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h" @@ -31,24 +32,34 @@ using ::mlir::func::FuncOp; void addCommonPreImportPasses(mlir::OpPassManager& pm) { pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // TODO(b/333505182): remove when partitioning is done in SDY. // We call prepare-for-export pass before SDY propagation, so that all IR // changes happen before shardings are added to operations, to ensure the // correct shardings are added and that they are not lost by this pass. pm.addNestedPass(mlir::mhlo::createPrepareForExportPass()); - // We import `mhlo.constant` ops to `sdy.constant` ops so that constants // aren't folded in greedy pattern rewriters, which would lift them outside of // nested regions (this undoes `WhileLoopConstantSinking` HLO pass). - // Therefore, this pass needs to be applied after any mhlo pass that expects - // `mhlo.constant`, and before any pass that has a greedy pattern rewriter. + // Therefore, this pass needs to be applied after any MHLO pass that + // expects `mhlo.constant`, and before any pass that has a greedy pattern + // rewriter. pm.addNestedPass(createImportConstantsPass()); - pm.addNestedPass(mlir::mhlo::createFlattenTuplePass()); // We need to canonicalize redundant mhlo::GetTupleElementOp and // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before // `createOpenWhileFreeVarsShardingPass`. - pm.addPass(mlir::createCanonicalizerPass()); + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; + config.fold = false; + config.cseConstants = false; + // TODO(tomnatan): consider only enabling the specific passes we need. + pm.addPass(mlir::createCanonicalizerPass(config)); + // Shardy is currently operating on stablehlo, since this is what JAX + // emits. Long term shardy will be fully dialect agnostic, and both mhlo + // and stablehlo can register their ops for sdy propagation. + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); } void addCommonPostImportPasses(mlir::OpPassManager& pm) { diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc index 7f2dff488a7f00..f994526846cb80 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc @@ -14,15 +14,15 @@ limitations under the License. ==============================================================================*/ #include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/register.h" #include "shardy/dialect/sdy/transforms/passes.h" -#include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h" @@ -36,6 +36,7 @@ limitations under the License. #include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" #include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" @@ -49,8 +50,8 @@ int main(int argc, char** argv) { mlir::mhlo::registerAllMhloPasses(); mlir::DialectRegistry dialects; - dialects.insert(); + mlir::sdy::registerAllDialects(dialects); + dialects.insert(); mlir::func::registerAllExtensions(dialects); // Register all SDY passes and pipelines. @@ -66,12 +67,14 @@ int main(int argc, char** argv) { xla::sdy::registerMhloExportPipeline(); xla::sdy::registerMhloExportShardingsPass(); + xla::sdy::registerMhloRoundTripExportCallbackCustomCallsPass(); xla::sdy::registerMhloRoundTripShardMapExportPass(); xla::sdy::registerExportNamedComputationsPass(); xla::sdy::registerExportOpsPass(); xla::sdy::registerSdyRoundTripMhloToHloToMhloPass(); xla::sdy::registerSdyRoundTripExportShardyAttrsPass(); + xla::sdy::registerSdyRoundTripImportCallbackCustomCallsPass(); xla::sdy::registerSdyRoundTripImportShardyAttrsPass(); xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass(); xla::sdy::registerSdyRoundTripExportOpsPass(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index 25c3928386d1df..3d5f950c31d92c 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -9,6 +9,7 @@ package( package_group( name = "friends", packages = [ + "//learning/deepmind/partir/compiler/mpmd/...", "//learning/deepmind/partir/compiler/shardonnay/...", "//third_party/openxla/shardy/tools/...", "//xla/...", @@ -20,7 +21,6 @@ cc_library( srcs = ["export_shardy_attrs.cc"], hdrs = ["export_shardy_attrs.h"], deps = [ - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", @@ -30,6 +30,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -38,7 +39,6 @@ cc_library( srcs = ["export_ops.cc"], hdrs = ["export_ops.h"], deps = [ - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", @@ -47,6 +47,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -55,8 +56,6 @@ cc_library( srcs = ["import_shardy_attrs.cc"], hdrs = ["import_shardy_attrs.h"], deps = [ - "//xla/mlir_hlo", - "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", @@ -68,6 +67,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -94,7 +94,6 @@ cc_library( srcs = ["shard_map_import.cc"], hdrs = ["shard_map_import.h"], deps = [ - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", @@ -106,6 +105,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -116,6 +116,7 @@ cc_library( deps = [ "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -125,6 +126,22 @@ cc_library( ], ) +cc_library( + name = "import_callback_custom_calls", + srcs = ["import_callback_custom_calls.cc"], + hdrs = ["import_callback_custom_calls.h"], + deps = [ + "//xla/service/spmd/shardy:utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "pipelines", srcs = ["pipelines.cc"], @@ -132,6 +149,7 @@ cc_library( deps = [ ":export_ops", ":export_shardy_attrs", + ":import_callback_custom_calls", ":import_shardy_attrs", ":remove_size_one_axes", ":shard_map_export", @@ -142,6 +160,5 @@ cc_library( "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc index 67c4bc63b86802..50f31670e7b40c 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc @@ -40,11 +40,11 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" -namespace mhlo = ::mlir::mhlo; +namespace stablehlo = ::mlir::stablehlo; namespace xla { namespace sdy { @@ -67,7 +67,7 @@ using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Converts `sdy::ConstantOp` to `mhlo::ConstantOp`. +// Converts `sdy::ConstantOp` to `stablehlo::ConstantOp`. class ConstantPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -76,7 +76,7 @@ class ConstantPattern : public OpConversionPattern { ConversionPatternRewriter& rewriter) const override { // We use the generic op builder so that unregistered attributes will be // added to the new op. - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op->getResultTypes(), adaptor.getOperands(), op->getAttrs()); return success(); } @@ -93,7 +93,7 @@ class ShardingConstraintPattern ConversionPatternRewriter& rewriter) const override { TensorShardingAttr sharding = op.getSharding(); - auto customCallOp = rewriter.replaceOpWithNewOp( + auto customCallOp = rewriter.replaceOpWithNewOp( op, op.getType(), adaptor.getInput()); customCallOp.setCallTargetName(kShardingCustomCallTargetName); @@ -117,11 +117,11 @@ class ShardingGroupPattern : public OpConversionPattern { LogicalResult matchAndRewrite( ShardingGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto customCallOp = rewriter.replaceOpWithNewOp( + auto customCallOp = rewriter.replaceOpWithNewOp( op, op->getResultTypes(), adaptor.getInput()); customCallOp.setCallTargetName(kShardingGroupCustomCallTargetName); - addFrontendAttribute(customCallOp, kShardingGroupIdAttr, + setFrontendAttribute(customCallOp, kShardingGroupIdAttr, op.getGroupIdAttr()); customCallOp.setHasSideEffectAttr(rewriter.getBoolAttr(true)); return success(); @@ -137,7 +137,7 @@ class SdyRoundTripExportOpsPass mlir::MLIRContext& context = getContext(); mlir::ConversionTarget target(context); target.addIllegalOp(); - target.addLegalOp(); + target.addLegalOp(); mlir::RewritePatternSet patterns(&context); patterns .add( diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc index 8474d3efb0e6e2..c6de645bf60fcf 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc @@ -43,7 +43,7 @@ limitations under the License. #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -66,7 +66,7 @@ using ::mlir::StringRef; using ::mlir::Value; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; using ::mlir::sdy::kShardingRuleAttr; @@ -79,7 +79,7 @@ using ::mlir::sdy::TensorShardingPerValueAttr; // the `op`. void saveOpShardingPerValueAttr( Operation* op, TensorShardingPerValueAttr shardingPerValueAttr) { - addFrontendAttribute(op, kShardingRoundTripAttr, shardingPerValueAttr); + setFrontendAttribute(op, kShardingRoundTripAttr, shardingPerValueAttr); } // Converts the shardings from `kShardingAttr` into @@ -88,7 +88,7 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { for (int64_t argNum = 0; argNum < funcOp.getNumArguments(); ++argNum) { if (auto oldSharding = funcOp.getArgAttrOfType( argNum, kShardingAttr)) { - addFrontendAttribute(funcOp, kShardingRoundTripAttr, oldSharding, argNum); + setFrontendAttribute(funcOp, kShardingRoundTripAttr, oldSharding, argNum); } } @@ -126,7 +126,7 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { } if (auto oldShardingRule = op->getAttrOfType(kShardingRuleAttr)) { - addFrontendAttribute(op, kShardingRuleRoundTripAttr, oldShardingRule); + setFrontendAttribute(op, kShardingRuleRoundTripAttr, oldShardingRule); op->removeAttr(kShardingRuleAttr); } }); @@ -159,7 +159,7 @@ class SdyRoundTripExportShardyAttrsPass mhloMeshes.emplace_back(meshOp.getSymNameAttr(), meshOp.getMeshAttr()); } if (!mhloMeshes.empty()) { - addFrontendAttribute(moduleOp, kMeshesRoundTripAttr, + setFrontendAttribute(moduleOp, kMeshesRoundTripAttr, DictionaryAttr::get(context, mhloMeshes)); } } @@ -177,7 +177,7 @@ class SdyRoundTripExportShardyAttrsPass } void getDependentDialects(mlir::DialectRegistry& registry) const final { - registry.insert(); + registry.insert(); } }; diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc new file mode 100644 index 00000000000000..0fa3f44d8204af --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.cc @@ -0,0 +1,91 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h" + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/service/spmd/shardy/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ModuleOp; +using ::mlir::StringRef; +using ::mlir::stablehlo::CustomCallOp; + +class SdyRoundTripImportCallbackCustomCallsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SdyRoundTripImportCallbackCustomCallsPass) + + void runOnOperation() final { + getOperation().walk([&](CustomCallOp op) { + if (op->getNumResults() != 0 || !isPythonCallbackCustomCall(op)) { + return; + } + mlir::IRRewriter rewriter(op); + // Shardy needs at least one op result to have a sharding annotation. + // Since the callback has no results, and we need to say the callbacks + // have a maximal sharding, we add a dummy result and set the result + // layout to the 0th operand layout. + CustomCallOp newCustomCall = cloneCustomCallWithNewResultTypes( + op, op->getOperand(0).getType(), rewriter); + newCustomCall.setResultLayoutsAttr(rewriter.getArrayAttr( + {op.getOperandLayoutsAttr().getValue().front()})); + rewriter.eraseOp(op); + }); + } + + StringRef getArgument() const override { + return "xla-sdy-round-trip-import-callback-custom-calls"; + } + + StringRef getDescription() const override { + return "Modifies the return types of XLA host callback custom calls to be " + "compatible with SDY"; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +std::unique_ptr createSdyRoundTripImportCallbackCustomCallsPass() { + return std::make_unique(); +} + +void registerSdyRoundTripImportCallbackCustomCallsPass() { + mlir::registerPass(createSdyRoundTripImportCallbackCustomCallsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h new file mode 100644 index 00000000000000..ce81f5ead47191 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates the pass to modify the return types of XLA host callback custom calls +// to be compatible with SDY. +// +// Shardy shardings require an op to have at least one result, and the XLA host +// callback custom calls are not guaranteed to return a value. +// To allow the custom calls to have a maximal sharding, we change the return +// type to return a dummy value. +std::unique_ptr createSdyRoundTripImportCallbackCustomCallsPass(); + +// Registers the xla-sdy-round-trip-import-callback-custom-calls pass. +void registerSdyRoundTripImportCallbackCustomCallsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_CALLBACK_CUSTOM_CALLS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc index 26f3539163b15f..b69302532b419d 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc @@ -45,8 +45,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -66,8 +65,6 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; - using ::mlir::sdy::kShardingAttr; using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshAttr; @@ -75,6 +72,8 @@ using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; +namespace stablehlo = ::mlir::stablehlo; + // Builds the shardy attributes coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs @@ -109,13 +108,19 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { if (!dictAttr) { return; } + // `SendOp` and `RecvOp` can have a sharding when doing TPU callbacks + // through JAX. + if (mlir::isa(op)) { + op->setAttr(kShardingAttr, parseStringAttr( + dictAttr, kShardingRoundTripAttr)); + } // NOTE: we are only setting the sharding on known custom-calls. For any // other op that has a `kShardingRoundTripAttr` we discard it. XLA sometimes // creates new instructions, copying over the operand's frontend attrs, // which may mean the shapes are wrong when the new instruction is a reshape // for example. This does mean we can't fully round-trip b/w HLO and MLIR // after SDY propagation. - if (auto customCallOp = mlir::dyn_cast(op)) { + if (auto customCallOp = mlir::dyn_cast(op)) { StringRef targetName = customCallOp.getCallTargetName(); if (targetName == kFuncResultShardingTargetName) { // This is a temporary CustomCallOp that holds the sharding from a @@ -140,7 +145,8 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { } if (targetName == kShardingCustomCallTargetName || targetName == kSPMDFullToShardShapeCallTargetName || - targetName == kSPMDShardToFullShapeCallTargetName) { + targetName == kSPMDShardToFullShapeCallTargetName || + isPythonCallbackCustomCall(customCallOp)) { customCallOp->setAttr(kShardingAttr, parseStringAttr( dictAttr, kShardingRoundTripAttr)); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index 32e15074c843a1..0f92d457152cf4 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_callback_custom_calls.h" #include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" #include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" @@ -49,6 +50,7 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { addCommonPreImportPasses(pm); + pm.addPass(createSdyRoundTripImportCallbackCustomCallsPass()); pm.addPass(createSdyRoundTripImportShardyAttrsPass()); pm.addPass(createSdyRoundTripShardMapImportPass()); pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass()); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc index 06a383f1fefafd..bee62bff1a3602 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include #include -#include #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -53,7 +53,6 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::sdy::AxisRefAttr; using ::mlir::sdy::DimensionShardingAttr; -using ::mlir::sdy::getMeshAttr; using ::mlir::sdy::ManualAxesAttr; using ::mlir::sdy::ManualComputationOp; using ::mlir::sdy::MeshAttr; @@ -76,7 +75,7 @@ MeshAttr removeSizeOneAxes(MeshAttr mesh) { TensorShardingAttr removeSizeOneAxes(TensorShardingAttr sharding, const SymbolTable& symbolTable) { MeshAttr mesh = sharding.getMesh(symbolTable); - CHECK(mesh) << "unknown mesh: " << std::string_view(sharding.getMeshName()); + CHECK(mesh) << "unknown mesh: " << absl::string_view(sharding.getMeshName()); auto isNotSizeOne = [&](AxisRefAttr axis) { return axis.getSize(mesh) != 1; }; diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc index 16d9397ed16ee7..dda4aec8eed052 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc @@ -98,11 +98,11 @@ class SdyRoundTripShardMapExportPass auto callOp = rewriter.create(loc, localResultTypes, funcName, operands); - addFrontendAttribute(callOp, kInShardings, + setFrontendAttribute(callOp, kInShardings, manualComputation.getInShardings()); - addFrontendAttribute(callOp, kOutShardings, + setFrontendAttribute(callOp, kOutShardings, manualComputation.getOutShardings()); - addFrontendAttribute(callOp, kManualAxes, + setFrontendAttribute(callOp, kManualAxes, manualComputation.getManualAxesAttr()); mlir::ResultRange results = manualComputation->getResults(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index c4e75e44cee0ad..a645b25a551a4e 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -44,7 +44,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -60,7 +60,7 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::func::CallOp; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; namespace sdy = ::mlir::sdy; diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 7969e578c6d884..62f6470ae00a5e 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -20,26 +20,21 @@ cc_library( srcs = ["mhlo_to_hlo_to_mhlo.cc"], hdrs = ["mhlo_to_hlo_to_mhlo.h"], deps = [ - "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate:stablehlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@shardy//shardy/dialect/sdy/ir:dialect", - "@stablehlo//:stablehlo_ops", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc index da7bda8f60e3b9..adcb9251dcca9d 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc @@ -18,32 +18,26 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/TypeID.h" -#include "shardy/dialect/sdy/ir/dialect.h" -#include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/hlo/translate/stablehlo.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/shape.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace sdy { @@ -53,35 +47,22 @@ namespace { using ::mlir::ModuleOp; using ::mlir::StringRef; -// Converts an MHLO module to an HLO module. +// Converts a StableHLO module to an HLO module. absl::StatusOr> toHlo(ModuleOp module) { - absl::StatusOr> hloModule; - xla::HloProto hloProto; - TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &hloProto, - /*use_tuple_args=*/false, - /*return_tuple=*/false)); - xla::HloModuleConfig moduleConfig; - xla::ProgramShape expectedProgramShape( - hloProto.hlo_module().host_program_shape()); - moduleConfig.SetDefaultComputationLayout(expectedProgramShape); - moduleConfig.set_use_spmd_partitioning(true); - return xla::HloModule::CreateFromProto(hloProto.hlo_module(), moduleConfig); + TF_ASSIGN_OR_RETURN(std::unique_ptr hloModule, + xla::ConvertStablehloToHlo(module)); + hloModule->mutable_config().set_use_spmd_partitioning(true); + return hloModule; } -// Converts an HLO module to an MHLO module. -absl::Status toMhlo(std::unique_ptr hloModule, ModuleOp module) { - // Delete the functions, which can be more than one due to preserving - // the shmap_body functions. - mlir::SymbolTableCollection symbolTableCollection; - mlir::SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module); - for (mlir::Operation& op : - llvm::make_early_inc_range(module.getBodyRegion().getOps())) { - symbolTable.erase(&op); - } - TF_RETURN_IF_ERROR( - xla::ConvertHloToMlirHlo(module, hloModule.get(), - /*import_all_computations=*/false, - /*flatten_computation_args_result=*/true)); +// Converts an HLO module to a StableHLO module. +absl::Status toStablehlo(std::unique_ptr hloModule, + ModuleOp& module) { + TF_ASSIGN_OR_RETURN( + mlir::OwningOpRef newModule, + xla::ConvertHloToStablehlo(*module->getContext(), hloModule.get())); + // Erase the old body region and replace it with the new one. + module.getBodyRegion().takeBody(newModule.get().getBodyRegion()); return absl::OkStatus(); } @@ -94,18 +75,18 @@ class SdyRoundTripMhloToHloToMhloPass private: void runOnOperation() final { ModuleOp module = getOperation(); - // 1. MHLO -> HLO + // 1. StableHLO -> HLO absl::StatusOr> hloModule = toHlo(module); if (!hloModule.ok()) { - module.emitError(absl::StrCat("Failed to convert to HLO from MHLO: ", + module.emitError(absl::StrCat("Failed to convert to HLO from StableHLO: ", hloModule.status().message())); return signalPassFailure(); } - // 2. HLO -> MHLO - if (absl::Status status = toMhlo(std::move(*hloModule), module); + // 2. HLO -> StableHLO + if (absl::Status status = toStablehlo(std::move(*hloModule), module); !status.ok()) { - module.emitError(absl::StrCat("Failed to convert to MHLO from HLO: ", + module.emitError(absl::StrCat("Failed to convert to StableHLO from HLO: ", status.message())); return signalPassFailure(); } @@ -116,13 +97,11 @@ class SdyRoundTripMhloToHloToMhloPass } StringRef getDescription() const override { - return "Round trips from MHLO -> HLO -> MHLO."; + return "Round trips from MHLO -> StableHLO -> MHLO."; } void getDependentDialects(mlir::DialectRegistry& registry) const final { - registry.insert(); + xla::RegisterMlirToHloDependentDialects(registry); } }; diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index ae6fe74dc3b7ab..3d6f4ac4f1692a 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -20,11 +20,9 @@ limitations under the License. #include #include #include -#include #include #include -#include "mhlo/transforms/passes.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -32,6 +30,8 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinAttributes.h" @@ -51,6 +51,7 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/stablehlo.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/layout.h" #include "xla/map_util.h" @@ -78,6 +79,15 @@ namespace sdy { namespace { +std::string uniqueModuleName(const HloModule& module) { + std::string result; + absl::StrAppendFormat(&result, "module_%04d", module.unique_id()); + if (!module.name().empty()) { + absl::StrAppend(&result, ".", module.name()); + } + return result; +} + // Creates a vector of HloComputation, which is used to replace the old // computations in the HloModule. It is adapted from CreateAndSanitizeFromProto // in internal xla/tests/fuzzing/hlo_fuzzer_utils.cc. @@ -298,17 +308,12 @@ absl::StatusOr ShardyXLA::Run( const absl::flat_hash_set& executionThreads) { LOG(INFO) << "Using Shardy for XLA SPMD propagation."; - // HLO -> MLIR MHLO + // HLO -> StableHLO auto mlirContext = std::make_unique(); loadAllRequiredDialects(mlirContext.get()); - mlir::OwningOpRef mlirModule = - xla::llvm_ir::CreateMlirModuleOp( - mlir::UnknownLoc::get(mlirContext.get())); - TF_RETURN_IF_ERROR( - ConvertHloToMlirHlo(*mlirModule, hloModule, - /*import_all_computations=*/false, - /*flatten_computation_args_result=*/true)); - + TF_ASSIGN_OR_RETURN( + mlir::OwningOpRef mlirModule, + xla::ConvertHloToStablehlo(*mlirContext.get(), hloModule)); std::string shardyDir = hloModule->config().debug_options().xla_dump_to(); if (shardyDir == "sponge") { @@ -324,8 +329,7 @@ absl::StatusOr ShardyXLA::Run( if (!shardyDir.empty()) { shardyDir = - tsl::io::JoinPath(shardyDir, "shardy", - std::string_view(mlirModule->getName().value_or(""))); + tsl::io::JoinPath(shardyDir, "shardy", uniqueModuleName(*hloModule)); LOG(INFO) << "Using Shardy output directory: " << shardyDir; } @@ -382,17 +386,12 @@ absl::StatusOr ShardyXLA::Run( useTupleArgs); if (runSdyShardingPropagation) { - // Shardy is currently operating on stablehlo, since this is what JAX - // emits. Long term shardy will be fully dialect agnostic, and both mhlo - // and stablehlo can register their ops for sdy propagation. - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); // NOTE: if we are using auto-spmd, we will use conservative propagation // since the TOAST cost model cannot account for split axes or padding. mlir::sdy::PropagationOptions options; options.dumpDirectory = shardyDir; options.conservativePropagation = hloModule->use_auto_spmd_partitioning(); mlir::sdy::addPropagationPipeline(pm, options); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); } addMhloExportPipeline(pm); pm.addPass(mlir::sdy::createSaveModuleOpPass(shardyDir, @@ -400,10 +399,10 @@ absl::StatusOr ShardyXLA::Run( tsl::StatusScopedDiagnosticHandler diagnosticHandler(mlirContext.get()); TF_RETURN_IF_ERROR(diagnosticHandler.consumeStatus(pm.run(*mlirModule))); - // MLIR MHLO -> HLO + // StableHlo -> HLO HloProto hloProto; - TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(*mlirModule, &hloProto, useTupleArgs, - /*return_tuple=*/false)); + TF_RETURN_IF_ERROR(ConvertStablehloWithManyArgsToHloProto( + *mlirModule, &hloProto, useTupleArgs)); TF_RETURN_IF_ERROR( createFromProtoAndReplaceComputations(hloModule, hloProto.hlo_module())); diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc index 6cb846048cff7b..1dea14f81ece92 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "absl/log/log.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; @@ -560,9 +560,8 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) { op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}")); // Verify the sharding of the while, and specifically that the sharding of the // result that corresponds to parameter(1) is further sharded. - EXPECT_THAT(whileInst, - op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, " - "{devices=[2,2]<=[4]}, {replicated}}")); + EXPECT_THAT(whileInst, op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, " + "{devices=[2,2]<=[4]}}")); } TEST_F(ShardyXLATest, ShardMap) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir index 35c4d62e8d099d..9ab41e20ce0a19 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir @@ -5,41 +5,41 @@ sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> // CHECK-LABEL: func @no_out_shardings func.func @no_out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) { - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> // CHECK-NEXT: sdy.return %[[MULT]] : tensor<8x2xi32> // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, // CHECK-SAME: random_attr = "random_value"} // CHECK-SAME: (tensor<8x2xi32>) -> tensor<8x2xi32> - // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> %0 = call @foo(%arg0) {random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> - %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = stablehlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> } func.func private @foo(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { - %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + %0 = stablehlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> return %0 : tensor<8x2xi32> } // CHECK-LABEL: func @out_shardings func.func @out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"bar">(%arg0) out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) { - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> // CHECK-NEXT: sdy.return %[[MULT]] : tensor<8x2xi32> // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, // CHECK-SAME: random_attr = "random_value"} // CHECK-SAME: (tensor<8x2xi32>) -> tensor<8x2xi32> - // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> %0 = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> - %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = stablehlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> } // NOTE: we ignore any arg/result shardings on the function. func.func private @bar(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) { - %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + %0 = stablehlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> return %0 : tensor<8x2xi32> } @@ -53,6 +53,6 @@ func.func @no_backend_config(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.shardin } func.func private @baz(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { - %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<8x2xi32> return %0 : tensor<8x2xi32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir index 9cc62dd41959b7..a8236ade495588 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir @@ -10,8 +10,8 @@ func.func @non_trivial_common_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,2,16]<=[32] last_tile_dim_replicate}"}, %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[4,4,2]<=[2,16]T(1,0) last_tile_dim_replicate}"}) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -24,10 +24,10 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices= %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,8,4]<=[2,4,4]T(0,2,1) last_tile_dim_replicate}"}, %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"}) -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { - // CHECK-NEXT: mhlo.add + // CHECK-NEXT: stablehlo.add // CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"axis_1", "axis_0"}, {}]>]>} - %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x8xf32> + %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -41,7 +41,7 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices= // CHECK-SAME: -> tensor<32x16xf32> { func.func @single_axis(%arg0: tensor<32x8xf32> {mhlo.sharding = "{devices=[16,1]<=[16]}"}, %arg1: tensor<8x16xf32>) -> tensor<32x16xf32> { - %0 = "mhlo.dot" (%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32> + %0 = "stablehlo.dot" (%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32> return %0 : tensor<32x16xf32> } @@ -51,16 +51,16 @@ func.func @single_axis(%arg0: tensor<32x8xf32> {mhlo.sharding = "{devices=[16,1] // CHECK-LABEL: func @multi_result_op func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { - %0 = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: mhlo.reduce + %0 = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: stablehlo.reduce // CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"axis_1"}]>, <@mesh, [{"axis_1"}, {}]>]>} - %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] + %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {mhlo.sharding = "{{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}, {devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}}"} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) reducer(%arg2: tensor, %arg4: tensor) (%arg3: tensor, %arg5: tensor) { - %2 = mhlo.add %arg2, %arg4 : tensor - %3 = mhlo.add %arg3, %arg5 : tensor - mhlo.return %2, %3 : tensor, tensor + %2 = stablehlo.add %arg2, %arg4 : tensor + %3 = stablehlo.add %arg3, %arg5 : tensor + stablehlo.return %2, %3 : tensor, tensor } return %1#0, %1#1 : tensor<4x8xf32>, tensor<4x8xf32> } @@ -77,8 +77,8 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) func.func @fully_replicated(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -92,7 +92,7 @@ func.func @fully_replicated(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4 // CHECK-SAME: -> tensor<6x35xf32> { func.func @prime_number(%arg0: tensor<6x35xf32> {mhlo.sharding = "{devices=[6,35]<=[7,10,3]T(2,1,0)}"}, %arg1: tensor<6x35xf32> {mhlo.sharding = "{replicated}"}) -> tensor<6x35xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<6x35xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<6x35xf32> return %0 : tensor<6x35xf32> } @@ -106,7 +106,7 @@ func.func @prime_number(%arg0: tensor<6x35xf32> {mhlo.sharding = "{devices=[6,35 // CHECK-SAME: -> tensor<231x550x42x42xf32> { func.func @prime_number_2(%arg0: tensor<231x550x42x42xf32> {mhlo.sharding = "{devices=[33,10,1,7]<=[2,3,5,7,11]T(1,4,2,0,3)}"}, %arg1: tensor<231x550x42x42xf32> {mhlo.sharding = "{devices=[7,55,6,1]<=[2,3,5,7,11]T(3,2,4,1,0)}"}) -> tensor<231x550x42x42xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<231x550x42x42xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<231x550x42x42xf32> return %0 : tensor<231x550x42x42xf32> } @@ -120,7 +120,7 @@ func.func @prime_number_2(%arg0: tensor<231x550x42x42xf32> {mhlo.sharding = "{de // CHECK-SAME: -> tensor<8x8xf32> { func.func @unknown_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{unknown}"}) -> tensor<8x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -130,10 +130,10 @@ func.func @unknown_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4 // CHECK-LABEL: sdy.mesh @maximal_mesh_0 = <[], device_ids=[0]> // CHECK-LABEL: func @one_maximal_mesh( -// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>} +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>} func.func @one_maximal_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -143,11 +143,11 @@ func.func @one_maximal_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal de // CHECK-LABEL: sdy.mesh @maximal_mesh_4 = <[], device_ids=[4]> // CHECK-LABEL: func @two_maximal_shardings_should_be_sorted( -// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_4, [{}, {}]>}, -// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}) +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_4, []>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>}) func.func @two_maximal_shardings_should_be_sorted(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=4}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}) -> tensor<8x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -155,11 +155,11 @@ func.func @two_maximal_shardings_should_be_sorted(%arg0: tensor<8x8xf32> {mhlo.s // CHECK-COUNT-1: sdy.mesh @maximal_mesh_0 = <[], device_ids=[0]> // CHECK-LABEL: func @duplicate_maximal_sharding_should_be_deduped( -// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}, -// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}) +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>}) func.func @duplicate_maximal_sharding_should_be_deduped(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}) -> tensor<8x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -170,12 +170,12 @@ func.func @duplicate_maximal_sharding_should_be_deduped(%arg0: tensor<8x8xf32> { // CHECK-LABEL: func @two_meshes( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_1"}, {}]>}, -// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}, %arg2: tensor<8x16xf32>) +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>}, %arg2: tensor<8x16xf32>) func.func @two_meshes(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -189,11 +189,11 @@ func.func @two_meshes(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]< // CHECK-SAME: -> tensor<8x8xf32> { func.func @maximal_sharding_on_op(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { -// CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg1 -// CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_4, [{}, {}]>]>} -// CHECK-NEXT: %[[MULTIPLY:.*]] = mhlo.multiply %[[ADD]], %[[ADD]] -// CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, [{}, {}]>]>} - %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{maximal device=4}"} : tensor<8x8xf32> - %1 = mhlo.multiply %0, %0 {mhlo.sharding = "{maximal device=0}"} : tensor<8x8xf32> +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 +// CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_4, []>]>} +// CHECK-NEXT: %[[MULTIPLY:.*]] = stablehlo.multiply %[[ADD]], %[[ADD]] +// CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} + %0 = stablehlo.add %arg0, %arg1 {mhlo.sharding = "{maximal device=4}"} : tensor<8x8xf32> + %1 = stablehlo.multiply %0, %0 {mhlo.sharding = "{maximal device=0}"} : tensor<8x8xf32> return %1 : tensor<8x8xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir index 1a5f443f4ec472..ca9d1d5d00647f 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir @@ -21,8 +21,8 @@ sdy.mesh @empty_mesh_1 = <[]> func.func @non_trivial_common_mesh(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_1"}, {"axis_2"}]>}) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -35,10 +35,10 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_0", "axis_1"}, {"axis_2"}]>}) { -// CHECK-NEXT: mhlo.add +// CHECK-NEXT: stablehlo.add // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} - %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -48,22 +48,22 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi // CHECK-SAME: -> tensor<32x16xf32> { func.func @single_axis(%arg0: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"axis_0"}, {}]>}, %arg1: tensor<8x16xf32>) -> tensor<32x16xf32> { - %0 = "mhlo.dot" (%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32> + %0 = stablehlo.dot %arg0, %arg1 : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32> return %0 : tensor<32x16xf32> } // CHECK-LABEL: func @multi_result_op func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { - %0 = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: mhlo.reduce + %0 = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: stablehlo.reduce // CHECK-SAME{LITERAL}: {mhlo.sharding = "{{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}, {devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}}"} - %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] + %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) reducer(%arg2: tensor, %arg4: tensor) (%arg3: tensor, %arg5: tensor) { - %2 = mhlo.add %arg2, %arg4 : tensor - %3 = mhlo.add %arg3, %arg5 : tensor - mhlo.return %2, %3 : tensor, tensor + %2 = stablehlo.add %arg2, %arg4 : tensor + %3 = stablehlo.add %arg3, %arg5 : tensor + stablehlo.return %2, %3 : tensor, tensor } return %1#0, %1#1 : tensor<4x8xf32>, tensor<4x8xf32> } @@ -76,8 +76,8 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) func.func @fully_replicated(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{}, {}]>}, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -87,16 +87,16 @@ func.func @fully_replicated(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding // CHECK-SAME: -> tensor<8x16xf32> { func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: "mhlo.dot" +// CHECK-NEXT: stablehlo.dot // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,8]<=[2,2,2,4]T(0,2,1,3) last_tile_dim_replicate}"} - %1 = "mhlo.dot" (%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } // CHECK-LABEL: func @split_constants func.func @split_constants() -> (tensor<8x8xf32>, tensor<8x8xf32>) { - // CHECK-NEXT: %[[CONST_0:.*]] = mhlo.constant {mhlo.sharding = "{devices=[8,1,4]<=[32] last_tile_dim_replicate}"} dense<1.000000e+00> - // CHECK-NEXT: %[[CONST_1:.*]] = mhlo.constant {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} dense<1.000000e+00> + // CHECK-NEXT: %[[CONST_0:.*]] = stablehlo.constant {mhlo.sharding = "{devices=[8,1,4]<=[32] last_tile_dim_replicate}"} dense<1.000000e+00> + // CHECK-NEXT: %[[CONST_1:.*]] = stablehlo.constant {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} dense<1.000000e+00> // CHECK-NEXT: return %[[CONST_0]], %[[CONST_1]] %0 = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x"}, {}]>]>} dense<1.000000e+00> : tensor<8x8xf32> %1 = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"y"}, {}]>]>} dense<1.000000e+00> : tensor<8x8xf32> @@ -129,42 +129,42 @@ func.func @reshard_fully_open_partially_open(%arg0: tensor<8x8xf32>) -> tensor<8 // CHECK-SAME: %arg1: tensor<16x32xf32> {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"}) // CHECK-SAME: -> (tensor<8x32xf32> {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"}) { func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"a"}, {}]>}) { -// CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,2,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x16xf32> -// CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> -// CHECK-NEXT: %2 = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32> -// CHECK-NEXT: %3 = mhlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> -// CHECK-NEXT: %4 = mhlo.copy %1 {mhlo.sharding = "{devices=[1,2,4,2]<=[8,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> -// CHECK-NEXT: %5 = mhlo.add %4, %4 {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> -// CHECK-NEXT: %6 = "mhlo.dot"(%5, %3) {mhlo.sharding = "{devices=[2,2,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> -// CHECK-NEXT: %7 = mhlo.sine %6 {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> -// CHECK-NEXT: %8 = mhlo.copy %7 {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> -// CHECK-NEXT: %9 = mhlo.custom_call @SPMDShardToFullShape(%8) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> -// CHECK-NEXT: return %9 : tensor<8x32xf32> +// CHECK-NEXT: %[[COPY_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,2,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x16xf32> +// CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_0]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> +// CHECK-NEXT: %[[COPY_1:.*]] = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32> +// CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> +// CHECK-NEXT: %[[RESHARD:.*]] = mhlo.copy %[[FULL_TO_SHARD_0]] {mhlo.sharding = "{devices=[1,2,4,2]<=[8,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[RESHARD]], %[[RESHARD]] {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> +// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[ADD]], %[[FULL_TO_SHARD_1]] {mhlo.sharding = "{devices=[2,2,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> +// CHECK-NEXT: %[[SINE:.*]] = stablehlo.sine %[[DOT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> +// CHECK-NEXT: %[[COPY_2:.*]] = mhlo.copy %[[SINE]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> +// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_2]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> +// CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_3, [{"b"}, {"a"}]>, <@mesh_3, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_3, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<4x8xf32>, %arg3: tensor<8x32xf32>) { %1 = sdy.reshard %arg2 <@mesh_3, [{}, {"d"}]> : tensor<4x8xf32> - %2 = mhlo.add %1, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {}]>]>} : tensor<4x8xf32> - %3 = "mhlo.dot"(%2, %arg3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {"d"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.sine %3 : tensor<4x32xf32> + %2 = stablehlo.add %1, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {}]>]>} : tensor<4x8xf32> + %3 = stablehlo.dot %2, %arg3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {"d"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.sine %3 : tensor<4x32xf32> sdy.return %4 : tensor<4x32xf32> } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> return %0 : tensor<8x32xf32> } // CHECK-LABEL: func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg1: tensor<8x8xf32>) -func.func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1 - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - // CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{maximal device=1}"} - %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> +func.func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, []>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + // CHECK: %[[ADD_WITH_SHARDING:.*]] = stablehlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{maximal device=1}"} + %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_1, []>]>} : tensor<8x8xf32> return %1 : tensor<8x8xf32> } // CHECK-LABEL: func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<8x8xf32>) func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@empty_mesh_0, [{}, {}]>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1 - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - // CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{replicated}"} - %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + // CHECK: %[[ADD_WITH_SHARDING:.*]] = stablehlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{replicated}"} + %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> return %1 : tensor<8x8xf32> } @@ -176,10 +176,10 @@ func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8 func.func @multiple_shardings_with_device_list(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { - // CHECK-NEXT: mhlo.add + // CHECK-NEXT: stablehlo.add // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}"} - %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -190,10 +190,10 @@ func.func @named_sharding_in_manual_computation( %arg0: tensor<32x2xi32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", "y"}, {}]>}) -> (tensor<32x2xi32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", "y"}, {}]>}) { // CHECK-NEXT: %[[COPY_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[32,1]<=[32]}"} : tensor<32x2xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : (tensor<32x2xi32>) -> tensor<4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : (tensor<32x2xi32>) -> tensor<4x2xi32> // CHECK-NEXT: %[[FOO:.*]] = call @foo(%[[FULL_TO_SHARD]]) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} : (tensor<4x2xi32>) -> tensor<4x2xi32> // CHECK-NEXT: %[[COPY_1:.*]] = mhlo.copy %[[FOO]] {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<4x2xi32>) -> tensor<32x2xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<4x2xi32>) -> tensor<32x2xi32> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<32x2xi32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_2, [{"x", "y"}, {}]>] out_shardings=[<@mesh_2, [{"x", "y"}, {}]>] manual_axes={"x"} (%arg1: tensor<4x2xi32>) { %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh_2, [{"y"}, {}]>] out_shardings=[<@mesh_2, [{"y"}, {}]>] (%arg2: tensor<4x2xi32>) { @@ -210,23 +210,110 @@ func.func @free_axis_inside_in_out_shardings_manual_computation( %arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}, {}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i", ?}, {?}]>}) { // CHECK-NEXT: %[[COPY_OPERAND:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %[[FULL_TO_SHARD]], %[[FULL_TO_SHARD]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD]], %[[FULL_TO_SHARD]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> // CHECK-NEXT: %[[COPY:.*]] = mhlo.copy %[[MULT]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> // CHECK-NEXT: %[[COPY_RESULT:.*]] = mhlo.copy %[[COPY]] {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dims={manual}}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT]]) {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"} : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_5, [{"i", ?}, {?}], replicated={"j"}>] out_shardings=[<@mesh_5, [{"i", ?}, {?}], replicated={"j"}>] manual_axes={"j"} (%arg1: tensor<4x8xf32>) { - %1 = mhlo.multiply %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i"}, {}]>]>} : tensor<4x8xf32> + %1 = stablehlo.multiply %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i"}, {}]>]>} : tensor<4x8xf32> %2 = sdy.reshard %1 <@mesh_5, [{"i"}, {}]> : tensor<4x8xf32> sdy.return %2 : tensor<4x8xf32> } : (tensor<4x8xf32>) -> tensor<4x8xf32> return %0 : tensor<4x8xf32> } +// CHECK-LABEL: func @custom_call_erf_topk +func.func @custom_call_erf_topk( + %arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}, {}]>} + ) -> (tensor<16x2xf32> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i", ?}, {?}]>}) { + // CHECK-NEXT: %[[ERF:.*]] = stablehlo.custom_call @mhlo.erf(%arg0) {mhlo.attributes = {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}", mhlo.version = 1 : i64}} : (tensor<16x8xf32>) -> tensor<16x8xf32> + // CHECK-NEXT: stablehlo.custom_call @mhlo.topk(%[[ERF]]) + // CHECK-SAME{LITERAL}: {mhlo.attributes = {k = 2 : i64, largest = true, mhlo.sharding = "{{devices=[2,1,2]<=[4] last_tile_dim_replicate}, {devices=[2,1,2]<=[4] last_tile_dim_replicate}}"}, mhlo.version = 1 : i64} : (tensor<16x8xf32>) -> (tensor<16x2xf32>, tensor<16x2xi32>) + %0 = stablehlo.custom_call @mhlo.erf(%arg0) { + mhlo.attributes = {mhlo.version = 1 : i64}, + sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i", ?}, {?}]>]> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + %1:2 = stablehlo.custom_call @mhlo.topk(%0) { + mhlo.attributes = {k = 2 : i64, largest = true}, + mhlo.version = 1 : i64, + sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i", ?}, {?}]>, <@mesh_5, [{"i", ?}, {?}]>]> + } : (tensor<16x8xf32>) -> (tensor<16x2xf32>, tensor<16x2xi32>) + return %1#0 : tensor<16x2xf32> +} + +// CHECK-LABEL: @callback_transform_to_tuple +func.func @callback_transform_to_tuple(%arg0: tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) -> (tensor<2xf64> {sdy.sharding = #sdy.sharding<@mesh_5, [{"i"}]>}) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor, tensor<2xf64>) -> tuple> + // CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple>) -> tensor<2xf64> + // CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64> + %1 = stablehlo.constant dense<56560393354880> : tensor + %2 = stablehlo.custom_call @xla_python_cpu_callback(%1, %arg0) {api_version = 2 : i32, backend_config = "56560393354880", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{}]>]>, xla_shape = "(f64[2]{0})"} : (tensor, tensor<2xf64>) -> tensor<2xf64> + return %2 : tensor<2xf64> +} + +// CHECK-LABEL: @callback_no_result +func.func private @callback_no_result(%arg0: tensor) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176", + // CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}", + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], + // CHECK-SAME: result_layouts = [] + // CHECK-SAME: } : (tensor, tensor) -> () + %c = stablehlo.constant dense<56238273106176> : tensor + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor, tensor) -> tuple<> + return +} + +// CHECK-LABEL: @callback_result_unused +func.func private @callback_result_unused(%arg0: tensor) { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176", + // CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}", + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], + // CHECK-SAME: result_layouts = [] + // CHECK-SAME: } : (tensor, tensor) -> () + %c = stablehlo.constant dense<56238273106176> : tensor + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor, tensor) -> tensor + return +} + +// CHECK-LABEL: @callback_tuple_result_token_used +func.func public @callback_tuple_result_token_used(%arg0: !stablehlo.token, %arg1: tensor<2xi64>) -> !stablehlo.token { + %c = stablehlo.constant dense<56238119409280> : tensor + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0, %arg1) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238119409280", + // CHECK-SAME: has_side_effect = true, mhlo.sharding = "{maximal device=0}", + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], + // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] + // CHECK-SAME: } : (tensor, !stablehlo.token, tensor<2xi64>) -> tuple + // CHECK-NEXT: %[[TOKEN:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] : (tuple) -> !stablehlo.token + // CHECK-NEXT: return %[[TOKEN]] : !stablehlo.token + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0, %arg1) {api_version = 2 : i32, backend_config = "56238119409280", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<> : tensor<0xindex>], sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, []>]>} : (tensor, !stablehlo.token, tensor<2xi64>) -> tuple + %1 = stablehlo.get_tuple_element %0[0] : (tuple) -> !stablehlo.token + return %1 : !stablehlo.token +} + +// CHECK-LABEL: @callback_no_tuple_result_used +func.func @callback_no_tuple_result_used(%arg0: tensor<2xf64>) -> tensor<2xf64> { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant + // CHECK-NEXT: %[[CALLBACK:.*]] = stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) {{{.*}} : (tensor, tensor<2xf64>) -> tuple> + // CHECK-NEXT: %[[GET_TUPLE:.*]] = stablehlo.get_tuple_element %[[CALLBACK]][0] {mhlo.sharding = "{replicated}"} : (tuple>) -> tensor<2xf64> + // CHECK-NEXT: return %[[GET_TUPLE]] : tensor<2xf64> + %c = stablehlo.constant dense<18990036333952> : tensor + %0 = stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "18990036333952", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>], sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_0, [{?}]>]>, xla_shape = "(f64[2]{0})"} : (tensor, tensor<2xf64>) -> tensor<2xf64> + return %0 : tensor<2xf64> +} + + // CHECK-LABEL: func private @foo // CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} // CHECK-SAME: -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir index 55ccddd9645d5e..7bdc2c28273723 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir @@ -32,7 +32,7 @@ func.func @manual(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, // CHECK-SAME: in_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>, <@mesh, [{"axis_0"}, {}]>] // CHECK-SAME: out_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>] // CHECK-SAME: manual_axes={"axis_0", "axis_1"} (%arg2: tensor<1x8xf32>, %arg3: tensor<1x8xf32>) { - // CHECK-LABEL: mhlo.add + // CHECK-LABEL: stablehlo.add // CHECK-LABEL: sdy.return %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8x8xf32>) -> tensor<8x8xf32> %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x8xf32>) -> tensor<1x8xf32> @@ -63,14 +63,14 @@ func.func @while_with_free_variables( // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> // CHECK-NEXT: %[[C32:.*]] = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, []>]>} dense<32> // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor @@ -93,16 +93,16 @@ func.func @while_with_free_variables( // CHECK-LABEL: func @while_with_sinked_constants func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %iterArg + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor @@ -124,7 +124,7 @@ func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96 // CHECK-LABEL: func @custom_call_with_tuple_operand_result func.func @custom_call_with_tuple_operand_result(%arg0: tensor<8x8xf32>, %arg1: tensor<4x8xf32>, %arg2: tensor<8x16xf32>) -> tensor<8x8xf32> { - // CHECK-NEXT: %[[FOO:.*]]:3 = mhlo.custom_call @foo(%arg0, %arg1, %arg2) : + // CHECK-NEXT: %[[FOO:.*]]:3 = stablehlo.custom_call @foo(%arg0, %arg1, %arg2) : // CHECK-SAME: (tensor<8x8xf32>, tensor<4x8xf32>, tensor<8x16xf32>) // CHECK-SAME: -> (tensor<8x8xf32>, tensor<4x8xf32>, tensor<8x16xf32>) // CHECK-NEXT: return %[[FOO]]#0 @@ -133,3 +133,13 @@ func.func @custom_call_with_tuple_operand_result(%arg0: tensor<8x8xf32>, %arg1: %2 = mhlo.get_tuple_element %1[0] : (!tuple) -> tensor<8x8xf32> return %2 : tensor<8x8xf32> } + +// ----- + +// CHECK-LABEL: func @import_sharding_group_with_unused_result +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func @import_sharding_group_with_unused_result(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> + return %arg0 : tensor<8x8xf32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir index 859d067123e635..9e094e6eb7e344 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir @@ -6,22 +6,22 @@ sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> // CHECK-LABEL: func @single_manual_comp func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) { // CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[4,2]<=[8]}"} : tensor<8x16xf32> - // CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x16xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x16xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %2 = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : tensor<16x32xf32> - // CHECK-NEXT: %3 = mhlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{manual}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> - // CHECK-NEXT: %4 = mhlo.add %1, %1 {mhlo.sharding = "{manual}"} : tensor<2x8xf32> - // CHECK-NEXT: %5 = "mhlo.dot"(%4, %3) {mhlo.sharding = "{manual}"} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - // CHECK-NEXT: %6 = "mhlo.all_reduce"(%5) + // CHECK-NEXT: %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{manual}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %4 = stablehlo.add %1, %1 {mhlo.sharding = "{manual}"} : tensor<2x8xf32> + // CHECK-NEXT: %5 = stablehlo.dot %4, %3 {mhlo.sharding = "{manual}"} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %6 = "stablehlo.all_reduce"(%5) // CHECK: %7 = mhlo.copy %6 {mhlo.sharding = "{manual}"} : tensor<2x32xf32> - // CHECK-NEXT: %8 = mhlo.custom_call @SPMDShardToFullShape(%7) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %8 = stablehlo.custom_call @SPMDShardToFullShape(%7) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %8 : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { - %1 = mhlo.add %arg2, %arg2 : tensor<2x8xf32> - %2 = "mhlo.dot"(%1, %arg3) : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %3 = "mhlo.all_reduce"(%2) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ + %1 = stablehlo.add %arg2, %arg2 : tensor<2x8xf32> + %2 = stablehlo.dot %1, %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %3 = "stablehlo.all_reduce"(%2) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ ^bb0(%arg4: tensor, %arg5: tensor): - %4 = mhlo.add %arg4, %arg5 : tensor - mhlo.return %4 : tensor + %4 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %4 : tensor }) : (tensor<2x32xf32>) -> tensor<2x32xf32> sdy.return %3 : tensor<2x32xf32> } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> @@ -32,13 +32,13 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.shard func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"b"}]>}) { // CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : tensor<8x8xf32> - // CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %2 = mhlo.copy %1 {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %3 = mhlo.custom_call @SPMDShardToFullShape(%2) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: %4 = mhlo.copy %3 {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : tensor<8x8xf32> - // CHECK-NEXT: %5 = mhlo.custom_call @SPMDFullToShardShape(%4) {mhlo.sharding = "{devices=[1,1,2,4]<=[4,2]T(1,0) last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<8x4xf32> + // CHECK-NEXT: %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {mhlo.sharding = "{devices=[1,1,2,4]<=[4,2]T(1,0) last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<8x4xf32> // CHECK-NEXT: %6 = mhlo.copy %5 {mhlo.sharding = "{devices=[1,1,2,4]<=[4,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<8x4xf32> - // CHECK-NEXT: %7 = mhlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : (tensor<8x4xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %7 = stablehlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : (tensor<8x4xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %7 : tensor<8x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { sdy.return %arg1 : tensor<2x8xf32> @@ -53,17 +53,17 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy // CHECK-LABEL: func @sharding_in_manual_computation_body func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {}]>}) { // CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,2,4]<=[16] last_tile_dim_replicate}"} : tensor<8x16xf32> - // CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> // CHECK-NEXT: %2 = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32> - // CHECK-NEXT: %3 = mhlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> - // CHECK-NEXT: %4 = mhlo.add %1, %1 {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> - // CHECK-NEXT: %5 = "mhlo.dot"(%4, %3) {mhlo.sharding = "{devices=[2,2,4]<=[4,2,2]T(2,1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> + // CHECK-NEXT: %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %4 = stablehlo.add %1, %1 {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> + // CHECK-NEXT: %5 = stablehlo.dot %4, %3 {mhlo.sharding = "{devices=[2,2,4]<=[4,2,2]T(2,1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> // CHECK-NEXT: %6 = mhlo.copy %5 {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> - // CHECK-NEXT: %7 = mhlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %7 = stablehlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %7 : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_1, [{"a"}, {"b"}]>, <@mesh_1, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_1, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<4x8xf32>, %arg3: tensor<8x32xf32>) { - %1 = mhlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<4x8xf32> - %2 = "mhlo.dot"(%1, %arg3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"d"}, {"c"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> + %1 = stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<4x8xf32> + %2 = stablehlo.dot %1, %arg3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"d"}, {"c"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> sdy.return %2 : tensor<4x32xf32> } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> return %0 : tensor<8x32xf32> @@ -71,14 +71,14 @@ func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.shar // CHECK-LABEL: func @call_op_with_no_operands_or_results func.func @call_op_with_no_operands_or_results() { - // CHECK-LABEL: %0 = mhlo.constant + // CHECK-LABEL: %cst = stablehlo.constant // CHECK-NOT: sdy.sharding // CHECK-NOT: mhlo.sharding - // CHECK-NEXT: %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> + // CHECK-NEXT: %0 = stablehlo.add %cst, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> // CHECK-NEXT: return sdy.manual_computation() in_shardings=[] out_shardings=[] manual_axes={} () { - %0 = mhlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> - %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> + %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> sdy.return } : () -> () return @@ -87,18 +87,18 @@ func.func @call_op_with_no_operands_or_results() { // CHECK-LABEL: func @nested_shmaps func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { // CHECK-NEXT: %[[COPY_OPERAND_OUTER:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_OPERAND_INNER:.*]] = mhlo.copy %[[FULL_TO_SHARD_OUTER]] {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[MULT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_RESULT_OUTER:.*]] = mhlo.copy %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL_OUTER]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { - %2 = mhlo.multiply %arg2, %arg2 : tensor<2x4xf32> + %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> sdy.return %2 : tensor<2x4xf32> } : (tensor<2x8xf32>) -> tensor<2x8xf32> sdy.return %1 : tensor<2x8xf32> @@ -109,26 +109,26 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@m // CHECK-LABEL: func @nested_shmaps_extra_op func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { // CHECK-NEXT: %[[COPY_OPERAND_OUTER:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_OPERAND_INNER:.*]] = mhlo.copy %[[FULL_TO_SHARD_OUTER]] {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[MULT]], %[[MULT]] {mhlo.sharding = "{devices=[2,1,4,2]<=[2,2,2,2]T(2,1,0,3) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SUB:.*]] = mhlo.subtract %[[ADD]], %[[ADD]] {mhlo.sharding = "{devices=[4,1,4]<=[2,2,4]T(2,1,0) last_tile_dims={manual}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[MULT]], %[[MULT]] {mhlo.sharding = "{devices=[2,1,4,2]<=[2,2,2,2]T(2,1,0,3) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[SUB:.*]] = stablehlo.subtract %[[ADD]], %[[ADD]] {mhlo.sharding = "{devices=[4,1,4]<=[2,2,4]T(2,1,0) last_tile_dims={manual}}"} : tensor<2x4xf32> // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[SUB]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_TO_FULL_INNER]], %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_TO_FULL_INNER]], %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> // CHECK-NEXT: %[[COPY_RESULT_OUTER:.*]] = mhlo.copy %[[ADD]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL_OUTER]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { - %2 = mhlo.multiply %arg2, %arg2 : tensor<2x4xf32> - %3 = mhlo.add %2, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<2x4xf32> - %4 = mhlo.subtract %3, %3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c", "d"}, {}]>]>} : tensor<2x4xf32> + %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> + %3 = stablehlo.add %2, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<2x4xf32> + %4 = stablehlo.subtract %3, %3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c", "d"}, {}]>]>} : tensor<2x4xf32> sdy.return %4 : tensor<2x4xf32> } : (tensor<2x8xf32>) -> tensor<2x8xf32> - %5 = mhlo.add %1, %1 : tensor<2x8xf32> + %5 = stablehlo.add %1, %1 : tensor<2x8xf32> sdy.return %5 : tensor<2x8xf32> } : (tensor<4x8xf32>) -> tensor<4x8xf32> return %0 : tensor<4x8xf32> @@ -137,22 +137,22 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh // CHECK-LABEL: func @multiple_manual_computation_uses func.func @multiple_manual_computation_uses(%arg0: tensor<2x4x8xi32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {}, {"a"}]>}, %arg1: tensor<32x16x8xi32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {}, {"a"}]>}) -> (tensor<131x4x8xi32> {sdy.sharding = #sdy.sharding<@mesh_0, [{?}, {?}, {"a"}]>}) { // CHECK-NEXT: %[[COPY_OPERAND_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<2x4x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_0]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<2x4x8xi32>) -> tensor<2x4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_0]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<2x4x8xi32>) -> tensor<2x4x2xi32> // CHECK-NEXT: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @sdy_testonly(%[[FULL_TO_SHARD_0]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<2x4x2xi32>) -> tensor<3x4x2xi32> // CHECK-NEXT: %[[COPY_RESULT_0:.*]] = mhlo.copy %[[CUSTOM_CALL]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<3x4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_0]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<3x4x2xi32>) -> tensor<3x4x8xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_0]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<3x4x2xi32>) -> tensor<3x4x8xi32> // CHECK-NEXT: %[[COPY_OPERAND_1:.*]] = mhlo.copy %arg1 {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<32x16x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_1]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<32x16x8xi32>) -> tensor<32x16x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_1]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<32x16x8xi32>) -> tensor<32x16x2xi32> // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[FULL_TO_SHARD_1]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<32x16x2xi32>) -> tensor<128x4x2xi32> // CHECK-NEXT: %[[COPY_RESULT_1:.*]] = mhlo.copy %[[RESHAPE]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<128x4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_1]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<128x4x2xi32>) -> tensor<128x4x8xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_1]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<128x4x2xi32>) -> tensor<128x4x8xi32> // CHECK-NEXT: %[[COPY_OPERAND_2:.*]] = mhlo.copy %[[SHARD_TO_FULL_0]] {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<3x4x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_2:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_2]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<3x4x8xi32>) -> tensor<3x4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_2:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_2]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<3x4x8xi32>) -> tensor<3x4x2xi32> // CHECK-NEXT: %[[COPY_OPERAND_3:.*]] = mhlo.copy %[[SHARD_TO_FULL_1]] {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<128x4x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_3:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_3]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<128x4x8xi32>) -> tensor<128x4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_3:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_3]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<128x4x8xi32>) -> tensor<128x4x2xi32> // CHECK-NEXT: %[[CONCAT:.*]] = stablehlo.concatenate %[[FULL_TO_SHARD_3]], %[[FULL_TO_SHARD_2]], dim = 0 {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<128x4x2xi32>, tensor<3x4x2xi32>) -> tensor<131x4x2xi32> // CHECK-NEXT: %[[COPY_RESULT_2:.*]] = mhlo.copy %[[CONCAT]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<131x4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL_2:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_2]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<131x4x2xi32>) -> tensor<131x4x8xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL_2:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_2]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<131x4x2xi32>) -> tensor<131x4x8xi32> // CHECK-NEXT: return %[[SHARD_TO_FULL_2]] : tensor<131x4x8xi32> %1 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{}, {}, {"a"}]>] out_shardings=[<@mesh_0, [{}, {}, {"a"}]>] manual_axes={"a"} (%arg2: tensor<2x4x2xi32>) { %4 = stablehlo.custom_call @sdy_testonly(%arg2) : (tensor<2x4x2xi32>) -> tensor<3x4x2xi32> diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir index 12641b0d746476..a62c58cc7a9e96 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir @@ -24,11 +24,11 @@ func.func public @call_op_with_one_operand_and_no_results(%arg0: tensor<4xf32>) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{}], replicated={"a"}>] out_shardings=[] manual_axes={"a"} (%arg1: tensor<4xf32>) { // CHECK-NEXT: sdy.return // CHECK-NEXT: } : (tensor<4xf32>) -> () - // CHECK-NEXT: %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg0 : tensor<4xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<4xf32>) -> tensor<4xf32> call @shmap_body_one_argument_empty_body(%1) : (tensor<4xf32>) -> () - %2 = mhlo.add %arg0, %arg0 : tensor<4xf32> + %2 = stablehlo.add %arg0, %arg0 : tensor<4xf32> return %2 : tensor<4xf32> } // CHECK-NOT: func.func private @shmap_body_one_argument_empty_body @@ -40,18 +40,18 @@ func.func private @shmap_body_one_argument_empty_body(%arg0: tensor<4xf32>) -> ( func.func public @call_op_with_no_operands_and_one_result() -> tensor<4xf32> { // CHECK: %0 = sdy.manual_computation() // CHECK-SAME{LITERAL}: in_shardings=[] out_shardings=[<@mesh_0, [{}], replicated={"a"}>] manual_axes={"a"} () { - // CHECK-LABEL: %1 = mhlo.constant - // CHECK-NEXT: sdy.return %1 : tensor<4xf32> + // CHECK-LABEL: %cst = stablehlo.constant + // CHECK-NEXT: sdy.return %cst : tensor<4xf32> // CHECK-NEXT: } : () -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> %0 = call @shmap_body_no_arg() : () -> (tensor<4xf32>) - %1 = mhlo.custom_call @Sharding(%0) : (tensor<4xf32>) -> tensor<4xf32> - %2 = mhlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> + %1 = stablehlo.custom_call @Sharding(%0) : (tensor<4xf32>) -> tensor<4xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> return %2 : tensor<4xf32> } // CHECK-NOT: func.func private @shmap_body_no_arg() func.func private @shmap_body_no_arg() -> tensor<4xf32> { - %0 = mhlo.constant dense <[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32> + %0 = stablehlo.constant dense <[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32> return %0 : tensor<4xf32> } @@ -59,20 +59,20 @@ func.func private @shmap_body_no_arg() -> tensor<4xf32> { func.func public @call_op_with_shamp_body_in_middle(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @prefix_shmap_body_suffix(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @prefix_shmap_body_suffix(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -80,20 +80,20 @@ func.func private @prefix_shmap_body_suffix(%arg0: tensor<4x32xf32>) -> (tensor< func.func public @shard_map_single_sharded_input_output_dim_0(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -101,20 +101,20 @@ func.func private @shmap_body(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { func.func public @shard_map_single_sharded_input_output_dim_1(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {"a"}]>] out_shardings=[<@mesh_1, [{}, {"a"}]>] manual_axes={"a"} (%arg1: tensor<16x8xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<16x8xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<16x8xf32> // CHECK-NEXT: sdy.return %1 : tensor<16x8xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x8xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x8xf32> %2 = call @shmap_body_0(%1) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x8xf32>) -> tensor<16x8xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body_0 func.func private @shmap_body_0(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x8xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x8xf32> return %0 : tensor<16x8xf32> } @@ -122,20 +122,20 @@ func.func private @shmap_body_0(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { func.func public @shard_map_single_replicated_input_sharded_output(%arg0: tensor<16x32xf32>) -> tensor<16x256xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{}, {"a", "b"}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<16x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<16x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<16x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x256xf32> // CHECK-NEXT: return %0 : tensor<16x256xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %2 = call @shmap_body_1(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a", "b"}]>]>} : (tensor<16x32xf32>) -> tensor<16x256xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a", "b"}]>]>} : (tensor<16x32xf32>) -> tensor<16x256xf32> return %4 : tensor<16x256xf32> } // CHECK-NOT func.func private @shmap_body_1 func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -143,51 +143,51 @@ func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) func.func public @shard_map_contracting_dim_matmul_all_reduce(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> tensor<8x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0, %arg1) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {"b"}]>, <@mesh_1, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_1, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { - // CHECK-NEXT: %1 = "mhlo.dot_general"(%arg2, %arg3) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - // CHECK-NEXT: %2 = "mhlo.all_reduce"(%1) <{ - // CHECK-SAME{LITERAL}: channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids + // CHECK-NEXT: %1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %2 = "stablehlo.all_reduce"(%1) <{ + // CHECK-SAME{LITERAL}: channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids // CHECK-SAME: }> ({ // CHECK-NEXT: ^bb0(%arg4: tensor, %arg5: tensor): - // CHECK-NEXT: %3 = mhlo.add %arg4, %arg5 : tensor - // CHECK-NEXT: mhlo.return %3 : tensor + // CHECK-NEXT: %3 = stablehlo.add %arg4, %arg5 : tensor + // CHECK-NEXT: stablehlo.return %3 : tensor // CHECK-NEXT: }) : (tensor<2x32xf32>) -> tensor<2x32xf32> // CHECK-NEXT: sdy.return %2 : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %0 : tensor<8x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {"b"}]>]>} : (tensor<8x16xf32>) -> tensor<8x16xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<8x16xf32>) -> tensor<2x8xf32> - %2 = mhlo.custom_call @Sharding(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {}], replicated={"a"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDFullToShardShape(%2) : (tensor<16x32xf32>) -> tensor<8x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {"b"}]>]>} : (tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<8x16xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @Sharding(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {}], replicated={"a"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) : (tensor<16x32xf32>) -> tensor<8x32xf32> %4 = call @shmap_body_2(%1, %3) : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %5 = mhlo.custom_call @Sharding(%4) : (tensor<2x32xf32>) -> tensor<2x32xf32> - %6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}], replicated={"b"}>]>}: (tensor<2x32xf32>) -> tensor<8x32xf32> + %5 = stablehlo.custom_call @Sharding(%4) : (tensor<2x32xf32>) -> tensor<2x32xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}], replicated={"b"}>]>}: (tensor<2x32xf32>) -> tensor<8x32xf32> return %6 : tensor<8x32xf32> } // CHECK-NOT: func.func private @shmap_body_2 func.func private @shmap_body_2(%arg0: tensor<2x8xf32>, %arg1: tensor<8x32xf32>) -> (tensor<2x32xf32>) { - %0 = "mhlo.dot_general"(%arg0, %arg1) <{dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]}> : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %1 = "mhlo.all_reduce"(%0) ({ + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %1 = "stablehlo.all_reduce"(%0) ({ ^bb0(%arg2: tensor, %arg3: tensor): - %2 = mhlo.add %arg2, %arg3 : tensor - mhlo.return %2 : tensor - }) {channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids} : (tensor<2x32xf32>) -> tensor<2x32xf32> + %2 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %2 : tensor + }) {channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids} : (tensor<2x32xf32>) -> tensor<2x32xf32> return %1 : tensor<2x32xf32> } // CHECK-LABEL: func.func public @shard_map_wrong_callee_name func.func public @shard_map_wrong_callee_name(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> // CHECK: call @shmap_head // CHECK-NOT: sdy.manual_computation %2 = call @shmap_head(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-LABEL: func.func private @shmap_head func.func private @shmap_head(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -197,16 +197,16 @@ func.func public @shard_map_multiple_results(%arg0: tensor<16x32xf32>) -> tensor // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{"a", "b"}, {}]>, <@mesh_1, [{"b", "a"}, {}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { // CHECK-NEXT: sdy.return %arg1, %arg1 : tensor<16x32xf32>, tensor<16x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> // CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %2:2 = call @shmap_body_4(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %5 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %7 = mhlo.add %4, %6 : tensor<128x32xf32> + %3 = stablehlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> + %5 = stablehlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> + %7 = stablehlo.add %4, %6 : tensor<128x32xf32> return %7 : tensor<128x32xf32> } // CHECK-NOT: func.func private @shmap_body_4 @@ -218,46 +218,46 @@ func.func private @shmap_body_4(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, func.func public @shard_map_multiple_call_ops(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32>) { // CHECK-NEXT: %[[SHARD_MAP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_0]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: %[[SHARD_MAP_1:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {"a"}]>] out_shardings=[<@mesh_1, [{}, {"a"}]>] manual_axes={"a"} (%arg1: tensor<16x8xf32>) { - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %arg1, %arg1 + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg1, %arg1 // CHECK-NEXT: sdy.return %[[MUL]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: %[[SHARD_MAP_2:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_1]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %[[SHARD_MAP_0]], %[[SHARD_MAP_1]], %[[SHARD_MAP_2]] - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body_5(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> - %5 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %6 = mhlo.custom_call @SPMDFullToShardShape(%5) : (tensor<16x32xf32>) -> tensor<16x8xf32> + %5 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %6 = stablehlo.custom_call @SPMDFullToShardShape(%5) : (tensor<16x32xf32>) -> tensor<16x8xf32> %7 = call @shmap_body_6(%6) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %8 = mhlo.custom_call @Sharding(%7) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %9 = mhlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> + %8 = stablehlo.custom_call @Sharding(%7) : (tensor<16x8xf32>) -> tensor<16x8xf32> + %9 = stablehlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> %10 = call @shmap_body_5(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %11 = mhlo.custom_call @Sharding(%10) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %12 = mhlo.custom_call @SPMDShardToFullShape(%11) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %11 = stablehlo.custom_call @Sharding(%10) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %12 = stablehlo.custom_call @SPMDShardToFullShape(%11) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4, %9, %12 : tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_5(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_6(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { - %0 = mhlo.multiply %arg0, %arg0 : tensor<16x8xf32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<16x8xf32> return %0 : tensor<16x8xf32> } @@ -265,42 +265,42 @@ func.func private @shmap_body_6(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { func.func public @sharding_with_missing_manual_axes(%arg0: tensor<16x16xf32>) -> tensor<32x4xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_2, [{"b"}, {"a"}]>] out_shardings=[<@mesh_2, [{"a"}, {}], replicated={"c"}>] manual_axes={"a", "b", "c"} (%arg1: tensor<8x4xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<8x4xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<8x4xf32> // CHECK-NEXT: sdy.return %1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<16x16xf32>) -> tensor<32x4xf32> // CHECK-NEXT: return %0 : tensor<32x4xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> %2 = call @shmap_body_7(%1) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> return %4 : tensor<32x4xf32> } // CHECK-NOT: func.func private @shmap_body_5 func.func private @shmap_body_7(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x4xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x4xf32> return %0 : tensor<8x4xf32> } // CHECK-LABEL: func.func public @shard_map_sharding_custom_call_other_uses func.func public @shard_map_sharding_custom_call_other_uses(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { - // CHECk-NEXT: %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} + // CHECk-NEXT: %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} // CHECK: %1 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %2 = mhlo.add %arg1, %arg1 : tensor<4x32xf32> + // CHECK-NEXT: %2 = stablehlo.add %arg1, %arg1 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %2 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %1, %0 - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body_8(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4, %0 : tensor<16x32xf32>, tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_8(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -308,22 +308,22 @@ func.func private @shmap_body_8(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { func.func public @shard_map_unused_results(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> { // CHECK: %[[SHARD_MAP:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{"b", "a"}, {}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg1, %arg1 - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[ADD]], %[[ADD]] + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[ADD]], %[[ADD]] // CHECK-NEXT: sdy.return %[[ADD]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<128x32xf32> // CHECK-NEXT: return %[[SHARD_MAP]] - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %2:3 = call @shmap_body_9(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> + %3 = stablehlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> return %4 : tensor<128x32xf32> } // CHECK-NOT: func.func private @shmap_body_9 func.func private @shmap_body_9(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> - %1 = mhlo.multiply %0, %0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> + %1 = stablehlo.multiply %0, %0 : tensor<16x32xf32> return %0, %0, %1 : tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32> } @@ -331,32 +331,32 @@ func.func private @shmap_body_9(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, func.func public @shard_map_multiple_call_ops_unused_result_in_one(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<4x128xf32>) { // CHECK-NEXT: %[[SHARD_MAP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_0]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: %[[SHARD_MAP_1:.*]]:2 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>, <@mesh_0, [{}, {"a"}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_1]], %[[ADD_1]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<4x128xf32>) // CHECK-NEXT: return %[[SHARD_MAP_0]], %[[SHARD_MAP_1]]#0, %[[SHARD_MAP_1]]#1 - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2:2 = call @shmap_body_10(%1) : (tensor<4x32xf32>) -> (tensor<4x32xf32>, tensor<4x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> %5:2 = call @shmap_body_10(%1) : (tensor<4x32xf32>) -> (tensor<4x32xf32>, tensor<4x32xf32>) - %6 = mhlo.custom_call @Sharding(%5#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %7 = mhlo.custom_call @SPMDShardToFullShape(%6) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> - %8 = mhlo.custom_call @Sharding(%5#1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %9 = mhlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {"a"}]>]>} : (tensor<4x32xf32>) -> tensor<4x128xf32> + %6 = stablehlo.custom_call @Sharding(%5#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %7 = stablehlo.custom_call @SPMDShardToFullShape(%6) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %8 = stablehlo.custom_call @Sharding(%5#1) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %9 = stablehlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {"a"}]>]>} : (tensor<4x32xf32>) -> tensor<4x128xf32> return %4, %7, %9 : tensor<16x32xf32>, tensor<16x32xf32>, tensor<4x128xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_10(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>, tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0, %0 : tensor<4x32xf32>, tensor<4x32xf32> } @@ -364,19 +364,19 @@ func.func private @shmap_body_10(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>, func.func public @shard_map_duplicate_operand(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0, %arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>, <@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>, %arg2: tensor<4x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg2 : tensor<4x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg2 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body_11(%1, %1) : (tensor<4x32xf32>, tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_11(%arg0: tensor<4x32xf32>, %arg1: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg1 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir index 51b1a4e49f7a9e..e41b9b7fe3e0a1 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir @@ -4,16 +4,16 @@ sdy.mesh @mesh_1 = <["a"=4, "b"=2]> sdy.mesh @mesh_2 = <["a"=4, "b"=2, "c"=3]> func.func public @multiple_meshes(%arg0: tensor<16x16xf32>) -> tensor<32x4xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> // expected-error @+1 {{Multiple meshes in a single manual computation.}} %2 = call @shmap_body_0(%1) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> return %4 : tensor<32x4xf32> } func.func private @shmap_body_0(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x4xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x4xf32> return %0 : tensor<8x4xf32> } @@ -24,12 +24,12 @@ sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // expected-error @+1 {{expecting CustomCallOp as operand}} %0 = call @shmap_body_1(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @Sharding(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2 = mhlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @Sharding(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -38,15 +38,15 @@ func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting SPMDFullToShardShape custom call as operand}} %1 = call @shmap_body_1(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2 = mhlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -55,15 +55,15 @@ func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting CustomCallOp as operand of SPMDFullToShardShape}} %1 = call @shmap_body(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2 = mhlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -72,16 +72,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @SPMDFullToShardShape(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting Sharding CustomCallOp as operand of SPMDFullToShardShape}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -90,16 +90,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting each result of shmap_body to have one or no uses}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - mhlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + stablehlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -108,16 +108,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting Sharding CustomCallOp user of the result to have one use}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4, %3 : tensor<16x32xf32>, tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -126,14 +126,14 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting CustomCallOp as the use of the result of the CallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -142,16 +142,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting Sharding CustomCallOp as the use of the result of the CallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDShardToFullShape(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -160,15 +160,15 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting CustomCallOp as the use of Sharding CustomCallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -177,15 +177,15 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting SPMDShardToFullShape CustomCallOp as the use of Sharding CustomCallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @Sharding(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @Sharding(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir index ca2fd01b7b28d8..fe13f45d4e09a4 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir @@ -9,38 +9,38 @@ func.func @while_with_free_variables( %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}, %arg2: tensor<32x96xf32>) -> (tensor<32x96xf32>, tensor<32x96xf32>) { - // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> - // CHECK-NEXT: %[[C1:.*]] = mhlo.constant dense<1> - // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} + // CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<0> + // CHECK-NEXT: %[[C1:.*]] = stablehlo.constant dense<1> + // CHECK-NEXT: %[[C32:.*]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %[[ADD_0]] <@mesh2, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_2, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_2:.*]] = mhlo.add %iterArg, %[[SC_0]] - // CHECK-NEXT: %[[ADD_3:.*]] = mhlo.add %[[ADD_2]], %arg2 - // CHECK-NEXT: %[[ADD_4:.*]] = mhlo.add %[[ADD_3]], %[[SC_1]] - // CHECK-NEXT: mhlo.return %[[ADD_4]], %[[ADD_1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg_2, %[[C1]] + // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: %[[ADD_3:.*]] = stablehlo.add %[[ADD_2]], %arg2 + // CHECK-NEXT: %[[ADD_4:.*]] = stablehlo.add %[[ADD_3]], %[[SC_1]] + // CHECK-NEXT: stablehlo.return %[[ADD_4]], %[[ADD_1]] // CHECK-NEXT: } // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0 - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<1> : tensor - %2 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor - %3 = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32> - %4:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.constant dense<1> : tensor + %2 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor + %3 = stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32> + %4:2 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %0) : tensor<32x96xf32>, tensor cond { - %5 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor - mhlo.return %5 : tensor + %5 = stablehlo.compare LT, %iterArg_2, %2 : (tensor, tensor) -> tensor + stablehlo.return %5 : tensor } do { - %5 = mhlo.add %iterArg_0, %1 : tensor - %6 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - %7 = mhlo.add %6, %arg2 : tensor<32x96xf32> - %8 = mhlo.add %7, %3 : tensor<32x96xf32> - mhlo.return %8, %5 : tensor<32x96xf32>, tensor + %5 = stablehlo.add %iterArg_2, %1 : tensor + %6 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + %7 = stablehlo.add %6, %arg2 : tensor<32x96xf32> + %8 = stablehlo.add %7, %3 : tensor<32x96xf32> + stablehlo.return %8, %5 : tensor<32x96xf32>, tensor } return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32> } @@ -50,44 +50,44 @@ func.func @free_var_used_in_multiple_while_ops( %arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}) -> tensor<32x96xf32> { - // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> - // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<0> + // CHECK-NEXT: %[[C32:.*]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE_0:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE_0:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_1 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_1, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg, %[[SC_0]] - // CHECK-NEXT: mhlo.return %[[ADD_0]], %iterArg_0 + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: stablehlo.return %[[ADD_0]], %iterArg_1 // CHECK-NEXT: } // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE_1:.*]]:2 = mhlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE_1:.*]]:2 = stablehlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_1 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_1, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC_1]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %iterArg_0 + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC_1]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %iterArg_1 // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE_1]]#0 - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor - %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor + %2:2 = stablehlo.while(%iterArg = %arg0, %iterArg_1 = %0) : tensor<32x96xf32>, tensor cond { - %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor - mhlo.return %4 : tensor + %4 = stablehlo.compare LT, %iterArg_1, %1 : (tensor, tensor) -> tensor + stablehlo.return %4 : tensor } do { - %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor + %4 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %4, %iterArg_1 : tensor<32x96xf32>, tensor } - %3:2 = mhlo.while(%iterArg = %2#0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + %3:2 = stablehlo.while(%iterArg = %2#0, %iterArg_1 = %0) : tensor<32x96xf32>, tensor cond { - %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor - mhlo.return %4 : tensor + %4 = stablehlo.compare LT, %iterArg_1, %1 : (tensor, tensor) -> tensor + stablehlo.return %4 : tensor } do { - %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor + %4 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %4, %iterArg_1 : tensor<32x96xf32>, tensor } return %3#0 : tensor<32x96xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir index f8a760d6c5794f..cf0dc80b83006d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir @@ -13,8 +13,8 @@ // CHECK-SAME: %arg0: tensor<8x16xf32>) func.func @main( %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %1 = mhlo.add %0, %0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %1 = stablehlo.add %0, %0 : tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -33,8 +33,8 @@ func.func @main( // CHECK: %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>} ) -> (tensor<8x16xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %1 = mhlo.add %0, %0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %1 = stablehlo.add %0, %0 : tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -57,8 +57,8 @@ func.func @main( %arg0: tensor<8x16xf32> // CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) { ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) { - // CHECK: mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK: stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> return %0 : tensor<8x16xf32> } @@ -123,10 +123,10 @@ sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> func.func @main( %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}p4]>}, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> // CHECK-NEXT: %[[WSC:.*]] = sdy.sharding_constraint %0 <@mesh, [{}, {"c", ?}p1]> : tensor<8x8xf32> // CHECK-NEXT: return %[[WSC]] : tensor<8x8xf32> - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"c", ?}p1]> : tensor<8x8xf32> return %1 : tensor<8x8xf32> } @@ -168,10 +168,10 @@ sdy.mesh @mesh_2 = <["x"=8, "y"=4]> func.func @main( // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[CUSTOM_CALL:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {backend_config = "", xla_shape = "(f32[8,16]{1,0}, f32[8,16]{1,0})"} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) - %1:2 = mhlo.custom_call @sdy_testonly(%arg0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[CUSTOM_CALL:.*]]:2 = stablehlo.custom_call @sdy_testonly(%arg0) {backend_config = "", xla_shape = "(f32[8,16]{1,0}, f32[8,16]{1,0})"} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + %1:2 = stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) // CHECK-NEXT: return %[[ADD]], %[[CUSTOM_CALL]]#0, %[[CUSTOM_CALL]]#1 return %0, %1#0, %1#1 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> } @@ -186,33 +186,33 @@ sdy.mesh @mesh = <["x"=2]> // CHECK-LABEL: func @main func.func @main( %arg0: tensor<32x96xf32>, - %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}}) + %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> - // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-DAG: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-DAG: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-DAG: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-DAG: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = sdy.constant dense<0> : tensor %1 = sdy.constant dense<32> : tensor - %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + %2:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor cond { - %3 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor - mhlo.return %3 : tensor + %3 = stablehlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor + stablehlo.return %3 : tensor } do { %3 = sdy.constant dense<1> : tensor - %4 = mhlo.add %iterArg_0, %3 : tensor - %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %5, %4 : tensor<32x96xf32>, tensor + %4 = stablehlo.add %iterArg_0, %3 : tensor + %5 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %5, %4 : tensor<32x96xf32>, tensor } return %2#0 : tensor<32x96xf32> } @@ -236,22 +236,22 @@ func.func @main(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { // CHECK: %[[NC:.*]]:2 = sdy.named_computation<"g.2.2">(%arg0) (%arg1: tensor<8x2xi32>) { - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %arg1, %arg1 : tensor<8x2xi32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg1, %arg1 : tensor<8x2xi32> // CHECK-NEXT: sdy.return %[[MUL]], %[[MUL]] : tensor<8x2xi32>, tensor<8x2xi32> // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>) - // CHECK-NEXT: %[[HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[NC]]#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> // CHECK-NEXT: return %[[HOST]] : tensor<8x2xi32> %0:2 = call @g.2(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, mhlo.sharding = "{{maximal device=0}, {replicated}}"} : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>) - %1 = mhlo.custom_call @MoveToHost(%0#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = stablehlo.custom_call @MoveToHost(%0#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> } // CHECK-NOT: g.2 func.func private @g.2(%arg0: tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>) { - %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<8x2xi32> return %0, %0 : tensor<8x2xi32>, tensor<8x2xi32> } -// TODO(b/335481977): Add more tests for MHLO ops. So far tested all SDY +// TODO(b/335481977): Add more tests for StableHLO ops. So far tested all SDY // compiler APIs other than shard as/like (doesn't exist yet). See // round_trip_pipeline_manual_computation.mlir for ManualComputationOp tests. diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir index 90754f8e9bf0a2..54ec035eed6aa8 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir @@ -16,12 +16,12 @@ func.func @main(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> { // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{"a", "b"}, {}]>, <@mesh_1, [{"b", "a"}, {}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { // CHECK-NEXT: sdy.return %arg1, %arg1 : tensor<16x32xf32>, tensor<16x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> // CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32> - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %1:2 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={\\\22a\\\22, \\\22b\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22, \\\22b\\\22}, {}]>, <@mesh_1, [{\\\22b\\\22, \\\22a\\\22}, {}]>]>"}} : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) - %2:2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1#0, %1#1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) - %3 = mhlo.add %2#0, %2#1 : tensor<128x32xf32> + %2:2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1#0, %1#1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) + %3 = stablehlo.add %2#0, %2#1 : tensor<128x32xf32> return %3 : tensor<128x32xf32> } // CHECK-NOT: func.func private @local_xla.sdy.manual_computation_body diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir index d0ed401a2a4299..17b6681d2b5c77 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir @@ -13,19 +13,19 @@ sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> // CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}]>}) func.func @main(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}]>}) { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[ADD_0]], %[[ADD_0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %[[MUL]], %[[MUL]] : tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[ADD_0]], %[[ADD_0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %[[MUL]], %[[MUL]] : tensor<8x16xf32> // CHECK-NEXT: return %[[ADD_1]] : tensor<8x16xf32> - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> %1 = func.call @nested_func(%0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>) - %2 = mhlo.add %1, %1 : tensor<8x16xf32> + %2 = stablehlo.add %1, %1 : tensor<8x16xf32> return %2 : tensor<8x16xf32> } // CHECK-NOT: func @nested_func func.func @nested_func(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}) { - %0 = mhlo.multiply %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<8x16xf32> return %0 : tensor<8x16xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 977de9208630fb..3004b8738a7528 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -20,25 +20,25 @@ sdy.mesh @mesh_2 = <["x"=8, "y"=4]> func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: mhlo.add +// CHECK-NEXT: stablehlo.add // CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22axis_1\\\22, \\\22axis_0\\\22}, {}]>]>"}, mhlo.sharding = - %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } // CHECK-LABEL: func @multi_result_op func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { - %0 = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: mhlo.reduce + %0 = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: stablehlo.reduce // CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\\\22y\\\22}]>, <@mesh_2, [{\\\22y\\\22}, {}]>]>"}, mhlo.sharding = - %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] + %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) reducer(%arg2: tensor, %arg4: tensor) (%arg3: tensor, %arg5: tensor) { - %2 = mhlo.add %arg2, %arg4 : tensor - %3 = mhlo.add %arg3, %arg5 : tensor - mhlo.return %2, %3 : tensor, tensor + %2 = stablehlo.add %arg2, %arg4 : tensor + %3 = stablehlo.add %arg3, %arg5 : tensor + stablehlo.return %2, %3 : tensor, tensor } return %1#0, %1#1 : tensor<4x8xf32>, tensor<4x8xf32> } @@ -49,9 +49,9 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) // CHECK-SAME: -> tensor<8x16xf32> { func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: "mhlo.dot" +// CHECK-NEXT: stablehlo.dot // CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22:(1)2, \\\22x\\\22:(4)2}, {}]>]>"}, mhlo.sharding = - %1 = "mhlo.dot" (%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -60,7 +60,7 @@ func.func @func_result_sharding_returning_func_arg( // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {mhlo.sharding = %arg0: tensor<8x16xf32> ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}) { - // CHECK: %[[CUSTOM_CALL:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: return %[[CUSTOM_CALL]] : tensor<8x16xf32> return %arg0 : tensor<8x16xf32> } @@ -75,22 +75,22 @@ func.func @func_result_sharding_returning_op_value(%arg0: tensor<8x16xf32>) tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{}, {}]>}) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, \\\22y\\\22}, {}]>, <@mesh_2, [{\\\22y\\\22, \\\22x\\\22}, {}]>]>"}, mhlo.sharding = - // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {\\\22y\\\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[ADD_RESULT_SHARDING_1:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {}]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = stablehlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, \\\22y\\\22}, {}]>, <@mesh_2, [{\\\22y\\\22, \\\22x\\\22}, {}]>]>"}, mhlo.sharding = + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {\\\22y\\\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_1:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {}]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: return %[[ADD_RESULT_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_1]], %[[ADD_RESULT_SHARDING_1]] - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %1:2 = mhlo.custom_call @sdy_testonly(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x","y"}, {}]>, <@mesh_2, [{"y","x"}, {}]>]>} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %1:2 = stablehlo.custom_call @sdy_testonly(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x","y"}, {}]>, <@mesh_2, [{"y","x"}, {}]>]>} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) return %0, %1#0, %1#1, %0 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> } // CHECK-LABEL: func @sharding_constraint // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {?}]>]>"}, mhlo.sharding = + // CHECK: stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {?}]>]>"}, mhlo.sharding = %0 = sdy.sharding_constraint %arg0 <@mesh_2, [{"x", ?}, {?}]> : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -98,14 +98,14 @@ func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK-LABEL: func @export_sharding_group // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @export_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "12 : i64"}} + // CHECK: stablehlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "12 : i64"}} sdy.sharding_group %arg0 group_id = 12: tensor<8x8xf32> return %arg0 : tensor<8x8xf32> } // CHECK-LABEL: func @constant func.func @constant() -> tensor { - // CHECK-NEXT: %[[CONST:.*]] = mhlo.constant dense<0> + // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<0> // CHECK-NEXT: return %[[CONST]] %0 = sdy.constant dense<0> : tensor return %0 : tensor @@ -118,11 +118,11 @@ func.func @constant() -> tensor { // CHECK-SAME: -> (tensor<32xi32> {mhlo.sharding = "{maximal device=5}"}) { func.func @inlined_mesh( %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding, [{"a"}]>} -) -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, [{}]>}) { - // CHECK-NEXT: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) +) -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, []>}) { + // CHECK-NEXT: %[[SHARDING:.*]] = stablehlo.custom_call @Sharding(%arg0) // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}, mhlo.sharding = "{devices=[4]<=[4]}"} - // CHECK-NEXT: %[[RESULT_SHARDING:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[SHARDING]]) - // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{}]>]>"} + // CHECK-NEXT: %[[RESULT_SHARDING:.*]] = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%[[SHARDING]]) + // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, []>]>"} // CHECK-NEXT: return %[[RESULT_SHARDING]] %0 = sdy.sharding_constraint %arg0 , [{"c"}]> : tensor<32xi32> return %0 : tensor<32xi32> @@ -160,10 +160,10 @@ func.func @non_sdy_module(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8 %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,2,16]<=[32] last_tile_dim_replicate}"}, %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[4,4,2]<=[2,16]T(1,0) last_tile_dim_replicate}"}) -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { - // CHECK-NEXT: mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} + // CHECK-NEXT: stablehlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} // CHECK-NOT: xla.sdy.sharding // CHECK-NOT: xla.sdy.sharding_rule - %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 9c8e27a4871429..d2cecee843abf0 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s --split-input-file -xla-sdy-import-constants -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s // CHECK-LABEL: module @multiple_func_result_shardings module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = @@ -25,11 +25,11 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}}, %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) { - %0 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %4 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> } @@ -39,16 +39,16 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME: ) -> ( // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, // CHECK-SAME: tensor<32xi32>) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg1 + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 // CHECK-NEXT: return %arg0, %[[ADD]] // CHECK-NEXT: } func.func @func_result_shardings_used_by_other_ops( %arg0: tensor<32xi32>, %arg1: tensor<32xi32> ) -> (tensor<32xi32>, tensor<32xi32>) { - %0 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.add %1, %2 : tensor<32xi32> + %0 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.add %1, %2 : tensor<32xi32> return %1, %3 : tensor<32xi32>, tensor<32xi32> } @@ -61,27 +61,27 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<1> : tensor %2 = mhlo.constant dense<32> : tensor - %3:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + %3:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor cond { - %4 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor - mhlo.return %4 : tensor + %4 = stablehlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor + stablehlo.return %4 : tensor } do { - %4 = mhlo.add %iterArg_0, %1 : tensor - %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %5, %4 : tensor<32x96xf32>, tensor + %4 = stablehlo.add %iterArg_0, %1 : tensor + %5 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %5, %4 : tensor<32x96xf32>, tensor } return %3#0 : tensor<32x96xf32> } @@ -89,29 +89,29 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-LABEL: func @while_with_sinked_constants func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %iterArg + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor - %1:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + %1:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor cond { %2 = mhlo.constant dense<32> : tensor - %3 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor - mhlo.return %3 : tensor + %3 = stablehlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor + stablehlo.return %3 : tensor } do { %2 = mhlo.constant dense<1> : tensor - %3 = mhlo.add %iterArg_0, %2 : tensor - %4 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> - mhlo.return %4, %3 : tensor<32x96xf32>, tensor + %3 = stablehlo.add %iterArg_0, %2 : tensor + %4 = stablehlo.add %iterArg, %iterArg : tensor<32x96xf32> + stablehlo.return %4, %3 : tensor<32x96xf32>, tensor } return %1#0 : tensor<32x96xf32> } @@ -122,27 +122,27 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x func.func @discard_shardings_on_unknown_ops( %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p0]>"}} ) -> tensor<32xi32> { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<32xi32> // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a"}p2]> : tensor<32xi32> - // CHECK-NEXT: %[[UNKNOWN:.*]] = mhlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32> + // CHECK-NEXT: %[[UNKNOWN:.*]] = stablehlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32> // CHECK-NEXT: return %[[UNKNOWN]] - %0 = mhlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32> - %1 = mhlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = mhlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32> + %1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3 : tensor<32xi32> } // CHECK-LABEL: func @inlined_mesh( // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding, [{"a"}]>}) - // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, [{}]>}) { + // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, []>}) { func.func @inlined_mesh( %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\\\22a\\\22}]>"}} ) -> tensor<32xi32> { // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %arg0 , [{"c"}]> : tensor<32xi32> // CHECK-NEXT: return %[[SHARDING]] - %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, []>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %1 : tensor<32xi32> } @@ -159,16 +159,16 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22c\\\22, \\\22b\\\22, ?}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>) { // CHECK-NEXT: %[[SC1:.*]] = sdy.sharding_constraint %arg0 <@mesh2, [{"b", ?}]> - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SC1]], %[[SC1]] + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SC1]], %[[SC1]] // CHECK-NOT: sdy.sharding // CHECK-NEXT: %[[SC2:.*]] = sdy.sharding_constraint %arg1 <@mesh2, [{}]> // CHECK-NEXT: return %[[ADD]], %[[SC2]] // CHECK-NEXT: } - %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.add %0, %0 : tensor<32xi32> - %2 = mhlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %4 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.add %0, %0 : tensor<32xi32> + %2 = stablehlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = stablehlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3, %4 : tensor<32xi32>, tensor<32xi32> } @@ -180,19 +180,19 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME{LITERAL}: out_shardings=[<@mesh2, [{}, {"b"}]>] // CHECK-SAME{LITERAL}: manual_axes={"b"} // CHECK-SAME: (%arg2: tensor<16x8xf32>, %arg3: tensor<16x8xf32>) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg2, %arg3 + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg3 // CHECK-NEXT: sdy.return %[[ADD]] // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] - %0:2 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) + %0:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh2, [{}, {\\\22b\\\22}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh2, [{}, {\\\22b\\\22, \\\22a\\\22}]>]>"}} : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body( func.func @local_xla.sdy.manual_computation_body(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> tensor<16x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<16x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<16x8xf32> return %0 : tensor<16x8xf32> } } @@ -238,16 +238,21 @@ module @no_meshes_attr_module { // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> - mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> () + stablehlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> () return %arg0 : tensor<8x8xf32> } // ----- -// CHECK-LABEL: func @import_sharding_group_with_unused_result -// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { -func.func @import_sharding_group_with_unused_result(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> - %0 = mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> - return %arg0 : tensor<8x8xf32> +func.func @callback_no_result(%arg0: tensor) { + // CHECK: %[[C:.*]] = sdy.constant + // CHECK-NEXT: stablehlo.custom_call @xla_python_cpu_callback(%[[C]], %arg0) { + // CHECK-SAME: api_version = 2 : i32, backend_config = "56238273106176", + // CHECK-SAME: has_side_effect = true, + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], + // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] + // CHECK-SAME: } : (tensor, tensor) -> tensor + %c = stablehlo.constant dense<56238273106176> : tensor + stablehlo.custom_call @xla_python_cpu_callback(%c, %arg0) {api_version = 2 : i32, backend_config = "56238273106176", has_side_effect = true, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = []} : (tensor, tensor) -> () + return } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir index 33e2b3a6c64757..ea59b214151d35 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir @@ -20,9 +20,9 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) // CHECK-NEXT: sdy.return %[[REDUCE]] : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32> - %0:2 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + %0:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> return %2 : tensor<8x32xf32> } @@ -36,9 +36,9 @@ func.func @single_manual_comp_name_is_not_prefix_nor_suffix(%arg0: tensor<8x8xf3 // CHECK-NEXT: sdy.return %arg1 : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x8xf32> - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> %1 = call @my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> return %2 : tensor<8x8xf32> } @@ -60,20 +60,20 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: sdy.return %arg1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32> - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_0(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> - %3 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> + %3 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> %4 = call @local_xla.sdy.manual_computation_body_1(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x4xf32>) -> tensor<8x4xf32> - %5 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> + %5 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> return %5 : tensor<8x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_3( func.func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> %1 = call @local_xla.sdy.manual_computation_body_2(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> return %2 : tensor<2x8xf32> } @@ -101,9 +101,9 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_3(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -126,9 +126,9 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> %1 = call @local_xla.sdy.manual_computation_body_5(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -144,7 +144,7 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> { // CHECK-NEXT: } : () -> tensor<4xi64> // CHECK-NEXT: return %[[SHMAP]] : tensor<4xi64> %0 = call @local_xla.sdy.manual_computation_body_6() {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} : () -> tensor<2xi64> - %1 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64> + %1 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64> return %1 : tensor<4xi64> } @@ -155,11 +155,11 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { // CHECK-SAME{LITERAL}: out_shardings=[] // CHECK-SAME{LITERAL}: manual_axes={"b"} // CHECK-SAME{LITERAL}: (%arg1: tensor<2xi64>) { - // CHECK-NEXT: mhlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () + // CHECK-NEXT: stablehlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () // CHECK-NEXT: sdy.return // CHECK-NEXT: } : (tensor<4xi64>) -> () // CHECK-NEXT: return - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> call @local_xla.sdy.manual_computation_body_7(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} : (tensor<2xi64>) -> () return } @@ -198,9 +198,9 @@ func.func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> te // CHECK-NOT: func @local_xla.sdy.manual_computation_body_5( func.func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> %1 = call @local_xla.sdy.manual_computation_body_4(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> %3 = stablehlo.add %2, %2 : tensor<2x8xf32> return %3 : tensor<2x8xf32> } @@ -213,6 +213,6 @@ func.func @local_xla.sdy.manual_computation_body_6() -> tensor<2xi64> { // CHECK-NOT: func @local_xla.sdy.manual_computation_body_7( func.func @local_xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) { - mhlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () + stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () return } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir index 9f2a3a5740924d..ba5f28da7a7484 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir @@ -3,14 +3,14 @@ sdy.mesh @mesh = <["a"=2]> func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %0 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) %1 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) - %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) - %3 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %2 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %3 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) // expected-error @+2 {{'func.call' op expected a unique FuncOp per @local_xla.sdy.manual_computation_body call}} // expected-error @+1 {{failed to legalize operation 'func.call'}} %4 = call @local_xla.sdy.manual_computation_body(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) - %5 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %5 = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) return %5 : tensor<8x8xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir index f30c0150ce0264..b884fc45eb841a 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir @@ -1,18 +1,18 @@ // RUN: sdy_opt %s -xla-sdy-import-sdy-custom-calls -split-input-file -verify-diagnostics func.func @sharding_group_import_failure_if_no_group_id(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { - // expected-error @+2 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} + // expected-error @+2 {{failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal}} // expected-error @+1 {{expected CustomCallOp with a sharding group id.}} - mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {}} : (tensor<16x16xf32>) -> () + stablehlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {}} : (tensor<16x16xf32>) -> () return %arg0 : tensor<16x16xf32> } // ----- func.func @sharding_group_import_with_used_result(%arg0: tensor<8x8xf32>) -> tuple> { - // expected-error @+2 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} + // expected-error @+2 {{failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal}} // expected-error @+1 {{xla.sdy.ShardingGroup CustomCallOp should have no uses.}} - %0 = mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> - %1 = "mhlo.tuple"(%0) : (tuple<>) -> tuple> + %0 = stablehlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> + %1 = "stablehlo.tuple"(%0) : (tuple<>) -> tuple> return %1 : tuple> } diff --git a/third_party/xla/xla/service/spmd/shardy/utils.cc b/third_party/xla/xla/service/spmd/shardy/utils.cc index 604ed05b306ec3..62eecad007b040 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.cc +++ b/third_party/xla/xla/service/spmd/shardy/utils.cc @@ -30,9 +30,12 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Support/LLVM.h" #include "shardy/dialect/sdy/ir/register.h" #include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" @@ -50,6 +53,7 @@ using ::mlir::StringRef; using xla::sdy::kFrontendAttributesAttr; using ::mlir::func::FuncOp; +using ::mlir::stablehlo::CustomCallOp; DictionaryAttr getFrontendAttrs(Operation* op) { return op->getAttrOfType(kFrontendAttributesAttr); @@ -62,14 +66,18 @@ DictionaryAttr getFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index) { namespace { -mlir::StringAttr getStringAttribute(Attribute attr, mlir::OpBuilder& builder) { +mlir::StringAttr getStringAttribute(Attribute attr, mlir::OpBuilder& builder, + bool escapeAttr) { std::string value; if (auto stringAttr = mlir::dyn_cast(attr)) { + if (!escapeAttr) { + return stringAttr; + } value = stringAttr.getValue().str(); } else { value = mlir::sdy::attributeToString(attr); } - return builder.getStringAttr(absl::CEscape(value)); + return builder.getStringAttr(escapeAttr ? absl::CEscape(value) : value); } SmallVector getExistingFrontendAttributes( @@ -80,17 +88,28 @@ SmallVector getExistingFrontendAttributes( } for (NamedAttribute entry : frontendAttributes) { if (entry.getName() != excludedAttribute) { - dictEntries.emplace_back(entry); + dictEntries.push_back(entry); } } return dictEntries; } -void addFrontendAttribute(SmallVector& existingAttributes, - StringRef name, Attribute value) { +void setFrontendAttribute(SmallVector& existingAttributes, + StringRef name, Attribute value, bool escapeAttr) { mlir::OpBuilder builder(value.getContext()); - existingAttributes.emplace_back(NamedAttribute( - builder.getStringAttr(name), getStringAttribute(value, builder))); + StringAttr stringValue = getStringAttribute(value, builder, escapeAttr); + for (auto* it = existingAttributes.begin(); it != existingAttributes.end(); + ++it) { + if (it->getName() == name) { + if (it->getValue() == stringValue) { + return; + } + existingAttributes.erase(it); + break; + } + } + existingAttributes.emplace_back( + NamedAttribute(builder.getStringAttr(name), stringValue)); } void removeFrontendAttribute( @@ -119,19 +138,20 @@ void setFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index, } // namespace -void addFrontendAttribute(Operation* op, StringRef name, Attribute value) { +void setFrontendAttribute(Operation* op, StringRef name, Attribute value, + bool escapeAttr) { SmallVector existingAttributes = getExistingFrontendAttributes(getFrontendAttrs(op), ""); - addFrontendAttribute(existingAttributes, name, value); + setFrontendAttribute(existingAttributes, name, value, escapeAttr); setFrontendAttrs(op, existingAttributes); } -void addFrontendAttribute(FuncOp funcOp, StringRef name, Attribute value, - int64_t argNum) { +void setFrontendAttribute(FuncOp funcOp, StringRef name, Attribute value, + int64_t argNum, bool escapeAttr) { SmallVector existingAttributes = getExistingFrontendAttributes(getFuncArgFrontendAttrs(funcOp, argNum), ""); - addFrontendAttribute(existingAttributes, name, value); + setFrontendAttribute(existingAttributes, name, value, escapeAttr); setFuncArgFrontendAttrs(funcOp, argNum, existingAttributes); } @@ -169,5 +189,25 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) { context->loadAllAvailableDialects(); } +CustomCallOp cloneCustomCallWithNewResultTypes(CustomCallOp op, + mlir::TypeRange resultTypes, + mlir::IRRewriter& rewriter) { + auto customCallOp = rewriter.create( + op.getLoc(), resultTypes, op.getOperands(), op.getCallTargetNameAttr(), + op.getHasSideEffectAttr(), op.getBackendConfigAttr(), + op.getApiVersionAttr(), op.getCalledComputations(), + op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(), + op.getOutputOperandAliases()); + customCallOp->setDiscardableAttrs(mlir::DictionaryAttr::get( + op->getContext(), llvm::to_vector(op->getDiscardableAttrs()))); + return customCallOp; +}; + +bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op) { + mlir::StringRef targetName = op.getCallTargetName(); + return targetName == kPythonCpuCallbackCustomCallTargetName || + targetName == kPythonGpuCallbackCustomCallTargetName; +} + } // namespace sdy } // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/utils.h b/third_party/xla/xla/service/spmd/shardy/utils.h index 552de063ce2e4a..7975a55599d648 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.h +++ b/third_party/xla/xla/service/spmd/shardy/utils.h @@ -28,7 +28,10 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/StablehloOps.h" namespace xla { namespace sdy { @@ -42,15 +45,18 @@ mlir::DictionaryAttr getFrontendAttrs(mlir::Operation* op); mlir::DictionaryAttr getFuncArgFrontendAttrs(mlir::func::FuncOp funcOp, unsigned int index); -// Add `name` into the frontend attributes of `op` with value `value`. Note that -// `value` will be turned into a `StringAttr`. -void addFrontendAttribute(mlir::Operation* op, mlir::StringRef name, - mlir::Attribute value); +// Adds `name` into the frontend attributes of `op` with value `value`. If +// `name` already exists, it will be overwritten. Note that `value` will be +// turned into a `StringAttr`. +void setFrontendAttribute(mlir::Operation* op, mlir::StringRef name, + mlir::Attribute value, bool escapeAttr = true); -// Add `name` into the argument at `argNum`'s frontend attributes of `funcOp` -// with value `value`. Note that `value` will be turned into a `StringAttr`. -void addFrontendAttribute(mlir::func::FuncOp funcOp, mlir::StringRef name, - mlir::Attribute value, int64_t argNum); +// Adds `name` into the argument at `argNum`'s frontend attributes of `funcOp` +// with value `value`. If `name` already exists, it will be overwritten. Note +// that `value` will be turned into a `StringAttr`. +void setFrontendAttribute(mlir::func::FuncOp funcOp, mlir::StringRef name, + mlir::Attribute value, int64_t argNum, + bool escapeAttr = true); // Remove `attributeName` from the frontend attributes of `op`. void removeFrontendAttribute(mlir::Operation* op, @@ -98,6 +104,15 @@ std::optional tryGetFrontendAttr(mlir::Operation* op, return std::nullopt; } +// Builds a new `stablehlo.custom_call` with the same operands and attributes +// as `op` but with new `resultTypes`. +mlir::stablehlo::CustomCallOp cloneCustomCallWithNewResultTypes( + mlir::stablehlo::CustomCallOp op, mlir::TypeRange resultTypes, + mlir::IRRewriter& rewriter); + +// Whether `op` is a Python callback custom call. +bool isPythonCallbackCustomCall(mlir::stablehlo::CustomCallOp op); + } // namespace sdy } // namespace xla diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index 844c1855175fed..6034c828804ba0 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -431,7 +432,7 @@ PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target, // propagated to constant.) if (hlo()->opcode() == HloOpcode::kConstant && !sharding().IsManual() && target.IsManual()) { - PartitionedHlo pconstant = this->Reshard(HloSharding::Replicate()); + PartitionedHlo pconstant = this->Replicate(); pconstant.hlo()->set_sharding(target); return pconstant; } @@ -1443,8 +1444,8 @@ PartitionedHlo::ReshardToPartialReplicateWithAllGather( int64_t replicate_factor = temp_sharding.tile_assignment().dim(dim) / target.tile_assignment().dim(dim); if (replicate_factor > 1) { - replicate_dims.emplace_back(dim); - replicate_factors.emplace_back(replicate_factor); + replicate_dims.push_back(dim); + replicate_factors.push_back(replicate_factor); } } @@ -1590,131 +1591,86 @@ PartitionedHlo PartitionedHlo::Broadcast() const { return PartitionedHlo(result, base_shape_, state_); } +namespace { + +HloSharding GetAllToAllSharding(const HloSharding& source_sharding, + absl::Span source_dims, + absl::Span target_dims) { + CHECK_EQ(source_dims.size(), target_dims.size()); + TileAssignment result = source_sharding.tile_assignment(); + + for (int64_t i = 0; i < source_dims.size(); ++i) { + const int64_t source_dim = source_dims[i]; + const int64_t target_dim = target_dims[i]; + CHECK_NE(source_dim, target_dim); + CHECK_EQ(result.dim(source_dim) % result.dim(target_dim), 0); + + std::vector shape_1_dims; + shape_1_dims.reserve(result.num_dimensions() + 2); + int64_t added_source_dim; + int64_t added_target_dim; + for (int64_t i = 0; i < result.num_dimensions(); ++i) { + if (i == source_dim) { + shape_1_dims.push_back(result.dim(target_dim)); + shape_1_dims.push_back(result.dim(source_dim) / result.dim(target_dim)); + added_source_dim = shape_1_dims.size() - 1; + } else if (i == target_dim) { + shape_1_dims.push_back(result.dim(i)); + shape_1_dims.push_back(1); + added_target_dim = shape_1_dims.size() - 1; + } else { + shape_1_dims.push_back(result.dim(i)); + } + } + + std::vector permutation(shape_1_dims.size()); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[added_source_dim], permutation[added_target_dim]); + std::vector shape_2_dims(result.dimensions().begin(), + result.dimensions().end()); + std::swap(shape_2_dims[source_dim], shape_2_dims[target_dim]); + result = result.Reshape(shape_1_dims) + .Transpose(permutation) + .Reshape(shape_2_dims); + } + + return source_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(result) + : HloSharding::Subgroup(result, source_sharding.subgroup_types()); +} + +} // namespace + PartitionedHlo PartitionedHlo::ReshardWithAllToAll( const HloSharding& target, - absl::Span> source_target_dims) const { + absl::Span> source_target_dims, + bool try_multiple_source_target_dims) const { + if (target == sharding()) { + return *this; + } + VLOG(5) << "Source: " << sharding().ToString(); + VLOG(5) << "Target: " << target.ToString(); if (source_target_dims.empty()) { - if (target == sharding()) { - return *this; - } // If the device order is different in the target, fix the order with // ReshardWithCollectivePermute. return ReshardWithCollectivePermute(target); } - VLOG(5) << "Source: " << sharding().ToString(); - VLOG(5) << "Target: " << target.ToString(); + if (try_multiple_source_target_dims) { + return TryMultipleSourceTargetDims(target, source_target_dims); + } + // Swap one pair of dimensions. - int64_t source_dim = source_target_dims[0].first; - int64_t target_dim = source_target_dims[0].second; + const int64_t source_dim = source_target_dims[0].first; + const int64_t target_dim = source_target_dims[0].second; + VLOG(5) << "Source dim: " << source_dim; + VLOG(5) << "Target dim: " << target_dim; + CHECK_NE(source_dim, target_dim); const int64_t group_size = sharding().tile_assignment().dim(source_dim) / sharding().tile_assignment().dim(target_dim); - VLOG(5) << "Group size: " << group_size; - auto temp_target_tile = [&] { - auto& original_tile_assignment = sharding().tile_assignment(); - std::vector reshape_tile_dims( - original_tile_assignment.num_dimensions() + 2); - int64_t i = 0; - int64_t added_source_dim = -1; - int64_t added_target_dim = -1; - for (int64_t j = 0; j < original_tile_assignment.num_dimensions(); ++j) { - if (source_dim == j) { - reshape_tile_dims[i] = original_tile_assignment.dim(j) / group_size; - reshape_tile_dims[++i] = group_size; - added_source_dim = i; - } else if (target_dim == j) { - reshape_tile_dims[i] = original_tile_assignment.dim(j); - reshape_tile_dims[++i] = 1; - added_target_dim = i; - } else { - reshape_tile_dims[i] = original_tile_assignment.dim(j); - } - ++i; - } - VLOG(5) << "Added target: " << added_target_dim; - VLOG(5) << "Added source: " << added_source_dim; - std::vector xpose_dims(reshape_tile_dims.size()); - std::iota(xpose_dims.begin(), xpose_dims.end(), 0); - xpose_dims[added_source_dim] = added_target_dim; - xpose_dims[added_target_dim] = added_source_dim; - auto temp_target_tile = - hlo_sharding_util::TransposeSharding( - HloSharding::Tile( - original_tile_assignment.Reshape(reshape_tile_dims)), - xpose_dims) - .tile_assignment(); - VLOG(5) << "Transposed target: " << temp_target_tile.ToString(); - std::vector temp_target_tile_dims( - sharding().tile_assignment().dimensions().begin(), - sharding().tile_assignment().dimensions().end()); - temp_target_tile_dims[source_dim] = - sharding().tile_assignment().dim(target_dim); - temp_target_tile_dims[target_dim] = - sharding().tile_assignment().dim(source_dim); - return temp_target_tile.Reshape(temp_target_tile_dims); - }(); - auto temp_target = target.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(temp_target_tile) - : HloSharding::Tile(temp_target_tile); - VLOG(5) << "Temp target sharding: " << temp_target.ToString(); - auto padded_shape = hlo_->shape(); - auto padded_base_shape = base_shape_; - auto current_base_padded_shape = base_shape_; - padded_base_shape.set_dimensions( - target_dim, RoundUpTo(base_shape_.dimensions(target_dim), - temp_target.tile_assignment().dim(target_dim))); - current_base_padded_shape.set_dimensions( - target_dim, hlo_->shape().dimensions(target_dim) * - sharding().tile_assignment().dim(target_dim)); - - auto padded_source_base_shape = base_shape_; - auto current_source_base_padded_shape = base_shape_; - padded_source_base_shape.set_dimensions( - source_dim, RoundUpTo(base_shape_.dimensions(source_dim), - temp_target.tile_assignment().dim(source_dim))); - current_source_base_padded_shape.set_dimensions( - source_dim, hlo_->shape().dimensions(source_dim) * - sharding().tile_assignment().dim(source_dim)); - - VLOG(5) << "Target dim: " << target_dim; - VLOG(5) << "Source dim: " << source_dim; - VLOG(5) << "Original sharded shape: " << hlo_->shape(); - VLOG(5) << "Base shape: " << base_shape_.ToString(); - VLOG(5) << "Padded base shape: " << padded_base_shape.ToString(); - VLOG(5) << "Current padded shape: " << current_base_padded_shape.ToString(); - VLOG(5) << "Padded source base shape: " - << padded_source_base_shape.ToString(); - VLOG(5) << "Current source padded shape: " - << current_source_base_padded_shape.ToString(); - VLOG(5) << "Dimension padded target_dim: " - << hlo_->shape().dimensions(target_dim) * - sharding().tile_assignment().dim(target_dim); - CHECK_GE(padded_base_shape.rank(), current_base_padded_shape.rank()); - CHECK_LE(padded_source_base_shape.rank(), - current_source_base_padded_shape.rank()); - - PaddingConfig pc; - for (int64_t i = 0; i < hlo_->shape().rank(); ++i) { - auto* pd = pc.add_dimensions(); - pd->set_edge_padding_low(0); - pd->set_edge_padding_high(padded_base_shape.dimensions(i) - - current_base_padded_shape.dimensions(i)); - pd->set_interior_padding(0); - } - PartitionedHlo p_hlo = *this; - VLOG(5) << "Before reshard: " << p_hlo.hlo_->ToString(); - HloInstruction* zero = CreateZero( - ShapeUtil::MakeShape(hlo_->shape().element_type(), {}), state_.b); - HloSharding sharding_copy = sharding(); - auto padded_phlo = - ReshardDataForPad(zero, pc, p_hlo, sharding_copy, state_.b); - CHECK(padded_phlo.has_value()); - VLOG(5) << "Resharded: " << padded_phlo->sharded_input->ToString(); - VLOG(5) << "Padded Window: " << padded_phlo->shard_window.DebugString(); - HloInstruction* padded_hlo = - PadDataFromWindowReshard(*padded_phlo, zero, state_.b); - VLOG(5) << "Padded data: " << padded_hlo->ToString(); + const HloSharding temp_target = + GetAllToAllSharding(sharding(), {source_dim}, {target_dim}); // The order of ids in the group must follow the temp_target sharding. std::vector> groups( @@ -1734,27 +1690,44 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( groups[group_id].push_back(device); }); - HloInstruction* result = nullptr; - - // Split along the split dimension (target_dim) of the all-to-all - // output. - std::vector dimensions; - const int64_t rank = base_shape_.rank(); - dimensions.reserve(rank + 1); - for (int64_t i = 0; i < rank; ++i) { + PaddingConfig pc; + for (int64_t i = 0; i < hlo_->shape().rank(); ++i) { + auto* pd = pc.add_dimensions(); + pd->set_edge_padding_low(0); if (i == target_dim) { - dimensions.push_back(group_size); - dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size); + pd->set_edge_padding_high( + RoundUpTo(base_shape_.dimensions(i), + temp_target.tile_assignment().dim(i)) - + hlo_->shape().dimensions(i) * sharding().tile_assignment().dim(i)); } else { - dimensions.push_back(padded_hlo->shape().dimensions(i)); + pd->set_edge_padding_high(0); } + pd->set_interior_padding(0); } - VLOG(5) << "Target ata shape: " - << ShapeUtil::MakeShape(base_shape_.element_type(), dimensions) - .ToString(); + PartitionedHlo p_hlo = *this; + VLOG(5) << "Before reshard: " << p_hlo.hlo_->ToString(); + HloInstruction* zero = CreateZero( + ShapeUtil::MakeShape(hlo_->shape().element_type(), {}), state_.b); + HloSharding sharding_copy = sharding(); + auto padded_phlo = + ReshardDataForPad(zero, pc, p_hlo, sharding_copy, state_.b); + CHECK(padded_phlo.has_value()); + VLOG(5) << "Resharded: " << padded_phlo->sharded_input->ToString(); + VLOG(5) << "Padded Window: " << padded_phlo->shard_window.DebugString(); + HloInstruction* padded_hlo = + PadDataFromWindowReshard(*padded_phlo, zero, state_.b); + VLOG(5) << "Padded data: " << padded_hlo->ToString(); + + // Split along the split dimension (target_dim) of the all-to-all output. + std::vector target_ata_dims(padded_hlo->shape().dimensions().begin(), + padded_hlo->shape().dimensions().end()); + target_ata_dims.insert(target_ata_dims.begin() + target_dim, group_size); + target_ata_dims[target_dim + 1] /= group_size; auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(base_shape_.element_type(), dimensions), + ShapeUtil::MakeShape(base_shape_.element_type(), target_ata_dims), padded_hlo)); + VLOG(5) << "Target ata shape: " << reshape->shape().ToString(); + // After the reshape, it is guaranteed to have at least 3 dimensions. auto all_to_all = state_.collective_ops_creator.create_cross_partition_all_to_all( @@ -1783,27 +1756,212 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( auto new_shape = ShapeInference::InferAllToAllShape( padded_hlo->shape(), target_dim, source_dim, group_size) .value(); - result = state_.b->AddInstruction( + HloInstruction* result = state_.b->AddInstruction( HloInstruction::CreateReshape(new_shape, transpose)); + CHECK_EQ(result->shape().rank(), base_shape_.rank()); result->set_sharding(temp_target); + + auto padded_source_base_shape = base_shape_; + auto current_source_base_padded_shape = base_shape_; + padded_source_base_shape.set_dimensions( + source_dim, RoundUpTo(base_shape_.dimensions(source_dim), + temp_target.tile_assignment().dim(source_dim))); + current_source_base_padded_shape.set_dimensions( + source_dim, hlo_->shape().dimensions(source_dim) * + sharding().tile_assignment().dim(source_dim)); + + VLOG(5) << "Original sharded shape: " << hlo_->shape(); + VLOG(5) << "Base shape: " << base_shape_.ToString(); + VLOG(5) << "Padded source base shape: " + << padded_source_base_shape.ToString(); + VLOG(5) << "Current source padded shape: " + << current_source_base_padded_shape.ToString(); + std::vector strides(result->shape().rank(), 1); std::vector starts(result->shape().rank(), 0); - std::vector limits(result->shape().rank()); - for (int64_t i = 0; i < result->shape().rank(); ++i) { - limits[i] = padded_source_base_shape.dimensions(i); - } auto sliced_phlo = ReshardDataForSlicing( - strides, starts, limits, + strides, starts, padded_source_base_shape.dimensions(), PartitionedHlo(result, current_source_base_padded_shape, state_), temp_target, state_.b); CHECK(sliced_phlo.has_value()); result = SliceDataFromWindowReshard(*sliced_phlo, strides, base_shape_, temp_target, state_.b); result->set_sharding(temp_target); - auto remaining_source_target_dims = source_target_dims; - remaining_source_target_dims.remove_prefix(1); return PartitionedHlo(result, base_shape_, state_) - .ReshardWithAllToAll(target, remaining_source_target_dims); + .ReshardWithAllToAll( + target, source_target_dims.last(source_target_dims.size() - 1)); +} + +PartitionedHlo PartitionedHlo::TryMultipleSourceTargetDims( + const HloSharding& target, + absl::Span> source_target_dims) const { + std::vector eligible_source_dims; + std::vector eligible_target_dims; + std::vector group_sizes; + std::vector> ineligible_source_target_dims; + absl::flat_hash_set seen_dims; + + std::vector> sorted_pairs_by_target_dim( + source_target_dims.begin(), source_target_dims.end()); + absl::c_stable_sort( + sorted_pairs_by_target_dim, + [](const std::pair& a, + const std::pair& b) { return a.second < b.second; }); + for (const auto& [source_dim, target_dim] : sorted_pairs_by_target_dim) { + CHECK_NE(source_dim, target_dim); + bool dims_already_seen = + seen_dims.contains(source_dim) || seen_dims.contains(target_dim); + bool source_dim_divisible = + base_shape_.dimensions(source_dim) % + sharding().tile_assignment().dim(source_dim) == + 0; + bool target_dim_divisible = base_shape_.dimensions(target_dim) % + target.tile_assignment().dim(target_dim) == + 0; + if (!dims_already_seen && source_dim_divisible && target_dim_divisible) { + eligible_source_dims.push_back(source_dim); + eligible_target_dims.push_back(target_dim); + group_sizes.push_back(sharding().tile_assignment().dim(source_dim) / + sharding().tile_assignment().dim(target_dim)); + seen_dims.insert(source_dim); + seen_dims.insert(target_dim); + } else { + ineligible_source_target_dims.push_back({source_dim, target_dim}); + } + } + + const int64_t num_eligible_dims = eligible_source_dims.size(); + if (num_eligible_dims < 2) { + return ReshardWithAllToAll(target, source_target_dims, false); + } + + // We go through 3 steps with the following example: + // base shape: (32,32,32,32) + // old sharding: [1,4,2,1], local shape (32,8,16,32) + // new sharding: [2,1,1,4], local shape (16,32,32,8) + // source_target_dims sorted by target_dims: {{2, 0}, {1, 3}} + + // Step 1. Merge sharding axes to a single dimension + // 1. reshape_0 (32,8,16,32) -> shape_0 (2,16,8,16,4,8) + // 2. transpose_0 (2,16,8,16,4,8) -> (2,4,16,8,16,8) with permutation_0 + // (0,4,1,2,3,5) + // 3. reshape_1 (2,4,16,8,16,8) -> (8,16,8,16,8) + std::vector shape_0_dims; + shape_0_dims.reserve(hlo_->shape().rank() + num_eligible_dims); + std::vector permutation_0; + for (int64_t i = 0; i < hlo_->shape().rank(); ++i) { + auto it = absl::c_find(eligible_target_dims, i); + if (it != eligible_target_dims.end()) { + int64_t group_size = + group_sizes[std::distance(eligible_target_dims.begin(), it)]; + permutation_0.push_back(shape_0_dims.size()); + shape_0_dims.push_back(group_size); + shape_0_dims.push_back(hlo_->shape().dimensions(i) / group_size); + } else { + shape_0_dims.push_back(hlo_->shape().dimensions(i)); + } + } + HloInstruction* reshape_0 = + state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), shape_0_dims), + hlo_)); + + for (int64_t i = 0; i < shape_0_dims.size(); ++i) { + if (!absl::c_linear_search(permutation_0, i)) { + permutation_0.push_back(i); + } + } + HloInstruction* transpose_0 = + state_.b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(reshape_0->shape(), permutation_0) + .value(), + reshape_0, permutation_0)); + + absl::Span transpose_shape_dims = + transpose_0->shape().dimensions(); + std::vector shape_1_dims; + shape_1_dims.reserve(1 + base_shape_.rank()); + shape_1_dims.push_back( + std::accumulate(transpose_shape_dims.begin(), + transpose_shape_dims.begin() + num_eligible_dims, 1, + std::multiplies())); + std::copy(transpose_shape_dims.begin() + num_eligible_dims, + transpose_shape_dims.end(), std::back_inserter(shape_1_dims)); + HloInstruction* reshape_1 = + state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), shape_1_dims), + transpose_0)); + + // // Step 2. Apply the all-to-all + // all-to-all on (8,16,8,16,8) with split_dimension = 0 + int64_t total_group_size = std::accumulate( + group_sizes.begin(), group_sizes.end(), 1, std::multiplies()); + const HloSharding temp_target = GetAllToAllSharding( + sharding(), eligible_source_dims, eligible_target_dims); + std::vector> groups( + temp_target.tile_assignment().num_elements() / total_group_size); + temp_target.tile_assignment().Each( + [&](absl::Span indices, int64_t device) { + int64_t group_id = 0; + for (int64_t dim = 0; dim < indices.size(); ++dim) { + auto it = absl::c_find(eligible_target_dims, dim); + if (it != eligible_target_dims.end()) { + int64_t group_size = + group_sizes[std::distance(eligible_target_dims.begin(), it)]; + group_id *= temp_target.tile_assignment().dim(dim) / group_size; + group_id += indices[dim] / group_size; + } else { + group_id *= temp_target.tile_assignment().dim(dim); + group_id += indices[dim]; + } + } + groups[group_id].push_back(device); + }); + HloInstruction* all_to_all = + state_.collective_ops_creator.create_cross_partition_all_to_all( + state_.b, {reshape_1}, groups, (*state_.next_channel_id)++, 0); + + // Step 3. Split sharding axes to multiple dimensions + // 1. reshape_2 (8,16,8,16,8) -> (2,4,16,8,16,8) + // 2. transpose_1 (2,4,16,8,16,8) -> (16,4,8,2,16,8) with permutation_1 + // (2,1,3,0,4,5) + // 3. reshape_3 (16,4,8,2,16,8) -> shape_3 (16,32,32,8) + HloInstruction* reshape_2 = state_.b->AddInstruction( + HloInstruction::CreateReshape(transpose_0->shape(), all_to_all)); + + std::vector permutation_1(base_shape_.rank()); + std::iota(permutation_1.begin(), permutation_1.end(), num_eligible_dims); + for (int64_t i = 0; i < num_eligible_dims; ++i) { + auto it = absl::c_find(permutation_1, + eligible_source_dims[i] + num_eligible_dims); + CHECK(it != permutation_1.end()); + permutation_1.insert(it, i); + } + HloInstruction* transpose_1 = + state_.b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(reshape_2->shape(), permutation_1) + .value(), + reshape_2, permutation_1)); + + std::vector shape_3_dims; + shape_3_dims.reserve(base_shape_.rank()); + for (int64_t i = 0; i < permutation_1.size(); ++i) { + if (permutation_1[i] < num_eligible_dims) { + shape_3_dims.push_back(transpose_1->shape().dimensions(i) * + transpose_1->shape().dimensions(i + 1)); + i++; + } else { + shape_3_dims.push_back(transpose_1->shape().dimensions(i)); + } + } + HloInstruction* reshape_3 = + state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), shape_3_dims), + transpose_1)); + reshape_3->set_sharding(temp_target); + + return PartitionedHlo(reshape_3, base_shape_, state_) + .ReshardWithAllToAll(target, ineligible_source_target_dims, false); } namespace { @@ -2600,103 +2758,65 @@ absl::Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) { return absl::OkStatus(); } -absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { +absl::Status SpmdPartitioningVisitor::HandleCollectivePermute( + HloInstruction* hlo) { + if (hlo->channel_id()) { + return HandleElementwise(hlo); + } + return DefaultAction(hlo); +} + +absl::Status SpmdPartitioningVisitor::HandleElementwiseWithDimsToReplicate( + HloInstruction* hlo, absl::Span dims_to_replicate) { const HloSharding& sharding = hlo->sharding(); if (sharding.IsTileMaximal()) { return DefaultAction(hlo); } - const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); - const int64_t dimension = hlo->concatenate_dimension(); - if (sharding.tile_assignment().dim(dimension) == 1) { - std::vector new_operands; - for (HloInstruction* operand : hlo->operands()) { - new_operands.push_back( - GetPartitionedHlo(operand).Reshard(sharding).hlo()); - } - SetPartitionedHlo(hlo, [&] { - return b_.AddInstruction( - hlo->CloneWithNewOperands(shard_shape, new_operands)); - }); - return absl::OkStatus(); - } + // 1. Replicate the final sharding along `dims_to_replicate` to get + // temp_sharding. + const HloSharding temp_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims_to_replicate); - // If the concatenate dimension is along one of the partitioned dimensions, - // allocate the full output shape, each partition updates its owned region, - // all-reduce across partitions, and then slice its output region. - - // temp_output_shape is the output shape where the concatenate dimension - // is changed to the full (and padded to shard count) dimension size. - auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); - auto last_operand_padded_shape = - MakePartitionedShape(hlo->operands().back()->shape(), sharding); - // If the last operand has more padding than the temp_output padding, needs to - // add extra padding to avoid dynamic update slice out of bound. - int last_operand_padding = - last_operand_padded_shape.dimensions(dimension) * - sharding.tile_assignment().dim(dimension) - - hlo->operands().back()->shape().dimensions(dimension); - int temp_output_padding = temp_output_shape.dimensions(dimension) * - sharding.tile_assignment().dim(dimension) - - hlo->shape().dimensions(dimension); - int padding_for_last_operand = - last_operand_padding < temp_output_padding - ? 0 - : last_operand_padding - temp_output_padding; - temp_output_shape.set_dimensions( - dimension, temp_output_shape.dimensions(dimension) * - sharding.tile_assignment().dim(dimension) + - padding_for_last_operand); - auto temp_output = CreateZero(temp_output_shape, &b_); - - // Offset of each operand along the concatenate dimension. - int64_t offset = 0; - auto state = MakePartitioningState(); + // 2. Reshard the operands to temp_sharding. + std::vector new_operands; + new_operands.reserve(hlo->operands().size()); for (HloInstruction* operand : hlo->operands()) { - auto spmd_operand = - GetPartitionedHlo(operand).Reshard(sharding).PadWithZero().hlo(); - std::vector start_indices( - hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(S32)))); - start_indices[dimension] = - MultiplyAddDivideOffsetCalculation( - spmd_operand->shape().dimensions(dimension), offset, 1) - .Calculate(MakeTiledPartitionOrdinals(sharding, state.partition_id, - &b_)[dimension], - &b_); - temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - temp_output_shape, temp_output, spmd_operand, start_indices)); - offset += operand->shape().dimensions(dimension); - } - std::vector non_concat_dims; - non_concat_dims.reserve(hlo->shape().rank() - 1); - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (i != dimension) { - non_concat_dims.push_back(i); - } - } - auto grouped = - hlo_sharding_util::GroupShardingOnDims(sharding, non_concat_dims); - auto per_group_partitioner_state = - CreatePerGroupPartitioningState(state, grouped.device_groups, &b_); - auto all_reduce = per_group_partitioner_state.collective_ops_creator - .create_cross_partition_all_reduce( - &b_, temp_output, - MakeBinaryAdd(hlo->shape().element_type(), module_), - {}, NewChannel()); - SetPartitionedHlo(hlo, [&] { - auto start_indices = MakeTiledPartitionOrdinals( - grouped.sharding, per_group_partitioner_state.partition_id, &b_); - start_indices[dimension] = MultiplyAddDivideOffsetCalculation( - shard_shape.dimensions(dimension), 0, 1) - .Calculate(start_indices[dimension], &b_); - return b_.AddInstruction(HloInstruction::CreateDynamicSlice( - shard_shape, all_reduce, start_indices, shard_shape.dimensions())); - }); + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(temp_sharding).hlo()); + } + + // 3. Apply the operation to get result in temp_sharding. + auto result_in_temp_sharding = b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), temp_sharding), new_operands)); + result_in_temp_sharding->set_sharding(temp_sharding); + // 4. Reshard the result from temp_sharding to the final sharding. + SetPartitionedHlo(hlo, PartitionedHlo(result_in_temp_sharding, hlo->shape(), + MakePartitioningState()) + .Reshard(sharding)); return absl::OkStatus(); } +absl::Status SpmdPartitioningVisitor::HandleCholesky(HloInstruction* hlo) { + CHECK_GE(hlo->shape().rank(), 2); + return HandleElementwiseWithDimsToReplicate( + hlo, {hlo->shape().rank() - 2, hlo->shape().rank() - 1}); +} + +absl::Status SpmdPartitioningVisitor::HandleTriangularSolve( + HloInstruction* hlo) { + CHECK_GE(hlo->shape().rank(), 2); + return HandleElementwiseWithDimsToReplicate( + hlo, {hlo->shape().rank() - 2, hlo->shape().rank() - 1}); +} + +absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { + return HandleElementwiseWithDimsToReplicate(hlo, + {hlo->concatenate_dimension()}); +} + absl::Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) { const HloSharding& sharding = hlo->sharding(); if (sharding.IsTileMaximal()) { @@ -2818,8 +2938,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions), MakePartitioningState()); // Reshard value to be replicated. - auto replicated_slice_input = - partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo(); + auto replicated_slice_input = partitioned_slice_input.Replicate().hlo(); // Slice top K index from the first parttioned sort. auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value()); @@ -2828,8 +2947,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions), MakePartitioningState()); // Reshard value to be replicated. - auto replicated_slice_index = - partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo(); + auto replicated_slice_index = partitioned_slice_index.Replicate().hlo(); // Creates replicated sort to do TopK, the input is value and index pairs // from all the partitions. @@ -2875,7 +2993,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { std::vector new_shardings; std::optional new_output_sharding; if (std::optional picked_dim = - hlo_sharding_util::GetFirstMergeableDimForSortOperand( + hlo_sharding_util::GetFirstTargetDimToMoveShardingTiles( subshape, cur_sharding, sort_dim)) { // We can move the sharding tiles from the sort dimension to the picked // dimension. @@ -3401,6 +3519,48 @@ absl::Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) { return DefaultAction(hlo); } +absl::Status SpmdPartitioningVisitor::HandleBitcastConvert( + HloInstruction* hlo) { + const Shape& input_shape = hlo->operand(0)->shape(); + const Shape& output_shape = hlo->shape(); + if (input_shape.rank() == output_shape.rank()) { + return HandleElementwise(hlo); + } + + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + PartitionedHlo& operand = GetPartitionedHlo(hlo->operand(0)); + HloSharding temp_input_sharding = HloSharding::Replicate(); + HloSharding temp_output_sharding = HloSharding::Replicate(); + if (input_shape.rank() > output_shape.rank()) { + CHECK_EQ(input_shape.rank(), output_shape.rank() + 1); + std::vector extra_dim = {output_shape.rank()}; + temp_input_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operand.sharding(), extra_dim); + temp_output_sharding = hlo_sharding_util::RemoveShapeDimensions( + temp_input_sharding, extra_dim); + } else { + CHECK_EQ(input_shape.rank() + 1, output_shape.rank()); + std::vector extra_dim = {input_shape.rank()}; + temp_output_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + hlo->sharding(), extra_dim); + temp_input_sharding = hlo_sharding_util::RemoveShapeDimensions( + temp_output_sharding, extra_dim); + } + Shape temp_output_shape = + MakePartitionedShape(output_shape, temp_output_sharding); + HloInstruction* temp_output = b_.AddInstruction(hlo->CloneWithNewOperands( + temp_output_shape, {operand.Reshard(temp_input_sharding).hlo()})); + temp_output->set_sharding(temp_output_sharding); + SetPartitionedHlo( + hlo, PartitionedHlo(temp_output, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding())); + return absl::OkStatus(); +} + absl::Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { if (hlo->sharding().IsTileMaximal()) { return DefaultAction(hlo); @@ -3471,9 +3631,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { continue; } // Replicate the indices.; - new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) - .Reshard(HloSharding::Replicate()) - .hlo(); + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)).Replicate().hlo(); } SetPartitionedHlo(hlo, [&]() { auto partitioned_shape = @@ -3528,9 +3686,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice( std::vector new_indices(hlo->shape().rank()); for (int64_t i = 0; i < new_indices.size(); ++i) { // Replicate the indices. - new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) - .Reshard(HloSharding::Replicate()) - .hlo(); + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo(); } auto dus = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( base.hlo()->shape(), base.hlo(), operand.hlo(), new_indices)); @@ -3559,9 +3715,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice( continue; } // Replicate the indices. - new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) - .Reshard(HloSharding::Replicate()) - .hlo(); + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo(); } // Get partitioned input. @@ -3679,9 +3833,7 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice( continue; } // Replicate the indices. - new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) - .Reshard(HloSharding::Replicate()) - .hlo(); + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)).Replicate().hlo(); } SetPartitionedHlo(hlo, [&]() { auto partitioned_shape = @@ -3849,9 +4001,7 @@ absl::Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { return DefaultAction(hlo); } auto lhs = GetPartitionedHlo(hlo->operand(0)); - auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) - .Reshard(HloSharding::Replicate()) - .hlo(); + auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)).Replicate().hlo(); auto reshard_operand = ReshardDataForPad( replicated_rhs, hlo->padding_config(), lhs, hlo->sharding(), &b_); if (!reshard_operand.has_value()) { @@ -3930,7 +4080,7 @@ absl::Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) { inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count)) - .Reshard(HloSharding::Replicate()) + .Replicate() .hlo()); inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); if (operand_id > 0) { @@ -4115,9 +4265,7 @@ absl::Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) { .Reshard(hlo_sharding_util::UngroupSharding(grouped_sharding)) .hlo(); } else { - cond = GetPartitionedHlo(hlo->operand(0)) - .Reshard(HloSharding::Replicate()) - .hlo(); + cond = GetPartitionedHlo(hlo->operand(0)).Replicate().hlo(); } } return b_.AddInstruction(HloInstruction::CreateConditional( @@ -4343,7 +4491,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { // Run on a single device (0) and distribute the data to all other cores. auto clone = clone_from_original(HloSharding::AssignDevice(0)); return PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) - .Reshard(HloSharding::Replicate()) + .Replicate() .hlo(); }); return absl::OkStatus(); @@ -4354,9 +4502,8 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { std::vector new_operands; new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { - new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) - .Reshard(HloSharding::Replicate()) - .hlo()); + new_operands.push_back( + GetPartitionedHlo(hlo->operand(i)).Replicate().hlo()); } if (!hlo->sharding().ReplicateOnLastTileDim()) { @@ -4403,8 +4550,8 @@ absl::Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) { for (const HloInstruction* input_array : input_arrays) { PartitionedHlo& operand = GetPartitionedHlo(input_array); // Replicate init - PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx]) - .Reshard(HloSharding::Replicate()); + PartitionedHlo replicated_init = + GetPartitionedHlo(init_values[input_idx]).Replicate(); const HloSharding& sharding = hlo->sharding().IsTuple() ? hlo->sharding().tuple_elements()[input_idx] @@ -4506,8 +4653,7 @@ absl::Status SpmdPartitioningVisitor::HandleSelectAndScatter( : LiteralUtil::CreateR0(float_pad_value))); // Replicate init - auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2)) - .Reshard(HloSharding::Replicate()); + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2)).Replicate(); auto state = MakePartitioningState(); auto partition_ordinals = diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.h b/third_party/xla/xla/service/spmd/spmd_partitioner.h index 0279fc0fc7d4a5..30da16cf8b4cff 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.h @@ -543,6 +543,13 @@ class PartitionedHlo { // Helper function to reshard the tensor using AllToAll (instead of the // default of Replicate followed by Slice). PartitionedHlo ReshardWithAllToAll( + const HloSharding& target, + absl::Span> source_target_dims, + bool try_multiple_source_target_dims = true) const; + + // Called by ReshardWithAllToAll if try_multiple_source_target_dims is true. + // Try to handle multiple source and target dims in a single AllToAll. + PartitionedHlo TryMultipleSourceTargetDims( const HloSharding& target, absl::Span> source_target_dims) const; @@ -582,10 +589,17 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { SpmdPartitioningVisitor(const SpmdPartitioningVisitor& src); absl::Status DefaultAction(HloInstruction* hlo) override; + absl::Status HandleAllReduce(HloInstruction* hlo) override; + absl::Status HandleBitcastConvert(HloInstruction* hlo) override; absl::Status HandleBroadcast(HloInstruction* hlo) override; absl::Status HandleCall(HloInstruction* hlo) override; + absl::Status HandleCholesky(HloInstruction* hlo) override; + absl::Status HandleCollectivePermute(HloInstruction* hlo) override; + absl::Status HandleConcatenate(HloInstruction* hlo) override; + absl::Status HandleConditional(HloInstruction* hlo) override; absl::Status HandleConstant(HloInstruction* hlo) override; + absl::Status HandleConvolution(HloInstruction* hlo) override; absl::Status HandleCustomCall(HloInstruction* hlo) override; absl::Status HandleDot(HloInstruction* hlo) override; absl::Status HandleDynamicSlice(HloInstruction* hlo) override; @@ -594,27 +608,25 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { absl::Status HandleGather(HloInstruction* hlo) override; absl::Status HandleGetTupleElement(HloInstruction* hlo) override; absl::Status HandleInfeed(HloInstruction* hlo) override; + absl::Status HandleIota(HloInstruction* hlo) override; absl::Status HandleOptimizationBarrier(HloInstruction* hlo) override; absl::Status HandleOutfeed(HloInstruction* hlo) override; absl::Status HandlePad(HloInstruction* hlo) override; absl::Status HandleParameter(HloInstruction* hlo) override; + absl::Status HandlePartitionId(HloInstruction* hlo) override; absl::Status HandleReduce(HloInstruction* hlo) override; - absl::Status HandleReverse(HloInstruction* hlo) override; - absl::Status HandleWhile(HloInstruction* hlo) override; - absl::Status HandleConditional(HloInstruction* hlo) override; absl::Status HandleReduceWindow(HloInstruction* hlo) override; - absl::Status HandleSelectAndScatter(HloInstruction* hlo) override; - absl::Status HandleTuple(HloInstruction* hlo) override; + absl::Status HandleReshape(HloInstruction* hlo) override; + absl::Status HandleReverse(HloInstruction* hlo) override; absl::Status HandleRng(HloInstruction* hlo) override; - absl::Status HandleConvolution(HloInstruction* hlo) override; - absl::Status HandleConcatenate(HloInstruction* hlo) override; absl::Status HandleScatter(HloInstruction* hlo) override; + absl::Status HandleSelectAndScatter(HloInstruction* hlo) override; absl::Status HandleSlice(HloInstruction* hlo) override; absl::Status HandleSort(HloInstruction* hlo) override; absl::Status HandleTranspose(HloInstruction* hlo) override; - absl::Status HandleReshape(HloInstruction* hlo) override; - absl::Status HandleIota(HloInstruction* hlo) override; - absl::Status HandlePartitionId(HloInstruction* hlo) override; + absl::Status HandleTriangularSolve(HloInstruction* hlo) override; + absl::Status HandleTuple(HloInstruction* hlo) override; + absl::Status HandleWhile(HloInstruction* hlo) override; // Implementation of dot partitioning given DotGeneralDimsMapping. absl::Status HandleDotHelper( @@ -628,6 +640,11 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { // Common handle for elementwise HLOs. absl::Status HandleElementwise(HloInstruction* hlo); + // All dimensions in the hlo are element-wise except that we replicate + // `dims_to_replicate`. + absl::Status HandleElementwiseWithDimsToReplicate( + HloInstruction* hlo, absl::Span dims_to_replicate); + // Common handle for HLOs that runs on a single device. absl::Status HandleSingleDevice(const HloInstruction* hlo); diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 59b7cce5432c8c..dba36a3340dea8 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -51,10 +51,10 @@ limitations under the License. #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -147,7 +147,8 @@ class SpmdPartitioningTest } } - int64_t NumOfInstructions(HloComputation* computation, HloOpcode opcode) { + int64_t NumOfInstructions(const HloComputation* computation, + HloOpcode opcode) { int64_t count = 0; for (const HloInstruction* inst : computation->instructions()) { if (inst->opcode() == opcode) { @@ -397,6 +398,68 @@ ENTRY entry { op::Shape("s32[8,1]"))); } +TEST_P(SpmdPartitioningTest, MultipleSourceTargetDimsInOneAllToAll1) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param= s32[64,64,64,64] parameter(0), sharding={devices=[1,4,2,1]<=[8]} + ROOT %copy = s32[64,64,64,64] copy(%param), sharding={devices=[2,1,1,4]<=[4,2]T(1,0)} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + const HloComputation* entry = module->entry_computation(); + EXPECT_EQ(NumOfInstructions(entry, HloOpcode::kAllToAll), 1); + EXPECT_EQ(NumOfInstructions(entry, HloOpcode::kCollectivePermute), 0); + + auto* all_to_all = FindInstruction(module.get(), "all-to-all"); + EXPECT_THAT(all_to_all, op::Shape("s32[8,32,16,32,16]")); + EXPECT_EQ(all_to_all->replica_groups().size(), 1); + EXPECT_EQ(all_to_all->replica_groups()[0].replica_ids_size(), 8); +} + +TEST_P(SpmdPartitioningTest, MultipleSourceTargetDimsInOneAllToAll2) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param= f32[64,64,64,64,64,64] parameter(0), sharding={devices=[2,2,2,1,1,1]<=[8]} + ROOT %copy = f32[64,64,64,64,64,64] copy(%param), sharding={devices=[1,1,1,2,2,2]<=[2,2,2]T(1,0,2)} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + const HloComputation* entry = module->entry_computation(); + EXPECT_EQ(NumOfInstructions(entry, HloOpcode::kAllToAll), 1); + EXPECT_EQ(NumOfInstructions(entry, HloOpcode::kCollectivePermute), 1); + + auto* all_to_all = FindInstruction(module.get(), "all-to-all"); + EXPECT_THAT(all_to_all, op::Shape("f32[8,32,32,32,32,32,32]")); + EXPECT_EQ(all_to_all->replica_groups().size(), 1); + EXPECT_EQ(all_to_all->replica_groups()[0].replica_ids_size(), 8); +} + +TEST_P(SpmdPartitioningTest, MultipleSourceTargetDimsInOneAllToAll3) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param= f32[64,64,64,64] parameter(0), sharding={devices=[2,4,8,1]<=[64]} + ROOT %copy = f32[64,64,64,64] copy(%param), sharding={devices=[4,2,1,8]<=[2,2,2,8]T(0,2,1,3)} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/64)); + + const HloComputation* entry = module->entry_computation(); + EXPECT_EQ(NumOfInstructions(entry, HloOpcode::kAllToAll), 1); + EXPECT_EQ(NumOfInstructions(entry, HloOpcode::kCollectivePermute), 0); + + auto* all_to_all = FindInstruction(module.get(), "all-to-all"); + EXPECT_THAT(all_to_all, op::Shape("f32[16,16,16,8,8]")); + EXPECT_EQ(all_to_all->replica_groups().size(), 4); + EXPECT_EQ(all_to_all->replica_groups()[0].replica_ids_size(), 16); +} + TEST_P(SpmdPartitioningTest, TiledToTiledUneven) { absl::string_view hlo_string = R"( HloModule module @@ -2249,11 +2312,9 @@ TEST_P(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) { HloModule module ENTRY entry { - %param0 = f32[14,257] parameter(0) - %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1} - %param1 = f32[14,116] parameter(1) - %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1} - ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1} + %param1 = f32[14,116] parameter(1), sharding={devices=[1,2]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0, %param1), dimensions={1}, sharding={devices=[1,2]0,1} })"; @@ -2261,27 +2322,28 @@ ENTRY entry { PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); - const auto root = module->entry_computation()->root_instruction(); - auto param0 = - AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), - op::Constant(), op::Reshape())), - op::Shape("f32[14,129]")); + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[14,129]")); auto param0_adjusted = AllOf(op::Select(op::Compare(op::Add(), op::Broadcast(op::Constant())), param0, op::Broadcast(op::Constant())), op::Shape("f32[14,129]")); - auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), - op::Reshape())), - op::Shape("f32[14,58]")); - EXPECT_THAT(root, AllOf(op::DynamicSlice( - AllOf(op::AllReduce(op::DynamicUpdateSlice( - op::DynamicUpdateSlice( - op::Broadcast(), param0_adjusted, - op::Constant(), op::Multiply()), - param1, op::Constant(), op::Add())), - op::Shape("f32[14,374]")), - op::Constant(), op::Multiply()), - op::Shape("f32[14,187]"))); + auto param0_replicated = AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), param0_adjusted, _, _)), + op::Shape("f32[14,257]")); + + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[14,58]")); + auto param1_replicated = AllOf( + op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), param1, _, _)), + op::Shape("f32[14,116]")); + + auto concatenate = + AllOf(op::Concatenate(param0_replicated, param1_replicated), + op::Shape("f32[14,373]")); + + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::DynamicSlice(op::Pad(concatenate, op::Constant()), _, _), + op::Shape("f32[14,187]"))); } TEST_P(SpmdPartitioningTest, ConcatenateAlongBothDimensions) { @@ -2299,22 +2361,59 @@ ENTRY entry { PartitionComputation(hlo_string, /*num_devices=*/4)); VLOG(1) << module->ToString(); - const auto root = module->entry_computation()->root_instruction(); auto param0 = AllOf(op::Parameter(0), op::Shape("f32[7,129]")); auto param0_adjusted = AllOf(op::Select(op::Compare(op::Add(), op::Broadcast(op::Constant())), param0, op::Broadcast(op::Constant())), op::Shape("f32[7,129]")); + auto param0_replicated = AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), param0_adjusted, _, _)), + op::Shape("f32[7,257]")); auto param1 = AllOf(op::Parameter(1), op::Shape("f32[7,58]")); - EXPECT_THAT(root, AllOf(op::DynamicSlice( - AllOf(op::AllReduce(op::DynamicUpdateSlice( - op::DynamicUpdateSlice( - op::Broadcast(), param0_adjusted, - op::Constant(), op::Multiply()), - param1, op::Constant(), op::Add())), - op::Shape("f32[7,374]")), - op::Constant(), op::Multiply()), - op::Shape("f32[7,187]"))); + auto param1_replicated = AllOf( + op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), param1, _, _)), + op::Shape("f32[7,116]")); + + auto concatenate = + AllOf(op::Concatenate(param0_replicated, param1_replicated), + op::Shape("f32[7,373]")); + + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::DynamicSlice(op::Pad(concatenate, op::Constant()), _, _), + op::Shape("f32[7,187]"))); +} + +TEST_P(SpmdPartitioningTest, DoNotPartitionConcatenate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[256] parameter(0), sharding={devices=[4]<=[4]} + %param1 = s32[] parameter(1), sharding={replicated} + %concatenate = f32[512] concatenate(%param0, %param0), dimensions={0}, sharding={devices=[4]<=[4]} + ROOT %dynamic-slice = f32[256] dynamic-slice(%concatenate, %param1), dynamic_slice_sizes={256}, sharding={devices=[4]<=[4]} +})"; + // In this test target, we do not need to partition the concatenate to satisfy + // the sharding={devices=[4]<=[4]} since the root instruction, the only user + // of the concatenate, requires the concatenate to be replicated. + // + // This pattern is generated by jax.numpy.roll with dynamic shift. + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + + auto param0_replicated = AllOf(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), op::Parameter(0), _))); + auto concatenate_replicated = + AllOf(op::Concatenate(param0_replicated, param0_replicated), + op::Shape("f32[512]")); + auto root_replicated = + AllOf(op::DynamicSlice(concatenate_replicated, op::Parameter(1)), + op::Shape("f32[256]")); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf(op::DynamicSlice(root_replicated, _), op::Shape("f32[64]"))); } TEST_P(SpmdPartitioningTest, PadAlongNonPartitionedDimension) { @@ -2834,7 +2933,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/8)); - LOG(ERROR) << module->ToString(); + auto custom_call = FindInstruction(module.get(), "custom-call.1"); EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 32128); auto sort = FindInstruction(module.get(), "sort"); @@ -7926,10 +8025,13 @@ ENTRY entry { auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), op::Shape("s32[2,3]")); - auto clamp = op::Clamp(min, op::Parameter(1), max); + auto clamped_indices = + op::Clamp(op::Broadcast(op::Constant()), op::Parameter(1), + op::Broadcast(op::Constant())); + auto clamp = op::Clamp(min, clamped_indices, max); auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); auto mask = - op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + op::Or(op::Lt(clamped_indices, min), op::Gt(clamped_indices, max)); auto masked = op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -7952,15 +8054,18 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/4)); VLOG(1) << module->ToString(); + auto clamped_indices = + op::Clamp(op::Broadcast(op::Constant()), op::Parameter(1), + op::Broadcast(op::Constant())); auto offset = op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), op::Shape("s32[2,3]")); - auto clamp = op::Clamp(min, op::Parameter(1), max); + auto clamp = op::Clamp(min, clamped_indices, max); auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); auto mask = - op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + op::Or(op::Lt(clamped_indices, min), op::Gt(clamped_indices, max)); auto masked = op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -8692,22 +8797,19 @@ TEST_P(SpmdPartitioningTest, TiledReversePassthrough) { HloModule module ENTRY entry { - constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), - sharding={devices=[2,1]0,1} - ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1}, + p0 = f32[3,3] parameter(0), sharding={devices=[2,1]0,1} + ROOT reverse = f32[3,3] reverse(p0), dimensions={1}, sharding={devices=[2,1]0,1} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"), - op::Reverse(op::DynamicSlice( - op::Pad(op::Constant(), op::Constant()), - op::Reshape(), op::Constant())))); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + AllOf(op::Shape("f32[2,3]"), op::Reverse(op::Parameter(0)))); } -TEST_P(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) { +TEST_P(SpmdPartitioningTest, TiledReverseViaReversedSharding) { absl::string_view hlo_string = R"( HloModule module @@ -11919,11 +12021,10 @@ ENTRY entry { VLOG(1) << module->ToString(); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::AllReduce(op::Select(_, _, op::Gather(_, _)))); - EXPECT_THAT(root->operand(0)->operand(2)->operand(1), - op::Subtract(op::Clamp(_, op::Parameter(1), _), _)); + EXPECT_THAT( + root->operand(0)->operand(2)->operand(1), + op::Subtract(op::Clamp(_, op::Clamp(_, op::Parameter(1), _), _), _)); - auto clamp = FindInstruction(module.get(), HloOpcode::kClamp); - EXPECT_THAT(clamp->operand(1), op::Parameter(1)); auto dynamic_slice = FindInstruction(module.get(), HloOpcode::kDynamicSlice); EXPECT_THAT(dynamic_slice->operand(1), op::PartitionId()); auto collective_permute = @@ -11955,8 +12056,9 @@ ENTRY entry { _, op::AllReduce(op::Select(_, _, op::Gather(op::AllReduce(_), _))), _, _, _))); auto gather = FindInstruction(module.get(), HloOpcode::kGather); - EXPECT_THAT(gather->operand(1), - op::Subtract(op::Clamp(_, op::Parameter(1), _), _)); + EXPECT_THAT( + gather->operand(1), + op::Subtract(op::Clamp(_, op::Clamp(_, op::Parameter(1), _), _), _)); auto collective_permute = FindInstruction(module.get(), HloOpcode::kCollectivePermute); EXPECT_NE(collective_permute, nullptr); @@ -15326,6 +15428,143 @@ ENTRY entry { op::Shape("f32[1]"))); } +TEST_P(SpmdPartitioningTest, BitcastConvertSameRank) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = s32[4] parameter(0), sharding={devices=[2]<=[2]} + ROOT result = f32[4] bitcast-convert(p0), sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("s32[2]")); + auto param0_replicated = AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), param0, _)), + op::Shape("s32[4]")); + auto result = + AllOf(op::BitcastConvert(param0_replicated), op::Shape("f32[4]")); + EXPECT_THAT(module->entry_computation()->root_instruction(), result); +} + +TEST_P(SpmdPartitioningTest, BitcastConvertInputRankGreaterThanOutputRank) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = s32[4,2] parameter(0), sharding={devices=[2,2]<=[4]} + ROOT result = f64[4] bitcast-convert(p0), sharding={devices=[2,2]<=[2,2]T(1,0) last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("s32[2,1]")); + auto param0_reshard = AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), param0, _, _)), + op::Shape("s32[2,2]")); + auto result = AllOf(op::BitcastConvert(param0_reshard), op::Shape("f64[2]")); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::CollectivePermute(result)); +} + +TEST_P(SpmdPartitioningTest, BitcastConvertInputRankSmallerThanOutputRank) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = s64[4] parameter(0), sharding={devices=[2,2]<=[2,2]T(1,0) last_tile_dim_replicate} + ROOT result = f32[4,2] bitcast-convert(p0), sharding={devices=[2,2]<=[4]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("s64[2]")); + auto param0_reshard = + AllOf(op::CollectivePermute(param0), op::Shape("s64[2]")); + auto result = + AllOf(op::BitcastConvert(param0_reshard), op::Shape("f32[2,2]")); + EXPECT_THAT(module->entry_computation()->root_instruction(), + AllOf(op::DynamicSlice(result, _, _), op::Shape("f32[2,1]"))); +} + +TEST_P(SpmdPartitioningTest, Cholesky) { + absl::string_view hlo_string = R"( +ENTRY entry { + %p0 = f32[32,32,32] parameter(0), sharding={devices=[2,2,2]<=[8]} + ROOT %cholesky = f32[32,32,32] cholesky(p0), lower=true, sharding={devices=[2,2,2]<=[8]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[16,16,16]")); + auto param0_reshard = + AllOf(op::Shape("f32[16,32,32]"), + op::AllReduce(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _)))); + auto cholesky = + AllOf(op::Cholesky(param0_reshard), op::Shape("f32[16,32,32]")); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf(op::DynamicSlice(cholesky, _, _, _), op::Shape("f32[16,16,16]"))); +} + +TEST_P(SpmdPartitioningTest, TriangularSolve) { + absl::string_view hlo_string = R"( +ENTRY main { + a = f32[10,32,32] parameter(0), sharding={devices=[2,2,2]<=[8]} + b = f32[10,32,48] parameter(1), sharding={devices=[2,2,2]<=[8]} + ROOT triangular-solve = f32[10,32,48] triangular-solve(a, b), left_side=true, unit_diagonal=true, lower=true, transpose_a=NO_TRANSPOSE, sharding={devices=[2,2,2]<=[8]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[5,16,16]")); + auto param0_reshard = + AllOf(op::Shape("f32[5,32,32]"), + op::AllReduce(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _)))); + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[5,16,24]")); + auto param1_reshard = + AllOf(op::Shape("f32[5,32,48]"), + op::AllReduce(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), param1, _, _, _)))); + + auto ts = AllOf(op::TriangularSolve(param0_reshard, param1_reshard), + op::Shape("f32[5,32,48]")); + EXPECT_THAT(module->entry_computation()->root_instruction(), + AllOf(op::DynamicSlice(ts, _, _, _), op::Shape("f32[5,16,24]"))); +} + +TEST_P(SpmdPartitioningTest, PartitionCollectivePermute) { + absl::string_view hlo_string = R"( +HloModule jit_f, entry_computation_layout={(s32[8]{0})->s32[8]{0}}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=8 + +ENTRY main.12 { + Arg_0.1 = s32[8]{0} parameter(0), sharding={devices=[8]<=[8]}, metadata={op_name="x"} + copy.2 = s32[8]{0} copy(Arg_0.1), sharding={devices=[4,2]<=[8] last_tile_dim_replicate} + custom-call.3 = s32[2]{0} custom-call(copy.2), custom_call_target="SPMDFullToShardShape", sharding={devices=[1,4,2]<=[8] last_tile_dims={manual, replicated}}, backend_config="unspecified_dims=[0]" + copy.1 = s32[2]{0} copy(custom-call.3), sharding={devices=[2,4]<=[4,2]T(1,0) last_tile_dims={manual}} + multiply.0 = s32[2]{0} multiply(copy.1, copy.1), sharding={devices=[2,4]<=[4,2]T(1,0) last_tile_dims={manual}} + collective-permute.0 = s32[2]{0} collective-permute(multiply.0), channel_id=1, source_target_pairs={{0,6},{2,0},{4,2},{6,4},{1,7},{3,1},{5,3},{7,5}}, sharding={devices=[2,4]<=[4,2]T(1,0) last_tile_dims={manual}} + ROOT custom-call.11 = s32[8]{0} custom-call(collective-permute.0), custom_call_target="SPMDShardToFullShape", sharding={devices=[8]<=[8]}, backend_config="unspecified_dims=[0]" +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + // Check the collective permute instruction is partitioned. + auto cp = FindInstruction(module.get(), HloOpcode::kCollectivePermute); + EXPECT_NE(cp, nullptr); + EXPECT_THAT(cp, op::Shape("s32[1]{0}")); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc index aa92be45a8ca40..b4eb1f1092cda7 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc @@ -398,7 +398,7 @@ std::optional PartialReplicateReshardCompatibleSharding( std::vector perm; perm.reserve(rank + expand_tile_sizes.size()); for (int64_t dim = 0; dim < rank; dim++) { - perm.emplace_back(dim); + perm.push_back(dim); if (expand_tile_dims_indices[dim] > -1) { perm.emplace_back(expand_tile_dims_indices[dim] + rank); } @@ -530,7 +530,7 @@ std::optional PadFromPartialReplicateShape( // If src sharding at this dimension is not partitioned, simply pad to // the desired shape. if (src_shard_count == 1) { - expand_dims_without_halo_exchange.emplace_back(dim); + expand_dims_without_halo_exchange.push_back(dim); continue; } diff --git a/third_party/xla/xla/service/triangular_solve_expander_test.cc b/third_party/xla/xla/service/triangular_solve_expander_test.cc index fa382b24d0d9db..1a2ba8c71ece6e 100644 --- a/third_party/xla/xla/service/triangular_solve_expander_test.cc +++ b/third_party/xla/xla/service/triangular_solve_expander_test.cc @@ -15,15 +15,20 @@ limitations under the License. #include "xla/service/triangular_solve_expander.h" +#include #include +#include #include +#include "xla/array2d.h" +#include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/tuple_util_test.cc b/third_party/xla/xla/service/tuple_util_test.cc index e2a7176bc12b44..6e91ad17f7e12d 100644 --- a/third_party/xla/xla/service/tuple_util_test.cc +++ b/third_party/xla/xla/service/tuple_util_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/value_range.cc b/third_party/xla/xla/service/value_range.cc index 850db808d73928..d4edd39db8edd7 100644 --- a/third_party/xla/xla/service/value_range.cc +++ b/third_party/xla/xla/service/value_range.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" @@ -54,7 +55,8 @@ std::string Range::ToString() const { return min_.ToString(); } return absl::StrCat( - "min: ", min_.ToString(), " max: ", max_.ToString(), + "min: ", min_.ToString(), + " max: ", IsBounded() ? max_.value().ToString() : "Unknown", " step: ", IsStepKnown() ? step_.value().ToString() : "Unknown"); } @@ -69,17 +71,27 @@ std::optional FindStepForBinaryOp(const Range& lhs, if (rhs.IsSingleValue()) { return lhs.step(); } - if (lhs.step().eq(rhs.step())) { + if (lhs.step()->eq(rhs.step().value())) { return lhs.step(); } return std::nullopt; } +// Helper function that updates the known_ranges map and returns the range. +Range RecordAndReturnRange( + const Range& range, const HloInstruction* instr, + absl::flat_hash_map& known_ranges) { + known_ranges[instr] = range; + VLOG(5) << "Computed range for: " << instr->name() << " -> " + << range.ToString(); + return range; +} + // Identify the value ranges of a scalar HLO with a integer type. It returns // a range of values that the instruction can have. Range RecursivelyIdentifyRange( const HloInstruction* instr, - const absl::flat_hash_map& predefined_ranges, + absl::flat_hash_map& known_ranges, const HloAliasAnalysis* alias_analysis) { // Non scalar or non-integer HLO. Abort. if ((!instr->shape().IsInteger() && instr->shape().element_type() != PRED) || @@ -87,32 +99,48 @@ Range RecursivelyIdentifyRange( return Range{}; } VLOG(5) << "Computing Range for " << instr->ToString(); - auto it = predefined_ranges.find(instr); - if (it != predefined_ranges.end()) { - VLOG(5) << "Found range! " << it->second.max().GetSignedValue() << " " - << it->second.min().GetSignedValue(); + auto it = known_ranges.find(instr); + if (it != known_ranges.end()) { + VLOG(5) << "Found range: " << it->second.ToString(); return it->second; } else if (alias_analysis != nullptr) { auto value_set = alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr); for (const auto& value : value_set.TakeValues()) { for (const HloPosition& position : value->positions()) { - auto it = predefined_ranges.find(position.instruction); - if (it != predefined_ranges.end()) { - VLOG(5) << "Found range in defining instruction! " - << it->second.max().GetSignedValue() << " " - << it->second.min().GetSignedValue(); + auto it = known_ranges.find(position.instruction); + if (it != known_ranges.end()) { + VLOG(5) << "Found range in defining instruction: " + << it->second.ToString(); return it->second; } } } } switch (instr->opcode()) { + case HloOpcode::kGetTupleElement: { + if (alias_analysis != nullptr) { + auto value_set = + alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr); + std::vector values = value_set.TakeValues(); + if (values.size() != 1) { + VLOG(5) << "Ambiguous value set"; + return Range{}; + } + HloInstruction* defining_instruction = + values.at(0)->defining_instruction(); + if (defining_instruction != nullptr) { + return RecursivelyIdentifyRange(defining_instruction, known_ranges, + alias_analysis); + } + } + return Range{}; + } case HloOpcode::kCompare: { VLOG(5) << "Handling Compare"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); @@ -120,37 +148,37 @@ Range RecursivelyIdentifyRange( if (instr->comparison_direction() != ComparisonDirection::kLt) { return Range{}; } - if (lhs.max().lt(rhs.min())) { - return Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + if (lhs.IsBounded() && lhs.max()->lt(rhs.min())) { + return RecordAndReturnRange( + Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } - if (!lhs.min().lt(rhs.max())) { - return Range{ - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + if (rhs.IsBounded() && !lhs.min().lt(rhs.max().value())) { + return RecordAndReturnRange( + Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } - VLOG(5) << "Compare failed"; - VLOG(5) << "rhs max " << rhs.max().GetSignedValue() << " rhs min " - << rhs.min().GetSignedValue() << " lhs max " - << lhs.max().GetSignedValue() << " lhs min " - << lhs.min().GetSignedValue(); return Range{}; } case HloOpcode::kConstant: { if (instr->shape().element_type() == PRED && instr->shape().dimensions_size() == 0) { if (instr->literal().IsAll(true)) { - return Range{ - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } - return Range{ - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + ConstantValue::GetZero(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } if (!instr->shape().IsInteger()) { return Range{}; @@ -162,25 +190,29 @@ Range RecursivelyIdentifyRange( primitive_util::IsSignedIntegralType(instr->shape().element_type()); if (is_signed) { const int64_t value = *instr->literal().GetFirstInteger(); - return Range{ConstantValue::GetSigned(value, bitwidth), - ConstantValue::GetSigned(value, bitwidth), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetSigned(value, bitwidth), + ConstantValue::GetSigned(value, bitwidth), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } const uint64_t value = *instr->literal().GetFirstInteger(); - return Range{ConstantValue::GetUnsigned(value, bitwidth), - ConstantValue::GetUnsigned(value, bitwidth), - ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), - /*is_linear=*/true}; + return RecordAndReturnRange( + Range{ConstantValue::GetUnsigned(value, bitwidth), + ConstantValue::GetUnsigned(value, bitwidth), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), + /*is_linear=*/true}, + instr, known_ranges); } case HloOpcode::kAdd: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Add"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); @@ -188,22 +220,29 @@ Range RecursivelyIdentifyRange( return Range{}; } ConstantValue min = lhs.min().add(rhs.min()); - ConstantValue max = lhs.max().add(rhs.max()); - if (max.lt(min)) { - VLOG(5) << "Add wrapped"; - return Range{}; + std::optional step = FindStepForBinaryOp(lhs, rhs); + if (lhs.IsBounded() && rhs.IsBounded()) { + ConstantValue max = lhs.max()->add(rhs.max().value()); + if (max.lt(min)) { + VLOG(5) << "Add wrapped"; + return Range{}; + } + return RecordAndReturnRange( + Range{min, max, step, lhs.IsLinear() && rhs.IsLinear()}, instr, + known_ranges); } - return Range{min, max, FindStepForBinaryOp(lhs, rhs), - lhs.IsLinear() && rhs.IsLinear()}; + return RecordAndReturnRange( + Range{min, std::nullopt, step, lhs.IsLinear() && rhs.IsLinear()}, + instr, known_ranges); } case HloOpcode::kMultiply: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Multiply"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); @@ -215,51 +254,88 @@ Range RecursivelyIdentifyRange( return Range{}; } ConstantValue single_value = lhs.IsSingleValue() ? lhs.min() : rhs.min(); - ConstantValue min = lhs.IsSingleValue() ? rhs.min().mul(single_value) - : lhs.min().mul(single_value); - ConstantValue max = lhs.IsSingleValue() ? rhs.max().mul(single_value) - : lhs.max().mul(single_value); - return Range{min, max, single_value, lhs.IsLinear() && rhs.IsLinear()}; + Range operand_range = lhs.IsSingleValue() ? rhs : lhs; + // When multiplying with a constant, min, max, and step are all + // multiplied by the single value. + ConstantValue min = operand_range.min().mul(single_value); + if (operand_range.IsBounded()) { + ConstantValue max = operand_range.max()->mul(single_value); + if (!operand_range.IsStepKnown()) { + return RecordAndReturnRange(Range{min, max, operand_range.IsLinear()}, + instr, known_ranges); + } + ConstantValue step = operand_range.step()->mul(single_value); + return RecordAndReturnRange( + Range{min, max, step, operand_range.IsLinear()}, instr, + known_ranges); + } + if (!operand_range.IsStepKnown()) { + return RecordAndReturnRange( + Range{min, std::nullopt, operand_range.IsLinear()}, instr, + known_ranges); + } + ConstantValue step = operand_range.step()->mul(single_value); + return RecordAndReturnRange( + Range{min, std::nullopt, step, operand_range.IsLinear()}, instr, + known_ranges); } case HloOpcode::kSelect: { VLOG(5) << "Handling Select: " << instr->ToString(); const HloInstruction* cmp = instr->operand(0); Range cmp_range = - RecursivelyIdentifyRange(cmp, predefined_ranges, alias_analysis); + RecursivelyIdentifyRange(cmp, known_ranges, alias_analysis); // Support only when the select has a constant value as condition. if (cmp_range.IsEmpty() || !cmp_range.IsSingleValue()) { VLOG(5) << "Select failed"; return Range{}; } if (cmp_range.GetSingleSignedValue() == 0) { - return RecursivelyIdentifyRange(instr->operand(2), predefined_ranges, - alias_analysis); + return RecordAndReturnRange( + RecursivelyIdentifyRange(instr->operand(2), known_ranges, + alias_analysis), + instr, known_ranges); } - return RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, - alias_analysis); + return RecordAndReturnRange( + RecursivelyIdentifyRange(instr->operand(1), known_ranges, + alias_analysis), + instr, known_ranges); } case HloOpcode::kSubtract: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Subtract"; - Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + Range lhs = RecursivelyIdentifyRange(instr->operand(0), known_ranges, alias_analysis); - Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + Range rhs = RecursivelyIdentifyRange(instr->operand(1), known_ranges, alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); if (lhs.IsEmpty() || rhs.IsEmpty()) { return Range{}; } - ConstantValue min = lhs.min().sub(rhs.max()); - ConstantValue max = lhs.max().sub(rhs.min()); - if (max.lt(min)) { - VLOG(5) << "Subtract wrapped"; + if (lhs.IsBounded() && rhs.IsBounded()) { + ConstantValue min = lhs.min().sub(rhs.max().value()); + ConstantValue max = lhs.max()->sub(rhs.min()); + if (max.lt(min)) { + VLOG(5) << "Subtract wrapped"; + return Range{}; + } + return RecordAndReturnRange( + Range{min, max, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}, + instr, known_ranges); + } else if (lhs.IsBounded()) { // bounded - unbounded -> Empty range + VLOG(5) << "Subtract unbounded from bounded is not represntable with a " + "range"; return Range{}; + } else { // unbounded - bounded -> Unbounded range + ConstantValue min = lhs.min().sub(rhs.max().value()); + return RecordAndReturnRange( + Range{min, std::nullopt, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}, + instr, known_ranges); } - return Range{min, max, FindStepForBinaryOp(lhs, rhs), - lhs.IsLinear() && rhs.IsLinear()}; } default: break; diff --git a/third_party/xla/xla/service/value_range.h b/third_party/xla/xla/service/value_range.h index b46b9bbcfa22fa..eb06d3b488ffd1 100644 --- a/third_party/xla/xla/service/value_range.h +++ b/third_party/xla/xla/service/value_range.h @@ -26,7 +26,10 @@ limitations under the License. namespace xla { -// Class keeping track of the range of an HLO value. +// Class keeping track of the range of an HLO value. A range is typically +// defined by a minimum value, a maximum value, and a step value. The step and +// maximum values are optional. If the maximum value is missing, the range is +// unbounded. The default step value is nullopt. class Range { public: Range() @@ -35,13 +38,14 @@ class Range { step_(ConstantValue::GetZero(/*bitwidth=*/64, /*is_signed=*/false)), empty_(true), is_linear_(false) {} - Range(const ConstantValue& min, const ConstantValue& max, bool is_linear) + Range(const ConstantValue& min, std::optional max, + bool is_linear) : min_(min), max_(max), step_(std::nullopt), empty_(false), is_linear_(is_linear) {} - Range(const ConstantValue& min, const ConstantValue& max, + Range(const ConstantValue& min, std::optional max, std::optional step, bool is_linear) : min_(min), max_(max), @@ -51,13 +55,15 @@ class Range { // Minimum value of the range. const ConstantValue& min() const { return min_; } // Maximum value of the range. - const ConstantValue& max() const { return max_; } + const std::optional& max() const { return max_; } // Step value of the range. - const ConstantValue& step() const { return step_.value(); } - // Returns if the range is empty (no value in set). + const std::optional& step() const { return step_; } + // Returns if the range has min and max values (it can be a single value). bool IsEmpty() const { return empty_; } // Only one value in set. This means the range is a constant. - bool IsSingleValue() const { return !IsEmpty() && min_ == max_; } + bool IsSingleValue() const { + return !IsEmpty() && max_.has_value() && min_ == max_; + } // This is a way to track in some way recurring values that change in a // monotonic way. This true means that the variables driving the range change // in a monotonic way and that the way they are composed together is linear @@ -65,6 +71,8 @@ class Range { // loop recursion. bool IsLinear() const { return is_linear_; } bool IsStepKnown() const { return step_.has_value(); } + // If this range is a bounded range with known max value. + bool IsBounded() const { return max_.has_value(); } // If this range represents a single value return that signed value. std::optional GetSingleSignedValue() const; // If this range represents a single value return that unsigned value. @@ -81,20 +89,20 @@ class Range { private: ConstantValue min_; - ConstantValue max_; + std::optional max_; std::optional step_; bool empty_; bool is_linear_; }; -// Constructs a Range object from a HloInstruction. Gets a "predefined_ranges" +// Constructs a Range object from a HloInstruction. Gets a "known_ranges" // object as input that returns known ranges for some variables for which we // already know the range. The final range is composed from operations over // these predetermined ranges. // The input HLO needs to be of scalar type and integer. Range RecursivelyIdentifyRange( const HloInstruction* instr, - const absl::flat_hash_map& predefined_ranges, + absl::flat_hash_map& known_ranges, const HloAliasAnalysis* alias_analysis = nullptr); } // namespace xla diff --git a/third_party/xla/xla/service/value_range_test.cc b/third_party/xla/xla/service/value_range_test.cc index 05a64ae3a6d9bf..ff389b92b11c57 100644 --- a/third_party/xla/xla/service/value_range_test.cc +++ b/third_party/xla/xla/service/value_range_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/value_range.h" +#include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/service/constant_value.h" @@ -59,8 +61,8 @@ TEST_F(ValueRangeTest, AddedValue) { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 124); - EXPECT_EQ(range.max().GetSignedValue(), 129); - EXPECT_EQ(range.step().GetSignedValue(), 1); + EXPECT_EQ(range.max()->GetSignedValue(), 124 + 5); + EXPECT_EQ(range.step()->GetSignedValue(), 1); } TEST_F(ValueRangeTest, MultiplyValue) { @@ -78,18 +80,64 @@ TEST_F(ValueRangeTest, MultiplyValue) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), - ConstantValue::GetOne(32, /*is_signed=*/false), - /*is_linear=*/true})); + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 0); - EXPECT_EQ(range.max().GetSignedValue(), 5120); - EXPECT_EQ(range.step().GetSignedValue(), 1024); + EXPECT_EQ(range.max()->GetSignedValue(), 32 * 1024); + EXPECT_EQ(range.step()->GetSignedValue(), 2 * 1024); +} + +TEST_F(ValueRangeTest, MultiplyValuePassedToLoop) { + constexpr absl::string_view hlo_string = R"( + HloModule module + body.comp { + p0 = (s32[], s32[]) parameter(0) + gte = s32[] get-tuple-element(p0), index=0 + ROOT tuple = (s32[], s32[]) tuple(gte, gte) + } + cond.comp { + p0 = (s32[], s32[]) parameter(0) + ROOT out = pred[] constant(true) + } + ENTRY entry { + c0 = s32[] constant(1024) + p0 = s32[] parameter(0) + %mul = s32[] multiply(p0, c0) + tuple = (s32[], s32[]) tuple(%mul, %mul) + ROOT out = (s32[], s32[]) while(tuple), condition=cond.comp, + body=body.comp + } + )"; + auto module = + ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); + TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, + HloAliasAnalysis::Run(module.get())); + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + absl::flat_hash_map fs; + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); + HloComputation* body = module->GetComputationWithName("body.comp"); + HloInstruction* gte = body->GetInstructionWithName("gte"); + auto range = RecursivelyIdentifyRange(gte, fs, alias_analysis.get()); + EXPECT_FALSE(range.IsEmpty()); + EXPECT_FALSE(range.IsSingleValue()); + EXPECT_TRUE(range.IsLinear()); + EXPECT_EQ(range.min().GetSignedValue(), 0); + EXPECT_EQ(range.max()->GetSignedValue(), 32 * 1024); + EXPECT_EQ(range.step()->GetSignedValue(), 2 * 1024); } TEST_F(ValueRangeTest, ConstantValuePred) { @@ -104,14 +152,15 @@ TEST_F(ValueRangeTest, ConstantValuePred) { auto module = ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); const HloInstruction* tuple = module->entry_computation()->root_instruction(); - auto false_range = RecursivelyIdentifyRange(tuple->operand(0), {}); + absl::flat_hash_map known_ranges; + auto false_range = RecursivelyIdentifyRange(tuple->operand(0), known_ranges); VLOG(3) << "false_range: " << false_range.ToString(); EXPECT_FALSE(false_range.IsEmpty()); EXPECT_TRUE(false_range.IsSingleValue()); EXPECT_TRUE(false_range.IsLinear()); EXPECT_EQ(false_range.min().GetUnsignedValue(), 0); - auto true_range = RecursivelyIdentifyRange(tuple->operand(1), {}); + auto true_range = RecursivelyIdentifyRange(tuple->operand(1), known_ranges); VLOG(3) << "true_range: " << true_range.ToString(); EXPECT_FALSE(true_range.IsEmpty()); EXPECT_TRUE(true_range.IsSingleValue()); @@ -137,7 +186,8 @@ TEST_F(ValueRangeTest, ConstantValueWithConditional) { ENTRY entry { p0 = s32[] parameter(0) branch_index = s32[] parameter(1) - ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), branch_computations={region1, region2} + ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), + branch_computations={region1, region2} } )"; auto module = @@ -151,27 +201,28 @@ TEST_F(ValueRangeTest, ConstantValueWithConditional) { const HloInstruction* p0 = module->entry_computation()->parameter_instruction(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), - ConstantValue::GetOne(32, /*is_signed=*/false), - /*is_linear=*/true})); + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); auto add_range = RecursivelyIdentifyRange(add, fs, alias_analysis.get()); EXPECT_FALSE(add_range.IsEmpty()); EXPECT_FALSE(add_range.IsSingleValue()); EXPECT_TRUE(add_range.IsLinear()); EXPECT_EQ(add_range.min().GetSignedValue(), 1024); - EXPECT_EQ(add_range.max().GetSignedValue(), 1029); - EXPECT_EQ(add_range.step().GetSignedValue(), 1); + EXPECT_EQ(add_range.max()->GetSignedValue(), 1024 + 32); + EXPECT_EQ(add_range.step()->GetSignedValue(), 2); auto mult_range = RecursivelyIdentifyRange(mult, fs, alias_analysis.get()); EXPECT_FALSE(mult_range.IsEmpty()); EXPECT_FALSE(mult_range.IsSingleValue()); EXPECT_TRUE(mult_range.IsLinear()); EXPECT_EQ(mult_range.min().GetSignedValue(), 0); - EXPECT_EQ(mult_range.max().GetSignedValue(), 5120); - EXPECT_EQ(mult_range.step().GetSignedValue(), 1024); + EXPECT_EQ(mult_range.max()->GetSignedValue(), 32 * 1024); + EXPECT_EQ(mult_range.step()->GetSignedValue(), 2 * 1024); } TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { @@ -181,28 +232,29 @@ TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { region1_param = s32[] parameter(0) region1_c0 = s32[] constant(1024) %add = s32[] add(region1_param, region1_c0) - - compare_const = s32[] constant(1030) // this valueis bigger than the max of add + + compare_const = s32[] constant(1030) compare1 = pred[] compare(%add, compare_const), direction=LT select1 = s32[] select(compare1, region1_param, %add) - + ROOT out = (s32[], s32[]) tuple(%add, %add) } region2 { region2_param = s32[] parameter(0) region2_c0 = s32[] constant(1024) %mult = s32[] multiply(region2_param, region2_c0) - - compare_const = s32[] constant(5121) // this valueis bigger than the max of mult + + compare_const = s32[] constant(5121) compare2 = pred[] compare(%mult, compare_const), direction=LT select2 = s32[] select(compare2, region2_param, %mult) - + ROOT out = (s32[], s32[]) tuple(%mult, %mult) } ENTRY entry { p0 = s32[] parameter(0) branch_index = s32[] parameter(1) - ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), branch_computations={region1, region2} + ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), + branch_computations={region1, region2} } )"; auto module = @@ -216,11 +268,12 @@ TEST_F(ValueRangeTest, SelectValueWithCompareInConditional) { const HloInstruction* p0 = module->entry_computation()->parameter_instruction(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), - ConstantValue::GetOne(32, /*is_signed=*/false), - /*is_linear=*/true})); + // p0 has range min = 0, max = 32, step = 2. + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(0, /*bitwidth=*/32), + /*max=*/ConstantValue::GetSigned(32, /*bitwidth=*/32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); auto select1_range = RecursivelyIdentifyRange(select1, fs, alias_analysis.get()); @@ -254,7 +307,7 @@ ENTRY entry { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetUnsignedValue(), 32768); - EXPECT_EQ(range.max().GetUnsignedValue(), 32773); + EXPECT_EQ(range.max()->GetUnsignedValue(), 32773); } TEST_F(ValueRangeTest, SubtractValue) { @@ -280,7 +333,7 @@ ENTRY entry { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), -124); - EXPECT_EQ(range.max().GetSignedValue(), -119); + EXPECT_EQ(range.max()->GetSignedValue(), -119); } TEST_F(ValueRangeTest, SelectValue) { @@ -308,7 +361,7 @@ ENTRY entry { EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); - EXPECT_EQ(range.max().GetSignedValue(), -119); + EXPECT_EQ(range.max()->GetSignedValue(), -119); EXPECT_EQ(range.min().GetSignedValue(), -124); } @@ -337,10 +390,47 @@ ENTRY entry { EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); - EXPECT_EQ(range.max().GetSignedValue(), 129); + EXPECT_EQ(range.max()->GetSignedValue(), 129); EXPECT_EQ(range.min().GetSignedValue(), 124); } +TEST_F(ValueRangeTest, SelectBoundedFromUnboundedRange) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + ROOT %s = s32[] subtract(p0, p1) +} +)"; + auto module = + ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + const HloInstruction* p1 = + module->entry_computation()->parameter_instruction(1); + absl::flat_hash_map fs; + // p0 has range min = 1, max = Unknown, step = 2 + fs.insert(std::make_pair( + p0, Range{/*min=*/ConstantValue::GetSigned(1, 32), + /*max=*/std::nullopt, + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); + // p1 has range min = 0, max = 10, step = 2 + fs.insert(std::make_pair( + p1, Range{/*min=*/ConstantValue::GetZero(32, /*is_signed=*/true), + /*max=*/ConstantValue::GetSigned(10, 32), + /*step=*/ConstantValue::GetUnsigned(2, /*bitwidth=*/32), + /*is_linear=*/true})); + auto range = RecursivelyIdentifyRange(root, fs); + EXPECT_FALSE(range.IsSingleValue()); + EXPECT_TRUE(range.IsLinear()); + EXPECT_FALSE(range.IsBounded()); + EXPECT_EQ(range.min().GetSignedValue(), 1 - 10); +} + TEST_F(ValueRangeTest, AddSubtractValue) { constexpr absl::string_view hlo_string = R"( HloModule module @@ -368,7 +458,7 @@ ENTRY entry { EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 112); - EXPECT_EQ(range.max().GetSignedValue(), 117); + EXPECT_EQ(range.max()->GetSignedValue(), 117); } TEST_F(ValueRangeTest, SubtractWrapAroundValue) { @@ -386,10 +476,10 @@ ENTRY entry { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetSigned(-32768, 16), - ConstantValue::GetZero(16, /*is_signed=*/true), - /*is_linear=*/true})); + fs.insert(std::make_pair(p0, Range{ConstantValue::GetSigned(-32768, 16), + ConstantValue::GetZero(16, + /*is_signed=*/true), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_TRUE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); @@ -411,10 +501,10 @@ ENTRY entry { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert( - std::make_pair(p0, Range{ConstantValue::GetZero(16, /*is_signed=*/true), - ConstantValue::GetSigned(32760, 16), - /*is_linear=*/true})); + fs.insert(std::make_pair(p0, Range{ConstantValue::GetZero(16, + /*is_signed=*/true), + ConstantValue::GetSigned(32760, 16), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_TRUE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller.cc b/third_party/xla/xla/service/while_loop_pipeline_unroller.cc index 97dca76ba65c50..8f242ab227f869 100644 --- a/third_party/xla/xla/service/while_loop_pipeline_unroller.cc +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -111,7 +110,7 @@ int64_t WhileLoopPipelineUnroller::ComputeWhileLoopPipelineDepth( absl::StatusOr WhileLoopPipelineUnroller::Run( HloModule* module, - const absl::flat_hash_set& execution_threads) { + const absl::flat_hash_set& execution_threads) { std::vector> while_instructions; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller.h b/third_party/xla/xla/service/while_loop_pipeline_unroller.h index 4e5318f8f90385..f259fe3b83617e 100644 --- a/third_party/xla/xla/service/while_loop_pipeline_unroller.h +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_SERVICE_WHILE_LOOP_PIPELINE_UNROLLER_H_ #include -#include #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -40,14 +39,14 @@ namespace xla { // drastically increase compile times due to linearly increasing graph size. class WhileLoopPipelineUnroller : public HloModulePass { public: - std::string_view name() const override { + absl::string_view name() const override { return "while_loop_pipeline_unroller"; } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, - const absl::flat_hash_set& execution_threads) override; + const absl::flat_hash_set& execution_threads) override; // The pipeline depth of a while loop is the number of loop iterations that // pipelined loop inputs live throughout. This is used to determine how many diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc b/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc index f8618a304514c6..82793a2e52b28f 100644 --- a/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include "xla/service/while_loop_pipeline_unroller.h" #include -#include #include #include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -49,7 +49,7 @@ class WhileLoopPipelineUnrollerTest : public HloTestBase { }; TEST_F(WhileLoopPipelineUnrollerTest, PipelinedLoop) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { @@ -100,7 +100,7 @@ ENTRY main { } TEST_F(WhileLoopPipelineUnrollerTest, PipelinedLoopWithInfeed) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { diff --git a/third_party/xla/xla/service/while_loop_simplifier_test.cc b/third_party/xla/xla/service/while_loop_simplifier_test.cc index 785c21ec941b07..a478453e4aa881 100644 --- a/third_party/xla/xla/service/while_loop_simplifier_test.cc +++ b/third_party/xla/xla/service/while_loop_simplifier_test.cc @@ -1070,12 +1070,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveTrivialCompare) { HloModule RemoveTrivialCompare RemoveTrivialCompare.body { loop_var = (pred[], s32[]) parameter(0) - + get-tuple-element.2 = s32[] get-tuple-element((pred[], s32[]) loop_var), index=1 - + cons = s32[] constant({{LOOP_CONSTANT}}) comp = pred[] compare(get-tuple-element.2, cons), direction={{DIRECTION}} - + constant.1 = s32[] constant(1) add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1) ROOT tuple = (pred[], s32[]) tuple(comp, @@ -1144,12 +1144,12 @@ TEST_F(WhileLoopSimplifierTest, NotRemoveCompare) { HloModule RemoveTrivialCompare RemoveTrivialCompare.body { loop_var = (pred[], s32[]) parameter(0) - + get-tuple-element.2 = s32[] get-tuple-element((pred[], s32[]) loop_var), index=1 - + five = s32[] constant(5) comp = pred[] compare(get-tuple-element.2, five), direction=LT - + constant.1 = s32[] constant(1) add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1) ROOT tuple = (pred[], s32[]) tuple(comp, diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index d2731d22f61575..8464f9babca796 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/algorithm/algorithm.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -47,9 +47,11 @@ limitations under the License. #include "xla/overflow_util.h" #include "xla/service/call_inliner.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/constant_value.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_cse.h" #include "xla/service/pattern_matcher.h" +#include "xla/service/value_range.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -67,7 +69,7 @@ using hlo_query::ContainsInstrWithOpcode; // Helper function to create a condition for a single iteration while loop in // the form of 'i <= init_value' where i is the induction variable. std::unique_ptr MakeTrivialLoopCondition( - HloInstruction* while_op, std::string_view name, int64_t induction_idx, + HloInstruction* while_op, absl::string_view name, int64_t induction_idx, int64_t init_value) { auto condition_builder = HloComputation::Builder(name); @@ -285,7 +287,7 @@ absl::StatusOr UnrollInternal(HloInstruction* while_op, computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), call_operands, unrolled_body)); call_operands.clear(); - call_operands.emplace_back(unrolled_body_call_op); + call_operands.push_back(unrolled_body_call_op); } TF_RETURN_IF_ERROR( computation->ReplaceInstruction(while_op, unrolled_body_call_op)); @@ -327,7 +329,7 @@ absl::StatusOr UnrollInternalWrappedAndReturnReplacement( absl::StrCat(while_op->name(), "-unrolled-body-call-", i)); call_operands.clear(); - call_operands.emplace_back(unrolled_body_call_op); + call_operands.push_back(unrolled_body_call_op); } HloComputation* new_body = module->AddEmbeddedComputation(body_builder.Build(unrolled_body_call_op)); @@ -361,21 +363,38 @@ absl::StatusOr UnrollInternalWrapped(HloInstruction* while_op, }; // namespace -// Recursively checks if the given instruction points to the induction var of -// the given loop config. -bool IsLoopInductionVar(const HloInstruction* instr, - const WhileLoopConfig& config) { - if (!instr->parent()->IsFusionComputation()) { - return Match(instr, match::GetTupleElement(match::Parameter(), - config.induction_var_idx)); - } else { +// Recursively checks if the given instruction inside a while loop body can be +// expressed as a value range, possibly depending on the loop induction variable +// of that while loop. +std::optional IdentifyRangeAsFunctionOfInductionVar( + const HloInstruction* instr, const WhileLoopConfig& config) { + if (instr->parent()->IsFusionComputation()) { if (!Match(instr, match::Parameter())) { - return false; + return std::nullopt; } HloInstruction* caller_fusion = instr->parent()->FusionInstruction(); - return IsLoopInductionVar(caller_fusion->operand(instr->parameter_number()), - config); + return IdentifyRangeAsFunctionOfInductionVar( + caller_fusion->operand(instr->parameter_number()), config); } + + std::optional loop_range = MatchTrivialLoopRange(config.while_instr); + if (loop_range == std::nullopt) { + return std::nullopt; + } + + const HloComputation* while_body = config.while_instr->while_body(); + absl::flat_hash_map predefined_ranges; + HloInstruction* while_body_input_tuple = while_body->parameter_instruction(0); + for (HloInstruction* user : while_body_input_tuple->users()) { + if (Match(user, match::GetTupleElement(match::Parameter(0), + config.induction_var_idx))) { + predefined_ranges[user] = loop_range.value(); + } + } + + Range instr_range = + RecursivelyIdentifyRange(instr, predefined_ranges, nullptr); + return instr_range; } // Recursively checks if the given instruction is effectively static by checking @@ -466,12 +485,16 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( return std::nullopt; } // Based on the instruction type, start indices start from index 1 or 2 of the - // operands. + // operands and the slice shape is either the shape of instr (i.e. its output + // shape) or the shape of its operand at index 1. int64_t start_indices_offset; + const Shape* slice_shape; if (instr->opcode() == HloOpcode::kDynamicSlice) { start_indices_offset = 1; + slice_shape = &instr->shape(); } else if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) { start_indices_offset = 2; + slice_shape = &instr->operand(1)->shape(); } else { return std::nullopt; } @@ -481,7 +504,8 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( return std::nullopt; } - int64_t dynamic_index = -1; + std::optional dynamic_index; + std::optional dynamic_index_range; for (int64_t start_index = start_indices_offset; start_index < instr->operand_count(); ++start_index) { const HloInstruction* index = instr->operand(start_index); @@ -496,46 +520,80 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( continue; } - // Check that the instruction's dynamic index points to the loop induction - // variable. - if (IsLoopInductionVar(index, config)) { + // Try to compute a Range for this interval based on the loop induction + // variable's Range. + std::optional index_range = + IdentifyRangeAsFunctionOfInductionVar(index, config); + if (index_range != std::nullopt && !index_range->IsSingleValue()) { // In order to cover the whole shape only a single non-constant index is // allowed. - if (dynamic_index != -1) { + if (dynamic_index != std::nullopt) { VLOG(3) << "Multiple non-constant indices."; return std::nullopt; } dynamic_index = start_index - start_indices_offset; + dynamic_index_range = index_range; + continue; } + + VLOG(3) << "Index is neither constant nor a function of loop induction " + "var."; + return std::nullopt; } - if (dynamic_index == -1) { + if (dynamic_index == std::nullopt) { VLOG(3) << "No dynamic index found."; return std::nullopt; } - if (operand->shape().dimensions(dynamic_index) != config.trip_count) { - VLOG(3) << "The dynamic_index dimension size of the operand must be equal " - "to the loop trip count."; + const ConstantValue& min_index_touched = dynamic_index_range->min(); + const ConstantValue operand_first_index = ConstantValue::GetZero( + min_index_touched.GetBitwidth(), min_index_touched.IsSigned()); + if (min_index_touched.gt(operand_first_index)) { + VLOG(3) << "The dynamic_index must cover index zero, but it begins at " + << min_index_touched.ToString(); return std::nullopt; } - if (opcode == HloOpcode::kDynamicSlice) { - const Shape& result_shape = instr->shape(); - if (result_shape.dimensions(dynamic_index) != 1) { - VLOG(3) << "The slice size on the dynamic_index dimension must be 1."; - return std::nullopt; - } + const ConstantValue slice_size = + ConstantValue::Get(slice_shape->dimensions(dynamic_index.value()), + dynamic_index_range->max()->GetBitwidth(), + dynamic_index_range->max()->IsSigned()); + const ConstantValue max_index_touched_plus_one = + dynamic_index_range->max()->add(slice_size); + const Shape& operand_shape = operand->shape(); + const ConstantValue operand_last_index_plus_one = + ConstantValue::Get(operand_shape.dimensions(dynamic_index.value()), + dynamic_index_range->max()->GetBitwidth(), + dynamic_index_range->max()->IsSigned()); + if (max_index_touched_plus_one.lt(operand_last_index_plus_one)) { + const ConstantValue constant_one = + ConstantValue::GetOne(dynamic_index_range->max()->GetBitwidth(), + dynamic_index_range->max()->IsSigned()); + VLOG(3) << "The dynamic_index must cover index " + << operand_last_index_plus_one.sub(constant_one).ToString() + << " but the last value it takes on is " + << dynamic_index_range->max()->ToString() + << " and the slice size is " << slice_size.ToString() + << " so it only reaches " + << max_index_touched_plus_one.sub(constant_one).ToString(); + return std::nullopt; + } - const Shape& operand_shape = operand->shape(); - CHECK_EQ(result_shape.dimensions_size(), operand_shape.dimensions_size()); - for (int64_t i = 0; i < result_shape.dimensions_size(); ++i) { - if (i != dynamic_index && - result_shape.dimensions(i) != operand_shape.dimensions(i)) { - VLOG(3) << "The slice sizes must match the operand-shape on " - "non-dynamic-index dimensions."; - return std::nullopt; - } + if (dynamic_index_range->step()->gt(slice_size)) { + VLOG(3) << "The dynamic_index has a step size of " + << dynamic_index_range->step()->ToString() + << " but the slice size is " << slice_size.ToString(); + return std::nullopt; + } + + CHECK_EQ(slice_shape->dimensions_size(), operand_shape.dimensions_size()); + for (int64_t i = 0; i < slice_shape->dimensions_size(); ++i) { + if (i != dynamic_index && + slice_shape->dimensions(i) != operand_shape.dimensions(i)) { + VLOG(3) << "The slice sizes must match the operand-shape on " + "non-dynamic-index dimensions."; + return std::nullopt; } } diff --git a/third_party/xla/xla/service/while_loop_unroller.h b/third_party/xla/xla/service/while_loop_unroller.h index 619c11697435bc..e3c96dc42cdc67 100644 --- a/third_party/xla/xla/service/while_loop_unroller.h +++ b/third_party/xla/xla/service/while_loop_unroller.h @@ -62,14 +62,15 @@ struct UnrollResult { // Check if `instr` is a dynamic index instruction, i.e., dynamic-slice or // dynamic-update-slice with the given input that operates on the entire // shape of the instruction. To satisfy this: -// 1. All start indices must be constant zero except only a single dimension. -// 2. The start index of that dimension should be equal to the enclosing loop -// induction variable. -// 3. The size of that dimension must match the loop trip count. -// 4. For dynamic-slice, the slice size for the induction variable dimension is -// 1, and the size of all other dimensions is the same as the shape of the -// input. -// If so, it returns the dynamic index. +// 1. All start indices must be constant zero except for a single dimension, +// hereafter referred to as the dynamic dimension. +// 2. The slice sizes of all nondynamic dimensions is the same as their size in +// the input shape. +// 3. The start index of the dynamic dimension should be equal to the enclosing +// loop induction variable times the dynamic dimension's slice size. +// 4. The size of the dynamic dimension must be at most the loop trip count +// times the slice size. +// If so, it returns the index of the dynamic dimension. std::optional MatchShapeCoveringDynamicIndexInstruction( const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, const WhileLoopConfig& config); diff --git a/third_party/xla/xla/service/while_loop_unroller_test.cc b/third_party/xla/xla/service/while_loop_unroller_test.cc index 952b6f5240a95f..dc659de2f4108b 100644 --- a/third_party/xla/xla/service/while_loop_unroller_test.cc +++ b/third_party/xla/xla/service/while_loop_unroller_test.cc @@ -31,10 +31,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { @@ -52,6 +52,14 @@ class WhileLoopUnrollerTest : public HloTestBase { MakeModuleWithWhileFeedingAnotherWhile(int num_iters); [[nodiscard]] std::unique_ptr MakeModuleWithSimpleLoopAllReduce(int num_iters); + // These two methods make a module with a while loop over + // (i = `start`; i < `stop`; i += `step`) whose iterations perform a + // dynamic slice (or dynamic update slice) at position i with slice size + // `slice_size` on a tensor whose dimension has size `dim_size`. + [[nodiscard]] std::unique_ptr MakeModuleWithDS( + int start, int stop, int step, int slice_size, int dim_size); + [[nodiscard]] std::unique_ptr MakeModuleWithDUS( + int start, int stop, int step, int slice_size, int dim_size); public: void UnrollAndCompare(std::unique_ptr module, @@ -311,6 +319,81 @@ WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) { return ParseAndReturnVerifiedModule(hlo_string).value(); } +std::unique_ptr WhileLoopUnrollerTest::MakeModuleWithDS( + int start, int stop, int step, int slice_size, int dim_size) { + std::string hlo_string_template = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 + constant.1 = s32[]{:T(128)} constant({{STEP}}) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[{{DIM_SIZE}},10]{1,0} get-tuple-element(loop_var.1), index=1 + zero = s32[] constant(0) + slice = s32[{{SLICE_SIZE}},10] dynamic-slice(get-tuple-element.2, get-tuple-element.1, zero), dynamic_slice_sizes={{{SLICE_SIZE}},10} + output = s32[{{DIM_SIZE}},10]{1,0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) tuple(idx, output) + } + SimpleLoop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[]{:T(128)} constant({{STOP}}) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s32[]{:T(128)} constant({{START}}) + constant.4 = s32[{{DIM_SIZE}},10]{1,0} constant({...}) + tuple.1 = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) tuple(constant.3, constant.4) + ROOT while = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) while(tuple.1), condition= SimpleLoop.condition, body=SimpleLoop.body + } + )"; + std::string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{START}}", absl::StrCat(start)}, + {"{{STOP}}", absl::StrCat(stop)}, + {"{{STEP}}", absl::StrCat(step)}, + {"{{SLICE_SIZE}}", absl::StrCat(slice_size)}, + {"{{DIM_SIZE}}", absl::StrCat(dim_size)}}); + return ParseAndReturnVerifiedModule(hlo_string).value(); +} + +std::unique_ptr WhileLoopUnrollerTest::MakeModuleWithDUS( + int start, int stop, int step, int slice_size, int dim_size) { + std::string hlo_string_template = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 + constant.1 = s32[]{:T(128)} constant({{STEP}}) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[{{DIM_SIZE}},10]{1,0} get-tuple-element(loop_var.1), index=1 + zero = s32[] constant(0) + broadcast = s32[{{SLICE_SIZE}},10] broadcast(zero) + slice = s32[{{DIM_SIZE}},10] dynamic-update-slice(get-tuple-element.2, broadcast, get-tuple-element.1, zero) + output = s32[{{DIM_SIZE}},10]{1,0} add(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) tuple(idx, output) + } + SimpleLoop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[]{:T(128)} constant({{STOP}}) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY SimpleLoop { + constant.3 = s32[]{:T(128)} constant({{START}}) + constant.4 = s32[{{DIM_SIZE}},10]{1,0} constant({...}) + tuple.1 = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) tuple(constant.3, constant.4) + ROOT while = (s32[]{:T(128)}, s32[{{DIM_SIZE}},10]{1,0}) while(tuple.1), condition= SimpleLoop.condition, body=SimpleLoop.body + } + )"; + std::string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{START}}", absl::StrCat(start)}, + {"{{STOP}}", absl::StrCat(stop)}, + {"{{STEP}}", absl::StrCat(step)}, + {"{{SLICE_SIZE}}", absl::StrCat(slice_size)}, + {"{{DIM_SIZE}}", absl::StrCat(dim_size)}}); + return ParseAndReturnVerifiedModule(hlo_string).value(); +} + TEST_F(WhileLoopUnrollerTest, SimpleLoopUnroll) { UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}, -1, false); UnrollAndCompare(MakeModuleWithSimpleLoop(/*num_iters=*/5), {}, -1, true); @@ -945,37 +1028,8 @@ TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { } TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDS) { - std::string hlo_string_template = R"( - HloModule SimpleLoop - SimpleLoop.body { - loop_var.1 = (s32[]{:T(128)}, s32[3,10]{1,0}) parameter(0) - get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 - constant.1 = s32[]{:T(128)} constant(1) - idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) - get-tuple-element.2 = s32[3,10]{1,0} get-tuple-element(loop_var.1), index=1 - zero = s32[] constant(0) - slice = s32[1,10] dynamic-slice(get-tuple-element.2, get-tuple-element.1, zero), dynamic_slice_sizes={1,10} - output = s32[3,10]{1,0} add(get-tuple-element.2, get-tuple-element.2) - ROOT tuple = (s32[]{:T(128)}, s32[3,10]{1,0}) tuple(idx, output) - } - SimpleLoop.condition { - loop_var.2 = (s32[]{:T(128)}, s32[3,10]{1,0}) parameter(0) - get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 - constant.2 = s32[]{:T(128)} constant({{LOOP_BOUND}}) - ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT - } - ENTRY SimpleLoop { - constant.3 = s32[]{:T(128)} constant(0) - constant.4 = s32[3,10]{1,0} constant({...}) - tuple.1 = (s32[]{:T(128)}, s32[3,10]{1,0}) tuple(constant.3, constant.4) - ROOT while = (s32[]{:T(128)}, s32[3,10]{1,0}) while(tuple.1), condition= - SimpleLoop.condition, body=SimpleLoop.body - } - )"; - - std::string hlo_string = absl::StrReplaceAll( - hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(3)}}); - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + auto module = MakeModuleWithDS(/*start=*/0, /*stop=*/3, /*step=*/1, + /*slice_size=*/1, /*dim_size=*/3); HloInstruction* loop = module->entry_computation()->root_instruction(); auto config = WhileLoopUnroller::IsLoopUnrollable(loop); EXPECT_TRUE(config.has_value()); @@ -1088,6 +1142,83 @@ TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSNested) { .has_value()); } +TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSIncrementByTwo) { + // In this version of the test, our dimension of interest gets incremented by + // two at a time so that it takes on values {0, 2, 4}. The DS has slice size + // two, so indeed all index values {0, 1, 2, 3, 4, 5} are retrieved by the DS. + auto module = MakeModuleWithDS(/*start=*/0, /*stop=*/6, /*step=*/2, + /*slice_size=*/2, /*dim_size=*/6); + HloInstruction* loop = module->entry_computation()->root_instruction(); + auto config = WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + HloComputation* body = module->GetComputationWithName("SimpleLoop.body"); + HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2"); + HloInstruction* instr = body->GetInstructionWithName("slice"); + EXPECT_TRUE(MatchShapeCoveringDynamicIndexInstruction( + instr, input, HloOpcode::kDynamicSlice, config.value()) + .has_value()); +} + +TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSIncrementByTwoMismatch) { + // In this version of the test, our dimension of interest gets incremented by + // two at a time so that it takes on values {0, 2, 4}. The DS has slice size + // two, so only index values {0, 1, 2, 3, 4, 5} are retrieved by the DS and + // index value 6 is not. + auto module = MakeModuleWithDS(/*start=*/0, /*stop=*/6, /*step=*/2, + /*slice_size=*/2, /*dim_size=*/7); + HloInstruction* loop = module->entry_computation()->root_instruction(); + auto config = WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + HloComputation* body = module->GetComputationWithName("SimpleLoop.body"); + HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2"); + HloInstruction* instr = body->GetInstructionWithName("slice"); + EXPECT_FALSE(MatchShapeCoveringDynamicIndexInstruction( + instr, input, HloOpcode::kDynamicSlice, config.value()) + .has_value()); +} + +TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDUS) { + auto module = MakeModuleWithDUS(/*start=*/0, /*stop=*/3, /*step=*/1, + /*slice_size=*/1, /*dim_size=*/3); + HloInstruction* loop = module->entry_computation()->root_instruction(); + auto config = WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + HloComputation* body = module->GetComputationWithName("SimpleLoop.body"); + HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2"); + HloInstruction* instr = body->GetInstructionWithName("slice"); + EXPECT_TRUE(MatchShapeCoveringDynamicIndexInstruction( + instr, input, HloOpcode::kDynamicUpdateSlice, config.value()) + .has_value()); +} + +TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDUSIncrementByTwo) { + auto module = MakeModuleWithDUS(/*start=*/0, /*stop=*/6, /*step=*/2, + /*slice_size=*/2, /*dim_size=*/6); + HloInstruction* loop = module->entry_computation()->root_instruction(); + auto config = WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + HloComputation* body = module->GetComputationWithName("SimpleLoop.body"); + HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2"); + HloInstruction* instr = body->GetInstructionWithName("slice"); + EXPECT_TRUE(MatchShapeCoveringDynamicIndexInstruction( + instr, input, HloOpcode::kDynamicUpdateSlice, config.value()) + .has_value()); +} + +TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDUSIncrementByTwoMismatch) { + auto module = MakeModuleWithDUS(/*start=*/0, /*stop=*/6, /*step=*/2, + /*slice_size=*/2, /*dim_size=*/7); + HloInstruction* loop = module->entry_computation()->root_instruction(); + auto config = WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + HloComputation* body = module->GetComputationWithName("SimpleLoop.body"); + HloInstruction* input = body->GetInstructionWithName("get-tuple-element.2"); + HloInstruction* instr = body->GetInstructionWithName("slice"); + EXPECT_FALSE(MatchShapeCoveringDynamicIndexInstruction( + instr, input, HloOpcode::kDynamicUpdateSlice, config.value()) + .has_value()); +} + // Unroller pass must remove all the DynamicGte custom-calls. TEST_F(WhileLoopUnrollerTest, UnrollLoopWithDynamicGte) { std::string hlo_string = R"( diff --git a/third_party/xla/xla/service/while_util_test.cc b/third_party/xla/xla/service/while_util_test.cc index f8e597ecc43932..e2162a841d599e 100644 --- a/third_party/xla/xla/service/while_util_test.cc +++ b/third_party/xla/xla/service/while_util_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/while_util.h" #include -#include #include #include @@ -224,7 +223,7 @@ ENTRY main { } TEST_F(WhileUtilTest, TryIncrementNonCounterTripCount) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { @@ -260,7 +259,7 @@ ENTRY main { } TEST_F(WhileUtilTest, TryIncrementNonConstantTripCount) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { @@ -297,7 +296,7 @@ ENTRY main { } TEST_F(WhileUtilTest, TryIncrementSideEffecting) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { @@ -334,7 +333,7 @@ ENTRY main { } TEST_F(WhileUtilTest, IncrementTripCountLt) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { @@ -372,7 +371,7 @@ ENTRY main { } TEST_F(WhileUtilTest, IncrementTripCountGt) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main body { diff --git a/third_party/xla/xla/service/xla_debug_info_manager.cc b/third_party/xla/xla/service/xla_debug_info_manager.cc index b6d5e5ff90d135..82bf0e89224d9b 100644 --- a/third_party/xla/xla/service/xla_debug_info_manager.cc +++ b/third_party/xla/xla/service/xla_debug_info_manager.cc @@ -79,7 +79,7 @@ void XlaDebugInfoManager::StopTracing( modules_to_serialize.emplace_back(std::move(m)); modules_.erase(cur_it); } else { - modules_to_serialize.emplace_back(m); + modules_to_serialize.push_back(m); } } } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 09e2f63305db32..11d4ab4b9e5d2d 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/types/span.h" #include "xla/layout.h" #include "xla/layout_util.h" diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 5b7fd2d89487b6..75c8c0f8256271 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/layout.h" #include "xla/primitive_util.h" @@ -75,7 +76,6 @@ class Shape { // Returns the rank (number of dimensions) of the given shape. Shape must be // an array. int64_t rank() const { - DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); return dimensions_.size(); } @@ -152,19 +152,11 @@ class Shape { return absl::MakeSpan(dynamic_dimensions_); } - // Add dimension_upper_bound(). - // Removes the given dimension from the shape. Layout, if it exists, is // adjusted to match the modified shape. void DeleteDimension(int64_t dim_to_delete); void DeleteDimensions(absl::Span sorted_dims_to_delete); - // The following methods mirror the protobuf generated code interface for the - // message ShapeProto. This enabled easy migration of this data structure - // from a proto to a proper C++ class. - // TODO(b/29771030): Replace or augment these methods with a more ergonomic - // interface. - // Methods for accessing the primitive type. PrimitiveType element_type() const { return element_type_; } void set_element_type(PrimitiveType value) { element_type_ = value; } diff --git a/third_party/xla/xla/shape_layout.cc b/third_party/xla/xla/shape_layout.cc index 7a3516b5fb7cec..057a523b731eff 100644 --- a/third_party/xla/xla/shape_layout.cc +++ b/third_party/xla/xla/shape_layout.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/shape_layout.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "xla/layout.h" #include "xla/layout_util.h" diff --git a/third_party/xla/xla/shape_test.cc b/third_party/xla/xla/shape_test.cc index 242bf12601435a..78f3fda40cb12d 100644 --- a/third_party/xla/xla/shape_test.cc +++ b/third_party/xla/xla/shape_test.cc @@ -15,10 +15,11 @@ limitations under the License. #include "xla/shape.h" +#include #include "absl/hash/hash_testing.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/shape_tree.cc b/third_party/xla/xla/shape_tree.cc index bc83698a02851d..9fb17e2ecb6a3a 100644 --- a/third_party/xla/xla/shape_tree.cc +++ b/third_party/xla/xla/shape_tree.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/shape_tree.h b/third_party/xla/xla/shape_tree.h index fd4448e0265089..9ea53dd4aeb79d 100644 --- a/third_party/xla/xla/shape_tree.h +++ b/third_party/xla/xla/shape_tree.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/shape_tree_test.cc b/third_party/xla/xla/shape_tree_test.cc index 5e29d719eb27dc..f810c1e895bd5c 100644 --- a/third_party/xla/xla/shape_tree_test.cc +++ b/third_party/xla/xla/shape_tree_test.cc @@ -20,9 +20,10 @@ limitations under the License. #include #include +#include +#include "xla/hlo/testlib/test.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index d537221747a46a..b8604459a77b1f 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -33,11 +33,13 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -69,6 +71,15 @@ namespace { constexpr int64_t kAnnotationPrintInterval = 5; +inline absl::Status ShapeError(const Shape& shape, absl::string_view message) { + return absl::InvalidArgumentError(absl::StrFormat( + "Shape Error: %s Shape(%s): %s", message, + PrimitiveType_IsValid(shape.element_type()) + ? primitive_util::LowercasePrimitiveTypeName(shape.element_type()) + : absl::StrCat(static_cast(shape.element_type())), + shape.DebugString())); +} + template void PrintShape(Printer* printer, const Shape& shape) { if constexpr (kPrintLayout) { @@ -97,18 +108,6 @@ void PrintTupleShapes(Printer* printer, absl::Span tuple_shapes) { printer->Append(")"); } -} // namespace - -std::string ShapeIndex::ToString() const { - return StrCat("{", absl::StrJoin(*this, ","), "}"); -} - -std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { - out << shape_index.ToString(); - return out; -} - -namespace { // Constructs and returns the new shape with the given minor_to_major order in // its Layout. absl::StatusOr MakeShapeWithLayoutInternal( @@ -171,6 +170,15 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { } // namespace +std::string ShapeIndex::ToString() const { + return StrCat("{", absl::StrJoin(*this, ","), "}"); +} + +std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { + out << shape_index.ToString(); + return out; +} + /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { bool equal = Shape::Equal()(lhs, rhs); @@ -623,7 +631,6 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ int64_t ShapeUtil::TupleElementCount(const Shape& shape) { - CHECK(shape.IsTuple()) << HumanString(shape); return shape.tuple_shapes_size(); } @@ -791,8 +798,6 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { - CHECK(lhs.IsArray()); - CHECK(rhs.IsArray()); if (!SameRank(lhs, rhs)) return false; for (int i = 0; i < lhs.rank(); ++i) { if (!lhs.is_unbounded_dynamic_dimension(i) && @@ -806,8 +811,6 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ bool ShapeUtil::SameRank(const Shape& lhs, const Shape& rhs) { - CHECK(lhs.IsArray()); - CHECK(rhs.IsArray()); return lhs.rank() == rhs.rank(); } @@ -930,8 +933,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return absl::OkStatus(); } if (!subshape.IsArray()) { - return InvalidArgument("Shape cannot be serialiized: %s", - shape.ToString()); + return ShapeError(shape, "Shape cannot be serialiized."); } if (subshape.is_dynamic()) { size += sizeof(DynamicSizeType) * subshape.rank(); @@ -954,46 +956,28 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return size; } -/* static */ absl::Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( - const Shape& shape) { - if (shape.element_type() == PRIMITIVE_TYPE_INVALID || - !PrimitiveType_IsValid(shape.element_type())) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString()); - } - if (shape.element_type() == TUPLE) { - if (shape.dimensions_size() != 0) { - return InvalidArgument("tuples must not have dimensions specified"); - } - for (auto& element_shape : shape.tuple_shapes()) { - TF_RETURN_IF_ERROR( - ValidateShapeWithOptionalLayoutInternal(element_shape)); - } +namespace { + +// Validates the shape size is sane. This makes sure it's safe to do +// calculations in int64_t without overflowing. +absl::Status ValidateShapeSize(const Shape& shape) { + if (!shape.IsArray()) { return absl::OkStatus(); } - // Non-tuple shape. - if (shape.tuple_shapes_size() > 0) { - return InvalidArgument("non-tuple shape has tuple_shapes field"); - } + auto [extent_product, extent_overflow] = + ShapeUtil::ExtentProduct(shape); + auto [dense_shape_size, byte_width_overflow] = OverflowSafeMultiply( + extent_product, ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())); - // Tokens and opaques should not have layout or dimensions. - if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) { - if (shape.dimensions_size() != 0) { - return InvalidArgument( - "shape has %s element type, but has dimensions field: %s", - primitive_util::LowercasePrimitiveTypeName(shape.element_type()), - shape.ShortDebugString()); - } - if (shape.has_layout()) { - return InvalidArgument( - "shape has %s element type, but has layout field: %s", - primitive_util::LowercasePrimitiveTypeName(shape.element_type()), - shape.ShortDebugString()); - } - return absl::OkStatus(); + if (extent_overflow || byte_width_overflow) { + return InvalidArgument("Shape %s size may overflow int64_t.", + ShapeUtil::HumanString(shape)); } + return absl::OkStatus(); +} +absl::Status ValidateDimensions(const Shape& shape) { bool any_overflows = false; int64_t product = 1; for (int64_t i = 0; i < shape.rank(); ++i) { @@ -1002,54 +986,69 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { continue; } if (dimension < 0) { - return InvalidArgument( - "shape's dimensions must not be < 0; dimension at index %d was %d", i, - dimension); + return ShapeError( + shape, + absl::StrFormat("Negative dimension at index %d: %d.", i, dimension)); } bool overflow; std::tie(product, overflow) = OverflowSafeMultiply(product, dimension); any_overflows |= overflow; } if (any_overflows) { - return InvalidArgument("shape's dimensions overflow: %s", - shape.ShortDebugString()); + return ShapeError(shape, "Dimensions overflow."); } - - TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); return absl::OkStatus(); } -/* static */ absl::Status ShapeUtil::ValidateShapeSize(const Shape& shape) { - VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - - if (!shape.IsArray()) { +// Validates all of the non-layout properties of the shape -- this is a helper +// used by both the layout-optional and layout-required public method. +absl::Status ValidateNonLayoutProperties(const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID || + !PrimitiveType_IsValid(shape.element_type())) { + return ShapeError(shape, "Invalid element type."); + } + if (shape.element_type() == TUPLE) { + if (shape.dimensions_size() != 0) { + return ShapeError(shape, "This type cannot have dimensions."); + } + for (auto& element_shape : shape.tuple_shapes()) { + TF_RETURN_IF_ERROR(ValidateNonLayoutProperties(element_shape)); + } return absl::OkStatus(); } - auto [extent_product, extent_overflow] = - ExtentProduct(shape); - auto [dense_shape_size, byte_width_overflow] = OverflowSafeMultiply( - extent_product, ByteSizeOfPrimitiveType(shape.element_type())); + // Non-tuple shape. + if (shape.tuple_shapes_size() > 0) { + return ShapeError(shape, "Non-tuple type contains tuple_shapes."); + } - if (extent_overflow || byte_width_overflow) { - return InvalidArgument("Shape %s size may overflow int64_t.", - ShapeUtil::HumanString(shape)); + // Tokens and opaques should not have layout or dimensions. + if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) { + if (shape.dimensions_size() != 0) { + return ShapeError(shape, "This type cannot have dimensions."); + } + if (shape.has_layout()) { + return ShapeError(shape, "This type cannot have a layout."); + } + return absl::OkStatus(); } - VLOG(3) << "Shape size is valid: " << dense_shape_size; + TF_RETURN_IF_ERROR(ValidateDimensions(shape)); + TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); return absl::OkStatus(); } +} // namespace /* static */ absl::Status ShapeUtil::ValidateShapeWithOptionalLayout( const Shape& shape) { - TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape)); + TF_RETURN_IF_ERROR(ValidateNonLayoutProperties(shape)); return LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); } /* static */ absl::Status ShapeUtil::ValidateShape(const Shape& shape) { - TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape)); + TF_RETURN_IF_ERROR(ValidateNonLayoutProperties(shape)); return LayoutUtil::ValidateLayoutInShape(shape); } diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 76a02174841650..0fcc75aa7aa507 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -140,7 +140,7 @@ class ShapeUtil { return product; } - // Returns the number of elements are contained within the provided shape; + // Returns the number of elements contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. // Precondition: shape.IsArray() static inline int64_t ElementsIn(const Shape& shape) { @@ -1057,15 +1057,6 @@ class ShapeUtil { static bool FillNewShape(PrimitiveType element_type, absl::Span dimensions, Shape* shape); - // Validates the shape size is sane. This makes sure it's safe to do - // calculations in int64_t without overflowing. - static absl::Status ValidateShapeSize(const Shape& shape); - - // Validates all of the non-layout properties of the shape -- this is a helper - // used by both the layout-optional and layout-required public method. - static absl::Status ValidateShapeWithOptionalLayoutInternal( - const Shape& shape); - // Helper for ForEachSubshape which visits the subshapes of the given shape in // DFS pre-order starting with the index. template diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index 71a0c2cf5ff69c..9c58b488a3be66 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -23,15 +23,16 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/shape.h" -#include "xla/test.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" @@ -1222,6 +1223,25 @@ TEST(ShapeUtilTest, B_251055887) { EXPECT_FALSE(ShapeUtil::ValidateShape(shape).ok()); } +TEST(ShapeUtilTest, B_385192799) { + // This case failed the fuzzer; see b/385192799. + ShapeProto proto; + + { + EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb(element_type: 2000)pb", &proto)); + Shape shape(proto); + EXPECT_FALSE(ShapeUtil::ValidateShape(shape).ok()); + } + + { + EXPECT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb(element_type: -1)pb", &proto)); + Shape shape(proto); + EXPECT_FALSE(ShapeUtil::ValidateShape(shape).ok()); + } +} + TEST(ShapeUtilTest, Int4ShapeSize) { Shape int4_shape = ShapeUtil::MakeShape(S4, {64, 128}); int4_shape.mutable_layout()->set_element_size_in_bits(4); diff --git a/third_party/xla/xla/side_effect_util.cc b/third_party/xla/xla/side_effect_util.cc index 602d76b66a4880..5c64d9a99e5f1c 100644 --- a/third_party/xla/xla/side_effect_util.cc +++ b/third_party/xla/xla/side_effect_util.cc @@ -73,4 +73,6 @@ const char kXlaCollectiveMatmulNone[] = "none"; const char kXlaMultiRecvCountAttr[] = "_xla_multi_recv_count"; +const char kXlaSchedulingGroupIdAttr[] = "_scheduling_group_id"; + } // namespace xla diff --git a/third_party/xla/xla/side_effect_util.h b/third_party/xla/xla/side_effect_util.h index 281a007b4cd8bc..d8c3c118004f59 100644 --- a/third_party/xla/xla/side_effect_util.h +++ b/third_party/xla/xla/side_effect_util.h @@ -82,6 +82,9 @@ extern const char kXlaCollectiveMatmulNone[]; // XLA frontend attribute for specifying the number of sends this recv should // match. extern const char kXlaMultiRecvCountAttr[]; + +// XLA frontend attribute for specifying the scheduling group id annotations. +extern const char kXlaSchedulingGroupIdAttr[]; } // namespace xla #endif // XLA_SIDE_EFFECT_UTIL_H_ diff --git a/third_party/xla/xla/status_macros.cc b/third_party/xla/xla/status_macros.cc index 5d24514b621f1a..449da54cee817f 100644 --- a/third_party/xla/xla/status_macros.cc +++ b/third_party/xla/xla/status_macros.cc @@ -20,6 +20,8 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/base/log_severity.h" #include "absl/base/optimization.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/status_macros_test.cc b/third_party/xla/xla/status_macros_test.cc index 5f54b5961e433e..474d1015137915 100644 --- a/third_party/xla/xla/status_macros_test.cc +++ b/third_party/xla/xla/status_macros_test.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index eecbc19aa3580b..72b762d3865aa4 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -3,7 +3,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "stream_executor_build_defs_bzl_deps", "stream_executor_friends", "stream_executor_internal") load("//xla/tsl:tsl.bzl", "if_google", "if_oss", "internal_visibility") load("//xla/tsl/platform:build_config.bzl", "tf_proto_library") -load("//xla/tsl/platform:build_config_root.bzl", "if_static") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -36,80 +35,6 @@ bzl_library( ] + stream_executor_build_defs_bzl_deps(), ) -#===--------------------------------------------------------------------------------------------===# -# StreamExecutor public API -#===--------------------------------------------------------------------------------------------===# - -# StreamExecutor itself is a small abstrtaction layer on top of platform-specific API -# implementations (e.g. see `stream_executor/cuda` folder for CUDA-specific details), and should -# not contribute a lot to binary size or compilation time. - -# TODO(klucke) Remove this target once the final user of this target is changed to use "stream" instead. -cc_library( - name = "stream_executor", - hdrs = [ - "stream.h", - ], - deps = [ - ":activate_context", - ":allocator_stats", - ":blas", - ":command_buffer", - ":data_type", - ":device_description", - ":device_description_proto_cc", - ":device_memory", - ":device_memory_allocator", - ":dnn", - ":event", - ":event_based_timer", - ":fft", - ":host_memory_allocation", # build_cleaner: keep - ":host_or_device_scalar", - ":kernel", - ":kernel_spec", - ":launch_dim", - ":memory_allocation", - ":module_spec", - ":numeric_options", - ":platform", - ":semantic_version", - ":stream_common", - ":stream_executor_common", - ":stream_executor_h", - "//xla/tsl/framework:device_id", - "//xla/tsl/framework:device_type", - "//xla/tsl/lib/gtl:int_type", - "//xla/tsl/protobuf:dnn_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", - ] + if_static([ - ":stream_executor_impl", - ]) + if_google([ - "@com_google_protobuf//:wrappers_cc_proto", # indirectly-used by dnn.h - ]), -) - #===--------------------------------------------------------------------------------------------===# # StreamExecutor public libraries #===--------------------------------------------------------------------------------------------===# @@ -145,8 +70,8 @@ cc_library( name = "device_memory", hdrs = ["device_memory.h"], deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", - "@local_tsl//tsl/platform:logging", ], ) @@ -474,9 +399,11 @@ cc_library( ":allocator_stats", ":blas", ":command_buffer", + ":device_description", ":device_memory", ":dnn", ":event", + ":event_based_timer", ":fft", ":kernel", ":kernel_spec", @@ -507,7 +434,6 @@ cc_library( ":device_memory", ":event", ":event_based_timer", - ":kernel", ":launch_dim", ":platform", "@com_google_absl//absl/functional:any_invocable", @@ -567,6 +493,7 @@ cc_library( ":device_memory", ":kernel_spec", ":launch_dim", + ":stream", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/status", @@ -659,10 +586,10 @@ cc_library( ":device_memory", ":kernel", ":launch_dim", - ":platform", "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index cbb470713b60e5..bb56f0f0c3ca4b 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -19,12 +19,12 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/bit_pattern.h" #include "xla/stream_executor/device_memory.h" @@ -52,7 +52,7 @@ class CommandBuffer { // Execution scope enables fine-grained synchronization scopes inside // commands buffers. Implementation is very backend-specific and for CUDA/ROCM // backends it's implemented as DAG edges. By default all commands launched in - // the `kDefaulExecutionScope` execution scope. + // the `kDefaultExecutionScope` execution scope. // // Example #1: independent execution scopes and independent barriers // @@ -114,7 +114,7 @@ class CommandBuffer { // semantics as stream wait operation. // TSL_LIB_GTL_DEFINE_INT_TYPE(ExecutionScopeId, uint64_t); - static constexpr auto kDefaulExecutionScope = ExecutionScopeId(0); + static constexpr auto kDefaultExecutionScope = ExecutionScopeId(0); // Builder constructs nested command buffers owned by a parent command buffer. // @@ -159,7 +159,7 @@ class CommandBuffer { // enum class Mode { kPrimary, kNested }; - friend std::string_view ModeToString(Mode mode) { + friend absl::string_view ModeToString(Mode mode) { switch (mode) { case CommandBuffer::Mode::kPrimary: return "primary"; @@ -188,7 +188,7 @@ class CommandBuffer { ExecutionScopeId to_execution_scope_id) = 0; // Adds an execution barrier to the default execution scope. - absl::Status Barrier() { return Barrier(kDefaulExecutionScope); } + absl::Status Barrier() { return Barrier(kDefaultExecutionScope); } // Adds a kernel launch command. virtual absl::Status Launch(ExecutionScopeId execution_scope_id, @@ -198,7 +198,7 @@ class CommandBuffer { // Adds a kernel launch command to the default execution scope. absl::Status Launch(const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel, const KernelArgs& args) { - return Launch(kDefaulExecutionScope, threads, blocks, kernel, args); + return Launch(kDefaultExecutionScope, threads, blocks, kernel, args); } // Type-safe wrapper for launching typed kernels. Notice that the order of @@ -214,7 +214,7 @@ class CommandBuffer { absl::Status Launch(const TypedKernel& kernel, const ThreadDim& threads, const BlockDim& blocks, Args... args) { - return Launch(kernel, kDefaulExecutionScope, threads, blocks, args...); + return Launch(kernel, kDefaultExecutionScope, threads, blocks, args...); } // Adds a nested command buffer. @@ -223,7 +223,7 @@ class CommandBuffer { // Adds a nested command buffer to the default execution scope. absl::Status AddNestedCommandBuffer(const CommandBuffer& nested) { - return AddNestedCommandBuffer(kDefaulExecutionScope, nested); + return AddNestedCommandBuffer(kDefaultExecutionScope, nested); } // Adds a device-to-device memory copy. @@ -236,7 +236,7 @@ class CommandBuffer { absl::Status MemcpyDeviceToDevice(DeviceMemoryBase* dst, const DeviceMemoryBase& src, uint64_t size) { - return MemcpyDeviceToDevice(kDefaulExecutionScope, dst, src, size); + return MemcpyDeviceToDevice(kDefaultExecutionScope, dst, src, size); } // Adds a memset command. @@ -247,7 +247,7 @@ class CommandBuffer { // Adds a memset command to the default execution scope. absl::Status Memset(DeviceMemoryBase* dst, BitPattern bit_pattern, size_t num_elements) { - return Memset(kDefaulExecutionScope, dst, bit_pattern, num_elements); + return Memset(kDefaultExecutionScope, dst, bit_pattern, num_elements); } //--------------------------------------------------------------------------// @@ -261,7 +261,7 @@ class CommandBuffer { // Adds a conditional If operation to default execution scope. absl::Status If(DeviceMemory pred, Builder then_builder) { - return If(kDefaulExecutionScope, pred, then_builder); + return If(kDefaultExecutionScope, pred, then_builder); } // Adds a conditional operation that will execute a command buffer constructed @@ -274,7 +274,7 @@ class CommandBuffer { // Adds a conditional IfElse operation to default execution scope. absl::Status IfElse(DeviceMemory pred, Builder then_builder, Builder else_builder) { - return IfElse(kDefaulExecutionScope, pred, then_builder, else_builder); + return IfElse(kDefaultExecutionScope, pred, then_builder, else_builder); } // Adds a conditional operation that will execute a command buffer constructed @@ -289,7 +289,7 @@ class CommandBuffer { // Adds a conditional Case operation to default execution scope. absl::Status Case(DeviceMemory index, std::vector branches) { - return Case(kDefaulExecutionScope, index, branches); + return Case(kDefaultExecutionScope, index, branches); } // Adds a conditional operation that will execute a command buffer constructed @@ -304,7 +304,7 @@ class CommandBuffer { // Adds a conditional For operation to default execution scope. absl::Status For(int32_t num_iteration, DeviceMemory loop_counter, Builder body_builder) { - return For(kDefaulExecutionScope, num_iteration, loop_counter, + return For(kDefaultExecutionScope, num_iteration, loop_counter, body_builder); } @@ -332,7 +332,7 @@ class CommandBuffer { // Adds a conditional While operation to default execution scope. absl::Status While(DeviceMemory pred, ExecutionScopeBuilder cond_builder, Builder body_builder) { - return While(kDefaulExecutionScope, pred, cond_builder, body_builder); + return While(kDefaultExecutionScope, pred, cond_builder, body_builder); } // Submits the command buffer for execution. diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index f15e8a84eaf77b..9acd61252f381a 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -570,13 +570,17 @@ cuda_only_cc_library( "//xla/stream_executor:activate_context", "//xla/stream_executor:kernel", "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", ], ) @@ -614,6 +618,7 @@ cc_library( deps = [ "//xla/stream_executor:kernel_spec", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -712,6 +717,7 @@ xla_cc_test( deps = [ ":ptx_compiler_helpers", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", @@ -981,9 +987,12 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", ], @@ -1522,6 +1531,7 @@ cc_library( ":compilation_provider", "//xla/stream_executor:device_description", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", ], @@ -1792,6 +1802,7 @@ xla_cc_test( ":mock_compilation_provider", "//xla/stream_executor:device_description", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc index c2d88551aa1736..2214b6d1e467e4 100644 --- a/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -137,7 +136,7 @@ AssembleCompilationProvider(const xla::DebugOptions& debug_options) { TF_RETURN_IF_ERROR(CheckIncompatibleFlagSettings(debug_options)); std::string decision_log; - const auto append_to_decision_log = [&](std::string_view decision) { + const auto append_to_decision_log = [&](absl::string_view decision) { VLOG(4) << decision; absl::StrAppend(&decision_log, " - ", decision, "\n"); }; diff --git a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc index e84e3ca97f42dc..23872d260ca3c1 100644 --- a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc @@ -16,13 +16,13 @@ limitations under the License. #include "xla/stream_executor/cuda/caching_compilation_provider.h" #include -#include #include #include #include #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" @@ -45,7 +45,7 @@ bool CachingCompilationProvider::SupportsCompileAndLink() const { } absl::StatusOr CachingCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { CacheKey cache_key{cc, std::string{ptx}, options}; { @@ -78,7 +78,7 @@ absl::StatusOr CachingCompilationProvider::Compile( absl::StatusOr CachingCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { CacheKey cache_key{cc, std::string{ptx}, options}; { diff --git a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h index cdde48c99340ac..264b0384d99d46 100644 --- a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -52,10 +51,10 @@ class CachingCompilationProvider : public CompilationProvider { bool SupportsCompileAndLink() const override; absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( const CudaComputeCapability& cc, diff --git a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc index e41c6f0954f005..e1aad01d026bc9 100644 --- a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc +++ b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc @@ -15,9 +15,8 @@ limitations under the License. #include "xla/stream_executor/cuda/command_buffer_kernels.h" -#include - #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/kernel_spec.h" namespace stream_executor { @@ -48,7 +47,7 @@ namespace { // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kSetIfConditionKernel = R"( +inline constexpr absl::string_view kSetIfConditionKernel = R"( .version 4.0 .target sm_50 .address_size 64 @@ -130,7 +129,7 @@ inline constexpr std::string_view kSetIfConditionKernel = R"( // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kSetIfElseConditionKernel = R"( +inline constexpr absl::string_view kSetIfElseConditionKernel = R"( .version 4.0 .target sm_50 .address_size 64 @@ -277,7 +276,7 @@ inline constexpr std::string_view kSetIfElseConditionKernel = R"( // // Easiest way to get PTX from C++ is to use https://godbolt.org. // May have to include these compiler options: -arch sm_50 -inline constexpr std::string_view kSetCaseConditionKernel = R"( +inline constexpr absl::string_view kSetCaseConditionKernel = R"( .version 4.0 .target sm_50 .address_size 64 @@ -635,7 +634,7 @@ inline constexpr std::string_view kSetCaseConditionKernel = R"( // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kSetForConditionKernel = R"( +inline constexpr absl::string_view kSetForConditionKernel = R"( .version 4.0 .target sm_50 .address_size 64 @@ -711,7 +710,7 @@ inline constexpr std::string_view kSetForConditionKernel = R"( })"; // While condition kernel is the same as an `If` with a single branch. -inline constexpr std::string_view kSetWhileConditionKernel = R"( +inline constexpr absl::string_view kSetWhileConditionKernel = R"( .version 4.0 .target sm_50 .address_size 64 @@ -783,7 +782,7 @@ inline constexpr std::string_view kSetWhileConditionKernel = R"( // __global__ void noop() {} // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kNoOpKernel = R"( +inline constexpr absl::string_view kNoOpKernel = R"( .version 4.0 .target sm_50 .address_size 64 diff --git a/third_party/xla/xla/stream_executor/cuda/compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/compilation_provider.h index c12e3d35f72775..38efab1e14ab8e 100644 --- a/third_party/xla/xla/stream_executor/cuda/compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/compilation_provider.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include @@ -95,7 +94,7 @@ class CompilationProvider { // Compiles a single PTX module into a CUDA program. This method is supported // by all compilation providers. virtual absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const = 0; // Compiles the given PTX string into relocatable CUBIN for the given @@ -103,7 +102,7 @@ class CompilationProvider { // providers. `SupportsCompileToRelocatableModule` can be used to check if // this method is supported. virtual absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const = 0; // Returns true if 'CompileToRelocatableModule' can be used. diff --git a/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.cc b/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.cc index ca15b11d216b2d..3571c33dcfe80f 100644 --- a/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -78,7 +77,7 @@ void CompilationProviderTest::SetUp() { } absl::StatusOr> -CompilationProviderTest::CreateCompilationProvider(std::string_view name) { +CompilationProviderTest::CreateCompilationProvider(absl::string_view name) { if (name == kSubprocessCompilationProviderName) { TF_ASSIGN_OR_RETURN(auto ptxas, FindCudaExecutable("ptxas", "/does/not/exist")); diff --git a/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h b/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h index 1b3c4f8a75e068..118d2c8389fe2e 100644 --- a/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h +++ b/third_party/xla/xla/stream_executor/cuda/compilation_provider_test.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/status/statusor.h" @@ -26,18 +25,18 @@ limitations under the License. namespace stream_executor::cuda { -inline constexpr std::string_view kSubprocessCompilationProviderName = +inline constexpr absl::string_view kSubprocessCompilationProviderName = "subprocess"; -inline constexpr std::string_view kNvJitLinkCompilationProviderName = +inline constexpr absl::string_view kNvJitLinkCompilationProviderName = "nvjitlink"; -inline constexpr std::string_view kNvptxcompilerCompilationProviderName = +inline constexpr absl::string_view kNvptxcompilerCompilationProviderName = "nvptxcompiler"; -inline constexpr std::string_view kDriverCompilationProviderName = "driver"; +inline constexpr absl::string_view kDriverCompilationProviderName = "driver"; class CompilationProviderTest - : public testing::TestWithParam { + : public testing::TestWithParam { absl::StatusOr> - CreateCompilationProvider(std::string_view name); + CreateCompilationProvider(absl::string_view name); void SetUp() override; std::unique_ptr compilation_provider_; @@ -51,7 +50,7 @@ class CompilationProviderTest // Prints the test parameter name as is. Needed for gtest instantiation. struct CompilationProviderTestParamNamePrinter { std::string operator()( - const ::testing::TestParamInfo& name) const { + const ::testing::TestParamInfo& name) const { return std::string(name.param); } }; diff --git a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc index 6ec968b714b853..c9e665aa514600 100644 --- a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -78,14 +77,14 @@ CompositeCompilationProvider::Create( } absl::StatusOr CompositeCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { return providers_.front()->Compile(cc, ptx, options); } absl::StatusOr CompositeCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { if (!relocatable_compilation_provider_) { return absl::UnavailableError( diff --git a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h index 5ec987ec39f6af..131d80d30b3aef 100644 --- a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/status/statusor.h" @@ -50,10 +49,10 @@ class CompositeCompilationProvider : public CompilationProvider { bool SupportsCompileAndLink() const override; absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( const CudaComputeCapability& cc, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index 1a5b09593de253..6fa559998b76f2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -18,11 +18,19 @@ limitations under the License. #include #include #include +#include +#include #include +#include "absl/base/const_init.h" +#include "absl/base/optimization.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "xla/stream_executor/cuda/cubin_or_ptx_image.h" #include "xla/stream_executor/cuda/ptx_compiler.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" @@ -52,4 +60,37 @@ absl::StatusOr> CompileGpuAsm( return CompileGpuAsmUsingPtxAs(cc, ptx, options, cancel_if_reg_spill); } +absl::StatusOr> CompileGpuAsmOrGetCached( + const CudaComputeCapability& cc, const std::string& ptx, + GpuAsmOpts compilation_options) { + using PtxCacheKey = std::tuple; + using PtxCompilerResult = absl::StatusOr>; + static absl::Mutex ptx_cache_mutex(absl::kConstInit); + static auto& ptx_cache ABSL_GUARDED_BY(ptx_cache_mutex) = + *new absl::flat_hash_map(); + + absl::MutexLock lock(&ptx_cache_mutex); + PtxCacheKey cache_key{cc, ptx, compilation_options.ToTuple()}; + auto it = ptx_cache.find(cache_key); + if (it == ptx_cache.end()) { + PtxCompilerResult compiled = CompileGpuAsm(cc, ptx, compilation_options); + it = ptx_cache.emplace(cache_key, std::move(compiled)).first; + } + + CHECK(it != ptx_cache.end()); + + // Failed compilation attempts are cached. + // Use separate status check and ValueOrDie invocation on ptx_cache + // entry to avoid value moving introduced by TF_ASSIGN_OR_RETURN. + + if (ABSL_PREDICT_FALSE(!it->second.ok())) { + return it->second.status(); + } + + const std::vector& compiled = it->second.value(); + return absl::MakeSpan(compiled); +} + + } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h index 52bba651def65a..caf2af501526e8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h @@ -45,6 +45,14 @@ inline absl::StatusOr> CompileGpuAsm( std::string(ptx_contents), options, cancel_if_reg_spill); } +// Same as CompileGpuAsm, but caches the result, and returns unowned view of +// the compiled binary. +// +// A copy of the string provided in ptx will be made. +absl::StatusOr> CompileGpuAsmOrGetCached( + const CudaComputeCapability& cc, const std::string& ptx_contents, + GpuAsmOpts compilation_options); + // Bundles the GPU machine code (cubins) and PTX if requested and returns the // resulting binary (i.e. a fatbin) as a byte array. absl::StatusOr> BundleGpuAsm( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc index 4ddb5348dc75bc..ca7b9b345dd6a5 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc @@ -612,7 +612,7 @@ absl::Status CudaCommandBuffer::PrepareFinalization() { } TF_ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel()); - TF_RETURN_IF_ERROR(CommandBuffer::Launch(*noop, kDefaulExecutionScope, + TF_RETURN_IF_ERROR(CommandBuffer::Launch(*noop, kDefaultExecutionScope, ThreadDim(), BlockDim())); return absl::OkStatus(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 57448f9c01319c..e27af9a3ae53be 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -4965,6 +4965,10 @@ static absl::StatusOr RebuildExecutionPlan( } // namespace +void FixDimsForRaggedOffset(std::vector& dims, int max_reg_per_batch) { + dims[0] *= max_reg_per_batch; +} + absl::StatusOr GetCudnnFlashAttentionOperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_descriptor, @@ -4974,7 +4978,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional bias_descriptor, const std::optional stats_descriptor, double scale, const bool use_dropout, const std::optional dropout_rate, - const dnn::FMHAMaskKind mask_type, const int sliding_window_length) { + const dnn::FMHAMaskKind mask_type, const int sliding_window_length, + const int max_seg_per_batch) { using cudnn_frontend::graph::Tensor_attributes; #if CUDNN_VERSION >= 90000 @@ -5007,23 +5012,34 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::vector q_dims = q_descriptor.GetCudnnCompatibleDimensions(true); + std::vector k_dims = k_descriptor.GetCudnnCompatibleDimensions(true); + std::vector v_dims = + v_descriptor.GetCudnnCompatibleDimensions(false); + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(q_dims, max_seg_per_batch); + FixDimsForRaggedOffset(k_dims, max_seg_per_batch); + FixDimsForRaggedOffset(v_dims, max_seg_per_batch); + } + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") - .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) + .set_dim(q_dims) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) .set_uid(next_uid())); std::shared_ptr k_tensor = graph.tensor(Tensor_attributes() .set_name("K") - .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) + .set_dim(k_dims) .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") - .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) + .set_dim(v_dims) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) .set_uid(next_uid())); @@ -5049,9 +5065,9 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - if (is_padding) { - auto q_dim = q_descriptor.GetCudnnCompatibleDimensions(true); - auto b = q_dim[0]; + if (is_padding || max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; auto seq_q_tensor = graph.tensor(Tensor_attributes() .set_name("seq_q") @@ -5070,6 +5086,30 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } + + std::shared_ptr offset_q; + if (max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; + offset_q = + graph.tensor(Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + auto offset_kv = + graph.tensor(Tensor_attributes() + .set_name("offset_kv") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + q_tensor->set_ragged_offset(offset_q); + k_tensor->set_ragged_offset(offset_kv); + v_tensor->set_ragged_offset(offset_kv); + } + // Setting seed and offset std::shared_ptr seed_tensor; std::shared_ptr offset_tensor; @@ -5100,10 +5140,16 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( auto [o_tensor, stats_tensor] = graph.sdpa(q_tensor, k_tensor, v_tensor, sdpa_options); + auto o_dims = o_descriptor.dimensions(); + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(o_dims, max_seg_per_batch); + o_tensor->set_ragged_offset(offset_q); + } // Set output attributes. o_tensor->set_name("O") .set_output(true) - .set_dim(o_descriptor.dimensions()) + .set_dim(o_dims) .set_stride(o_descriptor.GetLogicalStrides()) .set_uid(next_uid()); if (stats_descriptor.has_value()) { @@ -5488,7 +5534,8 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const std::optional bias_descriptor, std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, dnn::FMHAMaskKind mask_type, - bool force_deterministic, const int sliding_window_length) { + bool force_deterministic, const int sliding_window_length, + const int max_seg_per_batch) { #if CUDNN_VERSION >= 90000 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() @@ -5514,19 +5561,38 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - + // Get dims and strides + std::vector q_dims = q_desc.GetCudnnCompatibleDimensions(false); + std::vector k_dims = k_desc.GetCudnnCompatibleDimensions(false); + std::vector v_dims = v_desc.GetCudnnCompatibleDimensions(true); + std::vector p_dims = p_desc.GetCudnnCompatibleDimensions(false); + std::vector p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector do_dims = do_desc.GetCudnnCompatibleDimensions(false); + std::vector dq_dims = dq_desc.dimensions(); + std::vector dk_dims = dk_desc.dimensions(); + std::vector dv_dims = dv_desc.dimensions(); + std::vector stats_dims(p_dims.begin(), p_dims.end() - 1); + stats_dims.push_back(1); // Divide every stride by the last dim value. - std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); + std::vector stats_strides; + stats_strides.reserve(p_strides.size()); int64_t p_reduced_dim_len = p_dims.back(); for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); + stats_strides.push_back(stride / p_reduced_dim_len); + } + stats_strides[3] = 1; + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(q_dims, max_seg_per_batch); + FixDimsForRaggedOffset(k_dims, max_seg_per_batch); + FixDimsForRaggedOffset(v_dims, max_seg_per_batch); + FixDimsForRaggedOffset(p_dims, max_seg_per_batch); + FixDimsForRaggedOffset(do_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dq_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dk_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dv_dims, max_seg_per_batch); + FixDimsForRaggedOffset(stats_dims, max_seg_per_batch); } - p_reduction_strides[3] = 1; bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; auto sdpa_backward_options = @@ -5541,52 +5607,51 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") - .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(q_dims) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") - .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(k_dims) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") - .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) + .set_dim(v_dims) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr stats = graph.tensor(Tensor_attributes() .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) + .set_dim(stats_dims) + .set_stride(stats_strides) .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(do_dims) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); - auto bias_dim = bias_descriptor->dimensions(); - auto q_dim = q_desc.GetCudnnCompatibleDimensions(false); - auto b = bias_dim[0]; - auto n = bias_dim[1]; - auto q_n = q_dim[1]; - auto bias_tensor = - graph.tensor(Tensor_attributes() - .set_name("bias") - .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(next_uid())); + auto bias_dims = bias_descriptor->dimensions(); + auto bias_strides = bias_descriptor->GetLogicalStrides(); + auto b = bias_dims[0]; + auto n = bias_dims[1]; + auto q_n = q_dims[1]; + auto bias_tensor = graph.tensor(Tensor_attributes() + .set_name("bias") + .set_dim(bias_dims) + .set_stride(bias_strides) + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for @@ -5604,7 +5669,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::shared_ptr o = graph.tensor(Tensor_attributes() .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(do_dims) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); @@ -5612,9 +5677,10 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - if (is_padding) { - auto q_dim = q_desc.GetCudnnCompatibleDimensions(false); - auto b = q_dim[0]; + + if (is_padding || max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; auto seq_q_tensor = graph.tensor(Tensor_attributes() .set_name("seq_q") @@ -5633,6 +5699,31 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } + + std::shared_ptr offset_q, offset_kv; + if (max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; + offset_q = + graph.tensor(Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + offset_kv = + graph.tensor(Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + q->set_ragged_offset(offset_q); + k->set_ragged_offset(offset_kv); + v->set_ragged_offset(offset_kv); + o->set_ragged_offset(offset_q); + dO->set_ragged_offset(offset_q); + } // Setting seed and offset std::shared_ptr seed_tensor; std::shared_ptr offset_tensor; @@ -5668,20 +5759,25 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( auto [dQ, dK, dV] = graph.sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); + if (max_seg_per_batch > 1) { + dQ->set_ragged_offset(offset_q); + dK->set_ragged_offset(offset_kv); + dV->set_ragged_offset(offset_kv); + } dQ->set_output(true) - .set_dim(dq_desc.dimensions()) + .set_dim(dq_dims) .set_stride(dq_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dQ") .set_data_type(ioDataType); dK->set_output(true) - .set_dim(dk_desc.dimensions()) + .set_dim(dk_dims) .set_stride(dk_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dK") .set_data_type(ioDataType); dV->set_output(true) - .set_dim(dv_desc.dimensions()) + .set_dim(dv_dims) .set_stride(dv_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dV") @@ -8474,9 +8570,9 @@ absl::Status CudnnGraph::Execute(Stream& stream, const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); - RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( - dnn_support.cudnn_->GetHandle(stream.parent(), &stream).handle(), - tensor_to_ptr_map, workspace.opaque())); + auto cudnn = dnn_support.cudnn_->GetHandle(stream.parent(), &stream); + RETURN_IF_CUDNN_FRONTEND_ERROR( + graph_.execute(cudnn.handle(), tensor_to_ptr_map, workspace.opaque())); return absl::OkStatus(); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 16a08231263500..9d46794e2329b8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -68,7 +68,7 @@ class CudnnGraph : public dnn::DnnGraph { int64_t local_device_ordinal) const override; const cudnn_frontend::graph::Graph& Graph() const { return graph_; } void InitDropoutState(int64_t local_device_count, int64_t seed, - int64_t increment) { + int64_t increment) override { dropout_rng_seed_ = seed; current_dropout_rng_offset_ = std::vector(local_device_count, 0); dropout_rng_offset_increment_ = increment; @@ -707,7 +707,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional bias_descriptor, const std::optional stats_descriptor, double scale, const bool use_dropout, const std::optional dropout_rate, - const dnn::FMHAMaskKind mask_type, const int sliding_window_length); + const dnn::FMHAMaskKind mask_type, const int sliding_window_length, + const int max_seg_per_batch); absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( dnn::DnnSupport& dnn_support, @@ -730,7 +731,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, const dnn::FMHAMaskKind mask_type, bool force_deterministic, - const int sliding_window_length); + const int sliding_window_length, const int max_seg_per_batch); absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 250d0c4390e6a2..2c4b7bd5022c84 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -97,7 +97,7 @@ bool ShouldLaunchDelayKernel() { // Only launch the delay kernel if CUDA_LAUNCH_BLOCKING is not set to 1. static bool value = [] { const char* blocking = std::getenv("CUDA_LAUNCH_BLOCKING"); - return !blocking || std::string_view{blocking} != "1"; + return !blocking || absl::string_view{blocking} != "1"; }(); return value; } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc index 66d01bda9713a2..4c68fda5099683 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc @@ -18,8 +18,11 @@ limitations under the License. #include #include #include +#include +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -28,7 +31,9 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "tsl/platform/errors.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -75,5 +80,57 @@ absl::StatusOr CudaKernel::GetKernelMetadata() { return kernel_metadata; } +absl::Status CudaKernel::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + Stream* stream, const KernelArgs& args) { + CUfunction function = gpu_function(); + + // Launch kernels with packed arguments. + auto launch = [this, stream, &cluster_dims, &thread_dims, &block_dims, + function](const KernelArgsPackedArrayBase& packed) { + int32_t expected_number_of_arguments = + Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + + void** params = const_cast(packed.argument_addresses().data()); + + if (cluster_dims.has_value()) { + return stream->LaunchKernel(thread_dims, block_dims, cluster_dims, + function, name(), params, + packed.number_of_shared_bytes()); + } else { + return stream->LaunchKernel(thread_dims, block_dims, std::nullopt, + function, name(), params, + packed.number_of_shared_bytes()); + } + }; + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return launch(*packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(*this, *device_mem)); + return launch(*packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h index 55bc34f229072e..c2e0b990d999a6 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h @@ -60,6 +60,10 @@ class CudaKernel : public Kernel { absl::StatusOr GetKernelMetadata(); private: + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const std::optional &cluster_dims, + Stream *stream, const KernelArgs &args) override; + StreamExecutor* executor_ = nullptr; CUfunction gpu_function_ = nullptr; // wrapped CUDA kernel handle diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc b/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc index 469c19a8b60b58..fa1a7d2ad0865a 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc @@ -335,13 +335,12 @@ absl::Status CudaStream::DoHostCallbackWithStatus( } namespace { -absl::Status LaunchKernel(StreamExecutor* executor, - absl::string_view kernel_name, CUfunction function, - unsigned int grid_dim_x, unsigned int grid_dim_y, - unsigned int grid_dim_z, unsigned int block_dim_x, - unsigned int block_dim_y, unsigned int block_dim_z, - unsigned int shared_mem_bytes, CUstream stream, - void** kernel_params, void** extra) { +absl::Status LaunchCudaKernel( + StreamExecutor* executor, absl::string_view kernel_name, + CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, CUstream stream, + void** kernel_params, void** extra) { std::unique_ptr activation = executor->Activate(); VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z @@ -371,16 +370,14 @@ absl::Status LaunchKernel(StreamExecutor* executor, "; shared memory size: ", shared_mem_bytes)); } -absl::Status LaunchKernel(StreamExecutor* executor, - absl::string_view kernel_name, CUfunction function, - unsigned int cluster_dim_x, - unsigned int cluster_dim_y, - unsigned int cluster_dim_z, unsigned int grid_dim_x, - unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, - unsigned int shared_mem_bytes, CUstream stream, - void** kernel_params, void** extra) { +absl::Status LaunchCudaKernel( + StreamExecutor* executor, absl::string_view kernel_name, + CUfunction function, unsigned int cluster_dim_x, unsigned int cluster_dim_y, + unsigned int cluster_dim_z, unsigned int grid_dim_x, + unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, + unsigned int block_dim_y, unsigned int block_dim_z, + unsigned int shared_mem_bytes, CUstream stream, void** kernel_params, + void** extra) { std::unique_ptr activation = executor->Activate(); VLOG(2) << "launching kernel: " << kernel_name << "; cdx: " << cluster_dim_x << " cdy: " << cluster_dim_y << " cdz: " << cluster_dim_z @@ -433,62 +430,24 @@ absl::Status LaunchKernel(StreamExecutor* executor, } // namespace -absl::Status CudaStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - const CudaKernel* gpu_kernel = static_cast(&kernel); - CUfunction function = gpu_kernel->gpu_function(); - - // Launch kernels with packed arguments. - auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, - &function](const KernelArgsPackedArrayBase& packed) { - int32_t expected_number_of_arguments = - kernel.Arity() + (packed.number_of_shared_bytes() > 0); - - CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) - << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() - << " arguments, but expected " << expected_number_of_arguments - << "; arity=" << kernel.Arity() - << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); - - void** params = const_cast(packed.argument_addresses().data()); - - if (cluster_dims.has_value()) { - return LaunchKernel( - executor_, kernel.name(), function, cluster_dims->x, cluster_dims->y, - cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, - thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), stream_handle_, params, - /*extra=*/nullptr); - } else { - return LaunchKernel( - executor_, kernel.name(), function, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), stream_handle_, params, - /*extra=*/nullptr); - } - }; - - // If arguments are already packed we can just launch the kernel. - if (auto* packed = DynCast(&args)) { - return launch(*packed); - } - - // For device memory array we rely on a custom kernel arguments packing. - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return launch(*packed); +absl::Status CudaStream::LaunchKernel( + const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, void* function, + absl::string_view name, void** args, int64_t shmem_bytes) { + if (cluster_dims.has_value()) { + return LaunchCudaKernel(executor_, name, static_cast(function), + cluster_dims->x, cluster_dims->y, cluster_dims->z, + block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + shmem_bytes, stream_handle_, args, + /*extra=*/nullptr); + } else { + return LaunchCudaKernel(executor_, name, static_cast(function), + block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + shmem_bytes, stream_handle_, args, + /*extra=*/nullptr); } - - return absl::InternalError("Unsupported kernel arguments type"); } void CudaStream::SetName(std::string name) { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h index 7d8be77df9366c..a692e0e2508ca2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h @@ -89,9 +89,11 @@ class CudaStream : public StreamCommon { absl::Status RecordCompletedEvent(); - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) override; + absl::Status LaunchKernel(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + void* function, absl::string_view name, void** args, + int64_t shmem_bytes) override; StreamExecutor* executor_; CudaEvent completed_event_; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc index f678af5dd8f071..82818e8cfa23ec 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc @@ -219,7 +219,7 @@ TEST_F(CudaStreamTest, LaunchKernel) { EXPECT_THAT(stream->Memset32(&a, 1, kByteLength), IsOk()); EXPECT_THAT(stream->Memset32(&b, 2, kByteLength), IsOk()); EXPECT_THAT(stream->MemZero(&c, kByteLength), IsOk()); - EXPECT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(kLength), add, a, b, c), + EXPECT_THAT(add.Launch(ThreadDim(), BlockDim(kLength), stream.get(), a, b, c), IsOk()); EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc index 021ce4f7d2cdd7..426066bf1dd151 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc @@ -66,8 +66,7 @@ class CudaTimerTest : public ::testing::TestWithParam { ASSERT_THAT(stream->Memset32(&b, 2, byte_length), IsOk()); ASSERT_THAT(stream->MemZero(&c, byte_length), IsOk()); - ASSERT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c), - IsOk()); + ASSERT_THAT(add.Launch(ThreadDim(), BlockDim(4), stream, a, b, c), IsOk()); } StreamExecutor* executor_; diff --git a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.cc index 903a7d3dbef79c..bfada1b648147d 100644 --- a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -60,7 +59,7 @@ constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '}; absl::StatusOr DeferRelocatableCompilationCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { if (ptx.empty()) return RelocatableModule{}; @@ -103,7 +102,7 @@ DeferRelocatableCompilationCompilationProvider::CompileAndLink( absl::StatusOr DeferRelocatableCompilationCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { return delegate_->Compile(cc, ptx, options); } diff --git a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h index 8a4702021c7e1a..4451ea7255fc86 100644 --- a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -57,7 +56,7 @@ class DeferRelocatableCompilationCompilationProvider } absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( @@ -66,7 +65,7 @@ class DeferRelocatableCompilationCompilationProvider const CompilationOptions& options) const override; absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; private: diff --git a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider_test.cc b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider_test.cc index 52774c84475566..a82102dfc53fd9 100644 --- a/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include -#include #include #include #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/mock_compilation_provider.h" @@ -67,8 +67,8 @@ TEST(DeferRelocatableCompilationCompilationProviderTest, StatusIs(absl::StatusCode::kInvalidArgument)); } -constexpr std::string_view kSomePtxString = "some ptx string"; -constexpr std::string_view kSomeOtherPtxString = "some other ptx string"; +constexpr absl::string_view kSomePtxString = "some ptx string"; +constexpr absl::string_view kSomeOtherPtxString = "some other ptx string"; constexpr CudaComputeCapability kDefaultComputeCapability{10, 0}; constexpr CompilationOptions kDefaultCompilationOptions{}; diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc index e0c5138278e676..93dfa1053d8a68 100644 --- a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc @@ -66,9 +66,9 @@ absl::StatusOr LaunchDelayKernel(Stream* stream) { // Launch a delay kernel into this stream, which will spin until // GetElapsedDuration() is called, the timer is destroyed, or the timeout // in the kernel is reached. - TF_RETURN_IF_ERROR(stream->ThenLaunch(ThreadDim(1, 1, 1), BlockDim(1, 1, 1), - kernel, semaphore.device(), - GpuSemaphoreState::kRelease)); + TF_RETURN_IF_ERROR(kernel.Launch(ThreadDim(1, 1, 1), BlockDim(1, 1, 1), + stream, semaphore.device(), + GpuSemaphoreState::kRelease)); return semaphore; } diff --git a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc index 1996080f7ae7ff..469fa3351221b6 100644 --- a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -31,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" @@ -48,14 +48,14 @@ limitations under the License. namespace stream_executor::cuda { absl::StatusOr DriverCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { return CompileAndLink(cc, {Ptx{std::string{ptx}}}, options); } absl::StatusOr DriverCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { return absl::UnavailableError( "Compilation to relocatable module is not " @@ -165,7 +165,7 @@ absl::StatusOr DriverCompilationProvider::CompileAndLink( CHECK(info_log_buffer_size() <= kInfoLogBufferSize); info_log_buffer.resize(info_log_buffer_size()); - std::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; + absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; std::string architecture = absl::StrCat("sm_", cc.major, cc.minor, extension); if (result != CUDA_SUCCESS) { diff --git a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h index fac829bb06916b..e73db347c69e1b 100644 --- a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_DRIVER_COMPILATION_PROVIDER_H_ #include -#include #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -37,11 +36,11 @@ class DriverCompilationProvider : public CompilationProvider { bool SupportsCompileAndLink() const override { return true; } absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( diff --git a/third_party/xla/xla/stream_executor/cuda/dummy_cuda_binary.cc b/third_party/xla/xla/stream_executor/cuda/dummy_cuda_binary.cc index 807f30b6a24c43..737fb78e7229b6 100644 --- a/third_party/xla/xla/stream_executor/cuda/dummy_cuda_binary.cc +++ b/third_party/xla/xla/stream_executor/cuda/dummy_cuda_binary.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" int main(int argc, char** argv) { if (argc == 1) { @@ -27,25 +27,25 @@ int main(int argc, char** argv) { return -1; } - const auto process_was_called_as = [&](std::string_view binary_name) { + const auto process_was_called_as = [&](absl::string_view binary_name) { return argv[0] == binary_name || absl::EndsWith(argv[0], absl::StrCat("/", binary_name)); }; if (process_was_called_as("ptxas") && - argv[1] == std::string_view{"--version"}) { + argv[1] == absl::string_view{"--version"}) { std::cout << "ptxas dummy V111.2.3\n"; return 0; } if (process_was_called_as("nvlink") && - argv[1] == std::string_view{"--version"}) { + argv[1] == absl::string_view{"--version"}) { std::cout << "nvlink dummy V444.5.6\n"; return 0; } if (process_was_called_as("fatbinary") && - argv[1] == std::string_view{"--version"}) { + argv[1] == absl::string_view{"--version"}) { std::cout << "fatbinary dummy V777.8.9\n"; return 0; } diff --git a/third_party/xla/xla/stream_executor/cuda/mock_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/mock_compilation_provider.h index cca253d2f94032..23a69fe5c74fdd 100644 --- a/third_party/xla/xla/stream_executor/cuda/mock_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/mock_compilation_provider.h @@ -17,10 +17,10 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_MOCK_COMPILATION_PROVIDER_H_ #include -#include #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" @@ -34,11 +34,11 @@ class MockCompilationProvider : public CompilationProvider { MOCK_METHOD(bool, SupportsCompileAndLink, (), (const, override)); MOCK_METHOD(std::string, name, (), (const, override)); MOCK_METHOD(absl::StatusOr, Compile, - (const CudaComputeCapability& cc, std::string_view ptx, + (const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options), (const, override)); MOCK_METHOD(absl::StatusOr, CompileToRelocatableModule, - (const CudaComputeCapability& cc, std::string_view ptx, + (const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options), (const, override)); MOCK_METHOD(absl::StatusOr, CompileAndLink, diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc index e80550b5319c15..6dd1f1e215b694 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -36,14 +35,14 @@ namespace stream_executor::cuda { absl::StatusOr stream_executor::cuda::NvJitLinkCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { return CompileAndLink(cc, {Ptx{std::string{ptx}}}, options); } absl::StatusOr stream_executor::cuda::NvJitLinkCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { return absl::UnavailableError( "Compilation to relocatable module is not supported."); diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h index b8a11711d9784c..b680e0882a1729 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_compilation_provider.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_NVJITLINK_COMPILATION_PROVIDER_H_ #include -#include #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -35,11 +34,11 @@ class NvJitLinkCompilationProvider : public CompilationProvider { bool SupportsCompileAndLink() const override { return true; } absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc index 160f8bfcc50efd..f3342266f85115 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -42,7 +41,7 @@ limitations under the License. namespace stream_executor { -static std::string_view ToString(nvJitLinkResult status) { +static absl::string_view ToString(nvJitLinkResult status) { switch (status) { case NVJITLINK_SUCCESS: return "SUCCESS"; @@ -65,7 +64,8 @@ static std::string_view ToString(nvJitLinkResult status) { } } -static absl::Status ToStatus(nvJitLinkResult status, std::string_view message) { +static absl::Status ToStatus(nvJitLinkResult status, + absl::string_view message) { return absl::UnknownError(absl::StrCat(ToString(status), ": ", message)); } @@ -131,14 +131,15 @@ absl::StatusOr> CompileAndLinkUsingLibNvJitLink( return std::vector(); } - TF_ASSIGN_OR_RETURN((auto [major, minor]), GetNvJitLinkVersion()); - WarnIfBadPtxasVersion("nvJitLink", cc, {major, minor, 0}); + TF_ASSIGN_OR_RETURN(NvJitLinkVersion version, GetNvJitLinkVersion()); + auto [version_major, version_minor] = version; + WarnIfBadPtxasVersion("nvJitLink", cc, {version_major, version_minor, 0}); std::vector cli_args; // On Hopper, default to sm_90a so that all instructions can be used. But // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility - std::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; + absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; std::string architecture = absl::StrCat("sm_", cc.major, cc.minor, extension); cli_args.emplace_back(absl::StrCat("-arch=", architecture)); diff --git a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc index 3a4f05f6821f21..3cebebf368077d 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -34,7 +33,7 @@ limitations under the License. namespace stream_executor::cuda { absl::StatusOr> NvptxcompilerCompilationProvider::CompileHelper( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options, bool compile_to_relocatable_module) const { GpuAsmOpts asm_opts{}; @@ -55,7 +54,7 @@ NvptxcompilerCompilationProvider::CompileHelper( } absl::StatusOr NvptxcompilerCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { TF_ASSIGN_OR_RETURN(auto cubin, CompileHelper(cc, ptx, options, @@ -65,7 +64,7 @@ absl::StatusOr NvptxcompilerCompilationProvider::Compile( absl::StatusOr NvptxcompilerCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { TF_ASSIGN_OR_RETURN(auto cubin, CompileHelper(cc, ptx, options, diff --git a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h index 2ab45107abcc0d..5ffdee124c19fe 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/status/statusor.h" @@ -42,11 +41,11 @@ class NvptxcompilerCompilationProvider : public CompilationProvider { } absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( @@ -56,7 +55,7 @@ class NvptxcompilerCompilationProvider : public CompilationProvider { private: absl::StatusOr> CompileHelper( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options, bool compile_to_relocatable_module) const; }; diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc index 596fb58521a5a5..96a7d003e604b2 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc @@ -15,23 +15,40 @@ limitations under the License. #include "xla/stream_executor/cuda/ptx_compiler_helpers.h" -#include - +#include "absl/base/call_once.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/semantic_version.h" namespace stream_executor { +namespace { + +static constexpr absl::string_view kPtxasErrorPayloadKey = "ptxas_log"; + +} // namespace + +absl::Status PtxRegisterAllocationError(absl::string_view message) { + absl::Status status = absl::ResourceExhaustedError(message); + status.SetPayload(kPtxasErrorPayloadKey, absl::Cord()); + return status; +} + +bool IsPtxRegisterAllocationError(absl::Status status) { + return status.GetPayload(kPtxasErrorPayloadKey).has_value(); +} -bool IsPtxRegisterAllocationError(std::string_view str) { +bool IsPtxRegisterAllocationError(absl::string_view str) { return absl::StrContains(str, "ptxas fatal") && (absl::StrContains(str, "Register allocation failed") || absl::StrContains(str, "Insufficient registers")); } -absl::Status CreateErrorFromPTXASLog(std::string_view log, - std::string_view architecture, +absl::Status CreateErrorFromPTXASLog(absl::string_view log, + absl::string_view architecture, bool cancel_if_reg_spill) { // It happens when the loaded version of nvjitlink is too old for // the current GPU. Example error message associated with this error @@ -43,7 +60,7 @@ absl::Status CreateErrorFromPTXASLog(std::string_view log, "Loaded PTX assembler is too old for %s.", architecture)); } if (IsPtxRegisterAllocationError(log)) { - return absl::ResourceExhaustedError(log); + return PtxRegisterAllocationError(log); } if (absl::StrContains(log, "warning")) { LOG(INFO) << log; @@ -58,7 +75,7 @@ absl::Status CreateErrorFromPTXASLog(std::string_view log, // Warns if the ptxas version should be upgraded. // Only prints the warning upon the first invocation. -void WarnIfBadPtxasVersion(std::string_view method, +void WarnIfBadPtxasVersion(absl::string_view method, const CudaComputeCapability& cc, SemanticVersion compiler_version) { static absl::once_flag run_once; diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h index d061eee6184fd9..10b13b215760dc 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h @@ -14,27 +14,33 @@ limitations under the License. ==============================================================================*/ #ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ #define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ -#include #include "absl/status/status.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/semantic_version.h" namespace stream_executor { + +// Creates a status with a payload indicating a register allocation error. +absl::Status PtxRegisterAllocationError(absl::string_view message); + // Checks whether ptxas log contains errors related to register allocation. -bool IsPtxRegisterAllocationError(std::string_view); +bool IsPtxRegisterAllocationError(absl::string_view); + +// Checks whether the status is a register allocation error. +bool IsPtxRegisterAllocationError(absl::Status status); // Identifies errors in the ptxas log and creates an error status. // `architecture` is the name of the GPU architecture, e.g. "sm_80" and is only // used for error message generation. If `cancel_if_reg_spill` is true, then a // register spill warning will be treated as an error, otherwise it will be // ignored. -absl::Status CreateErrorFromPTXASLog(std::string_view log, - std::string_view architecture, +absl::Status CreateErrorFromPTXASLog(absl::string_view log, + absl::string_view architecture, bool cancel_if_reg_spill); // Warns if the ptxas version should be upgraded. -void WarnIfBadPtxasVersion(std::string_view method, +void WarnIfBadPtxasVersion(absl::string_view method, const CudaComputeCapability& cc, SemanticVersion compiler_version); } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers_test.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers_test.cc index 55f21fa49c4d9f..a9d40f42693d64 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers_test.cc @@ -15,11 +15,10 @@ limitations under the License. #include "xla/stream_executor/cuda/ptx_compiler_helpers.h" -#include - #include #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" @@ -29,7 +28,7 @@ using ::tsl::testing::IsOk; using ::tsl::testing::StatusIs; // When the compilation succeeds, then the error log is empty. -constexpr std::string_view kPtxasLogSuccessfulCompilation = R"( +constexpr absl::string_view kPtxasLogSuccessfulCompilation = R"( ptxas info : 0 bytes gmem ptxas info : Compiling entry function 'input_concatenate_fusion' for 'sm_80' ptxas info : Function properties for input_concatenate_fusion @@ -37,21 +36,21 @@ ptxas info : Function properties for input_concatenate_fusion ptxas info : Used 10 registers, 368 bytes cmem[0] )"; -constexpr std::string_view kPtxasLogTooOldError = R"( +constexpr absl::string_view kPtxasLogTooOldError = R"( // Something in the log before the error. ptxas fatal : Value 'sm_80' is not defined for option 'gpu-name' ptxas fatal : Ptx assembly aborted due to errors // Something in the log after the error. )"; -constexpr std::string_view kPtxasLogRegisterAllocationError = R"( +constexpr absl::string_view kPtxasLogRegisterAllocationError = R"( // Something in the log before the error. ptxas fatal : (C7600) Register allocation failed with register count of '64'. Compile the program with a higher register target ptxas fatal : Ptx assembly aborted due to errors // Something in the log after the error. )"; -constexpr std::string_view kPtxasLogRegisterSpillWarning = R"( +constexpr absl::string_view kPtxasLogRegisterSpillWarning = R"( // Something in the log before the warning. ptxas warning : Registers are spilled to local memory in function '__kernel', 8 bytes spill stores, 8 bytes spill loads // Something in the log after the warning. @@ -62,7 +61,7 @@ TEST(PtxCompilerHelpersTest, IsPtxRegisterAllocationError) { EXPECT_FALSE(IsPtxRegisterAllocationError(kPtxasLogRegisterSpillWarning)); } -constexpr std::string_view kDefaultArchitecture = "sm_80"; +constexpr absl::string_view kDefaultArchitecture = "sm_80"; TEST(PtxCompilerHelpersTest, CreateErrorFromPTXASLogNoError) { EXPECT_THAT(CreateErrorFromPTXASLog(kPtxasLogSuccessfulCompilation, @@ -102,5 +101,13 @@ TEST(PtxCompilerHelpersTest, IsOk()); } +TEST(PtxCompilerHelpersTest, IsPtxRegisterAllocationErrorStatus) { + EXPECT_TRUE(IsPtxRegisterAllocationError( + PtxRegisterAllocationError("Register allocation failed"))); + EXPECT_FALSE( + IsPtxRegisterAllocationError(absl::ResourceExhaustedError("OOM"))); + EXPECT_FALSE(IsPtxRegisterAllocationError(absl::OkStatus())); +} + } // namespace } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc index e48d73ca1c729b..daaaf7f891c7c8 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -32,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/nvPTXCompiler.h" #include "xla/stream_executor/cuda/ptx_compiler.h" @@ -44,7 +44,7 @@ limitations under the License. namespace stream_executor { -static std::string_view ToString(nvPTXCompileResult status) { +static absl::string_view ToString(nvPTXCompileResult status) { switch (status) { case NVPTXCOMPILE_SUCCESS: return "SUCCESS"; @@ -97,7 +97,7 @@ absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( // On Hopper, default to sm_90a so that all instructions can be used. But // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility - std::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; + absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; std::string architecture = absl::StrCat("sm_", cc.major, cc.minor, extension); options.extra_flags.emplace_back(absl::StrCat("-arch=", architecture)); @@ -141,7 +141,7 @@ absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( "Linked libnvptxcompiler is too old for %s.", architecture)); } if (IsPtxRegisterAllocationError(error_log)) { - return absl::ResourceExhaustedError(error_log); + return PtxRegisterAllocationError(error_log); } return absl::InternalError( diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc index f8b5ee85d142ae..895c21e7aaeefb 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -60,7 +59,7 @@ limitations under the License. namespace stream_executor { static absl::StatusOr GetToolVersionString( - std::string_view binary_path) { + absl::string_view binary_path) { // If binary_path doesn't exist, then tsl::SubProcess will log a bunch of // error messages that have confused users in the past. Therefore we first // check whether the binary_path exists and error out early if not. @@ -91,7 +90,7 @@ static absl::StatusOr GetToolVersionString( } static absl::StatusOr GetToolVersionImpl( - std::string_view tool_path) { + absl::string_view tool_path) { absl::StatusOr tool_version = GetToolVersionString(tool_path); if (!tool_version.ok()) { return absl::FailedPreconditionError( @@ -100,7 +99,7 @@ static absl::StatusOr GetToolVersionImpl( } static constexpr LazyRE2 kVersionRegex = {R"(\bV(\d+)\.(\d+)\.(\d+)\b)"}; SemanticVersion version{0, 0, 0}; - std::string_view vmaj_str, vmin_str, vdot_str; + absl::string_view vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(tool_version.value(), *kVersionRegex, &vmaj_str, &vmin_str, &vdot_str) || !absl::SimpleAtoi(vmaj_str, &version.major()) || @@ -113,7 +112,7 @@ static absl::StatusOr GetToolVersionImpl( return version; } -absl::StatusOr GetToolVersion(std::string_view tool_path) { +absl::StatusOr GetToolVersion(absl::string_view tool_path) { // This is only implementing a static cache. `GetToolVersionImpl` has the // actual business logic. static absl::Mutex mutex(absl::kConstInit); @@ -132,7 +131,7 @@ absl::StatusOr GetToolVersion(std::string_view tool_path) { } absl::StatusOr FindCudaExecutable( - std::string_view binary_name, std::string_view preferred_cuda_dir, + absl::string_view binary_name, absl::string_view preferred_cuda_dir, SemanticVersion minimum_version, absl::Span excluded_versions) { std::string binary_filename = std::string{binary_name}; @@ -146,14 +145,14 @@ absl::StatusOr FindCudaExecutable( // #2 - Check generic CUDA locations if that is preferred over the PATH if (!tsl::PreferPtxasFromPath()) { - for (std::string_view path : tsl::CandidateCudaRoots()) { + for (absl::string_view path : tsl::CandidateCudaRoots()) { candidates.emplace_back(tsl::io::JoinPath(path, "bin", binary_filename)); } } // #3 - Check the PATH environment variable if (const auto* path_env_ptr = std::getenv("PATH")) { - std::string_view path_env{path_env_ptr ? path_env_ptr : ""}; + absl::string_view path_env{path_env_ptr ? path_env_ptr : ""}; #if defined(PLATFORM_WINDOWS) constexpr char kSearchPathSeparator = ';'; @@ -161,7 +160,7 @@ absl::StatusOr FindCudaExecutable( constexpr char kSearchPathSeparator = ':'; #endif - for (std::string_view path : + for (absl::string_view path : absl::StrSplit(path_env, kSearchPathSeparator)) { candidates.emplace_back(tsl::io::JoinPath(path, binary_filename)); } @@ -169,7 +168,7 @@ absl::StatusOr FindCudaExecutable( // #4 - Check generic CUDA locations if we didn't do that already in #2 if (tsl::PreferPtxasFromPath()) { - for (std::string_view path : tsl::CandidateCudaRoots()) { + for (absl::string_view path : tsl::CandidateCudaRoots()) { candidates.emplace_back(tsl::io::JoinPath(path, "bin", binary_filename)); } } @@ -206,7 +205,7 @@ absl::StatusOr FindCudaExecutable( } absl::StatusOr FindCudaExecutable( - std::string_view binary_name, std::string_view preferred_cuda_dir) { + absl::string_view binary_name, absl::string_view preferred_cuda_dir) { static constexpr SemanticVersion kNoMinimumVersion{0, 0, 0}; static constexpr absl::Span kNoExcludedVersions{}; return FindCudaExecutable(binary_name, preferred_cuda_dir, kNoMinimumVersion, @@ -214,10 +213,10 @@ absl::StatusOr FindCudaExecutable( } absl::StatusOr FindPtxAsExecutable( - std::string_view preferred_cuda_dir) { + absl::string_view preferred_cuda_dir) { static constexpr SemanticVersion kMinimumSupportedPtxAsVersion{11, 8, 0}; static constexpr SemanticVersion kBuggyPtxAsVersions[] = {{12, 3, 103}}; - static constexpr std::string_view kPtxAsBinaryName = "ptxas"; + static constexpr absl::string_view kPtxAsBinaryName = "ptxas"; return FindCudaExecutable(kPtxAsBinaryName, preferred_cuda_dir, kMinimumSupportedPtxAsVersion, kBuggyPtxAsVersions); @@ -252,7 +251,7 @@ static void AppendArgsFromOptions(GpuAsmOpts options, } absl::StatusOr> CompileGpuAsmUsingPtxAs( - const CudaComputeCapability& cc, std::string_view ptx, GpuAsmOpts options, + const CudaComputeCapability& cc, absl::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill) { TF_ASSIGN_OR_RETURN(std::string ptxas_path, FindPtxAsExecutable(options.preferred_cuda_dir)); @@ -261,8 +260,8 @@ absl::StatusOr> CompileGpuAsmUsingPtxAs( } absl::StatusOr> CompileGpuAsmUsingPtxAs( - std::string_view ptxas_path, const CudaComputeCapability& cc, - std::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill) { + absl::string_view ptxas_path, const CudaComputeCapability& cc, + absl::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill) { TF_ASSIGN_OR_RETURN(auto version, GetToolVersion(ptxas_path)); WarnIfBadPtxasVersion("ptxas", cc, version); @@ -334,7 +333,7 @@ absl::StatusOr> CompileGpuAsmUsingPtxAs( } if (IsPtxRegisterAllocationError(stderr_output)) { LOG(INFO) << stderr_output; - return absl::ResourceExhaustedError(stderr_output); + return PtxRegisterAllocationError(stderr_output); } return absl::InternalError( @@ -364,7 +363,7 @@ absl::StatusOr> CompileGpuAsmUsingPtxAs( } absl::StatusOr GetAsmCompilerVersion( - std::string_view preferred_cuda_dir) { + absl::string_view preferred_cuda_dir) { TF_ASSIGN_OR_RETURN(std::string ptxas_path, FindPtxAsExecutable(preferred_cuda_dir)); return GetToolVersion(ptxas_path); @@ -454,17 +453,17 @@ absl::StatusOr> BundleGpuAsmUsingFatbin( } absl::StatusOr FindNvlinkExecutable( - std::string_view preferred_cuda_dir) { + absl::string_view preferred_cuda_dir) { static constexpr SemanticVersion kMinimumNvlinkVersion{11, 8, 0}; static constexpr absl::Span kNoExcludedVersions{}; - static constexpr std::string_view kNvLinkBinaryName = "nvlink"; + static constexpr absl::string_view kNvLinkBinaryName = "nvlink"; return FindCudaExecutable(kNvLinkBinaryName, preferred_cuda_dir, kMinimumNvlinkVersion, kNoExcludedVersions); } absl::StatusOr GetNvLinkVersion( - std::string_view preferred_cuda_dir) { + absl::string_view preferred_cuda_dir) { // Make sure nvlink exists and is executable. TF_ASSIGN_OR_RETURN(std::string bin_path, FindNvlinkExecutable(preferred_cuda_dir)); @@ -474,7 +473,7 @@ absl::StatusOr GetNvLinkVersion( absl::StatusOr> LinkUsingNvlink( stream_executor::CudaComputeCapability cc, - std::string_view preferred_cuda_dir, + absl::string_view preferred_cuda_dir, absl::Span> images) { TF_ASSIGN_OR_RETURN(std::string bin_path, FindNvlinkExecutable(preferred_cuda_dir)); @@ -483,7 +482,7 @@ absl::StatusOr> LinkUsingNvlink( } absl::StatusOr> LinkUsingNvlink( - std::string_view nvlink_path, stream_executor::CudaComputeCapability cc, + absl::string_view nvlink_path, stream_executor::CudaComputeCapability cc, absl::Span> images) { LOG_FIRST_N(INFO, 1) << "Using nvlink for parallel linking"; @@ -516,7 +515,7 @@ absl::StatusOr> LinkUsingNvlink( }; std::vector args; args.push_back(std::string{nvlink_path}); - std::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; + absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; args.push_back(absl::StrCat("-arch=sm_", cc.major, cc.minor, extension)); for (int i = 0; i < images.size(); i++) { args.push_back(temp_files[i]); diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h index 6bb374b9d6af67..f052da91069ca8 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include "absl/status/statusor.h" @@ -36,14 +35,14 @@ namespace stream_executor { // 'options' is used to query for the CUDA location in case it is // customized in a passed flag, and for controlling ptxas optimizations. absl::StatusOr> CompileGpuAsmUsingPtxAs( - const CudaComputeCapability& cc, std::string_view ptx_contents, + const CudaComputeCapability& cc, absl::string_view ptx_contents, GpuAsmOpts options, bool cancel_if_reg_spill = false); // Like the above, but uses the ptxas_binary from `ptxas_path` instead of // using `FindCudaExecutable` to find it. absl::StatusOr> CompileGpuAsmUsingPtxAs( - std::string_view ptxas_path, const CudaComputeCapability& cc, - std::string_view ptx_contents, GpuAsmOpts options, + absl::string_view ptxas_path, const CudaComputeCapability& cc, + absl::string_view ptx_contents, GpuAsmOpts options, bool cancel_if_reg_spill = false); // Finds the CUDA executable with the given binary_name @@ -53,35 +52,35 @@ absl::StatusOr> CompileGpuAsmUsingPtxAs( // A binary is only considered if it is of at least `minimum_version` and not // in `excluded_versions`. absl::StatusOr FindCudaExecutable( - std::string_view binary_name, std::string_view preferred_cuda_dir, + absl::string_view binary_name, absl::string_view preferred_cuda_dir, SemanticVersion minimum_version, absl::Span excluded_versions); // Same as above, but with no version constraints. absl::StatusOr FindCudaExecutable( - std::string_view binary_name, std::string_view preferred_cuda_dir); + absl::string_view binary_name, absl::string_view preferred_cuda_dir); // Returns the path to the first found ptxas binary that fulfills our version // requirements. absl::StatusOr FindPtxAsExecutable( - std::string_view preferred_cuda_dir); + absl::string_view preferred_cuda_dir); // Returns the path to the first found nvlink binary that fulfills our version // requirements. absl::StatusOr FindNvlinkExecutable( - std::string_view preferred_cuda_dir); + absl::string_view preferred_cuda_dir); // Runs tool --version and parses its version string. All the usual CUDA // tools are supported. -absl::StatusOr GetToolVersion(std::string_view tool_path); +absl::StatusOr GetToolVersion(absl::string_view tool_path); // On NVIDIA GPUs, returns the version of the ptxas command line tool. absl::StatusOr GetAsmCompilerVersion( - std::string_view preferred_cuda_dir); + absl::string_view preferred_cuda_dir); // On NVIDIA GPUs, returns the version of the nvlink command line tool. absl::StatusOr GetNvLinkVersion( - std::string_view preferred_cuda_dir); + absl::string_view preferred_cuda_dir); // Bundles the GPU machine code (cubins) and PTX if requested and returns the // resulting binary (i.e. a fatbin) as a byte array. @@ -91,13 +90,13 @@ absl::StatusOr> BundleGpuAsmUsingFatbin( // Links the given CUBIN `images` using nvlink. absl::StatusOr> LinkUsingNvlink( stream_executor::CudaComputeCapability cc, - std::string_view preferred_cuda_dir, + absl::string_view preferred_cuda_dir, absl::Span> images); // The same as above, but uses the nvlink_path instead of // `FindCudaExecutable` to find the nvlink binary. absl::StatusOr> LinkUsingNvlink( - std::string_view nvlink_path, stream_executor::CudaComputeCapability cc, + absl::string_view nvlink_path, stream_executor::CudaComputeCapability cc, absl::Span> images); } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc index 2e46f85d025a4d..52ac2f2cecaa0e 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -38,7 +37,7 @@ namespace stream_executor::cuda { absl::StatusOr> SubprocessCompilationProvider::CompileHelper( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options, bool compile_to_relocatable_module) const { GpuAsmOpts asm_opts{}; @@ -59,7 +58,7 @@ SubprocessCompilationProvider::CompileHelper( } absl::StatusOr SubprocessCompilationProvider::Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { TF_ASSIGN_OR_RETURN(auto cubin, CompileHelper(cc, ptx, options, @@ -69,7 +68,7 @@ absl::StatusOr SubprocessCompilationProvider::Compile( absl::StatusOr SubprocessCompilationProvider::CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { TF_ASSIGN_OR_RETURN(auto cubin, CompileHelper(cc, ptx, options, diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h index dc3b6c156d4f3a..2960b3c657476f 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include @@ -40,11 +39,11 @@ class SubprocessCompilationProvider : public CompilationProvider { path_to_nvlink_(std::move(path_to_nvlink)) {} absl::StatusOr Compile( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileToRelocatableModule( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const override; absl::StatusOr CompileAndLink( @@ -59,7 +58,7 @@ class SubprocessCompilationProvider : public CompilationProvider { private: absl::StatusOr> CompileHelper( - const CudaComputeCapability& cc, std::string_view ptx, + const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options, bool compile_to_relocatable_module) const; diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index f6bb2e4a41ad4e..f02102753f4650 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -59,7 +59,7 @@ struct CudaComputeCapability { this->minor = minor; } // cuda arch format "major.minor", example: "8.6". - explicit CudaComputeCapability(const std::string &cuda_arch_name) { + explicit CudaComputeCapability(std::string cuda_arch_name) { std::vector split = absl::StrSplit(cuda_arch_name, '.'); assert(split.size() == 2); this->major = std::stoi(split[0]); @@ -236,6 +236,8 @@ class RocmComputeCapability { bool has_fp8_support() const { return gfx9_mi300(); } + std::string ToString() const { return gcn_arch_name(); } + RocmComputeCapabilityProto ToProto() const { RocmComputeCapabilityProto proto; proto.set_gcn_arch_name(gcn_arch_name_); diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index 43b645b4c345df..d599faadf7562f 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -31,7 +31,7 @@ limitations under the License. #include #include "absl/base/attributes.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace stream_executor { @@ -83,8 +83,7 @@ class DeviceMemoryBase { // Warning: note that the pointer returned is not necessarily directly to // device virtual address space, but is platform-dependent. - void *opaque() { return opaque_; } - const void *opaque() const { return opaque_; } + void *opaque() const { return opaque_; } // Returns the payload of this memory region. uint64_t payload() const { return payload_; } @@ -129,7 +128,7 @@ class DeviceMemoryBase { // that represents one or more integers in Device memory. // // Thread-compatible. -template +template class DeviceMemory final : public DeviceMemoryBase { public: // Default constructor instantiates a null-pointed, zero-sized memory region. @@ -144,29 +143,25 @@ class DeviceMemory final : public DeviceMemoryBase { SetPayload(other.payload()); } - // Returns the number of elements of type ElemT that constitute this + // Returns the number of elements of type T that constitute this // allocation. - uint64_t ElementCount() const { return size() / sizeof(ElemT); } + uint64_t ElementCount() const { return size() / sizeof(T); } // Returns pointer to the allocated data - ElemT *base() { return reinterpret_cast(opaque()); } - const ElemT *base() const { - return reinterpret_cast(opaque()); - } + T *base() const { return reinterpret_cast(opaque()); } // Creates a typed area of DeviceMemory with a given opaque pointer and the // quantity of bytes in the allocation. This function is broken out to // distinguish bytes from an element count. - static DeviceMemory MakeFromByteSize(void *opaque, uint64_t bytes) { - return DeviceMemory(opaque, bytes); + static DeviceMemory MakeFromByteSize(void *opaque, uint64_t bytes) { + return DeviceMemory(opaque, bytes); } // Creates a memory region (slice) inside another allocated memory region. - // Offset and size are specified in terms of ElemT elements. - DeviceMemory GetSlice(uint64_t element_offset, - uint64_t element_count) { - return DeviceMemory(GetByteSlice(sizeof(ElemT) * element_offset, - sizeof(ElemT) * element_count)); + // Offset and size are specified in terms of T elements. + DeviceMemory GetSlice(uint64_t element_offset, uint64_t element_count) { + return DeviceMemory( + GetByteSlice(sizeof(T) * element_offset, sizeof(T) * element_count)); } protected: diff --git a/third_party/xla/xla/stream_executor/fft.h b/third_party/xla/xla/stream_executor/fft.h index 937ae639eed9f0..3349beb7146261 100644 --- a/third_party/xla/xla/stream_executor/fft.h +++ b/third_party/xla/xla/stream_executor/fft.h @@ -37,8 +37,7 @@ limitations under the License. // TF_CHECK_OK(stream.BlockHostUntilDone()); // // By using stream operations in this manner the user can easily intermix custom -// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned FFT -// routines. +// kernel launches with these pre-canned FFT routines. #ifndef XLA_STREAM_EXECUTOR_FFT_H_ #define XLA_STREAM_EXECUTOR_FFT_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 590296e7fcaca9..541585ed1bdcb2 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -366,12 +366,23 @@ cc_library( "manual", ], deps = [ + ":gpu_asm_opts", "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel", - "//xla/stream_executor:kernel_spec", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/cuda:cuda_asm_compiler", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:statusor", ], ) @@ -386,10 +397,8 @@ gpu_kernel_library( ":gpu_asm_opts", "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel", - "//xla/stream_executor:kernel_spec", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", - "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:statusor", @@ -403,12 +412,11 @@ gpu_only_cc_library( "redzone_allocator_kernel.h", ], hdrs = ["redzone_allocator.h"], - visibility = internal_visibility([ - "//xla/service/gpu:__subpackages__", - "//xla/stream_executor:__subpackages__", - "//tensorflow/core/kernels:__subpackages__", - ]), + visibility = internal_visibility([":friends"]), deps = [ + ":gpu_asm_opts", + "//xla:shape_util", + "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:device_memory_handle", @@ -555,6 +563,7 @@ gpu_kernel_library( testonly = 1, srcs = ["gpu_test_kernels.cu.cc"], hdrs = ["gpu_test_kernels.h"], + linkstatic = True, tags = ["gpu"], deps = [ "//xla/stream_executor:kernel_spec", @@ -565,46 +574,25 @@ gpu_kernel_library( ]), ) -genrule( - name = "gpu_test_kernels_fatbin_extractor", - testonly = True, - srcs = [":gpu_test_kernels"], - outs = ["gpu_test_kernels.fatbin"], - cmd = """ - STATIC_LIBRARY="" - for src in $(SRCS); do - if [[ $$src == *.a ]]; then - STATIC_LIBRARY=$$src - break - fi - done - - if [[ -z $$STATIC_LIBRARY ]]; then - echo "No static library found in $(SRCS)" >&2 - exit 1 - fi - - $(OBJCOPY) "--dump-section=.nv_fatbin=$@" "$$STATIC_LIBRARY" || true - - if [ ! -f "$@" ]; then - # binutils' objcopy doesn't return a non-zero exit code if the - # section was not found, so we need to check for the file's existence instead. - $(OBJCOPY) "--dump-section=.hip_fatbin=$@" "$$STATIC_LIBRARY" - fi - """, - tags = ["gpu"], - toolchains = ["@bazel_tools//tools/cpp:current_cc_toolchain"], -) - cc_library( name = "gpu_test_kernels_fatbin", testonly = True, srcs = ["gpu_test_kernels_fatbin.cc"], hdrs = ["gpu_test_kernels_fatbin.h"], - data = [":gpu_test_kernels_fatbin_extractor"], + data = [ + ":gpu_test_kernels", + ], tags = ["gpu"], deps = [ + ":gpu_init_impl", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Object", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -612,6 +600,19 @@ cc_library( ], ) +xla_test( + name = "gpu_test_kernels_fatbin_test", + srcs = ["gpu_test_kernels_fatbin_test.cc"], + backends = ["gpu"], + data = [":gpu_test_kernels_fatbin"], + deps = [ + ":gpu_test_kernels_fatbin", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_test( name = "gpu_kernel_test", srcs = ["gpu_kernel_test.cc"], @@ -639,6 +640,7 @@ xla_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], + data = [":gpu_test_kernels_fatbin"], ) xla_test( @@ -773,10 +775,10 @@ xla_test( local_defines = if_cuda_is_configured([ 'GPU_SPEC_FILE_NAMES=(std::string[]){\\"a100_pcie_80\\", \\"a100_sxm_40\\", \ \\"a100_sxm_80\\", \\"a6000\\", \\"h100_pcie\\", \\"h100_sxm\\", \\"p100\\", \\"v100\\"}', - 'PLATFORM_NAME=\\"CUDA\\"' + 'PLATFORM_NAME=\\"CUDA\\"', ]) + if_rocm_is_configured([ 'GPU_SPEC_FILE_NAMES=(std::string[]){\\"mi200\\"}', - 'PLATFORM_NAME=\\"ROCM\\"' + 'PLATFORM_NAME=\\"ROCM\\"', ]), deps = [ "//xla/service:platform_util", @@ -787,6 +789,8 @@ xla_test( "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index a889ffa095a625..70dabe4c9cb699 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -32,12 +31,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/bit_pattern.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -62,7 +61,7 @@ using GraphConditionalHandle = GpuCommandBuffer::GraphConditionalHandle; using GraphConditionalHandles = absl::Span; namespace { -std::string_view to_string(State state) { +absl::string_view to_string(State state) { switch (state) { case State::kCreate: return "create"; @@ -106,7 +105,7 @@ static std::atomic alive_execs(0); GpuCommandBuffer::GpuCommandBuffer(Mode mode, StreamExecutor* parent) : mode_(mode), parent_(parent) { - execution_scopes_.try_emplace(kDefaulExecutionScope); + execution_scopes_.try_emplace(kDefaultExecutionScope); } GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier( @@ -119,7 +118,7 @@ GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier( absl::Status GpuCommandBuffer::DisableBarriersExecution( GpuCommandBuffer& root_command_buffer) { - ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope]; + ExecutionScope& execution_scope = execution_scopes_[kDefaultExecutionScope]; for (GpuGraphBarrierInfo& barrier : execution_scope.barriers) { if (barrier.is_barrier_node) { @@ -670,8 +669,8 @@ absl::Status GpuCommandBuffer::For(ExecutionScopeId execution_scope_id, TF_RETURN_IF_ERROR(body->Barrier()); // Decide if we want to continue loop iteration. - return body->LaunchSetForConditionKernel(kDefaulExecutionScope, conditional, - loop_counter, num_iteration); + return body->LaunchSetForConditionKernel( + kDefaultExecutionScope, conditional, loop_counter, num_iteration); }; std::array builders = {std::move(body)}; @@ -695,9 +694,9 @@ absl::Status GpuCommandBuffer::While(ExecutionScopeId execution_scope_id, auto body = [&](GpuCommandBuffer* body, GraphConditionalHandle conditional) { TF_RETURN_IF_ERROR(body_builder(body)); TF_RETURN_IF_ERROR(body->Barrier()); - TF_RETURN_IF_ERROR(cond_builder(kDefaulExecutionScope, body)); + TF_RETURN_IF_ERROR(cond_builder(kDefaultExecutionScope, body)); TF_RETURN_IF_ERROR(body->Barrier()); - return body->LaunchSetWhileConditionKernel(kDefaulExecutionScope, + return body->LaunchSetWhileConditionKernel(kDefaultExecutionScope, conditional, pred); }; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 9c580a1986f6cd..886713d2277bd9 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -155,11 +155,11 @@ class GpuCommandBuffer : public CommandBuffer { absl::Span barriers(ExecutionScopeId id) const; absl::Span nodes() const { - return nodes(kDefaulExecutionScope); + return nodes(kDefaultExecutionScope); } absl::Span barriers() const { - return barriers(kDefaulExecutionScope); + return barriers(kDefaultExecutionScope); } // Returns the list of dependencies for a given node. `node` must be a node diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 188513c78c9090..afebe70913a135 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -198,9 +198,9 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, TraceCommandBufferFactory::Create( executor, [&](Stream* stream) { - return stream->Launch( + return add->Launch( ThreadDim(), BlockDim(4), - *add, args); + stream, args); }, primary)); @@ -1663,7 +1663,7 @@ static void BM_TraceCommandBuffer(benchmark::State& state) { for (auto s : state) { auto launch_kernels = [&](Stream* stream) { for (int i = 1; i < state.range(0); ++i) { - CHECK_OK(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, b, b, b)); + CHECK_OK(add.Launch(ThreadDim(), BlockDim(4), stream, b, b, b)); } return absl::OkStatus(); }; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc index aa98c212aaf4dc..e4719453343483 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include "absl/log/check.h" #include "absl/strings/ascii.h" diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc index d77ba42f2d497f..990e26a6d8897e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc @@ -13,7 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc index c0fc79b3248db7..a99a15ae62b21c 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include #include -#include #include #include @@ -74,7 +73,7 @@ class GpuKernelTest : public ::testing::Test { // Launch kernel. ASSERT_TRUE( - stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); + add.Launch(ThreadDim(), BlockDim(4), stream.get(), a, b, c).ok()); // Copy data back to host. std::vector dst(4, 42); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h index e12b054e40e4c0..f64d1015f2ad4e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_H_ -#include - #include "xla/stream_executor/kernel_spec.h" namespace stream_executor::gpu { @@ -38,7 +36,7 @@ namespace internal { // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kAddI32KernelPtx = R"( +inline constexpr absl::string_view kAddI32KernelPtx = R"( .version 4.0 .target sm_50 .address_size 64 diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc index d4fd7f68e83418..08d40a2794015c 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc @@ -17,9 +17,23 @@ limitations under the License. #include #include +#include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Object/Archive.h" +#include "llvm/Object/ELFObjectFile.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/MemoryBuffer.h" +#include "xla/stream_executor/gpu/gpu_init.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -27,13 +41,85 @@ limitations under the License. namespace stream_executor::gpu { -absl::StatusOr> GetGpuTestKernelsFatbin() { +namespace { + +// Reads an archive file, searches for a section that starts with +// 'fatbin_section_prefix' and returns the contents of that section as a vector +// of bytes. +absl::StatusOr> GetFatbinFromArchive( + llvm::StringRef archive_path, llvm::StringRef fatbin_section_prefix) { tsl::Env* env = tsl::Env::Default(); - std::string file_path = - tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "stream_executor", "gpu", - "gpu_test_kernels.fatbin"); + std::string file_contents; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(env, file_path, &file_contents)); - return std::vector(file_contents.begin(), file_contents.end()); + TF_RETURN_IF_ERROR( + tsl::ReadFileToString(env, std::string(archive_path), &file_contents)); + + const auto buffer = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(file_contents), + /*BufferName=*/"", /*RequiresNullTerminator=*/false); + + auto archive_ptr = llvm::object::Archive::create(buffer->getMemBufferRef()); + + if (!archive_ptr) { + return absl::InternalError(llvm::toString(archive_ptr.takeError())); + } + + const llvm::object::Archive* archive = archive_ptr.get().get(); + + llvm::Error archive_error = llvm::Error::success(); + for (const auto& child : archive->children(archive_error)) { + if (archive_error) { + return absl::InternalError(llvm::toString(std::move(archive_error))); + } + + auto binary = child.getAsBinary(); + if (!binary) { + continue; + } + + auto executable_elf_object_file_ptr = + llvm::dyn_cast(binary.get()); + if (!executable_elf_object_file_ptr) { + continue; + } + + const auto executable_elf_object_file = + executable_elf_object_file_ptr.get(); + + for (const auto& section : executable_elf_object_file->sections()) { + if (absl::StartsWith(section.getName().get().str(), + fatbin_section_prefix)) { + const std::string fatbin_contents = section.getContents().get().str(); + return std::vector(fatbin_contents.begin(), + fatbin_contents.end()); + } + } + } + + return absl::InternalError("Fatbin section not found in generated archive."); +} + +} // namespace + +absl::StatusOr> GetGpuTestKernelsFatbin() { + const std::string platform_name = GpuPlatformName(); + std::string archive_filename; + std::string fatbin_prefix; + + if (platform_name == "CUDA") { + archive_filename = "libgpu_test_kernels_cuda.a"; + fatbin_prefix = ".nv_fatbin"; + } else if (platform_name == "ROCM") { + archive_filename = "libgpu_test_kernels_rocm.a"; + fatbin_prefix = ".hip_fatbin"; + } else { + return absl::InternalError("Unsupported GPU platform: " + platform_name); + } + + std::string file_path = + tsl::io::JoinPath("external", "local_xla", "xla", "stream_executor", + "gpu", archive_filename); + + return GetFatbinFromArchive(file_path, fatbin_prefix); } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin_test.cc new file mode 100644 index 00000000000000..5295288e17cc2f --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin_test.cc @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" + +#include +#include + +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { + +TEST(GpuTestKernelsFatbinTest, GetGpuTestKernelsFatbin) { + std::vector fatbin; + + TF_ASSERT_OK_AND_ASSIGN(fatbin, GetGpuTestKernelsFatbin()); + EXPECT_FALSE(fatbin.empty()); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index 34b23d714591dd..408e76996a04cd 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -30,8 +30,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_handle.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/redzone_allocator_kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" @@ -159,10 +163,11 @@ static absl::StatusOr CheckRedzoneHost( // Run the redzone checker on the provided buffer redzone. // // Increment out_param if mismatch occurs. -static absl::Status RunRedzoneChecker( - Stream* stream, const DeviceMemory& redzone, - uint8_t redzone_pattern, const DeviceMemory& out_param, - const ComparisonKernel& comparison_kernel) { +static absl::Status RunRedzoneChecker(Stream* stream, + const DeviceMemory& redzone, + uint8_t redzone_pattern, + const DeviceMemory& out_param, + ComparisonKernel& comparison_kernel) { StreamExecutor* executor = stream->parent(); if (redzone.size() == 0) { @@ -175,9 +180,9 @@ static absl::Status RunRedzoneChecker( int64_t block_count = tsl::MathUtil::CeilOfRatio(num_elements, threads_per_block); - TF_RETURN_IF_ERROR(stream->ThenLaunch( - ThreadDim(threads_per_block), BlockDim(block_count), comparison_kernel, - redzone, redzone_pattern, redzone.size(), out_param)); + TF_RETURN_IF_ERROR(comparison_kernel.Launch( + ThreadDim(threads_per_block), BlockDim(block_count), stream, redzone, + redzone_pattern, redzone.size(), out_param)); return absl::OkStatus(); } @@ -202,7 +207,7 @@ static absl::Status ReinitializeRedzone(Stream* stream, static absl::StatusOr CheckRedzonesForBuffer( Stream* stream, DeviceMemoryBase memory, const DeviceMemory& out_param, - const ComparisonKernel& comparison_kernel, int64_t user_allocation_size, + ComparisonKernel& comparison_kernel, int64_t user_allocation_size, uint64_t redzone_size, uint8_t redzone_pattern) { int64_t rhs_slop = RoundUpToNearest(user_allocation_size, kRhsRedzoneAlign) - @@ -250,10 +255,22 @@ static absl::StatusOr CheckRedzonesForBuffer( return RedzoneCheckStatus::OK(); } +absl::StatusOr RedzoneAllocator::CreateBuffer( + const xla::Shape& shape, bool initialize_buffers, int64_t& rng_state) { + TF_ASSIGN_OR_RETURN(stream_executor::DeviceMemoryBase buffer, + AllocateBytes(xla::ShapeUtil::ByteSizeOf(shape))); + if (initialize_buffers) { + xla::gpu::InitializeBuffer(stream(), shape.element_type(), &rng_state, + buffer); + } + return buffer; +} + absl::StatusOr RedzoneAllocator::CheckRedzones() const { StreamExecutor* executor = stream_->parent(); - TF_ASSIGN_OR_RETURN(ComparisonKernel kernel, GetComparisonKernel(executor)); + TF_ASSIGN_OR_RETURN(ComparisonKernel * kernel, + GetComparisonKernel(stream_->parent(), GpuAsmOpts())); stream_executor::DeviceMemoryHandle out_param( executor, executor->AllocateScalar()); @@ -265,7 +282,7 @@ absl::StatusOr RedzoneAllocator::CheckRedzones() const { RedzoneCheckStatus redzone_status, CheckRedzonesForBuffer(stream_, *buf_and_size.first, DeviceMemory(out_param.memory()), - kernel, buf_and_size.second, redzone_size_, + *kernel, buf_and_size.second, redzone_size_, redzone_pattern_)); if (!redzone_status.ok()) { return redzone_status; diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h index 2fa11f1d174447..dba6fe2edd5af6 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/shape.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/scratch_allocator.h" @@ -103,6 +104,12 @@ class RedzoneAllocator : public ScratchAllocator { Stream* stream() const { return stream_; } + // Create a buffer for a given operation using redzone checker, initialize + // based on a given rng state. + absl::StatusOr CreateBuffer(const xla::Shape& shape, + bool initialize_buffers, + int64_t& rng_state); + private: const int device_ordinal_; Stream* stream_; diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h index 578ddd92e46438..7f1a3c3420ae0c 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/stream_executor.h" @@ -33,7 +34,8 @@ using ComparisonKernel = TypedKernel, uint8_t, uint64_t, // buffer_address // + buffer_length]` that is not equal to `redzone_pattern`, // `*mismatch_count_ptr` gets incremented by 1. -absl::StatusOr GetComparisonKernel(StreamExecutor* executor); +absl::StatusOr GetComparisonKernel(StreamExecutor* executor, + GpuAsmOpts gpu_asm_opts); } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc index 2cde896383c71a..3c50e648503677 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc @@ -14,15 +14,56 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include "absl/base/call_once.h" +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/stream_executor/cuda/cuda_asm_compiler.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/redzone_allocator_kernel.h" -#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/statusor.h" namespace stream_executor { +// Maintains a cache of pointers to loaded kernels +template +static absl::StatusOr*> LoadKernelOrGetPtr( + StreamExecutor* executor, absl::string_view kernel_name, + absl::string_view ptx, absl::Span cubin_data) { + using KernelPtrCacheKey = + std::tuple; + + static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit); + static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) = + *new absl::node_hash_map>(); + KernelPtrCacheKey kernel_ptr_cache_key{executor, kernel_name, ptx}; + absl::MutexLock lock(&kernel_ptr_cache_mutex); + + auto it = kernel_ptr_cache.find(kernel_ptr_cache_key); + if (it == kernel_ptr_cache.end()) { + TF_ASSIGN_OR_RETURN(TypedKernel loaded, + (TypedKernelFactory::Create( + executor, kernel_name, ptx, cubin_data))); + it = + kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first; + } + + CHECK(it != kernel_ptr_cache.end()); + return &it->second; +} + // PTX blob for the function which checks that every byte in // input_buffer (length is buffer_length) is equal to redzone_pattern. // @@ -39,7 +80,7 @@ namespace stream_executor { // } // // Code must compile for the oldest GPU XLA may be compiled for. -static const char* kRedzoneCheckerPtx = R"( +static const char* redzone_checker_ptx = R"( .version 4.2 .target sm_30 .address_size 64 @@ -79,11 +120,27 @@ static const char* kRedzoneCheckerPtx = R"( } )"; -absl::StatusOr GetComparisonKernel(StreamExecutor* executor) { - MultiKernelLoaderSpec spec(/*arity=*/4); - spec.AddCudaPtxInMemory(kRedzoneCheckerPtx, "redzone_checker"); +absl::StatusOr GetComparisonKernel(StreamExecutor* executor, + GpuAsmOpts gpu_asm_opts) { + absl::Span compiled_ptx = {}; + absl::StatusOr> compiled_ptx_or = + CompileGpuAsmOrGetCached( + executor->GetDeviceDescription().cuda_compute_capability(), + redzone_checker_ptx, gpu_asm_opts); + if (compiled_ptx_or.ok()) { + compiled_ptx = compiled_ptx_or.value(); + } else { + static absl::once_flag ptxas_not_found_logged; + absl::call_once(ptxas_not_found_logged, [&]() { + LOG(WARNING) << compiled_ptx_or.status() + << "\nRelying on driver to perform ptx compilation. " + << "\nModify $PATH to customize ptxas location." + << "\nThis message will be only logged once."; + }); + } - return TypedKernelFactory, uint8_t, uint64_t, - DeviceMemory>::Create(executor, spec); + return LoadKernelOrGetPtr, uint8_t, uint64_t, + DeviceMemory>( + executor, "redzone_checker", redzone_checker_ptx, compiled_ptx); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc index 87a254a34d9a75..2e701f2c0ddbb0 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc @@ -15,19 +15,19 @@ limitations under the License. #include -#include "absl/base/casts.h" #include "absl/status/statusor.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/redzone_allocator_kernel.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/statusor.h" namespace { -__global__ void redzone_checker(uint8_t* input_buffer, uint8_t redzone_pattern, - uint64_t buffer_length, - uint32_t* out_mismatched_ptr) { +__global__ void redzone_checker_kernel(uint8_t* input_buffer, + uint8_t redzone_pattern, + uint64_t buffer_length, + uint32_t* out_mismatched_ptr) { uint64_t idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= buffer_length) return; if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); @@ -36,12 +36,16 @@ __global__ void redzone_checker(uint8_t* input_buffer, uint8_t redzone_pattern, namespace stream_executor { -absl::StatusOr GetComparisonKernel(StreamExecutor* executor) { - MultiKernelLoaderSpec spec(/*arity=*/4); - spec.AddInProcessSymbol(absl::bit_cast(&redzone_checker), - "redzone_checker"); - return TypedKernelFactory, uint8_t, uint64_t, - DeviceMemory>::Create(executor, spec); +absl::StatusOr GetComparisonKernel( + StreamExecutor* executor, GpuAsmOpts /*gpu_asm_opts*/) { + static auto kernel = TypedKernelFactory< + DeviceMemory, uint8_t, uint64_t, + DeviceMemory>::Create(executor, "redzone_checker", + reinterpret_cast( + redzone_checker_kernel)); + + if (!kernel.ok()) return kernel.status(); + return &kernel.value(); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index dcad498362ece2..a5740d0dc2a484 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -118,17 +118,18 @@ cc_library( "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", ], ) diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.cc b/third_party/xla/xla/stream_executor/host/host_kernel.cc index cb64e7e9ff5329..6ed9ba7a2e0a13 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/optimization.h" @@ -27,10 +28,12 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/threadpool.h" namespace stream_executor::host { @@ -138,6 +141,26 @@ absl::Status HostKernel::Launch( return absl::OkStatus(); } +absl::Status HostKernel::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + Stream* stream, const KernelArgs& args) { + if (cluster_dims.has_value()) { + if (cluster_dims->x != 1 || cluster_dims->y != 1 || cluster_dims->z != 1) { + return absl::UnimplementedError("Not implemented for Host"); + } + } + const KernelArgsDeviceMemoryArray* device_mem = + DynCast(&args); + + if (device_mem != nullptr) { + return Launch(thread_dims, device_mem->device_memory_args()); + } + return absl::UnimplementedError( + "Host kernel implements Launch method only for DeviceMemoryArray " + "arguments."); +} + tsl::AsyncValueRef HostKernel::Launch( const ThreadDim& thread_dims, absl::Span buffers, TaskRunner task_runner) const { diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.h b/third_party/xla/xla/stream_executor/host/host_kernel.h index b8eaf62c1646e4..fe62b9071934d1 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel.h @@ -89,6 +89,9 @@ class HostKernel : public Kernel { absl::Span buffers) const; absl::Status Launch(const ThreadDim& thread_dims, absl::Span args) const; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, + Stream* stream, const KernelArgs& args) override; // Launches the kernel by iterating over all threads in `thread_dims` and // calling `task_runner` to run individual task (implementation might decide diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc index 4e766fc92158d5..aabcbc185aeb16 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc @@ -160,7 +160,8 @@ TEST(HostKernelTest, Addition3D) { TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; - TF_ASSERT_OK(stream->Launch(ThreadDim(2, 2, 3), BlockDim(1), *add, kargs)); + TF_ASSERT_OK( + add->Launch(ThreadDim(2, 2, 3), BlockDim(1), stream.get(), kargs)); std::vector expected = {11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33}; @@ -186,7 +187,7 @@ TEST(HostKernelTest, JitAddition) { TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; - TF_ASSERT_OK(stream->Launch(ThreadDim(4), BlockDim(1), *add, kargs)); + TF_ASSERT_OK(add->Launch(ThreadDim(4), BlockDim(1), stream.get(), kargs)); std::vector expected = {6, 8, 10, 12}; EXPECT_EQ(out, expected); diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index 1cbf01298ce213..ee812daad8d97a 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -196,27 +196,5 @@ absl::Status HostStream::BlockUntilDone() { return status; } -absl::Status HostStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - if (cluster_dims.has_value()) { - if (cluster_dims->x != 1 || cluster_dims->y != 1 || cluster_dims->z != 1) { - return absl::UnimplementedError("Not implemented for Host"); - } - } - const HostKernel* host_kernel = AsHostKernel(&kernel); - - const KernelArgsDeviceMemoryArray* device_mem = - DynCast(&args); - - if (device_mem != nullptr) { - return host_kernel->Launch(thread_dims, device_mem->device_memory_args()); - } - return absl::UnimplementedError( - "Host kernel implements Launch method only for DeviceMemoryArray " - "arguments."); -} - } // namespace host } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index dc6760f8f629ca..4644052803c5ef 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -72,9 +72,6 @@ class HostStream : public StreamCommon { uint64_t size) override; absl::Status DoHostCallbackWithStatus( absl::AnyInvocable callback) override; - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc b/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc index 89c5e7cc8cbdfd..3291af042faad1 100644 --- a/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc +++ b/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -38,6 +37,7 @@ limitations under the License. #include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/CoreContainers.h" #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" @@ -177,7 +177,7 @@ class ExecutionEngine { static absl::StatusOr> CreateFromModule( std::unique_ptr ctx, std::unique_ptr module, Options options, - absl::Span exported); + absl::Span exported); // Returns a pointer to the exported function. absl::Span exported() const { return exported_; } @@ -233,10 +233,10 @@ static std::string ToString(const llvm::Error &err) { } absl::StatusOr> -ExecutionEngine::CreateFromModule(std::unique_ptr ctx, - std::unique_ptr module, - Options options, - absl::Span exported) { +ExecutionEngine::CreateFromModule( + std::unique_ptr ctx, + std::unique_ptr module, Options options, + absl::Span exported) { auto engine = std::unique_ptr(new ExecutionEngine( options.enable_gdb_listener, options.enable_perf_listener)); @@ -324,7 +324,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, llvm::DataLayout data_layout = (*jit)->getDataLayout(); // Resolve all exported functions to function pointers. - for (std::string_view name : exported) { + for (absl::string_view name : exported) { // Trigger compilation by looking up the exported function. // TODO(tsilytskyi): // - Do we need to mangle function name? @@ -418,7 +418,7 @@ JitHostKernelFunction::CreateFromLlvmIr(absl::string_view name, engine_options.target_machine = std::move(target_machine.get()); engine_options.make_optimizing_transformer = MakeOptimizingTransformerForJit; - std::vector exported = {entry}; + std::vector exported = {entry}; // Compile input module to the native function. TF_ASSIGN_OR_RETURN(auto engine, @@ -439,9 +439,9 @@ static void RegisterJitKernelFunctionLoader() { if (!spec.has_llvm_host_kernel()) return std::nullopt; const LlvmHostKernel &llvm_host_kernel = spec.llvm_host_kernel(); - std::string_view name = llvm_host_kernel.kernel_name(); - std::string_view entry = llvm_host_kernel.entrypoint(); - std::string_view ir = llvm_host_kernel.ir(); + absl::string_view name = llvm_host_kernel.kernel_name(); + absl::string_view entry = llvm_host_kernel.entrypoint(); + absl::string_view ir = llvm_host_kernel.ir(); absl::Span options = llvm_host_kernel.options(); return JitHostKernelFunction::CreateFromLlvmIr(name, entry, ir, diff --git a/third_party/xla/xla/stream_executor/integrations/BUILD b/third_party/xla/xla/stream_executor/integrations/BUILD index 3170776a291e7a..6e48ab51834387 100644 --- a/third_party/xla/xla/stream_executor/integrations/BUILD +++ b/third_party/xla/xla/stream_executor/integrations/BUILD @@ -46,6 +46,7 @@ cc_library( deps = [ "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:memory_allocation", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -56,6 +57,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", @@ -93,9 +95,13 @@ xla_cc_test( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h b/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h index 736b62e051314a..4e941aeca37a55 100644 --- a/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h +++ b/third_party/xla/xla/stream_executor/integrations/device_mem_allocator.h @@ -31,12 +31,19 @@ class DeviceMemAllocator : public tsl::SubAllocator { // 'platform_device_id' refers to the ID of the device within // the process and must reference a valid ID in the process. // Note: stream_exec cannot be null. - explicit DeviceMemAllocator(StreamExecutor* stream_exec, - tsl::PlatformDeviceId device_id, - MemoryType memory_type, - const std::vector& alloc_visitors, - const std::vector& free_visitors) - : SubAllocator(alloc_visitors, free_visitors), + DeviceMemAllocator(StreamExecutor* stream_exec, + tsl::PlatformDeviceId device_id, MemoryType memory_type, + const std::vector& alloc_visitors) + : SubAllocator(alloc_visitors, {}), + stream_exec_(stream_exec), + device_id_(device_id), + memory_type_(memory_type) { + CHECK(stream_exec_ != nullptr); + } + + DeviceMemAllocator(StreamExecutor* stream_exec, + tsl::PlatformDeviceId device_id, MemoryType memory_type) + : SubAllocator({}, {}), stream_exec_(stream_exec), device_id_(device_id), memory_type_(memory_type) { diff --git a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc index fcdfccb763d444..bf09d1a3bad4e8 100644 --- a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc +++ b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.cc @@ -15,10 +15,14 @@ limitations under the License. #include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include + #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" @@ -49,8 +53,7 @@ absl::StatusOr TfAllocatorAdapter::Allocate( data = wrapped_->AllocateRaw(tsl::Allocator::kAllocatorAlignment, size, attrs); if (data == nullptr) { - return absl::ResourceExhaustedError(absl::StrCat( - "Out of memory while trying to allocate ", size, " bytes.")); + return MemoryAllocationError(size); } } return OwningDeviceMemory(DeviceMemoryBase(data, size), device_ordinal, this); @@ -81,4 +84,18 @@ absl::StatusOr TfAllocatorAdapter::GetAllocator( return wrapped_; } +static constexpr absl::string_view kMemoryAllocationErrorPayloadKey = + "tf-allocator-allocation-error"; + +absl::Status MemoryAllocationError(uint64_t size) { + absl::Status status = absl::ResourceExhaustedError( + absl::StrCat("Out of memory while trying to allocate ", size, " bytes.")); + status.SetPayload(kMemoryAllocationErrorPayloadKey, absl::Cord()); + return status; +} + +bool IsMemoryAllocationError(absl::Status status) { + return status.GetPayload(kMemoryAllocationErrorPayloadKey).has_value(); +} + } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h index 1d8b96b5e09da9..d712027bfb37aa 100644 --- a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h +++ b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_INTEGRATIONS_TF_ALLOCATOR_ADAPTER_H_ #define XLA_STREAM_EXECUTOR_INTEGRATIONS_TF_ALLOCATOR_ADAPTER_H_ +#include #include +#include #include #include #include @@ -29,6 +31,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -196,6 +199,13 @@ class MultiDeviceAdapter : public DeviceMemoryAllocator { std::vector> tf_allocators_; }; +// Creates a status with a payload indicating an error while allocating `size` +// bytes of memory. +absl::Status MemoryAllocationError(uint64_t size); + +// Checks whether the status is a memory allocation error. +bool IsMemoryAllocationError(absl::Status status); + } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_INTEGRATIONS_TF_ALLOCATOR_ADAPTER_H_ diff --git a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter_test.cc b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter_test.cc index 0969b97e866afb..6e845e9c3beabc 100644 --- a/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter_test.cc +++ b/third_party/xla/xla/stream_executor/integrations/tf_allocator_adapter_test.cc @@ -21,20 +21,22 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_set.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" -namespace se = stream_executor; +namespace stream_executor { +namespace { // Each allocation will have an incrementing address. class TestAllocator : public tsl::Allocator { @@ -73,11 +75,11 @@ class TestAllocator : public tsl::Allocator { TEST(MultiDeviceAdapter, UsesCorrectAllocator) { TF_ASSERT_OK_AND_ASSIGN(auto* platform, xla::PlatformUtil::GetDefaultPlatform()); - TF_ASSERT_OK_AND_ASSIGN(std::vector executors, + TF_ASSERT_OK_AND_ASSIGN(std::vector executors, xla::PlatformUtil::GetStreamExecutors(platform)) TF_ASSERT_OK_AND_ASSIGN(auto stream, executors[0]->CreateStream()); - std::vector infos; + std::vector infos; infos.emplace_back(std::make_unique(0x1000), stream.get(), /*memory_space=*/0, /*device_ordinal=*/0); infos.emplace_back(std::make_unique(0x2000), stream.get(), @@ -86,27 +88,27 @@ TEST(MultiDeviceAdapter, UsesCorrectAllocator) { /*memory_space=*/1, /*device_ordinal=*/0); infos.emplace_back(std::make_unique(0x4000), stream.get(), /*memory_space=*/1, /*device_ordinal=*/1); - std::unique_ptr allocator = - std::make_unique(platform, std::move(infos)); + std::unique_ptr allocator = + std::make_unique(platform, std::move(infos)); TF_ASSERT_OK_AND_ASSIGN( - se::OwningDeviceMemory buff0, + OwningDeviceMemory buff0, allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/0)); CHECK_EQ(reinterpret_cast(buff0->opaque()), 0x1001); TF_ASSERT_OK_AND_ASSIGN( - se::OwningDeviceMemory buff1, + OwningDeviceMemory buff1, allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/0)); CHECK_EQ(reinterpret_cast(buff1->opaque()), 0x1002); TF_ASSERT_OK_AND_ASSIGN( - se::OwningDeviceMemory buff2, + OwningDeviceMemory buff2, allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/1)); CHECK_EQ(reinterpret_cast(buff2->opaque()), 0x3001); TF_ASSERT_OK_AND_ASSIGN( - se::OwningDeviceMemory buff3, + OwningDeviceMemory buff3, allocator->Allocate(/*device_ordinal=*/1, 4, false, /*memory_space=*/0)); CHECK_EQ(reinterpret_cast(buff3->opaque()), 0x2001); TF_ASSERT_OK_AND_ASSIGN( - se::OwningDeviceMemory buff4, + OwningDeviceMemory buff4, allocator->Allocate(/*device_ordinal=*/1, 4, false, /*memory_space=*/1)); CHECK_EQ(reinterpret_cast(buff4->opaque()), 0x4001); } @@ -114,31 +116,30 @@ TEST(MultiDeviceAdapter, UsesCorrectAllocator) { TEST(MultiDeviceAdapter, DeallocationWithDifferentAllocator) { TF_ASSERT_OK_AND_ASSIGN(auto* platform, xla::PlatformUtil::GetDefaultPlatform()); - TF_ASSERT_OK_AND_ASSIGN(std::vector executors, + TF_ASSERT_OK_AND_ASSIGN(std::vector executors, xla::PlatformUtil::GetStreamExecutors(platform)); TF_ASSERT_OK_AND_ASSIGN(auto stream, executors[0]->CreateStream()); std::shared_ptr> allocations = std::make_shared>(); - std::vector info_allocator; + std::vector info_allocator; info_allocator.emplace_back( std::make_unique(0x1000, allocations), stream.get(), /*memory_space=*/0, /*device_ordinal=*/0); - std::unique_ptr allocator = - std::make_unique(platform, - std::move(info_allocator)); + std::unique_ptr allocator = + std::make_unique(platform, std::move(info_allocator)); - std::vector info_deallocator; + std::vector info_deallocator; info_deallocator.emplace_back( std::make_unique(0x1000, allocations), stream.get(), /*memory_space=*/0, /*device_ordinal=*/0); - std::unique_ptr deallocator = - std::make_unique(platform, - std::move(info_deallocator)); + std::unique_ptr deallocator = + std::make_unique(platform, + std::move(info_deallocator)); TF_ASSERT_OK_AND_ASSIGN( - se::OwningDeviceMemory buff0, + OwningDeviceMemory buff0, allocator->Allocate(/*device_ordinal=*/0, 4, false, /*memory_space=*/0)); CHECK_EQ(allocations->size(), 1); CHECK_EQ(reinterpret_cast(buff0->opaque()), 0x1001); @@ -150,3 +151,12 @@ TEST(MultiDeviceAdapter, DeallocationWithDifferentAllocator) { // destruction. allocations->insert(buff0->opaque()); } + +TEST(MemoryAllocationError, IsMemoryAllocationError) { + EXPECT_TRUE(IsMemoryAllocationError(MemoryAllocationError(100))); + EXPECT_FALSE(IsMemoryAllocationError(absl::OkStatus())); + EXPECT_FALSE(IsMemoryAllocationError(absl::InternalError(""))); +} + +} // namespace +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 6076717d430598..7ce0877cf6332b 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -90,6 +90,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "tsl/platform/logging.h" namespace stream_executor { @@ -228,16 +229,48 @@ class Kernel { args_packing_ = std::move(args_packing); } - std::string_view name() const { return name_; } + absl::string_view name() const { return name_; } void set_name(absl::string_view name); + // Launches a data parallel kernel with the given thread/block + // dimensionality and already-packed args/sizes to pass to the underlying + // platform driver. + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + Stream *stream, const KernelArgs &args); + + // Launches a data parallel kernel with the given thread/block + // dimensionality and already-packed args/sizes to pass to the underlying + // platform driver. + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const ClusterDim &cluster_dims, Stream *stream, + const KernelArgs &args); + private: + // Helper method to launch a kernel with optional cluster dimensions. + virtual absl::Status Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, + const std::optional &cluster_dims, + Stream *stream, const KernelArgs &args) = 0; + std::string name_; KernelMetadata metadata_; KernelArgsPacking args_packing_; }; +inline absl::Status Kernel::Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, Stream *stream, + const KernelArgs &args) { + return Launch(thread_dims, block_dims, std::nullopt, stream, args); +} +inline absl::Status Kernel::Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, + const ClusterDim &cluster_dims, + Stream *stream, const KernelArgs &args) { + return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), + stream, args); +} + //===----------------------------------------------------------------------===// // Typed kernel //===----------------------------------------------------------------------===// @@ -263,6 +296,39 @@ class TypedKernel { // Type of factory used to create a TypedKernel. using FactoryType = TypedKernelFactory; + // Launches a kernel with the given (variadic) parameters for the invocation + // onto the specified stream. These arguments can be things + // like DeviceMemory or primitive types such as int. What arguments you may + // pass to a given kernel are noted as the template parameters to the + // TypedKernel type that the compiler generates. + // + // Template parameters: + // Params... The type list of formal parameters that the typed kernel + // expects, which is matched against Args... + // Args... The deduced type list for passed actual arguments + // + // Implementation: A compile-time compatibility check is performed that has + // some leniency versus an exact parameter pack match -- for example, + // `const DeviceMemory` is considered "pack compatible" with a + // `const DeviceMemory&` formal parameter; in part, because we don't have + // perfect forwarding support without rvalue references. It also attempts to + // spit out helpful static_assert error traces with information as to the + // argument number and types that were mismatched. + template + inline absl::Status Launch(ThreadDim thread_dims, BlockDim block_dims, + Stream *stream, Args... args) { + auto kernel_args = PackKernelArgs(*this, args...); + return kernel_->Launch(thread_dims, block_dims, stream, *kernel_args); + } + + template + inline absl::Status Launch(ThreadDim thread_dims, BlockDim block_dims, + int32_t shmem_bytes, Stream *stream, + Args... args) { + auto kernel_args = PackKernelArgs(shmem_bytes, args...); + return kernel_->Launch(thread_dims, block_dims, stream, *kernel_args); + } + private: friend class TypedKernelFactory; explicit TypedKernel(std::unique_ptr kernel) diff --git a/third_party/xla/xla/stream_executor/mock_stream.h b/third_party/xla/xla/stream_executor/mock_stream.h index 41d06aa4f6e607..2aa6e8064a453f 100644 --- a/third_party/xla/xla/stream_executor/mock_stream.h +++ b/third_party/xla/xla/stream_executor/mock_stream.h @@ -25,11 +25,11 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" @@ -75,10 +75,10 @@ class MockStream : public Stream { (const, override)); MOCK_METHOD((std::variant), priority, (), (const, override)); - MOCK_METHOD(absl::Status, Launch, + MOCK_METHOD(absl::Status, LaunchKernel, (const ThreadDim &thread_dims, const BlockDim &block_dims, - const std::optional &cluster_dims, const Kernel &k, - const KernelArgs &args), + const std::optional &cluster_dims, void *function, + absl::string_view name, void **args, int64_t shmem_bytes), (override)); MOCK_METHOD(const std::string &, GetName, (), (const, override)); MOCK_METHOD(void, SetName, (std::string name), (override)); diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 5f2ecac1090216..b2ac827f6166c8 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -244,12 +244,16 @@ cc_library( "//xla/stream_executor:activate_context", "//xla/stream_executor:kernel", "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_config_rocm//rocm:rocm_headers", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", ], ) @@ -820,15 +824,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "rocm_rpath", - linkopts = select({ - "//conditions:default": [ - "-Wl,-rpath,../local_config_rocm/rocm/rocm/lib", - ], - }), -) - cc_library( name = "stream_executor_rocm", tags = [ @@ -837,12 +832,12 @@ cc_library( ], deps = [ ":rocm_platform_id", - ":rocm_rpath", "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", + "@local_config_rocm//rocm:rocm_rpath", ] + if_static( [":all_runtime"], ), @@ -1022,6 +1017,7 @@ cc_test( deps = [ ":rocm_status", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:status_matchers", diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc index a75b62927ba1c2..e345e97b3f9a58 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc @@ -18,8 +18,11 @@ limitations under the License. #include #include #include +#include +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/stream_executor/activate_context.h" @@ -27,7 +30,9 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_status.h" -#include "tsl/platform/errors.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -73,5 +78,57 @@ absl::StatusOr RocmKernel::GetKernelMetadata() { return kernel_metadata; } +absl::Status RocmKernel::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + Stream* stream, const KernelArgs& args) { + hipFunction_t function = gpu_function(); + + // Launch kernels with packed arguments. + auto launch = [this, &cluster_dims, &thread_dims, &block_dims, &function, + stream](const KernelArgsPackedArrayBase& packed) { + int32_t expected_number_of_arguments = + Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + + void** params = const_cast(packed.argument_addresses().data()); + + if (cluster_dims.has_value()) { + return stream->LaunchKernel(thread_dims, block_dims, cluster_dims, + function, name(), params, + packed.number_of_shared_bytes()); + } else { + return stream->LaunchKernel(thread_dims, block_dims, std::nullopt, + function, name(), params, + packed.number_of_shared_bytes()); + } + }; + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return launch(*packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(*this, *device_mem)); + return launch(*packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h index 7fe8542ae2e69e..a252666ea7d1ad 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h @@ -60,6 +60,10 @@ class RocmKernel : public Kernel { absl::StatusOr GetKernelMetadata(); private: + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const std::optional &cluster_dims, + Stream *stream, const KernelArgs &args) override; + StreamExecutor* executor_ = nullptr; hipFunction_t rocm_function_ = nullptr; // wrapped HIP kernel handle diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc index 0f5e46f33a557e..1ad8fb99a37596 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc @@ -15,11 +15,10 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_status.h" -#include - #include #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "rocm/include/hip/hip_runtime.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" @@ -42,7 +41,7 @@ TEST(RocmStatusTest, ToStatusReturnsExpectedStatusCodes) { } TEST(RocmStatusTest, ToStatusIncludesDetailMessage) { - constexpr std::string_view kMyMessage = "Some arbitrary message"; + constexpr absl::string_view kMyMessage = "Some arbitrary message"; EXPECT_THAT(ToStatus(hipErrorNotInitialized, kMyMessage), StatusIs(absl::StatusCode::kInternal, HasSubstr(kMyMessage))); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc b/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc index dff3a877227fc5..c7ab3c462ca32c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc @@ -326,13 +326,12 @@ absl::Status RocmStream::DoHostCallbackWithStatus( } namespace { -absl::Status LaunchKernel(StreamExecutor* executor, - absl::string_view kernel_name, hipFunction_t function, - unsigned int grid_dim_x, unsigned int grid_dim_y, - unsigned int grid_dim_z, unsigned int block_dim_x, - unsigned int block_dim_y, unsigned int block_dim_z, - unsigned int shared_mem_bytes, hipStream_t stream, - void** kernel_params, void** extra) { +absl::Status LaunchRocmKernel( + StreamExecutor* executor, absl::string_view kernel_name, + hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, hipStream_t stream, + void** kernel_params, void** extra) { std::unique_ptr activation = executor->Activate(); VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z @@ -366,21 +365,20 @@ absl::Status LaunchKernel(StreamExecutor* executor, return absl::OkStatus(); } -absl::Status LaunchKernel(StreamExecutor* executor, - absl::string_view kernel_name, hipFunction_t function, - unsigned int cluster_dim_x, - unsigned int cluster_dim_y, - unsigned int cluster_dim_z, unsigned int grid_dim_x, - unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, - unsigned int shared_mem_bytes, hipStream_t stream, - void** kernel_params, void** extra) { +absl::Status LaunchRocmKernel( + StreamExecutor* executor, absl::string_view kernel_name, + hipFunction_t function, unsigned int cluster_dim_x, + unsigned int cluster_dim_y, unsigned int cluster_dim_z, + unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, hipStream_t stream, + void** kernel_params, void** extra) { if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1) return absl::UnimplementedError("Not implemented for ROCm"); - return LaunchKernel(executor, kernel_name, function, grid_dim_x, grid_dim_y, - grid_dim_z, block_dim_x, block_dim_y, block_dim_z, - shared_mem_bytes, stream, kernel_params, extra); + return LaunchRocmKernel(executor, kernel_name, function, grid_dim_x, + grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, + block_dim_z, shared_mem_bytes, stream, kernel_params, + extra); } } // namespace @@ -389,62 +387,24 @@ absl::Status RocmStream::BlockHostUntilDone() { return SynchronizeStream(executor_, stream_handle_); } -absl::Status RocmStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - const RocmKernel* gpu_kernel = static_cast(&kernel); - hipFunction_t function = gpu_kernel->gpu_function(); - - // Launch kernels with packed arguments. - auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, - &function](const KernelArgsPackedArrayBase& packed) { - int32_t expected_number_of_arguments = - kernel.Arity() + (packed.number_of_shared_bytes() > 0); - - CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) - << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() - << " arguments, but expected " << expected_number_of_arguments - << "; arity=" << kernel.Arity() - << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); - - void** params = const_cast(packed.argument_addresses().data()); - - if (cluster_dims.has_value()) { - return LaunchKernel( - executor_, kernel.name(), function, cluster_dims->x, cluster_dims->y, - cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, - thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), stream_handle_, params, - /*extra=*/nullptr); - } else { - return LaunchKernel( - executor_, kernel.name(), function, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), stream_handle_, params, - /*extra=*/nullptr); - } - }; - - // If arguments are already packed we can just launch the kernel. - if (auto* packed = DynCast(&args)) { - return launch(*packed); - } - - // For device memory array we rely on a custom kernel arguments packing. - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return launch(*packed); +absl::Status RocmStream::LaunchKernel( + const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, void* function, + absl::string_view name, void** args, int64_t shmem_bytes) { + if (cluster_dims.has_value()) { + return LaunchRocmKernel( + executor_, name, static_cast(function), cluster_dims->x, + cluster_dims->y, cluster_dims->z, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, shmem_bytes, + stream_handle_, args, + /*extra=*/nullptr); + } else { + return LaunchRocmKernel( + executor_, name, static_cast(function), block_dims.x, + block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + shmem_bytes, stream_handle_, args, + /*extra=*/nullptr); } - - return absl::InternalError("Unsupported kernel arguments type"); } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream.h b/third_party/xla/xla/stream_executor/rocm/rocm_stream.h index 693335daa187bf..977d27f3b7e131 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_stream.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream.h @@ -84,9 +84,11 @@ class RocmStream : public StreamCommon { absl::Status RecordCompletedEvent(); - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) override; + absl::Status LaunchKernel(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + void* function, absl::string_view name, void** args, + int64_t shmem_bytes) override; StreamExecutor* executor_; RocmEvent completed_event_; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc index d8a312b0032864..8687230dd730a3 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc @@ -50,7 +50,6 @@ namespace { using ::testing::Each; using ::testing::ElementsAre; using ::testing::ElementsAreArray; -using ::testing::IsEmpty; using ::tsl::testing::IsOk; class RocmStreamTest : public ::testing::Test { @@ -219,7 +218,7 @@ TEST_F(RocmStreamTest, LaunchKernel) { EXPECT_THAT(stream->Memset32(&a, 1, kByteLength), IsOk()); EXPECT_THAT(stream->Memset32(&b, 2, kByteLength), IsOk()); EXPECT_THAT(stream->MemZero(&c, kByteLength), IsOk()); - EXPECT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(kLength), add, a, b, c), + EXPECT_THAT(add.Launch(ThreadDim(), BlockDim(kLength), stream.get(), a, b, c), IsOk()); EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc index 958c5dfa53316f..cbe8b38c6c9dff 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc @@ -64,10 +64,7 @@ class RocmTimerTest : public ::testing::Test { ASSERT_THAT(stream->Memset32(&a, 1, byte_length), IsOk()); ASSERT_THAT(stream->Memset32(&b, 2, byte_length), IsOk()); - ASSERT_THAT(stream->MemZero(&c, byte_length), IsOk()); - - ASSERT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c), - IsOk()); + ASSERT_THAT(add.Launch(ThreadDim(), BlockDim(4), stream, a, b, c), IsOk()); } RocmExecutor* executor_; diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 220cbf761c24fd..c3f01994de2ebc 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -38,7 +38,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -106,34 +105,6 @@ class Stream { // TODO(b/112196569): The semantics of failed sub-streams is error-prone. virtual void ReturnSubStream(Stream *sub_stream) = 0; - // Entrains onto the stream of operations: a kernel launch with the given - // (variadic) parameters for the invocation. These arguments can be things - // like DeviceMemory or primitive types such as int. What arguments you may - // pass to a given kernel are noted as the template parameters to the - // TypedKernel type that the compiler generates. - // - // Template parameters: - // Params... The type list of formal parameters that the typed kernel - // expects, which is matched against Args... - // Args... The deduced type list for passed actual arguments - // - // Implementation: A compile-time compatibility check is performed that has - // some leniency versus an exact parameter pack match -- for example, - // `const DeviceMemory` is considered "pack compatible" with a - // `const DeviceMemory&` formal parameter; in part, because we don't have - // perfect forwarding support without rvalue references. It also attempts to - // spit out helpful static_assert error traces with information as to the - // argument number and types that were mismatched. - template - absl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, - const TypedKernel &kernel, Args... args); - - // Same as above, with an explicit argument for shared memory size in bytes. - template - absl::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, - int32_t shmem_bytes, - const TypedKernel &kernel, Args... args); - // Create a dependency for this stream's next work on the other stream // completing. Does not take ownership of other, and other must not be // null. @@ -269,24 +240,6 @@ class Stream { // Gets priority for a stream. virtual std::variant priority() const = 0; - // Launches a data parallel kernel with the given thread/block - // dimensionality and already-packed args/sizes to pass to the underlying - // platform driver. - absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, - const Kernel &kernel, const KernelArgs &args) { - return Launch(thread_dims, block_dims, std::nullopt, kernel, args); - } - - // Launches a data parallel kernel with the given thread/block - // dimensionality and already-packed args/sizes to pass to the underlying - // platform driver. - absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &kernel, - const KernelArgs &args) { - return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), - kernel, args); - } - // Get/set a name for a stream, which can be shown in profiling tools virtual const std::string &GetName() const = 0; virtual void SetName(std::string name) = 0; @@ -306,34 +259,15 @@ class Stream { "This stream does not support EventBasedTimers."); } - private: // Helper method to launch a kernel with optional cluster dimensions. - virtual absl::Status Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, - const std::optional &cluster_dims, - const Kernel &kernel, const KernelArgs &args) { + virtual absl::Status LaunchKernel( + const ThreadDim &thread_dims, const BlockDim &block_dims, + const std::optional &cluster_dims, void *function, + absl::string_view name, void **args, int64_t shmem_bytes) { return absl::UnimplementedError("Not implemented"); } }; -template -inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, - BlockDim block_dims, - const TypedKernel &kernel, - Args... args) { - auto kernel_args = PackKernelArgs(kernel, args...); - return Launch(thread_dims, block_dims, *kernel, *kernel_args); -} - -template -inline absl::Status Stream::ThenLaunch(ThreadDim thread_dims, - BlockDim block_dims, int32_t shmem_bytes, - const TypedKernel &kernel, - Args... args) { - auto kernel_args = PackKernelArgs(shmem_bytes, args...); - return Launch(thread_dims, block_dims, *kernel, *kernel_args); -} - } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_STREAM_H_ diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 2ebd361fa16756..719207ba016314 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -31,9 +31,11 @@ limitations under the License. #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -47,19 +49,6 @@ namespace stream_executor { // Identifies the memory space where an allocation resides. enum class MemoryType { kDevice = 0, kUnified, kCollective, kHost = 5 }; -inline std::string MemoryTypeString(MemoryType memory_type) { - switch (memory_type) { - case MemoryType::kDevice: - return "device"; - case MemoryType::kUnified: - return "unified"; - case MemoryType::kCollective: - return "collective"; - case MemoryType::kHost: - return "host"; - } -} - /// The StreamExecutor is a single-device abstraction for: // // * Loading/launching data-parallel-kernels diff --git a/third_party/xla/xla/stream_executor/typed_kernel_factory.h b/third_party/xla/xla/stream_executor/typed_kernel_factory.h index 5e81c35c7f5374..5a0b5133e3f992 100644 --- a/third_party/xla/xla/stream_executor/typed_kernel_factory.h +++ b/third_party/xla/xla/stream_executor/typed_kernel_factory.h @@ -44,7 +44,7 @@ class TypedKernelFactory { return TypedKernel(std::move(kernel)); } - // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from a + // Creates a kernel which can be launched on a stream from a // PTX (and optional CUBIN), such that the types of the arguments provided for // launch would have to match types of the arguments provided at creation // time. The canonical storage for both ptx and cubin_data should outlive the @@ -63,7 +63,7 @@ class TypedKernelFactory { return Create(executor, loader_spec); } - // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from + // Creates a kernel which can be launched on a stream from // an in-process symbol pointer. static absl::StatusOr> Create( StreamExecutor *executor, absl::string_view kernel_name, void *symbol) { diff --git a/third_party/xla/xla/test.h b/third_party/xla/xla/test.h index 5117b8fd41a1c6..8ce11ab8a7a374 100644 --- a/third_party/xla/xla/test.h +++ b/third_party/xla/xla/test.h @@ -34,16 +34,7 @@ limitations under the License. // Note that while the use of gmock matchers is allowed in the xla project, the // use of mocks is disallowed in the whole tensorflow project! -#include "tsl/platform/platform.h" - -#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) -#include // IWYU pragma: export -#else -#include -#include // IWYU pragma: export -#include // IWYU pragma: export -#endif - -#include "tsl/platform/test.h" // IWYU pragma: export +// The current header will be deprecated in favour of the following. +#include "xla/hlo/testlib/test.h" #endif // XLA_TEST_H_ diff --git a/third_party/xla/xla/test_helpers.h b/third_party/xla/xla/test_helpers.h index bc0a054626b497..77336bd5aa53cc 100644 --- a/third_party/xla/xla/test_helpers.h +++ b/third_party/xla/xla/test_helpers.h @@ -16,53 +16,7 @@ limitations under the License. #ifndef XLA_TEST_HELPERS_H_ #define XLA_TEST_HELPERS_H_ -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "tsl/platform/test.h" - -// This module contains a minimal subset of gmock functionality just -// sufficient to execute the currently existing tests. - -namespace xla { -template -class Array2D; -class Literal; - -namespace testing { - -namespace internal_status { -// TODO(b/340953531) Eliminate this function. -inline const absl::Status& GetStatus(const absl::Status& status) { - return status; -} - -template -inline const absl::Status& GetStatus(const absl::StatusOr& status) { - return status.status(); -} -} // namespace internal_status - -} // namespace testing -} // namespace xla - -// The following macros are similar to macros in gmock, but deliberately named -// differently in order to avoid conflicts in files which include both. - -// Macros for testing the results of functions that return absl::Status or -// absl::StatusOr (for any type T). -#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(::absl::OkStatus(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(::absl::OkStatus(), \ - xla::testing::internal_status::GetStatus(expression)) -#undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(::absl::OkStatus(), \ - xla::testing::internal_status::GetStatus(expression)) -#undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(::absl::OkStatus(), \ - xla::testing::internal_status::GetStatus(expression)) +// The current header will be deprecated in favour of the following. +#include "xla/hlo/testlib/test_helpers.h" #endif // XLA_TEST_HELPERS_H_ diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 0b938d25e4f5eb..509e145f2f4421 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -12,10 +12,6 @@ load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_ load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") -load( - "//xla/tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -96,10 +92,10 @@ cc_library( "//xla:literal", "//xla:literal_comparison", "//xla:literal_util", - "//xla:test", - "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -163,10 +159,8 @@ cc_library( ], deps = [ ":pjrt_client_registry", - "//xla/pjrt:interpreter_device", "//xla/pjrt:pjrt_client", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:status", + "//xla/pjrt/interpreter:interpreter_client", ], ) @@ -176,49 +170,35 @@ cc_library( srcs = ["hlo_test_base.cc"], hdrs = ["hlo_test_base.h"], deps = [ - ":filecheck", ":hlo_runner_agnostic_test_base", - ":literal_test_util", ":pjrt_client_registry", - ":test_utils", - ":verified_hlo_module", - "//xla:debug_options_flags", "//xla:error_spec", "//xla:literal", - "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/hlo/utils:hlo_query", "//xla/pjrt:pjrt_client", "//xla/service:backend", "//xla/service:computation_placer_hdr", "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:hlo_module_util", "//xla/service:hlo_runner", "//xla/service:hlo_runner_interface", "//xla/service:hlo_runner_pjrt", - "//xla/service:hlo_verifier", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -232,39 +212,29 @@ cc_library( srcs = ["hlo_runner_agnostic_test_base.cc"], hdrs = ["hlo_runner_agnostic_test_base.h"], deps = [ - ":filecheck", ":literal_test_util", ":test_utils", - ":verified_hlo_module", - "//xla:debug_options_flags", "//xla:error_spec", "//xla:literal", - "//xla:literal_util", - "//xla:shape_layout", "//xla:shape_util", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", - "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:test_helpers", "//xla/hlo/testlib:verified_hlo_module", - "//xla/hlo/utils:hlo_query", - "//xla/service:backend", - "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", - "//xla/service:hlo_runner", "//xla/service:hlo_runner_interface", "//xla/service:hlo_verifier", "//xla/service:interpreter_plugin", # reference backend - "//xla/service:platform_util", - "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -274,10 +244,43 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:protobuf", + ], +) + +cc_library( + name = "hlo_runner_agnostic_reference_mixin", + testonly = True, + srcs = ["hlo_runner_agnostic_reference_mixin.cc"], + hdrs = ["hlo_runner_agnostic_reference_mixin.h"], + deps = [ + ":hlo_runner_agnostic_test_base", + ":literal_test_util", + ":test_utils", + "//xla:error_spec", + "//xla:literal", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:hlo_runner_interface", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "hlo_pjrt_interpreter_reference_mixin", + testonly = True, + hdrs = ["hlo_pjrt_interpreter_reference_mixin.h"], + deps = [ + ":hlo_runner_agnostic_reference_mixin", + "//xla/pjrt/interpreter:interpreter_client", + "//xla/service:hlo_runner_pjrt", ], ) @@ -292,15 +295,12 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_client", - "//xla/service:hlo_runner", + "//xla/pjrt/interpreter:interpreter_client", "//xla/service:hlo_runner_interface", "//xla/service:hlo_runner_pjrt", "//xla/service:interpreter_plugin", # reference backend - "//xla/service:platform_util", - "//xla/stream_executor:platform", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:logging", ], ) @@ -348,7 +348,6 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:status_macros", - "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:client_library", @@ -356,6 +355,7 @@ cc_library( "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test_helpers", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", "//xla/stream_executor:stream_executor_h", @@ -377,7 +377,7 @@ cc_library( hdrs = ["llvm_irgen_test_base.h"], deps = [ ":codegen_test_base", - ":filecheck", + "//xla/hlo/testlib:filecheck", "//xla/service:llvm_compiler", "//xla/service/llvm_ir:llvm_util", "//xla/tsl/lib/core:status_test_util", @@ -416,16 +416,16 @@ cc_library( hdrs = ["local_client_test_base.h"], deps = [ ":client_library_test_base", - ":verified_hlo_module", "//xla:shape_util", "//xla:status_macros", - "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:local_client", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test_helpers", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:computation_placer", "//xla/service:hlo_module_config", "//xla/service:local_service", @@ -448,16 +448,15 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", - "//xla:test", - "//xla:types", + "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", ], @@ -479,17 +478,17 @@ xla_test( }, {}, ), - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", - ":verified_hlo_module", ":xla_internal_test_main", "//xla:literal", "//xla:status_macros", "//xla/client:client_library", "//xla/client:local_client", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:backend", "//xla/service:executable", "//xla/stream_executor:stream_executor_memory_allocator", @@ -504,7 +503,7 @@ xla_test( "conv_depthwise_test.cc", ], shard_count = 50, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":conv_depthwise_common", @@ -513,10 +512,10 @@ xla_test( ":xla_internal_test_main", "//xla:execution_options_util", "//xla:status_macros", - "//xla:test", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "//xla/hlo/transforms:despecializer", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms/simplifiers:float_normalization", ], ) @@ -525,7 +524,7 @@ xla_test( timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], shard_count = 40, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -533,10 +532,10 @@ xla_test( ":xla_internal_test_main", "//xla:execution_options_util", "//xla:status_macros", - "//xla:test", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "//xla/hlo/transforms:despecializer", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms/simplifiers:float_normalization", ], ) @@ -549,7 +548,7 @@ xla_test( "cpu", ], shard_count = 50, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -558,10 +557,10 @@ xla_test( ":xla_internal_test_main", "//xla:execution_options_util", "//xla:status_macros", - "//xla:test", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "//xla/hlo/transforms:despecializer", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms/simplifiers:float_normalization", "@com_google_absl//absl/algorithm:container", ], ) @@ -569,19 +568,19 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", ], ) @@ -589,15 +588,15 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -607,7 +606,7 @@ xla_test( name = "while_test", srcs = ["while_test.cc"], # placeholder for extra args for while_test - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -635,7 +634,7 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -650,7 +649,7 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -661,14 +660,14 @@ xla_test( "//xla:array2d", "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status:statusor", ], @@ -681,7 +680,7 @@ xla_test( shard_count = 30, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -706,7 +705,7 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -722,7 +721,7 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -740,7 +739,7 @@ xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], shard_count = 2, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -754,7 +753,7 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -780,7 +779,8 @@ xla_test( "gpu", ], tags = ["test_xla_cpu_thunks", - "cuda-only",], #(TODO)(rocm): weekly sync 24-11-05 + "cuda-only", + "test_xla_cpu_no_thunks",], #(TODO)(rocm): weekly sync 24-11-05 deps = [ ":client_library_test_base", ":literal_test_util", @@ -799,7 +799,7 @@ xla_test( name = "scalar_computations_test", srcs = ["scalar_computations_test.cc"], shard_count = 32, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -808,12 +808,12 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla:status_macros", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -824,17 +824,17 @@ xla_test( xla_test( name = "deallocation_test", srcs = ["deallocation_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", ":xla_internal_test_main", - "//xla:test", - "//xla:test_helpers", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], @@ -843,20 +843,20 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test", @@ -866,14 +866,10 @@ xla_test( xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), shard_count = 25, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", @@ -883,20 +879,16 @@ xla_test( "//xla:fp_util", "//xla:literal", "//xla:shape_util", - "//xla:test", "//xla:types", - "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", + "//xla/stream_executor:device_description", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:ml_dtypes", - "@ml_dtypes//:float8", - ] + if_rocm_is_configured([ - # keep sorted - "@local_config_rocm//rocm:rocm_headers", - ]), + ], ) cc_library( @@ -910,17 +902,17 @@ cc_library( ":test_macros_header", "//xla:execution_options_util", "//xla:status_macros", - "//xla:test", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "//xla/hlo/transforms:despecializer", - "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms/simplifiers:float_normalization", ], ) xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -929,11 +921,11 @@ xla_test( "//xla:array2d", "//xla:literal", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -943,7 +935,7 @@ xla_test( xla_test( name = "fft_test", srcs = ["fft_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -966,7 +958,7 @@ xla_test( }, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":hlo_test_base", @@ -984,7 +976,7 @@ xla_test( shard_count = 20, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -1000,7 +992,6 @@ xla_test( "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", - "//xla:test_helpers", "//xla:types", "//xla/client:client_library", "//xla/client:executable_build_options", @@ -1009,12 +1000,15 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test_helpers", "//xla/service", "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1042,7 +1036,7 @@ xla_test( "optonly", # TODO(b/151340488): Timed out on 2020-03-12. "nozapfhahn", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -1060,7 +1054,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", @@ -1088,7 +1085,7 @@ xla_test( tags = [ "nozapfhahn", "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -1106,7 +1103,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", @@ -1122,28 +1122,31 @@ xla_test( srcs = ["gather_operation_test.cc"], shard_count = 20, tags = [ - "test_hlo_pjrt_runner", - "test_xla_cpu_thunks", + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", - ":hlo_test_base", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:array", "//xla:execution_options_util", + "//xla:literal", "//xla:literal_util", - "//xla:status_macros", - "//xla:test", "//xla/hlo/builder:xla_builder", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "//xla/service", + "//xla/service:hlo_module_config", + "@local_tsl//tsl/platform:statusor", ], ) xla_test( name = "scatter_test", srcs = ["scatter_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], # TODO(b/245550554): enable Pjrt runner for scatter test once it's fixed. deps = [ ":client_library_test_base", @@ -1155,8 +1158,8 @@ xla_test( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:test", "//xla:types", + "//xla/hlo/testlib:test", "@com_google_absl//absl/strings", ], ) @@ -1175,6 +1178,7 @@ xla_test( "optonly", "test_xla_cpu_thunks", "cuda-only", #TODO(rocm): weekly sync 24-10-01 + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -1192,7 +1196,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", @@ -1206,7 +1213,7 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1226,7 +1233,7 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1258,6 +1265,7 @@ CONVOLUTION_TEST_DEPS = [ "//xla:shape_util", "@com_google_absl//absl/status:statusor", "//xla:util", + "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", @@ -1281,7 +1289,7 @@ xla_test( "optonly", # Timed out on 2020-07-18 "nozapfhahn", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = CONVOLUTION_TEST_DEPS + [ "//xla:error_spec", @@ -1328,7 +1336,7 @@ xla_test( tags = [ "cuda-only", "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = CONVOLUTION_TEST_DEPS + [ "//xla:array3d", @@ -1438,7 +1446,7 @@ xla_test( "cpu": ["nomsan"], }, shard_count = 30, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1461,7 +1469,7 @@ xla_test( timeout = "long", srcs = ["convolution_dimension_numbers_test.cc"], shard_count = 20, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1469,10 +1477,10 @@ xla_test( ":xla_internal_test_main", "//xla:array4d", "//xla:reference_util", - "//xla:test", "//xla/client:local_client", "//xla/hlo/builder:padding", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status:statusor", ], ) @@ -1507,7 +1515,7 @@ xla_test( "interpreter", ], shard_count = 40, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1520,8 +1528,6 @@ xla_test( "//xla:literal", "//xla:reference_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", @@ -1531,6 +1537,8 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:math", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1543,7 +1551,7 @@ xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], shard_count = 40, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1556,13 +1564,13 @@ xla_test( "//xla:literal", "//xla:reference_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", "//xla:util", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -1571,12 +1579,12 @@ xla_test( xla_test( name = "float8_test", srcs = ["float8_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", - "//xla:test", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:ml_dtypes", ], @@ -1589,16 +1597,16 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", ":test_utils", ":xla_internal_test_main", "//xla:literal", - "//xla:test", - "//xla:test_helpers", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", ], ) @@ -1611,14 +1619,14 @@ xla_test( "gpu", "interpreter", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", ":xla_internal_test_main", - "//xla:test", "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", ], @@ -1629,7 +1637,7 @@ xla_test( timeout = "long", srcs = ["slice_test.cc"], shard_count = 40, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1650,7 +1658,7 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1669,7 +1677,7 @@ xla_test( timeout = "moderate", srcs = ["dynamic_ops_test.cc"], shard_count = 4, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1678,10 +1686,10 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:reference_util", - "//xla:test_helpers", "//xla/client:client_library", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test_helpers", "//xla/service:computation_placer", "//xla/service:local_service", "//xla/service:platform_util", @@ -1698,7 +1706,7 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1708,12 +1716,12 @@ xla_test( "//xla:array2d", "//xla:literal_util", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:test_helpers", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", @@ -1723,7 +1731,7 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1745,7 +1753,7 @@ xla_test( shard_count = 31, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -1783,7 +1791,7 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -1837,7 +1845,7 @@ xla_test( shard_count = 40, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], xla_test_library_deps = [":reduce_window_test_library"], deps = [ @@ -1855,7 +1863,7 @@ xla_test( "no_mac", # b/194731834 "nozapfhahn", "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -1879,17 +1887,18 @@ xla_test( name = "copy_test", srcs = ["copy_test.cc"], tags = [ - "test_hlo_pjrt_runner", - "test_xla_cpu_thunks", + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", - ":hlo_test_base", + ":hlo_pjrt_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array3d", "//xla:array4d", + "//xla:error_spec", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -1912,7 +1921,7 @@ xla_test( "cpu", "interpreter", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -1934,7 +1943,7 @@ xla_test( xla_test( name = "sort_test", srcs = ["sort_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -1950,7 +1959,7 @@ xla_test( xla_test( name = "topk_test", srcs = ["topk_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -1965,7 +1974,7 @@ xla_test( name = "runtime_topk_test", srcs = ["runtime_topk_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -1982,7 +1991,7 @@ xla_test( xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -2001,7 +2010,7 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2010,10 +2019,10 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test_helpers", "@local_tsl//tsl/platform:test", ], ) @@ -2022,7 +2031,6 @@ xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2068,7 +2076,7 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2086,7 +2094,7 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2096,9 +2104,9 @@ xla_test( "//xla:array4d", "//xla:literal", "//xla:literal_util", - "//xla:test", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -2107,7 +2115,7 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -2131,7 +2139,7 @@ xla_test( xla_test( name = "fmax_fmin_test", srcs = ["fmax_fmin_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -2145,7 +2153,7 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2161,7 +2169,7 @@ xla_test( name = "matrix_ops_simple_test", timeout = "long", srcs = ["matrix_ops_simple_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2172,19 +2180,18 @@ xla_test( "//xla:literal", "//xla:reference_util", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test_helpers", + "//xla/stream_executor:device_description", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - ] + if_rocm_is_configured([ - # keep sorted - "@local_config_rocm//rocm:rocm_headers", - ]), + ], ) xla_test( @@ -2197,7 +2204,7 @@ xla_test( "no_mac", "noasan", "nosan", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -2205,11 +2212,11 @@ xla_test( ":xla_internal_test_main", "//xla:literal", "//xla:shape_util", - "//xla:test", "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:protobuf", @@ -2221,7 +2228,7 @@ xla_test( name = "rng_test", srcs = ["rng_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", "//xla:literal", @@ -2241,7 +2248,7 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2256,12 +2263,12 @@ xla_test( "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", - "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -2278,13 +2285,13 @@ xla_test( "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error. }, disabled_backends = ["interpreter"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", "//xla:literal_util", - "//xla:test", + "//xla/hlo/testlib:test", "@local_tsl//tsl/platform:statusor", ], ) @@ -2292,7 +2299,7 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -2315,7 +2322,7 @@ xla_test( name = "stochastic_convert_test", srcs = ["stochastic_convert_test.cc"], backends = ["cpu"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", "//xla:error_spec", @@ -2333,7 +2340,7 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2341,13 +2348,13 @@ xla_test( ":xla_internal_test_main", "//xla:array4d", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -2356,7 +2363,7 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2367,11 +2374,11 @@ xla_test( "//xla:array3d", "//xla:literal_util", "//xla:reference_util", - "//xla:test", - "//xla:test_helpers", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -2383,6 +2390,7 @@ xla_test( tags = [ "test_xla_cpu_thunks", "cuda-only", #TODO(rocm): weekly sync 24-10-01 + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -2408,18 +2416,15 @@ xla_test( "interpreter", ], tags = [ - "test_hlo_pjrt_runner", - "test_xla_cpu_thunks", + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":test_macros_header", ":xla_internal_test_main", - "//xla:literal", "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", + "//xla/hlo/testlib:test", ], ) @@ -2440,25 +2445,28 @@ xla_test( "gpu", "cpu", ], - tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", ":test_macros_header", ":test_utils", - ":verified_hlo_module", ":xla_internal_test_main", "//xla:literal", "//xla:literal_util", "//xla:shape_util", + "//xla:types", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:computation_placer", - "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", + "@ml_dtypes//:float8", ], ) @@ -2479,19 +2487,14 @@ xla_test( ":hlo_test_base", ":literal_test_util", ":test_macros_header", - ":test_utils", - ":verified_hlo_module", ":xla_internal_test_main", "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:shape_util", - "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:computation_placer", - "//xla/service:executable", "//xla/service:hlo_module_config", "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", @@ -2513,17 +2516,15 @@ xla_test( ], deps = [ ":hlo_test_base", - ":literal_test_util", ":xla_internal_test_main", "//xla:literal", - "//xla:literal_util", "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_module_config", - "@com_google_absl//absl/log", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", ], ) @@ -2557,21 +2558,20 @@ xla_test( "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) xla_test( name = "collective_pipeliner_execution_test", srcs = ["collective_pipeliner_execution_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", @@ -2580,7 +2580,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:collective_pipeliner", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -2598,23 +2598,30 @@ xla_test( ], }, backends = ["gpu"], + tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ - ":hlo_test_base", - ":test_macros_header", + ":hlo_pjrt_test_base", + ":literal_test_util", ":xla_internal_test_main", "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test_helpers", + "//xla/service:computation_placer_hdr", + "//xla/service:hlo_runner_interface", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/strings:string_view", ], ) xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2642,15 +2649,15 @@ xla_test( ":xla_internal_test_main", "//xla:literal", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", ], ) xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -2671,15 +2678,15 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", + "//xla/hlo/testlib:test", + "//xla/hlo/testlib:test_helpers", "//xla/service:hlo_proto_cc", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", @@ -2693,7 +2700,7 @@ xla_test( xla_test( name = "value_inference_test", srcs = ["value_inference_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":literal_test_util", ":test_macros_header", @@ -2702,7 +2709,6 @@ xla_test( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:global_data", @@ -2711,6 +2717,7 @@ xla_test( "//xla/hlo/builder:xla_computation", "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:prng", + "//xla/hlo/testlib:test", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2723,7 +2730,7 @@ xla_test( xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":literal_test_util", ":test_macros_header", @@ -2732,12 +2739,12 @@ xla_test( "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:global_data", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2747,7 +2754,7 @@ xla_test( xla_test( name = "client_test", srcs = ["client_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2756,12 +2763,12 @@ xla_test( ":xla_internal_test_main", "//xla:shape_util", "//xla:status_macros", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder:xla_computation", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -2770,7 +2777,7 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2794,15 +2801,17 @@ xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], tags = [ - "test_hlo_pjrt_runner", - "test_xla_cpu_thunks", + "test_migrated_to_hlo_runner_pjrt", + "test_xla_cpu_no_thunks", ], deps = [ - ":hlo_test_base", + ":hlo_pjrt_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:literal", + "//xla:array3d", + "//xla:array4d", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -2817,13 +2826,13 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", "//xla:literal_util", - "//xla:test_helpers", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/testlib:test_helpers", "//xla/service:backend", "//xla/service:llvm_compiler", "//xla/stream_executor:device_description", @@ -2840,7 +2849,7 @@ xla_test( xla_test( name = "round_trip_packed_literal_test", srcs = ["round_trip_packed_literal_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2868,7 +2877,7 @@ xla_test( "gpu", "interpreter", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2927,7 +2936,7 @@ xla_cc_test( linkstatic = 1, tags = [ "not_run:arm", # b/341355246 - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ "//xla:executable_run_options", @@ -2941,7 +2950,7 @@ xla_cc_test( xla_test( name = "local_client_allocation_test", srcs = ["local_client_allocation_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":literal_test_util", ":local_client_test_base", @@ -2965,7 +2974,7 @@ xla_test( shard_count = 30, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":literal_test_util", @@ -2975,12 +2984,12 @@ xla_test( ":xla_internal_test_main", "//xla:literal", "//xla:shape_util", - "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:local_client", "//xla/hlo/builder:sharding_builder", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test_helpers", "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", @@ -3004,7 +3013,7 @@ xla_test( # Outfeed ops are not supported on the interpreter backend. "interpreter", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":local_client_test_base", ":test_macros_header", @@ -3018,12 +3027,12 @@ xla_cc_test( srcs = [ "hlo_metadata_test.cc", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":local_client_test_base", - "//xla:test_helpers", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test_helpers", "//xla/service:cpu_plugin", "//xla/service:local_service", "@local_tsl//tsl/platform:test_main", @@ -3033,7 +3042,7 @@ xla_cc_test( xla_test( name = "round_trip_transfer_test", srcs = ["round_trip_transfer_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -3051,7 +3060,7 @@ xla_test( xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -3063,10 +3072,10 @@ xla_test( "//xla:reference_util", "//xla:shape_util", "//xla:status_macros", - "//xla:test_helpers", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test", @@ -3076,7 +3085,7 @@ xla_test( xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -3087,11 +3096,11 @@ xla_test( xla_cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":literal_test_util", "//xla:literal", - "//xla:test_helpers", + "//xla/hlo/testlib:test_helpers", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", @@ -3105,7 +3114,7 @@ xla_test( name = "transfer_manager_test", srcs = ["transfer_manager_test.cc"], shard_count = 50, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":literal_test_util", ":local_client_test_base", @@ -3136,14 +3145,14 @@ xla_test( "cpu", "gpu", ], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:test", "//xla:types", + "//xla/hlo/testlib:test", ], ) @@ -3156,7 +3165,7 @@ xla_test( deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep - "//xla:test", + "//xla/hlo/testlib:test", "//xla/service:cpu_plugin", # reference backend "//xla/service:platform_util", "@local_tsl//tsl/platform:path", @@ -3169,7 +3178,7 @@ xla_test( srcs = ["test_utils_test.cc"], # There is nothing backend specific in this test, so just pick an arbitrary backend. backends = ["cpu"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":local_client_test_base", ":test_macros_header", @@ -3193,7 +3202,7 @@ xla_test( }, shard_count = 50, tags = [ - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -3215,7 +3224,7 @@ xla_cc_test( name = "multiple_devices_on_host_test", srcs = ["multiple_devices_on_host_test.cc"], args = ["--xla_force_host_platform_device_count=4"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":xla_internal_test_main", # fixdeps: keep "//xla:shape_util", @@ -3236,28 +3245,28 @@ xla_test( tags = [ # Disabled in OSS until nvidia publicly releases a fixed ptxas. "no_oss", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//xla:debug_options_flags", - "//xla:test", + "//xla/hlo/testlib:test", ], ) xla_test( name = "get_dimension_size_test", srcs = ["get_dimension_size_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", "//xla:literal_util", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", ], @@ -3269,14 +3278,14 @@ xla_test( backend_tags = { "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error. }, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", "//xla:literal_util", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", ], @@ -3289,7 +3298,7 @@ xla_test( shard_count = 3, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -3299,12 +3308,12 @@ xla_test( "//xla:array", "//xla:array2d", "//xla:literal", - "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:math", "//xla/hlo/builder/lib:matrix", + "//xla/hlo/testlib:test", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -3317,7 +3326,7 @@ xla_test( shard_count = 10, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", ], deps = [ ":client_library_test_base", @@ -3326,11 +3335,11 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla:literal", - "//xla:test", "//xla:types", "//xla/hlo/builder:xla_builder", "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", + "//xla/hlo/testlib:test", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", ], @@ -3339,26 +3348,26 @@ xla_test( xla_test( name = "constant_reduction_function_test", srcs = ["constant_reduction_function_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:test", "//xla:types", + "//xla/hlo/testlib:test", ], ) xla_cc_test( name = "tile_assignment_test", srcs = ["tile_assignment_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":xla_internal_test_main", "//xla:array3d", - "//xla:test", "//xla/hlo/ir:tile_assignment", + "//xla/hlo/testlib:test", "@com_google_absl//absl/hash", ], ) @@ -3366,15 +3375,15 @@ xla_cc_test( xla_test( name = "numerics_test", srcs = ["numerics_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", ":xla_internal_test_main", "//xla:literal_util", - "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -3387,7 +3396,7 @@ xla_test( backend_tags = { "gpu": ["notsan"], }, - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -3395,7 +3404,7 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", - "//xla:test", + "//xla/hlo/testlib:test", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -3409,14 +3418,14 @@ xla_test( xla_test( name = "batch_norm_grad_test", srcs = ["batch_norm_grad_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", "//xla:literal_util", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", ], @@ -3425,14 +3434,14 @@ xla_test( xla_test( name = "batch_norm_training_test", srcs = ["batch_norm_training_test.cc"], - tags = ["test_xla_cpu_thunks"], + tags = ["test_xla_cpu_no_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep "//xla:literal", "//xla:literal_util", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:test", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", ], @@ -3503,3 +3512,22 @@ xla_test( "@local_tsl//tsl/platform:path", ], ) + +xla_test( + name = "xnn_fusion_test", + srcs = ["xnn_fusion_test.cc"], + backends = ["cpu"], + deps = [ + ":hlo_test_base", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla/tsl/platform:test", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/tests/all_reduce_test.cc b/third_party/xla/xla/tests/all_reduce_test.cc index 714ee5fc0c3a94..d4ce1b89b63ab4 100644 --- a/third_party/xla/xla/tests/all_reduce_test.cc +++ b/third_party/xla/xla/tests/all_reduce_test.cc @@ -16,17 +16,15 @@ limitations under the License. #include #include -#include "xla/literal.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" namespace xla { namespace { -class TrivialAllReduceTest : public HloTestBase {}; +using TrivialAllReduceTest = HloPjRtTestBase; // Currently the CPU and GPU backends only support AllReduce with one // replica. But we can at least check this. diff --git a/third_party/xla/xla/tests/array_elementwise_ops_test.cc b/third_party/xla/xla/tests/array_elementwise_ops_test.cc index c12ce79a06e8fa..fde51c9d99b16d 100644 --- a/third_party/xla/xla/tests/array_elementwise_ops_test.cc +++ b/third_party/xla/xla/tests/array_elementwise_ops_test.cc @@ -25,34 +25,29 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/casts.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "ml_dtypes/include/float8.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/global_data.h" #include "xla/client/local_client.h" #include "xla/comparison_util.h" #include "xla/fp_util.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" -#include "xla/test.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" #include "tsl/platform/ml_dtypes.h" -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - namespace xla { namespace { @@ -1752,11 +1747,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { -#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION == 50700 - GTEST_SKIP() - << "This test fails on rocm-5.7.0 platform due to a compiler bug"; -#endif - + auto device_description = + client_->backend().default_stream_executor()->GetDeviceDescription(); + bool is_rocm = std::holds_alternative( + device_description.gpu_compute_capability()); + if (is_rocm && device_description.runtime_version() == + stream_executor::SemanticVersion(5, 7, 0)) { + GTEST_SKIP() + << "This test fails on rocm-5.7.0 platform due to a compiler bug"; + } SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto eps = std::numeric_limits::epsilon(); diff --git a/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc b/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc index cb077b05dda71d..c4a8efbc7509e0 100644 --- a/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc +++ b/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc @@ -16,15 +16,12 @@ limitations under the License. // Tests that passing a bad shape to RNG's output parameter causes a validation // failure rather than causing a crash. -#include - #include "absl/status/statusor.h" -#include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" +#include "xla/shape.h" #include "xla/tests/client_library_test_base.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" @@ -37,29 +34,19 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { XlaBuilder builder(TestName()); auto zero = ConstantR0(&builder, 0.0); auto one = ConstantR0(&builder, 1.0); - Shape default_constructed; - RngUniform(zero, one, default_constructed); - - absl::StatusOr computation = builder.Build(); - EXPECT_FALSE(computation.ok()); - LOG(INFO) << "status received: " << computation.status(); - EXPECT_THAT(computation.status().message(), - ::testing::HasSubstr("shape has invalid")); + RngUniform(zero, one, Shape()); + EXPECT_FALSE(builder.Build().ok()); } TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { XlaBuilder builder(TestName()); auto zero = ConstantR0(&builder, 0.0); auto one = ConstantR0(&builder, 1.0); - Shape sans_layout; - sans_layout.set_element_type(F32); - sans_layout.add_dimensions(1); - - RngUniform(zero, one, sans_layout); - - absl::StatusOr computation = builder.Build(); - ASSERT_TRUE(computation.ok()); - LOG(INFO) << computation.status(); + Shape shape; + shape.set_element_type(F32); + shape.add_dimensions(1); + RngUniform(zero, one, shape); + EXPECT_TRUE(builder.Build().ok()); } } // namespace diff --git a/third_party/xla/xla/tests/batch_norm_grad_test.cc b/third_party/xla/xla/tests/batch_norm_grad_test.cc index 0bff1da41b90fd..74512febada3c5 100644 --- a/third_party/xla/xla/tests/batch_norm_grad_test.cc +++ b/third_party/xla/xla/tests/batch_norm_grad_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/batch_norm_training_test.cc b/third_party/xla/xla/tests/batch_norm_training_test.cc index 581a47090cfa01..77386432c733b6 100644 --- a/third_party/xla/xla/tests/batch_norm_training_test.cc +++ b/third_party/xla/xla/tests/batch_norm_training_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/batch_normalization_test.cc b/third_party/xla/xla/tests/batch_normalization_test.cc index 3b6aebc95cb05d..8569a2b48e651e 100644 --- a/third_party/xla/xla/tests/batch_normalization_test.cc +++ b/third_party/xla/xla/tests/batch_normalization_test.cc @@ -29,11 +29,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/reference_util.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/bfloat16_test.cc b/third_party/xla/xla/tests/bfloat16_test.cc index 22085485fde573..2eb9dea66b8596 100644 --- a/third_party/xla/xla/tests/bfloat16_test.cc +++ b/third_party/xla/xla/tests/bfloat16_test.cc @@ -26,11 +26,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/reference_util.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/broadcast_simple_test.cc b/third_party/xla/xla/tests/broadcast_simple_test.cc index 2876714ab94e02..0f7f5656dc75ee 100644 --- a/third_party/xla/xla/tests/broadcast_simple_test.cc +++ b/third_party/xla/xla/tests/broadcast_simple_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "xla/array4d.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/broadcast_test.cc b/third_party/xla/xla/tests/broadcast_test.cc index c46104f8443195..a45b9309008bfa 100644 --- a/third_party/xla/xla/tests/broadcast_test.cc +++ b/third_party/xla/xla/tests/broadcast_test.cc @@ -16,12 +16,13 @@ limitations under the License. #include #include +#include "xla/array3d.h" +#include "xla/array4d.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" @@ -30,7 +31,7 @@ limitations under the License. namespace xla { namespace { -class BroadcastTest : public HloTestBase {}; +using BroadcastTest = HloPjRtTestBase; XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. @@ -46,7 +47,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0(42.0), result, - error_spec_)); + kDefaultErrorSpec)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -63,7 +64,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), result, - error_spec_)); + kDefaultErrorSpec)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { @@ -86,11 +87,11 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralSlice(result, {0}), error_spec_)); + LiteralSlice(result, {0}), kDefaultErrorSpec)); EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralSlice(result, {1}), error_spec_)); + LiteralSlice(result, {1}), kDefaultErrorSpec)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -107,7 +108,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), result, - error_spec_)); + kDefaultErrorSpec)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -126,7 +127,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), result, - error_spec_)); + kDefaultErrorSpec)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { @@ -144,7 +145,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - result, error_spec_)); + result, kDefaultErrorSpec)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -165,8 +166,9 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(expected), + result, kDefaultErrorSpec)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -195,8 +197,9 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(expected), + result, kDefaultErrorSpec)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -218,7 +221,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array), - result, error_spec_)); + result, kDefaultErrorSpec)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -237,8 +240,9 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(expected), + result, kDefaultErrorSpec)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -259,8 +263,9 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(expected), + result, kDefaultErrorSpec)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -290,8 +295,9 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(expected), + result, kDefaultErrorSpec)); } } // namespace diff --git a/third_party/xla/xla/tests/buffer_donation_test.cc b/third_party/xla/xla/tests/buffer_donation_test.cc index 666ebb6dd411c0..dc5176ec69d214 100644 --- a/third_party/xla/xla/tests/buffer_donation_test.cc +++ b/third_party/xla/xla/tests/buffer_donation_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/executable.h" @@ -29,7 +30,6 @@ limitations under the License. #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 827db91a79a66f..2ea2c38888cb08 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -22,6 +22,7 @@ NVIDIA_GPU_DEFAULT_BACKENDS = [ "gpu_any", "gpu_a100", "gpu_h100", + "gpu_b100", ] AMD_GPU_DEFAULT_BACKENDS = ["gpu_amd_any"] @@ -285,6 +286,10 @@ def xla_test( "//xla/service:cpu_plugin", "//xla/tests:test_macros_cpu", ] + + # TODO: b/382779188 - Remove this when all tests are migrated to PjRt. + if "test_migrated_to_hlo_runner_pjrt" in tags: + backend_deps.append("//xla/tests:pjrt_cpu_client_registry") elif backend in NVIDIA_GPU_BACKENDS + AMD_GPU_DEFAULT_BACKENDS: backend_deps += [ "//xla/service:gpu_plugin", @@ -295,11 +300,19 @@ def xla_test( if backend in AMD_GPU_DEFAULT_BACKENDS: this_backend_tags.append("gpu") this_backend_copts.append("-DXLA_TEST_BACKEND_GPU=1") + + # TODO: b/382779188 - Remove this when all tests are migrated to PjRt. + if "test_migrated_to_hlo_runner_pjrt" in tags: + backend_deps.append("//xla/tests:pjrt_gpu_client_registry") elif backend == "interpreter": backend_deps += [ "//xla/service:interpreter_plugin", "//xla/tests:test_macros_interpreter", ] + + # TODO: b/382779188 - Remove this when all tests are migrated to PjRt. + if "test_migrated_to_hlo_runner_pjrt" in tags: + backend_deps.append("//xla/tests:pjrt_interpreter_client_registry") elif backend in plugins: backend_deps += plugins[backend]["deps"] this_backend_copts += plugins[backend]["copts"] diff --git a/third_party/xla/xla/tests/call_test.cc b/third_party/xla/xla/tests/call_test.cc index 36aae1aed51de1..4fdfc73db84296 100644 --- a/third_party/xla/xla/tests/call_test.cc +++ b/third_party/xla/xla/tests/call_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/check_execution_arity_test.cc b/third_party/xla/xla/tests/check_execution_arity_test.cc index fd0f5bd9bf75e0..5ec29108109446 100644 --- a/third_party/xla/xla/tests/check_execution_arity_test.cc +++ b/third_party/xla/xla/tests/check_execution_arity_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "xla/client/global_data.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/tests/cholesky_test.cc b/third_party/xla/xla/tests/cholesky_test.cc index c52ea4b9ea849c..26d10e3dc773b5 100644 --- a/third_party/xla/xla/tests/cholesky_test.cc +++ b/third_party/xla/xla/tests/cholesky_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 01944740eab9ec..1882948c9595a7 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -26,11 +26,11 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test_helpers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/tests/client_test.cc b/third_party/xla/xla/tests/client_test.cc index 59eafe57b12141..77f2345c7a1ab0 100644 --- a/third_party/xla/xla/tests/client_test.cc +++ b/third_party/xla/xla/tests/client_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 77cd2a7ee82357..0ca1c15c778e9d 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "xla/array.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -43,9 +42,9 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/types.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace { @@ -162,7 +161,7 @@ class AsyncCollectiveOps : public CollectiveOpsTestE2E, {DebugOptions::NOOP, DebugOptions::ALLREDUCE, DebugOptions::ALLGATHER, DebugOptions::REDUCESCATTER, DebugOptions::COLLECTIVEBROADCAST, DebugOptions::ALLTOALL, - DebugOptions::COLLECTIVEPERMUTE}) { + DebugOptions::COLLECTIVEPERMUTE, DebugOptions::RAGGEDALLTOALL}) { debug_options.add_xla_gpu_disable_async_collectives(option); } } @@ -209,7 +208,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllReduce) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_all_reduce = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -246,7 +248,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGather) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_all_gather = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -288,7 +293,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGatherMixedTypes) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_all_gather = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -326,7 +334,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncCollectiveBroadcast) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_collective_broadcast = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -359,7 +370,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncCollectivePermute) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_collective_permute = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -403,7 +417,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncReduceScatter) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_reduce_scatter = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -437,7 +454,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithSplitDim) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_all_to_all = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -522,7 +542,10 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const bool enable_async_all_to_all = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -575,7 +598,10 @@ TEST_P(AsyncCollectiveOps, MatmulReplicated) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -695,7 +721,10 @@ TEST_F(CollectiveOpsTestE2E, WhileLoopReduceScatterCodeMotion) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(true); @@ -750,7 +779,10 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -779,16 +811,23 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) { class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { public: void CollectiveOpsCompareWindowedNonWindowed( - absl::string_view hlo_text, bool disable_dot_merger = false) { + absl::string_view hlo_text, bool disable_dot_merger = false, + bool enable_a2a_rewrite = false) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); auto opts = GetDebugOptionsForTest(); opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_experimental_enable_alltoall_windowed_einsum( + enable_a2a_rewrite); opts.set_xla_gpu_graph_min_graph_size(200); opts.set_xla_gpu_enable_triton_gemm(false); if (disable_dot_merger) { @@ -1062,7 +1101,9 @@ ENTRY main.9_spmd { } )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/false, + /*enable_a2a_rewrite=*/true); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -1078,7 +1119,9 @@ ENTRY main.9_spmd { } )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/false, + /*enable_a2a_rewrite=*/true); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -1099,7 +1142,9 @@ ENTRY main.9_spmd { } )"; - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/false, + /*enable_a2a_rewrite=*/true); } TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) { @@ -1157,7 +1202,10 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1174,7 +1222,11 @@ class CollectiveOpsTestE2EPipelinedNonPipelined : public CollectiveOpsTestE2E { void CollectiveOpsComparePipelinedNonPipelined(absl::string_view hlo_string) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions); @@ -1368,7 +1420,11 @@ ENTRY entry { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1420,7 +1476,11 @@ ENTRY entry { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1516,7 +1576,10 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } const int64_t kNumPartitions = 4; HloModuleConfig config = @@ -1536,7 +1599,7 @@ ENTRY entry { EXPECT_TRUE(executable->has_module()); } -class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E { +class RaggedAllToAllTest : public AsyncCollectiveOps { public: // Creates random test data for a ragged-all-to-all. // @@ -1557,8 +1620,9 @@ class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E { // `input_sizes` is a 2D array of shape [num_replicas, num_replicas]. // `input_sizes[i, j]` is the number of elements in the j-th ragged row of the // i-th replica input. + template void CreateRandomTestData(HloModule* module, - const Array& input_sizes) { + const Array& input_sizes) { auto ragged_all_to_all = FindInstruction(module, HloOpcode::kRaggedAllToAll); EXPECT_THAT(ragged_all_to_all, NotNull()); @@ -1575,12 +1639,12 @@ class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E { std::vector> output_data(num_replicas, Array(ragged_tensor_sizes)); - Array output_sizes = input_sizes; + Array output_sizes = input_sizes; output_sizes.TransposeDimensions({1, 0}); // Computes ragged tensor offsets based on the sizes of the ragged rows. - auto get_offsets = [&](const Array& sizes) { - Array offsets(sizes.dimensions()); + auto get_offsets = [&](const Array& sizes) { + Array offsets(sizes.dimensions()); for (int i = 0; i < num_replicas; ++i) { for (int j = 1; j < num_replicas; ++j) { offsets(i, j) = offsets(i, j - 1) + sizes(i, j - 1); @@ -1589,8 +1653,9 @@ class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E { return offsets; }; - Array input_offsets = get_offsets(input_sizes); - Array output_offsets = get_offsets(output_sizes); + Array input_offsets = get_offsets(input_sizes); + Array output_offsets = get_offsets(output_sizes); + output_offsets.TransposeDimensions({1, 0}); std::vector chunk_sizes{ragged_tensor_sizes.begin(), ragged_tensor_sizes.end()}; @@ -1610,18 +1675,19 @@ class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E { start_indices[0] = input_offsets(i, j); input_data[i].UpdateSlice(chunk_data, start_indices); - start_indices[0] = output_offsets(j, i); + start_indices[0] = output_offsets(i, j); output_data[j].UpdateSlice(chunk_data, start_indices); } } - auto get_row = [&](int64_t row_id, const Array& data) { - Array row = data.Slice({row_id, 0}, {row_id + 1, num_replicas}); + auto get_row = [&](int64_t row_id, const Array& data) { + Array row = + data.Slice({row_id, 0}, {row_id + 1, num_replicas}); row.Reshape({num_replicas}); return row; }; - // Create literals concert array to literals. + // Create literals from array data. for (int replica_id = 0; replica_id < num_replicas; ++replica_id) { inputs_.push_back(LiteralUtil::CreateFromArray(input_data[replica_id])); input_offsets_.push_back( @@ -1667,7 +1733,7 @@ class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E { Literal output_init_; }; -TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs) { +XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs) { absl::string_view kModuleReplicatedStr = R"( HloModule module, num_partitions=1 @@ -1684,7 +1750,11 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs) { const int64_t kNumReplicas = 2; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions); @@ -1692,8 +1762,9 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs) { TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); - CreateRandomTestData(module.get(), /*input_sizes=*/{/*replica_0=*/{1, 1}, - /*replica_1=*/{3, 1}}); + CreateRandomTestData( + module.get(), /*input_sizes=*/{/*replica_0=*/{1, 1}, + /*replica_1=*/{3, 1}}); TF_ASSERT_OK_AND_ASSIGN( std::vector results, @@ -1706,17 +1777,17 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs) { EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1])); } -TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs_MultiDimData) { +XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_MultiDimData) { absl::string_view kModuleReplicatedStr = R"( HloModule module, num_partitions=1 ENTRY entry { input = f32[16, 5, 32] parameter(0) output = f32[16, 5, 32] parameter(1) - input_offsets = s32[2] parameter(2) - send_sizes = s32[2] parameter(3) - output_offsets = s32[2] parameter(4) - recv_sizes = s32[2] parameter(5) + input_offsets = s64[2] parameter(2) + send_sizes = s64[2] parameter(3) + output_offsets = s64[2] parameter(4) + recv_sizes = s64[2] parameter(5) ROOT ra2a = f32[16, 5, 32] ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}} @@ -1724,7 +1795,11 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs_MultiDimData) { const int64_t kNumReplicas = 2; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions); @@ -1736,8 +1811,9 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs_MultiDimData) { FindInstruction(module.get(), HloOpcode::kRaggedAllToAll); EXPECT_THAT(ragged_all_to_all, NotNull()); - CreateRandomTestData(module.get(), /*input_sizes=*/{/*replica_0=*/{4, 7}, - /*replica_1=*/{2, 5}}); + CreateRandomTestData( + module.get(), /*input_sizes=*/{/*replica_0=*/{4, 7}, + /*replica_1=*/{2, 5}}); TF_ASSERT_OK_AND_ASSIGN( std::vector results, @@ -1751,7 +1827,67 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_2GPUs_MultiDimData) { EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1])); } -TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_8GPUs) { +XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_Degenerate_2GPUs) { + absl::string_view kModuleReplicatedStr = R"( + HloModule module + + ENTRY entry { + input = f32[4] parameter(0) + output = f32[4] parameter(1) + input_offsets = s32[1] parameter(2) + send_sizes = s32[1] parameter(3) + output_offsets = s32[1] parameter(4) + recv_sizes = s32[1] parameter(5) + ROOT ra2a = f32[4] ragged-all-to-all(input, output, input_offsets, + send_sizes, output_offsets, recv_sizes), replica_groups={{0},{1}} + })"; + + const int64_t kNumReplicas = 2; + const int64_t kNumPartitions = 1; + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions); + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + + inputs_.push_back(LiteralUtil::CreateR1({1, 0, 0, 0})); + inputs_.push_back(LiteralUtil::CreateR1({2, 3, 4, 0})); + + input_sizes_.push_back(LiteralUtil::CreateR1({1})); + input_sizes_.push_back(LiteralUtil::CreateR1({3})); + + output_sizes_.push_back(LiteralUtil::CreateR1({1})); + output_sizes_.push_back(LiteralUtil::CreateR1({3})); + + input_offsets_.push_back(LiteralUtil::CreateR1({0})); + input_offsets_.push_back(LiteralUtil::CreateR1({0})); + + output_offsets_.push_back(LiteralUtil::CreateR1({2})); + output_offsets_.push_back(LiteralUtil::CreateR1({1})); + + output_init_ = LiteralUtil::CreateR1({-1, -1, -1, -1}); + + expected_outputs_.push_back(LiteralUtil::CreateR1({-1, -1, 1, -1})); + expected_outputs_.push_back(LiteralUtil::CreateR1({-1, 2, 3, 4})); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + HloTestBase::ExecuteReplicated(std::move(module), GetInputLiteralPtrs(), + /*num_replicas=*/kNumReplicas, + /*run_hlo_passes=*/true, + /*device_assignment=*/nullptr)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[0], results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1])); +} + +XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) { absl::string_view kModuleReplicatedStr = R"( HloModule module, num_partitions=1 @@ -1769,7 +1905,11 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_8GPUs) { const int64_t kNumReplicas = 8; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions); @@ -1780,7 +1920,7 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_8GPUs) { Array input_sizes({kNumReplicas, kNumReplicas}); input_sizes.FillRandomUniform(0, 10); - CreateRandomTestData(module.get(), input_sizes); + CreateRandomTestData(module.get(), input_sizes); TF_ASSERT_OK_AND_ASSIGN( std::vector results, @@ -1795,5 +1935,130 @@ TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll_8GPUs) { } } +INSTANTIATE_TEST_SUITE_P(RaggedAllToAllTest, RaggedAllToAllTest, + ::testing::Bool()); + +TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) { + absl::string_view hlo_string = R"( +HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4 + +None.4 { + Arg_1.6 = bf16[32,96]{1,0} parameter(1) + Arg_0.5 = bf16[32,96]{1,0} parameter(0) + collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}} + constant.7 = bf16[] constant(2) + broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={} + multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8) + ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10) +} // None.4 + +region_0.12 { + arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0) + get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0 + constant.17 = s32[] constant(1) + add.21 = s32[] add(get-tuple-element.14, constant.17) + get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1 + get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2 + call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4 + get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0 + get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1 + ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20) +} // region_0.12 + +region_1.23 { + arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0) + get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1 + get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2 + get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0 + constant.28 = s32[] constant(3) + ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT +} // region_1.23 + +shmap_body.30 { + constant.32 = s32[] constant(0) + Arg_0.31 = bf16[32,96]{1,0} parameter(0) + constant.33 = bf16[] constant(0) + broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={} + tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34) + while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12 + get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0 + get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1 + get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2 + ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39) +} // shmap_body.30 + +ENTRY main.49 { + Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]} + custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]} + custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual} + call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30 + get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0 + custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual} + custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]} + get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1 + custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual} + custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]} + ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47) +} // main.49 +)"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } + + HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions); + auto opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_use_memcpy_local_p2p(true); + config.set_debug_options(opts); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string, config)); + auto fake_arguments = xla::MakeFakeArguments(module.get()).value(); + std::vector fake_ptrs(fake_arguments.size()); + for (int i = 0; i < fake_arguments.size(); ++i) { + fake_ptrs[i] = &fake_arguments[i]; + } + + DeviceAssignment assn(/*replica_count=*/kNumReplicas, + /*computation_count=*/kNumPartitions); + for (int64_t i = 0; i < kNumPartitions; ++i) { + assn(0, i) = i; + } + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + HloTestBase::ExecuteReplicated( + std::move(module), fake_ptrs, kNumPartitions, &assn, + /*run_hlo_passes=*/true, /*use-threads=*/true)); + ASSERT_EQ(results.size(), kNumPartitions); + + HloModuleConfig ref_config = + GetModuleConfigForTest(kNumReplicas, kNumPartitions); + auto ref_opts = GetDebugOptionsForTest(); + ref_opts.set_xla_gpu_use_memcpy_local_p2p(false); + ref_config.set_debug_options(ref_opts); + TF_ASSERT_OK_AND_ASSIGN(auto ref_module, + ParseAndReturnVerifiedModule(hlo_string, ref_config)); + auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value(); + std::vector ref_fake_ptrs(fake_ref_arguments.size()); + for (int i = 0; i < fake_ref_arguments.size(); ++i) { + ref_fake_ptrs[i] = &fake_ref_arguments[i]; + } + + TF_ASSERT_OK_AND_ASSIGN( + std::vector ref_results, + HloTestBase::ExecuteReplicated( + std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn, + /*run_hlo_passes=*/true, /*use-threads=*/true)); + ASSERT_EQ(ref_results.size(), kNumPartitions); + ErrorSpec error_spec{1e-5, 1e-5}; + // Expect same results with and without pipelining of collectives. + for (int i = 0; i < kNumPartitions; ++i) { + EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec)); + } +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index ef59426121609b..eee86a396d43f8 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include #include @@ -24,21 +23,25 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/types/span.h" +#include "ml_dtypes/include/float8.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/computation_placer.h" -#include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" +#include "xla/types.h" #include "tsl/platform/blocking_counter.h" -#include "tsl/platform/env.h" -#include "tsl/platform/threadpool.h" namespace xla { namespace { @@ -221,6 +224,19 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) { /*expected_value=*/LiteralUtil::CreateR1({2})); } +XLA_TEST_F(CollectiveOpsTest, + AllReduceTwoReplicasOneOperand_float8_e4m3b11fnuz) { + TestAllOpsForReduce(); +} + +XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int4) { + TestAllOpsForReduce(); +} + +XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint4) { + TestAllOpsForReduce(); +} + XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) { TestAllOpsForReduce(); } @@ -486,7 +502,10 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) { // Test a prime number so it's not all powers of 2. const int64_t kNumElems = 137; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -533,7 +552,10 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) { } )"; static constexpr int kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -672,7 +694,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_TwoGPUs)) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -710,7 +735,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -746,7 +774,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_TwoGPUs) { } )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -778,7 +809,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -815,7 +849,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -851,7 +888,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NotDegenerate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -888,7 +928,10 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -926,7 +969,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncCollectivePermute)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -968,7 +1014,10 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1014,7 +1063,10 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1054,7 +1106,10 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1086,7 +1141,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1435,7 +1493,6 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(ReduceScatterReassociate)) { kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); - const ErrorSpec es{1e-5, 1e-5}; LiteralTestUtil::ExpectR1Equal({26, 30, 34, 38}, results[0]); LiteralTestUtil::ExpectR1Equal({42, 46, 50, 54}, results[1]); } @@ -1486,7 +1543,6 @@ XLA_TEST_F(CollectiveOpsTest, kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); - const ErrorSpec es{1e-5, 1e-5}; LiteralTestUtil::ExpectR1Equal({26, 30, 34, 38}, results[0]); LiteralTestUtil::ExpectR1Equal({42, 46, 50, 54}, results[1]); } @@ -1630,7 +1686,7 @@ XLA_TEST_F(CollectiveOpsTest, results[0]); } -XLA_TEST_F(CollectiveOpsTest, AllGather_16BitInt) { +XLA_TEST_F(CollectiveOpsTest, AllGather16BitInt) { const char* const kModuleStr = R"( HloModule test ENTRY test_computation { @@ -1660,7 +1716,40 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_16BitInt) { } } -XLA_TEST_F(CollectiveOpsTest, AllToAll_16BitInt) { +XLA_TEST_F(CollectiveOpsTest, AllGather4BitInt) { + // Test with all-gather inputs having an odd number of elements to ensure that + // the 4 bits of padding are handled correctly. + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + id32 = u32[] replica-id() + id = u4[] convert(id32) + id2 = u4[1, 3] broadcast(id), dimensions={} + a0 = u4[1, 3] constant({{3, 5, 7}}) + a1 = u4[1, 3] add(id2, a0) + allgather = u4[2, 3] all-gather(a1), dimensions={0} + ROOT out = u4[6] reshape(allgather) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal( + {u4{3}, u4{5}, u4{7}, u4{4}, u4{6}, u4{8}}, result); + } +} + +XLA_TEST_F(CollectiveOpsTest, AllToAll16BitInt) { const char* const kModuleStr = R"( HloModule test ENTRY test_computation { @@ -1688,7 +1777,35 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_16BitInt) { LiteralTestUtil::ExpectR1Equal({15, 16}, results[1]); } -XLA_TEST_F(CollectiveOpsTest, CollectivePermute_16BitInt) { +XLA_TEST_F(CollectiveOpsTest, AllToAll4BitInt) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + id32 = u32[] replica-id() + id = u4[] convert(id32) + id2 = u4[2] broadcast(id), dimensions={} + a0 = u4[2] constant({5, 7}) + a1 = u4[2] add(id2, a0) + ROOT a2a = u4[2] all-to-all(a1), dimensions={0} + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({u4{5}, u4{6}}, results[0]); + LiteralTestUtil::ExpectR1Equal({u4{7}, u4{8}}, results[1]); +} + +XLA_TEST_F(CollectiveOpsTest, CollectivePermute16BitInt) { const char* const kModuleStr = R"( HloModule test ENTRY test_computation { @@ -1716,7 +1833,37 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_16BitInt) { LiteralTestUtil::ExpectR1Equal({10, 15}, results[1]); } -XLA_TEST_F(CollectiveOpsTest, AllReduce_16BitInt) { +XLA_TEST_F(CollectiveOpsTest, CollectivePermute4BitInt) { + // Test with collective-permute inputs having an odd number of elements to + // ensure that the 4 bits of padding are handled correctly. + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + id32 = u32[] replica-id() + id = u4[] convert(id32) + id2 = u4[3] broadcast(id), dimensions={} + a0 = u4[3] constant({3, 5, 7}) + a1 = u4[3] add(id2, a0) + ROOT cp = u4[3] collective-permute(a1), source_target_pairs={{0,1}, {1,0}} + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({u4{4}, u4{6}, u4{8}}, results[0]); + LiteralTestUtil::ExpectR1Equal({u4{3}, u4{5}, u4{7}}, results[1]); +} + +XLA_TEST_F(CollectiveOpsTest, AllReduce16BitInt) { const char* const kModuleStr = R"( HloModule test @@ -1752,7 +1899,45 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_16BitInt) { } } -XLA_TEST_F(CollectiveOpsTest, ReduceScatter_16BitInt) { +XLA_TEST_F(CollectiveOpsTest, AllReduce4BitInt) { + // Test with all-reduce inputs having an odd number of elements to ensure that + // the 4 bits of padding are handled correctly. + const char* const kModuleStr = R"( + HloModule test + + sum { + a = u4[] parameter(0) + b = u4[] parameter(1) + ROOT add.2 = u4[] add(a, b) + } + + ENTRY test_computation { + id32 = u32[] replica-id() + id = u4[] convert(id32) + id2 = u4[3] broadcast(id), dimensions={} + a0 = u4[3] constant({3, 5, 7}) + a1 = u4[3] add(id2, a0) + ROOT cp = u4[3] all-reduce(a1), replica_groups={}, to_apply=sum + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({u4{7}, u4{11}, u4{15}}, result); + } +} + +XLA_TEST_F(CollectiveOpsTest, ReduceScatter16BitInt) { const char* const kModuleStr = R"( HloModule test @@ -1787,6 +1972,41 @@ XLA_TEST_F(CollectiveOpsTest, ReduceScatter_16BitInt) { LiteralTestUtil::ExpectR1Equal({31}, results[1]); } +XLA_TEST_F(CollectiveOpsTest, ReduceScatter4BitInt) { + const char* const kModuleStr = R"( + HloModule test + + sum { + a = u4[] parameter(0) + b = u4[] parameter(1) + ROOT add.2 = u4[] add(a, b) + } + + ENTRY test_computation { + id32 = u32[] replica-id() + id = u4[] convert(id32) + id2 = u4[2] broadcast(id), dimensions={} + a0 = u4[2] constant({5, 7}) + a1 = u4[2] add(id2, a0) + ROOT cp = u4[1]reduce-scatter(a1), dimensions={0}, replica_groups={}, to_apply=sum + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({u4{11}}, results[0]); + LiteralTestUtil::ExpectR1Equal({u4{15}}, results[1]); +} + XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) { const char* const kModuleStr = R"( HloModule test @@ -1991,7 +2211,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_Simple)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2071,7 +2294,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_TwoConcurrentChains)) { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2150,7 +2376,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2251,7 +2480,10 @@ body { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2292,7 +2524,10 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecvCrossReplica)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/kNumReplicas, /*num_partitions=*/1); @@ -2335,7 +2570,11 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecvCrossPartition)) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } // Create device assignment running across partitions. DeviceAssignment device_assignment(/*replica_count=*/kNumReplicas, diff --git a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc index 1a9997fe21dbfb..f10ab3d181da33 100644 --- a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc +++ b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/error_spec.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/computation_placer.h" @@ -31,7 +32,6 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { @@ -103,7 +103,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest, const int64_t kNumReplicas = 4; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -296,7 +299,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest, NaiveBFSMicrobatch4Replica4) { const int64_t kNumReplicas = 4; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -416,7 +422,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest, NaiveBFSMicrobatch5Replica4) { const int64_t kNumReplicas = 4; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -536,7 +545,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest, const int64_t kNumReplicas = 4; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -674,7 +686,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest, const int64_t kNumReplicas = 4; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -814,7 +829,10 @@ XLA_TEST_P(CollectivePipelineParallelismTest, const int64_t kNumReplicas = 4; const int64_t kNumPartitions = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -907,7 +925,11 @@ XLA_TEST_P(CollectivePipelineParallelismTest, SendRecvLoop) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -992,7 +1014,11 @@ XLA_TEST_P(CollectivePipelineParallelismTest, SendRecvLoop2Devices) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -1088,7 +1114,11 @@ XLA_TEST_P(CollectivePipelineParallelismTest, const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -1186,7 +1216,11 @@ XLA_TEST_P(CollectivePipelineParallelismTest, const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + if (test_runner().device_count() < kNumReplicas * kNumPartitions) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions + << " devices (" << test_runner().device_count() + << " available)"; + } // Parse HLO module. HloModuleConfig config = GetModuleConfigForTest( @@ -1223,6 +1257,424 @@ XLA_TEST_P(CollectivePipelineParallelismTest, LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[1]); } +// This is the partially pipelined version of +// NaiveBFSMicrobatch5CircularRepeat2Replica4 and should yield the same results. +// TODO(b/383868854): replace this with GPU pipeliner implementation. +XLA_TEST_P(CollectivePipelineParallelismTest, + NaiveBFSMb5Cr2Replica4SendRecvPartiallyPipelined) { + constexpr char kMoreComputationsStr[] = R"( + while_condition { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[], + (f32[16], token[]), (f32[16], token[])) parameter(0) + i = u32[] get-tuple-element(tuple), index=5 + n = u32[] constant(13) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[], + (f32[16], token[]), (f32[16], token[])) parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[5,16] get-tuple-element(tuple), index=1 + output = f32[5,16] get-tuple-element(tuple), index=2 + buffer = f32[5,16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4 + i = u32[] get-tuple-element(tuple), index=5 + + prev_iter_fwd_recv_done = (f32[16], token[]) + get-tuple-element(tuple), index=6 + prev_iter_bwd_recv_done = (f32[16], token[]) + get-tuple-element(tuple), index=7 + prev_stage_slice_fwd = f32[16] get-tuple-element(prev_iter_fwd_recv_done), + index=0 + prev_stage_slice_bwd = f32[16] get-tuple-element(prev_iter_bwd_recv_done), + index=0 + + c0 = u32[] constant(0) + c1 = u32[] constant(1) + c2 = u32[] constant(2) + c3 = u32[] constant(3) + c4 = u32[] constant(4) + c5 = u32[] constant(5) + + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5 + + // Shift data to the next stage in the pipeline. + // Directly depends on the updated buffer of the previous iteration and, + // therefore, depends on the previous iteration's compute. + is_output_replica = pred[] call(), to_apply=is_output_replica + next_stage_slice = select(is_output_replica, buffer_slice, + prev_iteration_compute_res) + + // Shift data to the next stage in the pipeline. + after_all_fwd = token[] after-all() + fwd_send = (f32[16], u32[], token[]) send(next_stage_slice, after_all_fwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb5 + compute_arg_bwd = f32[16] select(is_read_input, input_slice, prev_stage_slice_bwd) + compute_res_bwd = f32[16] dot(weights, compute_arg_bwd), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + is_device_zero = pred[] call(), to_apply=is_input_replica + compute_arg_fwd = f32[16] select(is_device_zero, + prev_stage_slice_bwd, prev_stage_slice_fwd) + compute_res_fwd = f32[16] dot(weights, compute_arg_fwd), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + + // Update buffers. + compute_res = f32[16] select(is_device_zero, compute_res_bwd, compute_res_fwd) + output_ = f32[5,16] call(output, compute_res, c1, i), + to_apply=update_buffer_mb5 + buffer_ = f32[5,16] call(buffer, prev_iteration_compute_res, c4, i), + to_apply=update_buffer_mb5 + + fwd_recv = (f32[16], u32[], token[]) recv(after_all_fwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + fwd_recv_done = (f32[16], token[]) recv-done(fwd_recv), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, + control-predecessors={fwd_send} + + after_all_bwd = token[] after-all() + bwd_send = (f32[16], u32[], token[]) send(next_stage_slice, after_all_bwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}} + bwd_recv = (f32[16], u32[], token[]) recv(after_all_bwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}} + bwd_recv_done = (f32[16], token[]) recv-done(bwd_recv), + frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}}, + control-predecessors={bwd_send} + + i_ = add(i, c1) + + ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[], + (f32[16], token[]), (f32[16], token[])) tuple(weights, input, output_, + buffer_, compute_res, i_, fwd_recv_done, bwd_recv_done) + fwd_send_done = token[] send-done(fwd_send) + bwd_send_done = token[] send-done(bwd_send) + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[5,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[5,16] broadcast(cf0), dimensions={} + buffer = f32[5,16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + input_slice = f32[16] call(input, c0, c0), to_apply=read_buffer_mb5 + + after_all_fwd = token[] after-all() + fwd_recv = (f32[16], u32[], token[]) recv(after_all_fwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + fwd_recv_done = (f32[16], token[]) recv-done(fwd_recv), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + after_all_bwd = token[] after-all() + bwd_recv = (f32[16], u32[], token[]) recv(after_all_bwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}} + bwd_recv_done = (f32[16], token[]) recv-done(bwd_recv), + frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}} + bwd_send = (f32[16], u32[], token[]) send(input_slice, after_all_bwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{3,0}}} + bwd_send_done = token[] send-done(bwd_send) + + + // Iterate through pipeline stages. + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[], + (f32[16], token[]), (f32[16], token[])) tuple(weights, input, output, + buffer, prev_iteration_compute_res, c0, fwd_recv_done, bwd_recv_done) + tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[], + (f32[16], token[]), (f32[16], token[])) while(tuple), + condition=while_condition, body=while_body + + + // unroll while loop results + weights_ = f32[16,16] get-tuple-element(tuple_), index=0 + input_ = f32[5,16] get-tuple-element(tuple_), index=1 + output_ = f32[5,16] get-tuple-element(tuple_), index=2 + buffer_ = f32[5,16] get-tuple-element(tuple_), index=3 + prev_iteration_compute_res_ = f32[16] get-tuple-element(tuple_), index=4 + i_ = u32[] get-tuple-element(tuple_), index=5 + prev_stage_fwd_recv_done_ = (f32[16], token[]) get-tuple-element(tuple_), index=6 + prev_stage_bwd_recv_done_ = (f32[16], token[]) get-tuple-element(tuple_), index=7 + prev_stage_slice_fwd_ = f32[16] get-tuple-element(prev_stage_fwd_recv_done_), index=0 + prev_stage_slice_bwd_ = f32[16] get-tuple-element(prev_stage_bwd_recv_done_), index=0 + + c0_ = u32[] constant(0) + c1_ = u32[] constant(1) + c2_ = u32[] constant(2) + c3_ = u32[] constant(3) + c4_ = u32[] constant(4) + c5_ = u32[] constant(5) + + // Read from buffers. + input_slice_ = f32[16] call(input, c0_, i_), to_apply=read_buffer_mb5 + buffer_slice_ = f32[16] call(buffer, c3_, i_), to_apply=read_buffer_mb5 + + // Shift data to the next stage in the pipeline. + // Directly depends on the updated buffer of the previous iteration and, + // therefore, depends on the previous iteration's compute. + is_output_replica_ = pred[] call(), to_apply=is_output_replica + next_stage_slice_ = select(is_output_replica_, buffer_slice_, + prev_iteration_compute_res_) + + fwd_send = (f32[16], u32[], token[]) send(next_stage_slice_, after_all_fwd), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + fwd_send_done = token[] send-done(fwd_send) + + + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input_ = pred[] call(i_), to_apply=is_read_input_mb5 + compute_arg_bwd_ = f32[16] select(is_read_input_, input_slice_, prev_stage_slice_bwd_) + compute_res_bwd_ = f32[16] dot(weights_, compute_arg_bwd_), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + is_device_zero_ = pred[] call(), to_apply=is_input_replica + compute_arg_fwd_ = f32[16] select(is_device_zero_, prev_stage_slice_bwd_, prev_stage_slice_fwd_) + compute_res_fwd_ = f32[16] dot(weights_, compute_arg_fwd_), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + // Update buffers. + compute_res_ = f32[16] select(is_device_zero_, compute_res_bwd_, compute_res_fwd_) + ROOT output__ = f32[5,16] call(output_, compute_res_, c1_, i_), + to_apply=update_buffer_mb5 + + } + )"; + + const int64_t kNumReplicas = 4; + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); + + const int64_t kInputSize = 16; + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); + + const int64_t kMicrobatches = 5; + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = LiteralUtil::CreateFull( + {kMicrobatches, kInputSize}, /*value=*/0.0); + + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0 * 1.0 * 2.0 * 3.0 * 4.0; + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + // TODO(rosiezou): enable send/recv combiner pass. + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} + +// This is the async-grouped version of +// NaiveBFSMicrobatch5CircularRepeat2Replica4 and should yield the same results. +// TODO(b/383868854): replace this with GPU pipeliner implementation. +XLA_TEST_P(CollectivePipelineParallelismTest, + NaiveBFSMb5Cr2Replica4SendRecvAsyncGroup) { + constexpr char kMoreComputationsStr[] = R"( + + wrapped_send_recv_1 { + fwd_send_data = f32[16] parameter(0) + fwd_send_after_all = token[] parameter(1) + fwd_send = (f32[16], u32[], token[]) send(fwd_send_data, fwd_send_after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + fwd_recv_after_all = token[] parameter(2) + fwd_recv = (f32[16], u32[], token[]) recv(fwd_recv_after_all), frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + bwd_send_data = f32[16] parameter(3) + bwd_send_after_all = token[] parameter(4) + bwd_send = (f32[16], u32[], token[]) send(bwd_send_data, bwd_send_after_all), frontend_attributes={ + _xla_send_recv_source_target_pairs={{3,0}}} + + bwd_recv_after_all = token[] parameter(5) + bwd_recv = (f32[16], u32[], token[]) recv(bwd_recv_after_all), frontend_attributes={ + _xla_send_recv_source_target_pairs={{3,0}}} + + ROOT out = ((f32[16], u32[], token[]),(f32[16], u32[], token[]), + (f32[16], u32[], token[]),(f32[16], u32[], token[])) tuple(fwd_send, + fwd_recv, bwd_send, bwd_recv) + + } + + while_condition { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + i = u32[] get-tuple-element(tuple), index=5 + n = u32[] constant(13) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[5,16] get-tuple-element(tuple), index=1 + output = f32[5,16] get-tuple-element(tuple), index=2 + buffer = f32[5,16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4 + i = u32[] get-tuple-element(tuple), index=5 + + c0 = u32[] constant(0) + c1 = u32[] constant(1) + c2 = u32[] constant(2) + c3 = u32[] constant(3) + c4 = u32[] constant(4) + c5 = u32[] constant(5) + + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5 + + // Shift data to the next stage in the pipeline. + // Directly depends on the updated buffer of the previous iteration and, + // therefore, depends on the previous iteration's compute. + is_output_replica = pred[] call(), to_apply=is_output_replica + next_stage_slice = select(is_output_replica, buffer_slice, + prev_iteration_compute_res) + + + // Shift data to the next stage in the pipeline. + after_all_fwd = token[] after-all() + after_all_bwd = token[] after-all() + + async_comp_start = (( f32[16], token[], token[], f32[16], token[], token[]), + ((f32[16], u32[], token[]), (f32[16], u32[], token[]), (f32[16], u32[], token[]), + (f32[16], u32[], token[])), s32[]) async-start(next_stage_slice, + after_all_fwd, after_all_fwd, next_stage_slice, + after_all_bwd, after_all_bwd), calls=wrapped_send_recv_1 + + async_comp_done = ((f32[16], u32[], token[]), (f32[16], u32[], token[]), + (f32[16], u32[], token[]), (f32[16], u32[], token[])) async-done(async_comp_start) + unpack_fwd_recv = (f32[16], u32[], token[]) get-tuple-element(async_comp_done), index=1 + fwd_recv_data = f32[16] get-tuple-element(unpack_fwd_recv), index=0 + fwd_recv_token = token[] get-tuple-element(unpack_fwd_recv), index=2 + fwd_recv_done = (f32[16], token[]) tuple(fwd_recv_data, fwd_recv_token), + control-predecessors={async_comp_start} + + unpack_bwd_recv = (f32[16], u32[], token[]) get-tuple-element(async_comp_done), index=3 + bwd_recv_data = f32[16] get-tuple-element(unpack_bwd_recv), index=0 + bwd_recv_token = token[] get-tuple-element(unpack_bwd_recv), index=2 + bwd_recv_done = (f32[16], token[]) tuple(bwd_recv_data, bwd_recv_token), + control-predecessors={async_comp_start} + prev_stage_slice_fwd = f32[16] get-tuple-element(fwd_recv_done), index=0 + prev_stage_slice_bwd = f32[16] get-tuple-element(bwd_recv_done), index=0 + + + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb5 + compute_arg_bwd = f32[16] select(is_read_input, input_slice, prev_stage_slice_bwd) + compute_res_bwd = f32[16] dot(weights, compute_arg_bwd), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + is_device_zero = pred[] call(), to_apply=is_input_replica + compute_arg_fwd = f32[16] select(is_device_zero, prev_stage_slice_bwd, prev_stage_slice_fwd) + compute_res_fwd = f32[16] dot(weights, compute_arg_fwd), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + // Update buffers. + compute_res = f32[16] select(is_device_zero, compute_res_bwd, compute_res_fwd) + output_ = f32[5,16] call(output, compute_res, c2, i), + to_apply=update_buffer_mb5 + buffer_ = f32[5,16] call(buffer, compute_res, c0, i), + to_apply=update_buffer_mb5 + + i_ = add(i, c1) + + ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output_, buffer_, compute_res, i_) + + unpack-send-done1 = (f32[16], u32[], token[]) get-tuple-element(async_comp_done), index=0 + send-done1 = token[] get-tuple-element(unpack-send-done1), index=2 + unpack-send-done2 = (f32[16], u32[], token[]) get-tuple-element(async_comp_done), index=2 + send-done2 = token[] get-tuple-element(unpack-send-done2), index=2 + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[5,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[5,16] broadcast(cf0), dimensions={} + buffer = f32[5,16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + + // Iterate through pipeline stages. + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output, buffer, prev_iteration_compute_res, c0) + tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + while(tuple), condition=while_condition, body=while_body + + ROOT output_ = f32[5,16] get-tuple-element(tuple_), index=2 + } + )"; + + const int64_t kNumReplicas = 4; + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); + + const int64_t kInputSize = 16; + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); + + const int64_t kMicrobatches = 5; + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); + + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0 * 1.0 * 2.0 * 3.0 * 4.0; + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual( + expected_output, results[3], + ErrorSpec{/*abs_error=*/1e-5, /*rel_error=*/1e-5})); +} + INSTANTIATE_TEST_SUITE_P(CollectivePipelineParallelismTestWithAndWithoutOpts, CollectivePipelineParallelismTest, ::testing::Bool(), ::testing::PrintToStringParamName()); diff --git a/third_party/xla/xla/tests/compute_constant_test.cc b/third_party/xla/xla/tests/compute_constant_test.cc index 6524e47ffa5486..7d3065bb0d1fb5 100644 --- a/third_party/xla/xla/tests/compute_constant_test.cc +++ b/third_party/xla/xla/tests/compute_constant_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include "xla/client/global_data.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" diff --git a/third_party/xla/xla/tests/concat_test.cc b/third_party/xla/xla/tests/concat_test.cc index 6f831d8f29c998..7301d7bf8c9c05 100644 --- a/third_party/xla/xla/tests/concat_test.cc +++ b/third_party/xla/xla/tests/concat_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/concatenate_test.cc b/third_party/xla/xla/tests/concatenate_test.cc index 02a28684beaf22..460087a3e16c4b 100644 --- a/third_party/xla/xla/tests/concatenate_test.cc +++ b/third_party/xla/xla/tests/concatenate_test.cc @@ -23,10 +23,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/tests/constant_reduction_function_test.cc b/third_party/xla/xla/tests/constant_reduction_function_test.cc index 57c603023610cd..4c2529ca46f33e 100644 --- a/third_party/xla/xla/tests/constant_reduction_function_test.cc +++ b/third_party/xla/xla/tests/constant_reduction_function_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc b/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc index d53458d3ae9af9..c0770876733e16 100644 --- a/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_common.cc b/third_party/xla/xla/tests/conv_depthwise_common.cc index 5c4bb5d1fcef45..09cd38576322fa 100644 --- a/third_party/xla/xla/tests/conv_depthwise_common.cc +++ b/third_party/xla/xla/tests/conv_depthwise_common.cc @@ -19,10 +19,10 @@ limitations under the License. #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_common.h b/third_party/xla/xla/tests/conv_depthwise_common.h index 350858498111f4..010dde84898815 100644 --- a/third_party/xla/xla/tests/conv_depthwise_common.h +++ b/third_party/xla/xla/tests/conv_depthwise_common.h @@ -20,10 +20,10 @@ limitations under the License. #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_test.cc b/third_party/xla/xla/tests/conv_depthwise_test.cc index 05d2e6c446ee4a..b5dc09522591e5 100644 --- a/third_party/xla/xla/tests/conv_depthwise_test.cc +++ b/third_party/xla/xla/tests/conv_depthwise_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/conv_depthwise_common.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc index 557f4046ca4e82..833a9266afb3bf 100644 --- a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc +++ b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/reference_util.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/convolution_test.cc b/third_party/xla/xla/tests/convolution_test.cc index d337322911b60b..48bef5d0bb0884 100644 --- a/third_party/xla/xla/tests/convolution_test.cc +++ b/third_party/xla/xla/tests/convolution_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/window_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" @@ -2069,27 +2070,23 @@ class Transposed2DConvHloTest } } - std::string GetPaddingString(int kernel_x, int kernel_y) { - return absl::StrCat(GetPaddingValue(kernel_x, /*low=*/true), "_", - GetPaddingValue(kernel_x, /*low=*/false), "x", - GetPaddingValue(kernel_y, /*low=*/true), "_", - GetPaddingValue(kernel_y, /*low=*/false)); - } + auto GetWindow() { + Window window; - std::string GetWindowString() { - const auto padding_string = GetPaddingString(kernel_x_, kernel_y_); - const auto window_size_string = absl::StrCat(kernel_x_, "x", kernel_y_); - const auto lhs_dilation_string = - absl::StrCat(lhs_dilation_x_, "x", lhs_dilation_y_); + auto add_dim = [&](int size, int lhs_dilation) { + auto dim = window.add_dimensions(); + dim->set_size(size); + dim->set_stride(1); + dim->set_padding_low(GetPaddingValue(size, /*low=*/true)); + dim->set_padding_high(GetPaddingValue(size, /*low=*/false)); + dim->set_window_dilation(1); + dim->set_base_dilation(lhs_dilation); + }; - return absl::StrCat("{size=", window_size_string, " pad=", padding_string, - " lhs_dilate=", lhs_dilation_string, "}"); - } + add_dim(kernel_x_, lhs_dilation_x_); + add_dim(kernel_y_, lhs_dilation_y_); - int GetOutputShape(int input_size, int kernel_size, int lhs_dilation) { - return lhs_dilation * (input_size - 1) + kernel_size - - (kernel_size - GetPaddingValue(kernel_size, /*low=*/true) - 1) - - (kernel_size - GetPaddingValue(kernel_size, /*low=*/false) - 1); + return window; } public: @@ -2107,27 +2104,23 @@ class Transposed2DConvHloTest }; XLA_TEST_P(Transposed2DConvHloTest, Simple) { - const auto window = GetWindowString(); - const auto input_shape = - absl::StrCat(batch_, ",", input_channels_, ",", input_x_, ",", input_y_); - const auto kernel_shape = absl::StrCat(output_channels_, ",", input_channels_, - ",", kernel_x_, ",", kernel_y_); - const auto output_shape = - absl::StrCat(batch_, ",", output_channels_, ",", - GetOutputShape(input_x_, kernel_x_, lhs_dilation_x_), ",", - GetOutputShape(input_y_, kernel_y_, lhs_dilation_y_)); + ShapeUtil::MakeShape(F32, {batch_, input_channels_, input_x_, input_y_}); + const auto kernel_shape = ShapeUtil::MakeShape( + F32, {output_channels_, input_channels_, kernel_x_, kernel_y_}); + + const auto window = GetWindow(); // clang-format off const std::string hlo = absl::StrCat(R"( HloModule TestModule ENTRY TestComputation { - input.1 = f32[)", input_shape, R"(]{3,2,1,0} parameter(0) - filter.2 = f32[)", kernel_shape, R"(]{3,2,1,0} parameter(1) - ROOT conv.3 = f32[)", output_shape, R"(]{3,2,1,0} convolution( - input.1, filter.2), - window=)", window, R"(, dim_labels=bf01_oi01->bf01 + input.1 = )", input_shape.ToString(), R"( parameter(0) + filter.2 = )", kernel_shape.ToString(), R"( parameter(1) + ROOT conv.3 = convolution(input.1, filter.2), + window={)", window_util::ToString(window), R"(}, + dim_labels=bf01_oi01->bf01 } )"); // clang-format on diff --git a/third_party/xla/xla/tests/copy_test.cc b/third_party/xla/xla/tests/copy_test.cc index 36b7e0815a844f..5ba971c9d10ac7 100644 --- a/third_party/xla/xla/tests/copy_test.cc +++ b/third_party/xla/xla/tests/copy_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array3d.h" #include "xla/array4d.h" +#include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -36,7 +37,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/platform.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" @@ -45,7 +46,7 @@ limitations under the License. namespace xla { namespace { -class CopyOpTest : public HloTestBase { +class CopyOpTest : public HloPjRtTestBase { protected: CopyOpTest() : platform_(*PlatformUtil::GetDefaultPlatform()) {} @@ -89,7 +90,7 @@ class CopyOpTest : public HloTestBase { se::Platform* platform() const { return platform_; } private: - se::Platform* platform_; + se::Platform* platform_ = nullptr; }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { @@ -190,7 +191,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {&literal}); - LiteralTestUtil::ExpectR0Near(42.0f, result, error_spec_); + LiteralTestUtil::ExpectR0Near(42.0f, result, ErrorSpec{0.0001}); } XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { @@ -211,7 +212,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, - error_spec_); + ErrorSpec{0.0001}); } XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { @@ -240,7 +241,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, result, - error_spec_); + ErrorSpec{0.0001}); } void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index 3f264f1996fc63..ff88a0de868cf8 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -409,6 +409,18 @@ XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail", PLATFORM, kAlwaysFail); +static absl::Status Tokens(ffi::Token, ffi::Result, + ffi::Result) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER( + kTokens, Tokens, + ffi::Ffi::Bind().Arg().Ret().Ret()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens", PLATFORM, + kTokens); + static absl::Status FfiR0F32Add2(R0F32Buffer in, R0F32ResultBuffer out) { auto in_data = in.typed_data(); auto out_data = out->typed_data(); @@ -843,6 +855,24 @@ XLA_TEST_F(FfiCustomCallTest, FfiReportsSuccess) { EXPECT_EQ(status, absl::OkStatus()); } +XLA_TEST_F(FfiCustomCallTest, Tokens) { + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + + std::vector ret = {ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape()}; + + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeTupleShape(ret), {token}, "__xla_test$$tokens", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + module->AddEntryComputation(builder.Build()); + + auto status = Execute(std::move(module), {}).status(); + EXPECT_EQ(status, absl::OkStatus()); +} + XLA_TEST_F(FfiCustomCallTest, FfiUnknownTarget) { auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); diff --git a/third_party/xla/xla/tests/deallocation_test.cc b/third_party/xla/xla/tests/deallocation_test.cc index 213e3f05ed9931..b901d7f5f7ebb3 100644 --- a/third_party/xla/xla/tests/deallocation_test.cc +++ b/third_party/xla/xla/tests/deallocation_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/test.h" -#include "xla/test_helpers.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/deconstruct_tuple_test.cc b/third_party/xla/xla/tests/deconstruct_tuple_test.cc index e5579e7abc4e20..da6752ffdb1b4f 100644 --- a/third_party/xla/xla/tests/deconstruct_tuple_test.cc +++ b/third_party/xla/xla/tests/deconstruct_tuple_test.cc @@ -23,10 +23,10 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index 674ada04d96c30..866e8693ece841 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -29,14 +29,15 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/reference_util.h" +#include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" diff --git a/third_party/xla/xla/tests/dynamic_ops_test.cc b/third_party/xla/xla/tests/dynamic_ops_test.cc index ab27dbe99072fe..747d073fc5dd83 100644 --- a/third_party/xla/xla/tests/dynamic_ops_test.cc +++ b/third_party/xla/xla/tests/dynamic_ops_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/reference_util.h" #include "xla/service/local_service.h" #include "xla/service/platform_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/dynamic_reshape_test.cc b/third_party/xla/xla/tests/dynamic_reshape_test.cc index b4584a785c7f58..6b0f534c66851e 100644 --- a/third_party/xla/xla/tests/dynamic_reshape_test.cc +++ b/third_party/xla/xla/tests/dynamic_reshape_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index 91eb51e18ac10b..93411097695227 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -141,7 +141,7 @@ exhaustive_xla_test( shard_count = 50, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", # This is a big test that we skip for capacity reasons in OSS testing. "no_oss", ], @@ -183,7 +183,7 @@ xla_test( shard_count = 50, tags = [ "optonly", - "test_xla_cpu_thunks", + "test_xla_cpu_no_thunks", # This is a big test that we skip for capacity reasons in OSS testing. "no_oss", ], @@ -250,7 +250,6 @@ exhaustive_xla_test( shard_count = 50, tags = [ "optonly", - "test_xla_cpu_thunks", # This is a big test that we skip for capacity reasons in OSS testing. "no_oss", ], diff --git a/third_party/xla/xla/tests/float8_test.cc b/third_party/xla/xla/tests/float8_test.cc index 648c718d7cd958..71d50ebd6f8676 100644 --- a/third_party/xla/xla/tests/float8_test.cc +++ b/third_party/xla/xla/tests/float8_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/hlo/builder/xla_builder.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/ml_dtypes.h" diff --git a/third_party/xla/xla/tests/gather_operation_test.cc b/third_party/xla/xla/tests/gather_operation_test.cc index 7bf57a8f05138f..6544c85d1e0226 100644 --- a/third_party/xla/xla/tests/gather_operation_test.cc +++ b/third_party/xla/xla/tests/gather_operation_test.cc @@ -13,23 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include +#include + #include "xla/array.h" #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" +#include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/service.h" -#include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/tests/hlo_pjrt_test_base.h" #include "xla/tests/test_macros.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { -using std::nullopt; - -class GatherOperationTest : public HloTestBase { +class GatherOperationTest : public HloPjRtTestBase { protected: void RunTest(const std::string& hlo_text, Literal* operand, Literal* start_indices) { @@ -41,7 +48,7 @@ class GatherOperationTest : public HloTestBase { config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text, config)); - EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, std::nullopt)); } }; diff --git a/third_party/xla/xla/tests/get_dimension_size_test.cc b/third_party/xla/xla/tests/get_dimension_size_test.cc index 44d88f0608ea20..3c815fd989d17b 100644 --- a/third_party/xla/xla/tests/get_dimension_size_test.cc +++ b/third_party/xla/xla/tests/get_dimension_size_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "absl/status/status.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/grouped_convolution_test.cc b/third_party/xla/xla/tests/grouped_convolution_test.cc index 7a86547f171aae..01f10b82737450 100644 --- a/third_party/xla/xla/tests/grouped_convolution_test.cc +++ b/third_party/xla/xla/tests/grouped_convolution_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "xla/execution_options_util.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/half_test.cc b/third_party/xla/xla/tests/half_test.cc index 385e3622230775..3da23420fccbd9 100644 --- a/third_party/xla/xla/tests/half_test.cc +++ b/third_party/xla/xla/tests/half_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" diff --git a/third_party/xla/xla/tests/hlo_metadata_test.cc b/third_party/xla/xla/tests/hlo_metadata_test.cc index 30cb1fa0e3b262..ecd7d3de892a47 100644 --- a/third_party/xla/xla/tests/hlo_metadata_test.cc +++ b/third_party/xla/xla/tests/hlo_metadata_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/service/local_service.h" -#include "xla/test_helpers.h" #include "xla/tests/local_client_test_base.h" namespace xla { diff --git a/third_party/xla/xla/tests/hlo_pjrt_interpreter_reference_mixin.h b/third_party/xla/xla/tests/hlo_pjrt_interpreter_reference_mixin.h new file mode 100644 index 00000000000000..cdbd2f3575cfbe --- /dev/null +++ b/third_party/xla/xla/tests/hlo_pjrt_interpreter_reference_mixin.h @@ -0,0 +1,50 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_HLO_PJRT_INTERPRETER_REFERENCE_MIXIN_H_ +#define XLA_TESTS_HLO_PJRT_INTERPRETER_REFERENCE_MIXIN_H_ + +#include + +#include "xla/pjrt/interpreter/interpreter_client.h" +#include "xla/service/hlo_runner_pjrt.h" +#include "xla/tests/hlo_runner_agnostic_reference_mixin.h" + +namespace xla { + +// A wrapper mixin around HloRunnerAgnosticReferenceMixin which provides a +// default reference backend via HloRunnerPjRt using the PjRt InterpreterClient. +// +// The mixin requires that that the test class is a subclass of +// HloRunnerAgnosticTestBase. +template +class HloPjRtInterpreterReferenceMixin + : public HloRunnerAgnosticReferenceMixin { + protected: + template + explicit HloPjRtInterpreterReferenceMixin(BaseArgs&&... base_args) + : HloRunnerAgnosticReferenceMixin( + std::make_unique( + std::make_unique(), + InterpreterClient::DeviceShapeRepresentation, + InterpreterClient::ShapeSizeBytes, + /*use_parameter_layout_on_device=*/true), + std::forward(base_args)...) {} + ~HloPjRtInterpreterReferenceMixin() override = default; +}; + +} // namespace xla + +#endif // XLA_TESTS_HLO_PJRT_INTERPRETER_REFERENCE_MIXIN_H_ diff --git a/third_party/xla/xla/tests/hlo_pjrt_test_base.cc b/third_party/xla/xla/tests/hlo_pjrt_test_base.cc index 8a7b7064a49759..8eb3002d26dbe7 100644 --- a/third_party/xla/xla/tests/hlo_pjrt_test_base.cc +++ b/third_party/xla/xla/tests/hlo_pjrt_test_base.cc @@ -17,18 +17,14 @@ limitations under the License. #include #include -#include #include #include "absl/log/check.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" +#include "xla/pjrt/interpreter/interpreter_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" -#include "xla/service/platform_util.h" -#include "xla/stream_executor/platform.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" #include "xla/tests/pjrt_client_registry.h" #include "xla/util.h" @@ -57,21 +53,20 @@ std::unique_ptr GetHloRunnerForTest() { } std::unique_ptr GetHloRunnerForReference() { - absl::StatusOr platform = - PlatformUtil::GetPlatform("interpreter"); - CHECK_OK(platform.status()) - << "Failed to get interpreter platform. " << platform.status(); - return std::make_unique(*platform); + return std::make_unique( + std::make_unique(), + InterpreterClient::DeviceShapeRepresentation, + InterpreterClient::ShapeSizeBytes, + /*use_parameter_layout_on_device=*/true); } } // namespace -HloPjRtTestBase::HloPjRtTestBase( - bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, - HloPredicate instruction_can_change_layout_func) - : HloRunnerAgnosticTestBase( - GetHloRunnerForTest(), GetHloRunnerForReference(), - verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, - instruction_can_change_layout_func) {} +HloPjRtTestBase::HloPjRtTestBase(HloPjRtTestBaseOptions options) + : HloRunnerAgnosticTestBase(GetHloRunnerForTest(), + GetHloRunnerForReference(), + options.verifier_layout_sensitive, + options.allow_mixed_precision_in_hlo_verifier, + options.instruction_can_change_layout_func) {} } // namespace xla diff --git a/third_party/xla/xla/tests/hlo_pjrt_test_base.h b/third_party/xla/xla/tests/hlo_pjrt_test_base.h index 7253f378fb529a..fe7b95dfba363b 100644 --- a/third_party/xla/xla/tests/hlo_pjrt_test_base.h +++ b/third_party/xla/xla/tests/hlo_pjrt_test_base.h @@ -22,14 +22,17 @@ limitations under the License. namespace xla { +struct HloPjRtTestBaseOptions { + bool verifier_layout_sensitive = false; + bool allow_mixed_precision_in_hlo_verifier = true; + HloPredicate instruction_can_change_layout_func; +}; + class HloPjRtTestBase : public HloRunnerAgnosticTestBase { protected: // This uses the SE interpreter backend for the reference backend and // automatically finds a PjRt backend for the test backend. - explicit HloPjRtTestBase( - bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true, - HloPredicate instruction_can_change_layout_func = {}); + explicit HloPjRtTestBase(HloPjRtTestBaseOptions options = {}); }; } // namespace xla diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.cc b/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.cc new file mode 100644 index 00000000000000..dbf6acc37c59ff --- /dev/null +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.cc @@ -0,0 +1,47 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tests/hlo_runner_agnostic_reference_mixin.h" + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/shape.h" + +namespace xla { + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); ++i) { + if (!Shape::Equal().IgnoreElementSizeInLayout()(lhs.parameters(i), + rhs.parameters(i))) { + return false; + } + } + return Shape::Equal().IgnoreElementSizeInLayout()(lhs.result(), rhs.result()); +} + +} // namespace xla diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h b/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h new file mode 100644 index 00000000000000..e661e0509c47ee --- /dev/null +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_reference_mixin.h @@ -0,0 +1,251 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_ +#define XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/literal.h" +#include "xla/service/hlo_runner_interface.h" +#include "xla/shape.h" +#include "xla/tests/hlo_runner_agnostic_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/test_utils.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" + +namespace xla { + +ProgramShape GetProgramShapeWithLayout(const HloModule& module); + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs); + +// This class is designed to be used as a mixin for tests that want to run +// against a reference implementation via a runner implementing +// HloRunnerInterface. +// +// The mixin requires that that the test class is a subclass of +// HloRunnerAgnosticTestBase. +template +class HloRunnerAgnosticReferenceMixin : public T { + static_assert( + std::is_base_of_v, + "Mixin must be used with a subclass of HloRunnerAgnosticTestBase."); + + protected: + template + explicit HloRunnerAgnosticReferenceMixin( + absl::Nonnull> reference_runner, + BaseArgs&&... base_args) + : T(std::forward(base_args)...), + reference_runner_(std::move(reference_runner)) {} + ~HloRunnerAgnosticReferenceMixin() override = default; + + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, absl::Span arguments, + const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr) { + const absl::StatusOr<::testing::AssertionResult> result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor, + test_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return *result; + } + + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const absl::Span arguments, + const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr) { + const absl::StatusOr<::testing::AssertionResult> result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor, + test_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return *result; + } + + // Executes an hlo module with fake inputs and compares the results. + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr, + const std::optional args_max_bits_of_precision = std::nullopt) { + const absl::StatusOr> fake_arguments = + MakeFakeArguments(module.get(), /*pseudo_random=*/true, + /*use_large_range=*/false, + /*treat_gte_as_data_formatting=*/false, + args_max_bits_of_precision); + if (!fake_arguments.ok()) { + return ::testing::AssertionFailure() << fake_arguments.status().message(); + } + std::vector fake_argument_ptrs; + absl::c_transform( + *fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + return RunAndCompare(std::move(module), fake_argument_ptrs, error, + reference_preprocessor, test_preprocessor); + } + + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr) { + const absl::StatusOr> fake_arguments = + MakeFakeArguments(module.get()); + if (!fake_arguments.ok()) { + return ::testing::AssertionFailure() << fake_arguments.status().message(); + } + std::vector fake_argument_ptrs; + absl::c_transform( + *fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, + error, reference_preprocessor, + test_preprocessor); + } + + // Convenient wrapper for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. + ::testing::AssertionResult RunAndCompare( + const absl::string_view hlo_string, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr, + const std::optional args_max_bits_of_precision = std::nullopt) { + absl::StatusOr> module = + this->ParseAndReturnVerifiedModule(hlo_string); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + return RunAndCompare(*std::move(module), error, reference_preprocessor, + test_preprocessor, args_max_bits_of_precision); + } + + ::testing::AssertionResult RunAndCompareNoHloPasses( + const absl::string_view hlo_string, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr) { + absl::StatusOr> module = + this->ParseAndReturnVerifiedModule(hlo_string); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + return RunAndCompareNoHloPasses(*std::move(module), error, + reference_preprocessor, test_preprocessor); + } + + HloRunnerInterface& reference_runner() const { return *reference_runner_; } + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + absl::StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor = nullptr) { + std::unique_ptr reference_module = test_module.Clone(); + const ProgramShape program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return absl::InvalidArgumentError( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(this->verifier().Run(reference_module.get()).status()); + return std::move(reference_module); + } + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, absl::Span arguments, + const std::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr) { + TF_RETURN_IF_ERROR(this->verifier().Run(module.get()).status()); + TF_ASSIGN_OR_RETURN(std::unique_ptr reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + TF_RETURN_IF_ERROR(this->PreprocessModuleForTestRunner(module.get())); + if (test_preprocessor != nullptr) { + test_preprocessor(module.get()); + } + // Execute on two backends. + TF_ASSIGN_OR_RETURN(const Literal test, + this->test_runner().Execute(std::move(module), + arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(const Literal reference, + reference_runner_->Execute(std::move(reference_module), + arguments, run_hlo_passes)); + if (reference.IsAll(0)) { + LOG(WARNING) << "Reference value is only zeros."; + } + + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, + error); + } + + std::unique_ptr reference_runner_; +}; + +} // namespace xla + +#endif // XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_ diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc index 40f47428a72c79..b781a0eebd37d0 100644 --- a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.cc @@ -30,19 +30,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" -#include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" @@ -51,15 +45,14 @@ limitations under the License. #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_verifier.h" #include "xla/shape.h" -#include "xla/tests/filecheck.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" -#include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h index 4aaa14da2c0b6c..3bb8c5b787a917 100644 --- a/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h +++ b/third_party/xla/xla/tests/hlo_runner_agnostic_test_base.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -35,32 +34,17 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/layout.h" #include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/backend.h" -#include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" -#include "xla/service/hlo_verifier.h" -#include "xla/service/platform_util.h" -#include "xla/shape_layout.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" -#include "xla/test_helpers.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { @@ -95,6 +79,9 @@ namespace xla { // other implementations. We plan to incrementally migrate tests to this class // and away from HloTestBase. class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { + public: + static constexpr ErrorSpec kDefaultErrorSpec{0.0001}; + protected: explicit HloRunnerAgnosticTestBase( absl::Nonnull> test_runner, @@ -187,7 +174,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, @@ -195,14 +182,14 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and compares the results. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, @@ -210,26 +197,26 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and checks that the execution is // successful. - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( std::unique_ptr module, bool run_hlo_passes, const std::function& test_preprocessor = nullptr); // Convenient wrappers for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, std::optional args_max_bits_of_precision = std::nullopt); - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( absl::string_view hlo_string, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr, const tsl::protobuf::Message* backend_config = nullptr, @@ -297,19 +284,19 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { const std::optional& error, bool run_hlo_passes = true); // Executes an hlo module with fake inputs on multiple replicas. - [[nodiscard]] ::testing::AssertionResult RunReplicated( + ::testing::AssertionResult RunReplicated( absl::string_view hlo_string, bool run_hlo_passes = true, int64_t num_replicas = 1, const tsl::protobuf::Message* backend_config = nullptr); // If assert_determinism is true, the assertion will fail unless all runs // produce exactly the same output. - [[nodiscard]] ::testing::AssertionResult RunMultipleTimes( + ::testing::AssertionResult RunMultipleTimes( absl::string_view hlo_string, bool run_hlo_passes, std::vector* profiles, const tsl::protobuf::Message* backend_config = nullptr, bool assert_determinism = false); - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index f66c3208ed22a9..896de2cb58aa10 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -15,58 +15,32 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" -#include -#include #include -#include #include #include #include -#include #include -#include -#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/hlo/utils/hlo_query.h" -#include "xla/literal.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/backend.h" -#include "xla/service/computation_placer.h" -#include "xla/service/executable.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_module_util.h" #include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" -#include "xla/service/hlo_verifier.h" #include "xla/service/platform_util.h" -#include "xla/shape.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/pjrt_client_registry.h" -#include "xla/tests/test_utils.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index 6b9bda3327377b..00e755a010e2c3 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -206,13 +206,6 @@ class HloTestBase : public HloRunnerAgnosticTestBase { std::unique_ptr allocator_; }; -#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ - int64_t num_devices = backend().device_count(); \ - if (num_devices < x) { \ - GTEST_SKIP() << "Test requires at least " << x << " devices (" \ - << num_devices << " available)"; \ - } - } // namespace xla #endif // XLA_TESTS_HLO_TEST_BASE_H_ diff --git a/third_party/xla/xla/tests/int4_test.cc b/third_party/xla/xla/tests/int4_test.cc index 264a68e2d0479d..dc925069a462d3 100644 --- a/third_party/xla/xla/tests/int4_test.cc +++ b/third_party/xla/xla/tests/int4_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/literal_test_util.h b/third_party/xla/xla/tests/literal_test_util.h index 01b2aa6433c30a..d5d3090288000f 100644 --- a/third_party/xla/xla/tests/literal_test_util.h +++ b/third_party/xla/xla/tests/literal_test_util.h @@ -28,10 +28,10 @@ limitations under the License. #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/error_spec.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/tests/literal_test_util_test.cc b/third_party/xla/xla/tests/literal_test_util_test.cc index 4912a37255d9d6..7c6b201fb9a260 100644 --- a/third_party/xla/xla/tests/literal_test_util_test.cc +++ b/third_party/xla/xla/tests/literal_test_util_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/strings/str_join.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" -#include "xla/test_helpers.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" diff --git a/third_party/xla/xla/tests/llvm_compiler_test.cc b/third_party/xla/xla/tests/llvm_compiler_test.cc index 94e37c64664948..b099b6271319b5 100644 --- a/third_party/xla/xla/tests/llvm_compiler_test.cc +++ b/third_party/xla/xla/tests/llvm_compiler_test.cc @@ -25,11 +25,11 @@ limitations under the License. #include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/service/backend.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/casts.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/llvm_irgen_test_base.cc b/third_party/xla/xla/tests/llvm_irgen_test_base.cc index db3d06c69f62dd..ca879bf88098f8 100644 --- a/third_party/xla/xla/tests/llvm_irgen_test_base.cc +++ b/third_party/xla/xla/tests/llvm_irgen_test_base.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/tests/filecheck.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index 22c469aa992863..8d4c0f5345fb60 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/sharding_builder.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/platform_util.h" @@ -34,7 +35,6 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/test_helpers.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/local_client_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.cc b/third_party/xla/xla/tests/local_client_test_base.cc index 0f4750132889ba..aaebced1f3b9b6 100644 --- a/third_party/xla/xla/tests/local_client_test_base.cc +++ b/third_party/xla/xla/tests/local_client_test_base.cc @@ -25,12 +25,12 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/map_util.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/test_helpers.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.h b/third_party/xla/xla/tests/local_client_test_base.h index dfe45beb735b89..df1facab14590f 100644 --- a/third_party/xla/xla/tests/local_client_test_base.h +++ b/third_party/xla/xla/tests/local_client_test_base.h @@ -27,6 +27,7 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/hlo_module_config.h" #include "xla/service/local_service.h" #include "xla/service/platform_util.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/tests/map_test.cc b/third_party/xla/xla/tests/map_test.cc index 6d654a74a06656..fcb381bf7e0691 100644 --- a/third_party/xla/xla/tests/map_test.cc +++ b/third_party/xla/xla/tests/map_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/matmul_test.cc b/third_party/xla/xla/tests/matmul_test.cc index 1ed47869346ad0..19c671f138b052 100644 --- a/third_party/xla/xla/tests/matmul_test.cc +++ b/third_party/xla/xla/tests/matmul_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/matrix_ops_simple_test.cc b/third_party/xla/xla/tests/matrix_ops_simple_test.cc index 65bad8ae68fe38..b7e06fc286c961 100644 --- a/third_party/xla/xla/tests/matrix_ops_simple_test.cc +++ b/third_party/xla/xla/tests/matrix_ops_simple_test.cc @@ -18,26 +18,26 @@ limitations under the License. #include #include #include +#include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif -#include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/reference_util.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -182,16 +182,25 @@ class MatOpsDotAddTest : public ClientLibraryTestBase, public ::testing::WithParamInterface> { public: + // Returns true if the test is using a GPU. + bool IsGpu() { + auto stream_executor = client_->platform()->ExecutorForDevice(0).value(); + auto gpu_compute_capability = + stream_executor->GetDeviceDescription().gpu_compute_capability(); + if ((std::holds_alternative( + gpu_compute_capability)) || + std::holds_alternative( + gpu_compute_capability)) { + return true; + } + return false; + } template void TestImpl() { bool row_major = std::get<0>(GetParam()); bool add_lhs = std::get<1>(GetParam()); bool transpose = std::get<2>(GetParam()); -#if GOOGLE_CUDA || TF_HIPBLASLT - bool use_cublaslt = std::get<3>(GetParam()); -#else - bool use_cublaslt = false; -#endif + bool use_cublaslt = IsGpu() ? std::get<3>(GetParam()) : false; execution_options_.mutable_debug_options()->set_xla_gpu_enable_cublaslt( use_cublaslt); Array2D lhs({{1.0f, 2.0f}, {3.0f, 4.0f}}); @@ -287,11 +296,7 @@ class MatOpsDotAddTest void TestImplBiasAddEpilogueFusion() { bool row_major = std::get<0>(GetParam()); bool transpose = std::get<2>(GetParam()); -#if GOOGLE_CUDA || TF_HIPBLASLT - bool use_cublaslt = std::get<3>(GetParam()); -#else - bool use_cublaslt = false; -#endif + bool use_cublaslt = IsGpu() ? std::get<3>(GetParam()) : false; execution_options_.mutable_debug_options()->set_xla_gpu_enable_cublaslt( use_cublaslt); Array2D lhs({{1.0f, 2.0f}, {3.0f, 4.0f}}); @@ -337,11 +342,7 @@ class MatOpsDotAddTest void TestImplReluActivationEpilogueFusion() { bool row_major = std::get<0>(GetParam()); bool transpose = std::get<2>(GetParam()); -#if GOOGLE_CUDA || TF_HIPBLASLT - bool use_cublaslt = std::get<3>(GetParam()); -#else - bool use_cublaslt = false; -#endif + bool use_cublaslt = IsGpu() ? std::get<3>(GetParam()) : false; execution_options_.mutable_debug_options()->set_xla_gpu_enable_cublaslt( use_cublaslt); Array2D lhs({{-1.0f, 2.0f}, {3.0f, 4.0f}}); @@ -382,11 +383,7 @@ class MatOpsDotAddTest void TestImplBiasAddReluActivationEpilogueFusion() { bool row_major = std::get<0>(GetParam()); bool transpose = std::get<2>(GetParam()); -#if GOOGLE_CUDA || TF_HIPBLASLT - bool use_cublaslt = std::get<3>(GetParam()); -#else - bool use_cublaslt = false; -#endif + bool use_cublaslt = IsGpu() ? std::get<3>(GetParam()) : false; execution_options_.mutable_debug_options()->set_xla_gpu_enable_cublaslt( use_cublaslt); Array2D lhs({{-1.0f, 2.0f}, {3.0f, 4.0f}}); diff --git a/third_party/xla/xla/tests/multithreaded_compilation_test.cc b/third_party/xla/xla/tests/multithreaded_compilation_test.cc index 1e5f138389a289..f8708a6fef4f6d 100644 --- a/third_party/xla/xla/tests/multithreaded_compilation_test.cc +++ b/third_party/xla/xla/tests/multithreaded_compilation_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/tests/nccl_group_execution_test.cc b/third_party/xla/xla/tests/nccl_group_execution_test.cc index b2d5e6d3317f70..2e08042b5432f7 100644 --- a/third_party/xla/xla/tests/nccl_group_execution_test.cc +++ b/third_party/xla/xla/tests/nccl_group_execution_test.cc @@ -18,8 +18,6 @@ limitations under the License. #include #include -#include -#include "absl/log/log.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/testlib/verified_hlo_module.h" @@ -27,7 +25,9 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { @@ -89,7 +89,10 @@ XLA_TEST_F(NcclGroupExecutionTest, NcclGroupSendRecvNoWhileLoop) { recv-done2 = (f32[], token[]) tuple(recv-done-data2, recv-done-token2), control-predecessors={async-comp-start} data-out2 = f32[] get-tuple-element(recv-done2), index=0 - ROOT out = (f32[], f32[]) tuple(data-out1, data-out2) + c100 = f32[] constant(100) + res1 = f32[] dot(data-out1, c100) + res2 = f32[] dot(data-out2, c100) + ROOT out = (f32[], f32[]) tuple(res1, res2) unpack-send-done1 = (f32[], u32[], token[]) get-tuple-element(async-comp-done), index=0 send-done1 = token[] get-tuple-element(unpack-send-done1), index=2 unpack-send-done2 = (f32[], u32[], token[]) get-tuple-element(async-comp-done), index=1 @@ -98,7 +101,10 @@ XLA_TEST_F(NcclGroupExecutionTest, NcclGroupSendRecvNoWhileLoop) { )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -114,9 +120,9 @@ XLA_TEST_F(NcclGroupExecutionTest, NcclGroupSendRecvNoWhileLoop) { // TODO (rosiezou): remove the string comparison once a tuple comparison // function is available in LiteralTestUtil. EXPECT_EQ(results[0].ToStringWithoutShapeOneline(), "( 0, 0 )"); - EXPECT_EQ(results[1].ToStringWithoutShapeOneline(), "( 10, 0 )"); - EXPECT_EQ(results[2].ToStringWithoutShapeOneline(), "( 10, 0 )"); - EXPECT_EQ(results[3].ToStringWithoutShapeOneline(), "( 0, 20 )"); + EXPECT_EQ(results[1].ToStringWithoutShapeOneline(), "( 1000, 0 )"); + EXPECT_EQ(results[2].ToStringWithoutShapeOneline(), "( 1000, 0 )"); + EXPECT_EQ(results[3].ToStringWithoutShapeOneline(), "( 0, 2000 )"); } } // namespace diff --git a/third_party/xla/xla/tests/numerics_test.cc b/third_party/xla/xla/tests/numerics_test.cc index b1bfcd9ed24d4c..988f9d6990c1ca 100644 --- a/third_party/xla/xla/tests/numerics_test.cc +++ b/third_party/xla/xla/tests/numerics_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/types.h" diff --git a/third_party/xla/xla/tests/pjrt_interpreter_client_registry.cc b/third_party/xla/xla/tests/pjrt_interpreter_client_registry.cc index 9e21b88ea4db75..52389287bab30b 100644 --- a/third_party/xla/xla/tests/pjrt_interpreter_client_registry.cc +++ b/third_party/xla/xla/tests/pjrt_interpreter_client_registry.cc @@ -14,25 +14,23 @@ limitations under the License. ==============================================================================*/ #include -#include -#include "absl/status/statusor.h" -#include "xla/pjrt/interpreter_device.h" +#include "xla/pjrt/interpreter/interpreter_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/tests/pjrt_client_registry.h" -#include "tsl/platform/status.h" namespace xla { namespace { // Register an interpreter PjRt client for tests. -const bool kUnused = (RegisterPjRtClientTestFactory([]() { - absl::StatusOr> client = - GetInterpreterClient(); - TF_CHECK_OK(client.status()); - return *std::move(client); - }), - true); +const bool kUnused = + (RegisterPjRtClientTestFactory( + []() { return std::make_unique(); }, + [](PjRtClient* client) { + return InterpreterClient::DeviceShapeRepresentation; + }, + [](PjRtClient* client) { return InterpreterClient::ShapeSizeBytes; }), + true); } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/prng_test.cc b/third_party/xla/xla/tests/prng_test.cc index b68f56c4157635..0400c3683cc2f6 100644 --- a/third_party/xla/xla/tests/prng_test.cc +++ b/third_party/xla/xla/tests/prng_test.cc @@ -27,10 +27,10 @@ limitations under the License. #include "unsupported/Eigen/SpecialFunctions" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "xla/util.h" diff --git a/third_party/xla/xla/tests/ptxas_bug_120501638.cc b/third_party/xla/xla/tests/ptxas_bug_120501638.cc index 2c1217cf8be918..9cc57edcac2ce2 100644 --- a/third_party/xla/xla/tests/ptxas_bug_120501638.cc +++ b/third_party/xla/xla/tests/ptxas_bug_120501638.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/debug_options_flags.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/query_inferred_shape_test.cc b/third_party/xla/xla/tests/query_inferred_shape_test.cc index 871e6266220cc2..a163ba58dd8313 100644 --- a/third_party/xla/xla/tests/query_inferred_shape_test.cc +++ b/third_party/xla/xla/tests/query_inferred_shape_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/reduce_precision_test.cc b/third_party/xla/xla/tests/reduce_precision_test.cc index b3614174902dc5..72b535e2908f7a 100644 --- a/third_party/xla/xla/tests/reduce_precision_test.cc +++ b/third_party/xla/xla/tests/reduce_precision_test.cc @@ -27,9 +27,9 @@ limitations under the License. #include "xla/client/global_data.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/replicated_io_feed_test.cc b/third_party/xla/xla/tests/replicated_io_feed_test.cc index 415faa01ff89e7..a6d82d33112c40 100644 --- a/third_party/xla/xla/tests/replicated_io_feed_test.cc +++ b/third_party/xla/xla/tests/replicated_io_feed_test.cc @@ -13,32 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo_runner_interface.h" #include "xla/shape_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tests/hlo_pjrt_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" // Tests replicated infeed/outfeed operations. namespace xla { +namespace { -class ReplicatedIOFeedTest : public HloTestBase {}; +class ReplicatedIOFeedTest : public HloPjRtTestBase {}; -static DeviceAssignment MakeDeviceAssn(size_t num_devices) { - DeviceAssignment assn(/*replica_count=*/num_devices, - /*computation_count=*/1); - for (int64_t i = 0; i < num_devices; ++i) { - assn(i, 0) = i; - } - return assn; -} - -XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { - std::string hlo_text = R"( +TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { + static constexpr int kNumReplicas = 4; + static constexpr absl::string_view kHloText = R"( HloModule infeed ENTRY main { // Read from infeed, add replica_id, and send to outfeed. @@ -50,22 +55,14 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { result = u32[] add(infeed.data, replica_id) outfeed = token[] outfeed(result, infeed.token), outfeed_shape=u32[] })"; - - const int kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); - - auto config = GetModuleConfigForTest(); - config.set_replica_count(kNumReplicas); - std::unique_ptr module = - ParseAndReturnVerifiedModule(hlo_text, config).value(); - auto executable = - CreateExecutable(std::move(module), /*run_hlo_passes=*/true).value(); - - auto device_assn = MakeDeviceAssn(kNumReplicas); + if (test_runner().device_count() < kNumReplicas) { + GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices (" + << test_runner().device_count() << " available)"; + } std::vector outfeed_literals; - HloRunner::ReplicatedExecuteOptions opts; + HloRunnerInterface::ReplicatedExecuteOptions opts; opts.num_replicas = kNumReplicas; // Initialize infeed literal = replica_id * 10 @@ -79,9 +76,15 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { opts.outfeed_values = &outfeed_literals; opts.use_threads = true; - TF_ASSERT_OK( - ExecuteReplicatedWithHloRunner(executable.get(), opts, &device_assn) - .status()); + DeviceAssignment device_assn(/*replica_count=*/kNumReplicas, + /*computation_count=*/1); + device_assn.FillIota(0); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule( + kHloText, GetModuleConfigForTest(kNumReplicas))); + TF_ASSERT_OK(test_runner() + .ExecuteReplicated(std::move(module), opts, &device_assn) + .status()); // Verify that each infeed and outfeed is routed correctly. Each replica // should produce 10*replica (indeed) + replica (from HLO) @@ -89,4 +92,6 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { LiteralTestUtil::ExpectR0Equal(10 * i + i, outfeed_literals[i]); } } + +} // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/reshape_motion_test.cc b/third_party/xla/xla/tests/reshape_motion_test.cc index 2300df5990c635..07dc6473f2e167 100644 --- a/third_party/xla/xla/tests/reshape_motion_test.cc +++ b/third_party/xla/xla/tests/reshape_motion_test.cc @@ -25,12 +25,12 @@ limitations under the License. #include "xla/client/global_data.h" #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/reference_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/reshape_test.cc b/third_party/xla/xla/tests/reshape_test.cc index 84d51c5f53de49..925b3e2f72b843 100644 --- a/third_party/xla/xla/tests/reshape_test.cc +++ b/third_party/xla/xla/tests/reshape_test.cc @@ -31,12 +31,12 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/sample_file_test.cc b/third_party/xla/xla/tests/sample_file_test.cc index 3179fbc95aa59b..9bf2903dc6b294 100644 --- a/third_party/xla/xla/tests/sample_file_test.cc +++ b/third_party/xla/xla/tests/sample_file_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/test.h" #include "xla/service/platform_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/sample_text_test.cc b/third_party/xla/xla/tests/sample_text_test.cc index 576cdeffea8a84..3b5dc2692149de 100644 --- a/third_party/xla/xla/tests/sample_text_test.cc +++ b/third_party/xla/xla/tests/sample_text_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/scalar_computations_test.cc b/third_party/xla/xla/tests/scalar_computations_test.cc index bc9ab8b7326d3e..b5383efe438b30 100644 --- a/third_party/xla/xla/tests/scalar_computations_test.cc +++ b/third_party/xla/xla/tests/scalar_computations_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/status_macros.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/scatter_test.cc b/third_party/xla/xla/tests/scatter_test.cc index 0151b863e1a08a..f3d32dfe758e90 100644 --- a/third_party/xla/xla/tests/scatter_test.cc +++ b/third_party/xla/xla/tests/scatter_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/strings/substitute.h" #include "xla/array2d.h" #include "xla/error_spec.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/set_dimension_size_test.cc b/third_party/xla/xla/tests/set_dimension_size_test.cc index 5e7d3984f5952c..3674e582802647 100644 --- a/third_party/xla/xla/tests/set_dimension_size_test.cc +++ b/third_party/xla/xla/tests/set_dimension_size_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/status/status.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/tile_assignment_test.cc b/third_party/xla/xla/tests/tile_assignment_test.cc index 0f8368555f3200..9f20e86d6b4529 100644 --- a/third_party/xla/xla/tests/tile_assignment_test.cc +++ b/third_party/xla/xla/tests/tile_assignment_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/hash/hash.h" #include "xla/array3d.h" -#include "xla/test.h" +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/triangular_solve_test.cc b/third_party/xla/xla/tests/triangular_solve_test.cc index 3bbe5ca227c074..a2e6334f69f99c 100644 --- a/third_party/xla/xla/tests/triangular_solve_test.cc +++ b/third_party/xla/xla/tests/triangular_solve_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" -#include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/tuple_test.cc b/third_party/xla/xla/tests/tuple_test.cc index 5cc7f7b1bb9d18..2c9e6ed0073cba 100644 --- a/third_party/xla/xla/tests/tuple_test.cc +++ b/third_party/xla/xla/tests/tuple_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/value_inference_test.cc b/third_party/xla/xla/tests/value_inference_test.cc index 5ac2f038f67180..661ff06b44075f 100644 --- a/third_party/xla/xla/tests/value_inference_test.cc +++ b/third_party/xla/xla/tests/value_inference_test.cc @@ -28,11 +28,11 @@ limitations under the License. #include "xla/hlo/builder/lib/prng.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" diff --git a/third_party/xla/xla/tests/vector_ops_simple_test.cc b/third_party/xla/xla/tests/vector_ops_simple_test.cc index eb67f886d6254f..a9705b0c61f283 100644 --- a/third_party/xla/xla/tests/vector_ops_simple_test.cc +++ b/third_party/xla/xla/tests/vector_ops_simple_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/shape_util.h" -#include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/xnn_fusion_test.cc b/third_party/xla/xla/tests/xnn_fusion_test.cc new file mode 100644 index 00000000000000..d76873f23a8bb5 --- /dev/null +++ b/third_party/xla/xla/tests/xnn_fusion_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" + +namespace xla { +namespace { + +using XnnFusionTest = HloTestBase; + +XLA_TEST_F(XnnFusionTest, CorrectComputation) { + constexpr absl::string_view kModuleStr = R"( + HloModule xnn-fusion + + xnn_fusion { + %lhs = f32[4] parameter(0) + %rhs = f32[4] parameter(1) + %add = f32[4] add(%lhs, %rhs) + ROOT %mul = f32[4] multiply(%add, %add) + } + + ENTRY entry { + %p0 = f32[4] parameter(0) + %p1 = f32[4] parameter(1) + ROOT %fusion = f32[4] fusion(%p0, %p1), kind=kCustom, calls=xnn_fusion, + backend_config={"fusion_config": {kind: "__xnn_fusion"}} + })"; + + EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{0.0})); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/text_literal_reader_test.cc b/third_party/xla/xla/text_literal_reader_test.cc index 11d76f224f4c9a..face01dce4a620 100644 --- a/third_party/xla/xla/text_literal_reader_test.cc +++ b/third_party/xla/xla/text_literal_reader_test.cc @@ -17,9 +17,10 @@ limitations under the License. #include +#include +#include "xla/hlo/testlib/test.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/test.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/text_literal_writer_test.cc b/third_party/xla/xla/text_literal_writer_test.cc index 7ba40aff24b2e8..eea2e0eca0dba5 100644 --- a/third_party/xla/xla/text_literal_writer_test.cc +++ b/third_party/xla/xla/text_literal_writer_test.cc @@ -17,9 +17,10 @@ limitations under the License. #include +#include +#include "xla/hlo/testlib/test.h" +#include "xla/hlo/testlib/test_helpers.h" #include "xla/literal_util.h" -#include "xla/test.h" -#include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 62e900243b6c5e..cf6526fd309b3d 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -95,6 +95,7 @@ xla_cc_binary( "//xla/service:interpreter_plugin", "//xla/service:local_service", "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", @@ -125,6 +126,7 @@ xla_cc_binary( "//xla/service:hlo_proto_cc", "//xla/service:interpreter_plugin", "//xla/service:local_service", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -161,8 +163,8 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:call_inliner", "//xla/service:compilation_environments", "//xla/service:hlo_module_config", @@ -189,8 +191,15 @@ xla_cc_binary( testonly = True, linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], deps = [ - "//xla/tools/hlo_opt:opt_main", - ], + "//xla/hlo/tools/hlo_opt:opt_main", + "//xla/tools/hlo_opt:cpu_opt", + ] + if_gpu_is_configured([ + "//xla/tools/hlo_opt:gpu_opt", + ]) + if_cuda_is_configured([ + "//xla/stream_executor:cuda_platform", + ]) + if_rocm_is_configured([ + "//xla/stream_executor:rocm_platform", + ]), ) cc_library( @@ -345,6 +354,7 @@ cc_library( deps = [ ":run_hlo_module_proto_cc", "//xla:debug_options_flags", + "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", @@ -355,6 +365,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", @@ -563,7 +574,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:call_graph", "//xla/service:collective_ops_utils", "//xla/service:tuple_util", @@ -677,7 +688,6 @@ tsl_gpu_library( name = "xla_compile_lib", srcs = ["xla_compile_lib.cc"], hdrs = ["xla_compile_lib.h"], - defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), visibility = ["//visibility:public"], deps = [ ":hlo_module_loader", @@ -694,11 +704,15 @@ tsl_gpu_library( "//xla/service:export_hlo", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "//xla/service:platform_util", "//xla/service:symbol_repository", "//xla/service:xla_compile_result_proto_cc_impl", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_executable", + "//xla/service/gpu:gpu_symbol_repository", + "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/cleanup", @@ -721,16 +735,8 @@ tsl_gpu_library( "@stablehlo//:register", ] + if_cuda_is_configured([ "//xla/service/gpu:nvptx_compiler", - "//xla/service/gpu:nvptx_compiler_impl", ]) + if_rocm_is_configured([ "//xla/service/gpu:amdgpu_compiler", - "//xla/service/gpu:amdgpu_compiler_impl", - ]) + if_gpu_is_configured([ - "//xla/service/gpu:executable_proto_cc", - "//xla/service/gpu:gpu_compiler", - "//xla/service/gpu/autotuning:autotuner_util", - "//xla/stream_executor/gpu:gpu_init", - "//xla/service/gpu:gpu_symbol_repository", ]) + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) @@ -809,7 +815,7 @@ xla_test( deps = [ ":hlo_decomposer_lib", "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", + "//xla/hlo/testlib:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/tools/driver.cc b/third_party/xla/xla/tools/driver.cc index 4f4895b57123ae..7f0d9c4507a2a2 100644 --- a/third_party/xla/xla/tools/driver.cc +++ b/third_party/xla/xla/tools/driver.cc @@ -101,12 +101,14 @@ void Log(const std::string& msg) { // Needs to be kept in sync with PrimitiveType in xla_data.proto. enum PrimitiveType { + S1, S2, S4, S8, S16, S32, S64, + U1, U2, U4, U8, @@ -129,19 +131,14 @@ enum PrimitiveType { }; const std::vector& primitive_strings() { - static auto vec = new std::vector({"s2", "s4", - "s8", "s16", - "s32", "s64", - "u2", "u4", - "u8", "u16", - "u32", "u64", - "f16", "bf16", - "f32", "f64", - "c64", "c128", - "f8e5m2", "f8e4m3", - "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz", - "f8e3m4"}); + static auto vec = new std::vector( + {"s1", "s2", "s4", "s8", + "s16", "s32", "s64", "u1", + "u2", "u4", "u8", "u16", + "u32", "u64", "f16", "bf16", + "f32", "f64", "c64", "c128", + "f8e5m2", "f8e4m3", "f8e4m3fn", "f8e4m3b11fnuz", + "f8e5m2fnuz", "f8e4m3fnuz", "f8e3m4"}); return *vec; } @@ -429,6 +426,8 @@ void Fill(void* buffer, const ArrayShape& shape) { case BF16: case C64: case C128: + case S1: + case U1: case S2: case U2: case S4: @@ -487,6 +486,8 @@ void Display(const void* buffer, const ArrayShape& shape) { case BF16: case C64: case C128: + case S1: + case U1: case S2: case U2: case S4: diff --git a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc index b6be3188ffa02e..e70fcb935596e7 100644 --- a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc +++ b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/third_party/xla/xla/tools/dumped_computation_to_text.cc b/third_party/xla/xla/tools/dumped_computation_to_text.cc index 72e1710e194507..78811d51bfc93f 100644 --- a/third_party/xla/xla/tools/dumped_computation_to_text.cc +++ b/third_party/xla/xla/tools/dumped_computation_to_text.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/client_library.h" diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index d094aba983de94..53a2c083f163d0 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -46,7 +46,7 @@ cc_library( "//xla:literal_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms/simplifiers:hlo_dce", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/tools/hlo_decomposer.cc b/third_party/xla/xla/tools/hlo_decomposer.cc index 577456c24787a2..e083dc798a5c1e 100644 --- a/third_party/xla/xla/tools/hlo_decomposer.cc +++ b/third_party/xla/xla/tools/hlo_decomposer.cc @@ -223,10 +223,10 @@ std::unique_ptr ExtractProducerConsumerIntoNewModule( std::unique_ptr ExtractComputationIntoNewModule( const HloComputation& computation) { - auto new_hlo_module = - std::make_unique("extracted", HloModuleConfig{}, - std::make_unique( - computation.parent()->comp_envs())); + auto new_hlo_module = std::make_unique( + std::string(computation.name()), HloModuleConfig{}, + std::make_unique( + computation.parent()->comp_envs())); HloCloneContext clone_context(new_hlo_module.get()); new_hlo_module->AddEntryComputationWithLayouts( computation.CloneInContext(clone_context)); diff --git a/third_party/xla/xla/tools/hlo_decomposer_test.cc b/third_party/xla/xla/tools/hlo_decomposer_test.cc index d60f94fdd26aa6..c38aa8faa53599 100644 --- a/third_party/xla/xla/tools/hlo_decomposer_test.cc +++ b/third_party/xla/xla/tools/hlo_decomposer_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/filecheck.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -157,5 +157,20 @@ CHECK-THEN: ROOT %e.1 )"); } +TEST_F(HloDecomposerTest, ExtractComputationIntoNewModule) { + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +HloModule module + +ENTRY main { + p0 = s8[10,10] parameter(0) + p1 = s8[10,10] parameter(1) + ROOT r = s8[10,10] add(p0, p1) +})") + .value(); + auto new_module = + ExtractComputationIntoNewModule(*module->entry_computation()); + EXPECT_EQ(new_module->name(), module->entry_computation()->name()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tools/hlo_extractor_test.cc b/third_party/xla/xla/tools/hlo_extractor_test.cc index 35c4c44953e6d9..8c2ab34524db83 100644 --- a/third_party/xla/xla/tools/hlo_extractor_test.cc +++ b/third_party/xla/xla/tools/hlo_extractor_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include @@ -481,7 +480,7 @@ TEST_F(HloExtractorTest, TestWithCalledComputationsAndFusion) { } TEST_F(HloExtractorTest, TestInvalidModule) { - constexpr std::string_view hlo = R"( + constexpr absl::string_view hlo = R"( HloModule main computation { diff --git a/third_party/xla/xla/tools/hlo_module_loader.cc b/third_party/xla/xla/tools/hlo_module_loader.cc index 3ab573dfa2ac42..db1439a226ba4e 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.cc +++ b/third_party/xla/xla/tools/hlo_module_loader.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -28,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "re2/re2.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" @@ -36,8 +36,10 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/tools/run_hlo_module.pb.h" +#include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" @@ -55,12 +57,12 @@ absl::Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config, } // namespace -std::string StripLogHeaders(std::string_view hlo_string) { +std::string StripLogHeaders(absl::string_view hlo_string) { // I0521 12:04:45.883483 1509 service.cc:186] ... static RE2* matcher = new RE2( "[IWEF]\\d{4} " "\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)"); - std::string_view matches[4]; + absl::string_view matches[4]; std::vector lines = absl::StrSplit(hlo_string, '\n'); for (auto& line : lines) { if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) { @@ -74,7 +76,7 @@ std::string StripLogHeaders(std::string_view hlo_string) { } absl::StatusOr> LoadModuleFromData( - const std::string& data, std::string_view format, + const std::string& data, absl::string_view format, const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) { @@ -150,7 +152,7 @@ absl::StatusOr> LoadModuleFromFile( } absl::StatusOr> -LoadInputFromData(const std::string& data, std::string_view format) { +LoadInputFromData(const std::string& data, absl::string_view format) { HloSnapshot proto; if (format == "pb") { if (!proto.ParseFromString(data) && diff --git a/third_party/xla/xla/tools/hlo_module_loader.h b/third_party/xla/xla/tools/hlo_module_loader.h index 4dc0653cd9729b..a8b7c1e48123f4 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.h +++ b/third_party/xla/xla/tools/hlo_module_loader.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -32,7 +31,7 @@ namespace xla { namespace hlo_module_loader_details { struct Config { - Config() {} + Config() = default; int64_t num_replicas = 1; int64_t num_partitions = 1; }; @@ -41,7 +40,7 @@ struct Config { // Given a string composed by multiple lines, strip the log headers, if present // at the beginning of each line. -std::string StripLogHeaders(std::string_view hlo_string); +std::string StripLogHeaders(absl::string_view hlo_string); // Loads an HLO module from a string. // The data can have the followings formats: @@ -58,7 +57,7 @@ std::string StripLogHeaders(std::string_view hlo_string); // and the hlo module format is proto, it loads buffer assignment from the // proto. absl::StatusOr> LoadModuleFromData( - const std::string& data, std::string_view format, + const std::string& data, absl::string_view format, const hlo_module_loader_details::Config& ovr_config = hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, @@ -93,7 +92,7 @@ absl::StatusOr> LoadModuleFromFile( // 1) A binary proto (format "pb") // 2) A text proto (format "pbtxt") absl::StatusOr> -LoadInputFromData(const std::string& data, std::string_view format); +LoadInputFromData(const std::string& data, absl::string_view format); // Loads an HLO snapshot from file, only for its inputs // The file must be one of the following: diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index c3d5137ed4f797..e59e76e7e3a087 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -26,109 +26,25 @@ package( ) cc_library( - name = "opt_main", - testonly = True, - srcs = ["opt_main.cc"], + name = "compiled_opt_lib", + srcs = ["compiled_opt_lib.cc"], + hdrs = ["compiled_opt_lib.h"], deps = [ - ":cpu_opt", - ":opt_lib", "//xla:debug_options_flags", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_runner", - "//xla/service:platform_util", - "//xla/tools:hlo_module_loader", - "//xla/tools:run_hlo_module_lib", - "//xla/tsl/util:command_line_flags", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ] + if_gpu_is_configured([ - ":gpu_opt", - ]) + if_cuda_is_configured([ - "//xla/stream_executor:cuda_platform", - ]) + if_rocm_is_configured([ - "//xla/stream_executor:rocm_platform", - ]), -) - -# Includes a macro to register a provider. -cc_library( - name = "opt_lib", - srcs = ["opt_lib.cc"], - hdrs = [ - "opt_lib.h", - "transforms_example_passes.h", - ], - deps = [ - "//xla/hlo/analysis:indexed_array_analysis", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:all_reduce_folder", - "//xla/hlo/transforms:batch_dot_simplification", - "//xla/hlo/transforms:broadcast_canonicalizer", - "//xla/hlo/transforms:cholesky_expander", - "//xla/hlo/transforms:comparison_expander", - "//xla/hlo/transforms:conditional_canonicalizer", - "//xla/hlo/transforms:convert_memory_placement_to_internal_annotations", - "//xla/hlo/transforms:convert_mover", - "//xla/hlo/transforms:convolution_4d_expander", - "//xla/hlo/transforms:convolution_group_converter", - "//xla/hlo/transforms:convolution_pred_expander", - "//xla/hlo/transforms:dot_decomposer", - "//xla/hlo/transforms:dynamic_dimension_simplifier", - "//xla/hlo/transforms:dynamic_index_splitter", - "//xla/hlo/transforms:eigh_expander", - "//xla/hlo/transforms:flatten_call_graph", - "//xla/hlo/transforms:float_normalization", - "//xla/hlo/transforms:gather_simplifier", - "//xla/hlo/transforms:hlo_constant_folding", - "//xla/hlo/transforms:hlo_dce", - "//xla/hlo/transforms:logistic_expander", - "//xla/hlo/transforms:operand_upcaster", - "//xla/hlo/transforms:optimization_barrier_expander", - "//xla/hlo/transforms:optimize_input_output_buffer_alias", - "//xla/hlo/transforms:qr_expander", - "//xla/hlo/transforms:real_imag_expander", - "//xla/hlo/transforms:reduce_decomposer", - "//xla/hlo/transforms:reshape_decomposer", - "//xla/hlo/transforms:reshape_mover", - "//xla/hlo/transforms:result_caster", - "//xla/hlo/transforms:rng_expander", - "//xla/hlo/transforms:simplify_fp_conversions", - "//xla/hlo/transforms:slice_sinker", - "//xla/hlo/transforms:sort_simplifier", - "//xla/hlo/transforms:stable_sort_expander", - "//xla/hlo/transforms:stochastic_convert_decomposer", - "//xla/hlo/transforms:sub_byte_normalization", - "//xla/hlo/transforms:tree_reduction_rewriter", - "//xla/hlo/transforms:tuple_simplifier", - "//xla/hlo/transforms:while_loop_trip_count_annotator", - "//xla/hlo/transforms:zero_sized_hlo_elimination", - "//xla/hlo/transforms/collectives:all_gather_broadcast_reorder", - "//xla/hlo/transforms/collectives:all_reduce_contiguous", - "//xla/hlo/transforms/collectives:collective_quantizer", + "//xla/hlo/tools/hlo_opt:opt_lib", + "//xla/hlo/transforms:bitcast_dtypes_expander", "//xla/service:all_reduce_simplifier", "//xla/service:all_to_all_decomposer", "//xla/service:batched_gather_scatter_normalizer", - "//xla/service:bitcast_dtypes_expander", "//xla/service:call_inliner", + "//xla/service:compiler", "//xla/service:conditional_simplifier", "//xla/service:conditional_to_select", "//xla/service:copy_insertion", - "//xla/service:float_support", + "//xla/service:executable", "//xla/service:gather_expander", - "//xla/service:hlo_graph_dumper", "//xla/service:map_inliner", "//xla/service:platform_util", "//xla/service:reduce_scatter_reassociate", @@ -148,32 +64,6 @@ cc_library( "//xla/service/gpu/transforms:scatter_expander", "//xla/service/gpu/transforms:scatter_slice_simplifier", "//xla/service/spmd/shardy:shardy_xla_pass", - "//xla/stream_executor/platform:initialize", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "compiled_opt_lib", - srcs = ["compiled_opt_lib.cc"], - hdrs = ["compiled_opt_lib.h"], - deps = [ - ":opt_lib", - "//xla:debug_options_flags", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:compiler", - "//xla/service:executable", - "//xla/service:platform_util", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log:check", @@ -194,8 +84,8 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:reduce_window_rewriter", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:reduce_window_rewriter", "//xla/service:buffer_value", "//xla/service:compiler", "//xla/service:dump", @@ -258,7 +148,6 @@ cc_library( srcs = ["cpu_opt.cc"], deps = [ ":compiled_opt_lib", - ":opt_lib", "//xla:debug_options_flags", "//xla:util", "//xla:xla_data_proto_cc", @@ -268,9 +157,9 @@ cc_library( "//xla/backends/cpu/codegen:target_machine_features", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/transforms:algebraic_simplifier", - "//xla/hlo/transforms:reduce_window_rewriter", "//xla/hlo/transforms:rng_bit_generator_expander", + "//xla/hlo/transforms/simplifiers:algebraic_simplifier", + "//xla/hlo/transforms/simplifiers:reduce_window_rewriter", "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/service:batchnorm_expander", "//xla/service:change_op_data_type", @@ -278,7 +167,6 @@ cc_library( "//xla/service:dynamic_dimension_inference", "//xla/service:dynamic_padder", "//xla/service:executable", - "//xla/service:float_support", "//xla/service:gather_expander", "//xla/service:hlo_execution_profile", "//xla/service:hlo_graph_dumper", @@ -302,6 +190,7 @@ cc_library( "//xla/service/spmd:stateful_rng_spmd_partitioner", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/platform:initialize", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:MC", @@ -320,6 +209,7 @@ lit_test_suite( [ "tests/cpu_hlo.hlo", "tests/cpu_llvm.hlo", + "tests/cpu_hlo_pass.hlo", "tests/gpu_hlo.hlo", "tests/gpu_hlo_backend.hlo", "tests/gpu_hlo_buffers.hlo", @@ -327,12 +217,9 @@ lit_test_suite( "tests/gpu_hlo_pass.hlo", "tests/gpu_hlo_ptx.hlo", "tests/gpu_hlo_unoptimized_llvm.hlo", - "tests/run_single_pass.hlo", + "tests/gpu_hlo_html.hlo", "tests/list_passes.hlo", - "tests/run_multiple_passes.hlo", "tests/run_pass_with_input.hlo", - "tests/gpu_hlo_html.hlo", - "tests/cpu_hlo_pass.hlo", ], include = [ "tests/*.hlo", diff --git a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc index 279836160128e0..4b907482e1470c 100644 --- a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc +++ b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc @@ -26,9 +26,36 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" +#include "xla/service/all_reduce_simplifier.h" +#include "xla/service/all_to_all_decomposer.h" +#include "xla/service/batched_gather_scatter_normalizer.h" +#include "xla/service/call_inliner.h" #include "xla/service/compiler.h" +#include "xla/service/conditional_simplifier.h" +#include "xla/service/conditional_to_select.h" +#include "xla/service/copy_insertion.h" #include "xla/service/executable.h" +#include "xla/service/gather_expander.h" +#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" +#include "xla/service/gpu/transforms/scatter_expander.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" +#include "xla/service/map_inliner.h" #include "xla/service/platform_util.h" +#include "xla/service/reduce_scatter_reassociate.h" +#include "xla/service/scatter_determinism_expander.h" +#include "xla/service/scatter_simplifier.h" +#include "xla/service/select_and_scatter_expander.h" +#include "xla/service/sharding_remover.h" +#include "xla/service/spmd/shardy/shardy_xla_pass.h" +#include "xla/service/topk_rewriter.h" +#include "xla/service/triangular_solve_expander.h" +#include "xla/service/while_loop_all_reduce_code_motion.h" +#include "xla/service/while_loop_constant_sinking.h" +#include "xla/service/while_loop_invariant_code_motion.h" +#include "xla/service/while_loop_simplifier.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla.pb.h" @@ -117,4 +144,36 @@ std::set CompiledOptProvider::SupportedStages() { return {"hlo", "html", "hlo-backend"}; } +void CompiledOptProvider::RegisterSharedHardwareSpecificPasses() { + // go/keep-sorted start + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(GatherExpander::kEliminateSimpleGathers); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + RegisterPass(); + // go/keep-sorted end +} + } // namespace xla diff --git a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h index 9cbe2d61810f80..eaabe294b5533a 100644 --- a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h +++ b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.h @@ -24,17 +24,19 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/tools/hlo_opt/opt_lib.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/stream_executor/platform.h" -#include "xla/tools/hlo_opt/opt_lib.h" namespace xla { // Platform-specific provider of `hlo-opt` functionality. class CompiledOptProvider : public OptProvider { public: - CompiledOptProvider() : OptProvider() {} + CompiledOptProvider() : OptProvider() { + RegisterSharedHardwareSpecificPasses(); + } // Generates textual output for a given stage on a given platform, returns // empty optional if the stage is not supported. @@ -61,6 +63,10 @@ class CompiledOptProvider : public OptProvider { // Gets a compiler associated with the provider. virtual absl::StatusOr GetCompiler(); + + // Registers hardware-specific passes which are shared by + // multiple backends (CPU, GPU, xPU). + void RegisterSharedHardwareSpecificPasses(); }; } // namespace xla diff --git a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc index c9b5605a0f0929..cb5c3bebe13ab2 100644 --- a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/MC/TargetRegistry.h" @@ -66,7 +67,6 @@ limitations under the License. #include "xla/service/transpose_folding.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/tools/hlo_opt/compiled_opt_lib.h" -#include "xla/tools/hlo_opt/opt_lib.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/tools/hlo_opt/tests/run_pass_with_input.hlo b/third_party/xla/xla/tools/hlo_opt/tests/run_pass_with_input.hlo index c1fbfc81dd11ca..defd1a6cb2a116 100644 --- a/third_party/xla/xla/tools/hlo_opt/tests/run_pass_with_input.hlo +++ b/third_party/xla/xla/tools/hlo_opt/tests/run_pass_with_input.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --passes=gather_expander | FileCheck %s +// RUN: hlo-opt %s --platform=cpu --passes=gather_expander | FileCheck %s HloModule test diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index 1d130c855b045e..fa76ac509597d5 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -46,6 +46,7 @@ cc_library( "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:cpu_plugin", "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -111,10 +112,10 @@ cc_library( "//xla/pjrt/distributed:client", "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:service", - "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/pjrt/plugin/xla_cpu:cpu_client_options", "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", + "//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc index a1c3bb027b5a70..822766ff392ab9 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc @@ -27,12 +27,12 @@ limitations under the License. #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/service.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" +#include "xla/pjrt/plugin/xla_gpu/xla_gpu_pjrt_client.h" #include "xla/status_macros.h" #include "xla/xla.pb.h" #include "tsl/platform/status.h" @@ -113,7 +113,7 @@ absl::StatusOr> CreateGpuClient( return absl::InvalidArgumentError( "Node id is expected to be in range [0, num_nodes)"); } - return GetStreamExecutorGpuClient(options); + return xla::GetXlaPjrtGpuClient(options); } absl::StatusOr> CreateMockGpuClient(int num_nodes) { diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 01bd7f02c6fc69..740a1f92dc7b55 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -543,7 +542,7 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client, absl::Status FunctionalHloRunner::LoadAndCompile( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, - const RawCompileOptions& raw_compile_options, std::string_view hlo_file, + const RawCompileOptions& raw_compile_options, absl::string_view hlo_file, InputFormat input_format, int task_id, int num_nodes, std::shared_ptr kv_store, bool use_gpu_count_workaround) { @@ -1308,16 +1307,15 @@ FunctionalHloRunner::CopyArgumentsToDevice( TF_RET_CHECK(!shape.IsTuple()) << "Param tuple without flattened_arguments"; return non_tuple_memory_space(shape); }; - TF_ASSIGN_OR_RETURN(const std::vector>& + TF_ASSIGN_OR_RETURN(const std::vector>& executable_parameter_pjrt_layouts, executable->GetParameterLayouts()); std::vector executable_parameter_layouts; executable_parameter_layouts.reserve( executable_parameter_pjrt_layouts.size()); - for (const std::unique_ptr& pjrt_layout : + for (const std::shared_ptr& pjrt_layout : executable_parameter_pjrt_layouts) { - executable_parameter_layouts.push_back( - xla::GetXlaLayoutUnsafe(pjrt_layout)); + executable_parameter_layouts.push_back(pjrt_layout->xla_layout()); } auto buffer_from_host_literal = [&client, &argument_memory_space, &executable_parameter_layouts]( diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 26a7894c28c80b..15ce5de0917f83 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -16,14 +16,17 @@ limitations under the License. #ifndef XLA_TOOLS_MULTIHOST_HLO_RUNNER_FUNCTIONAL_HLO_RUNNER_H_ #define XLA_TOOLS_MULTIHOST_HLO_RUNNER_FUNCTIONAL_HLO_RUNNER_H_ +#include +#include #include #include #include +#include #include -#include #include #include "absl/container/btree_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -273,7 +276,7 @@ class FunctionalHloRunner { static absl::Status LoadAndCompile( PjRtClient& client, const DebugOptions& debug_options, const PreprocessingOptions& preproc_options, - const RawCompileOptions& raw_compile_options, std::string_view hlo_file, + const RawCompileOptions& raw_compile_options, absl::string_view hlo_file, InputFormat input_format, int task_id = 0, int num_nodes = 1, std::shared_ptr kv_store = nullptr, bool use_gpu_count_workaround = true); diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index 99c87551e32dd6..9a09f15481c307 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc index 20d0a33593368a..2fc800c6e6e541 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.cc +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -45,15 +46,17 @@ limitations under the License. #include "xla/pjrt/mlir_to_hlo.h" #include "xla/service/compiler.h" #include "xla/service/cpu/cpu_compiler.h" -#include "xla/service/cpu/cpu_executable.h" #include "xla/service/executable.h" #include "xla/service/export_hlo.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/gpu_symbol_repository.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" +#include "xla/service/platform_util.h" #include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" #include "xla/shape.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tools/hlo_module_loader.h" @@ -67,18 +70,6 @@ limitations under the License. #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/autotuning/autotuner_util.h" -#include "xla/service/gpu/executable.pb.h" -#include "xla/service/gpu/gpu_symbol_repository.h" -#include "xla/stream_executor/gpu/gpu_init.h" -#endif -#if GOOGLE_CUDA -#include "xla/service/gpu/nvptx_compiler.h" -#elif TENSORFLOW_USE_ROCM -#include "xla/service/gpu/amdgpu_compiler.h" -#endif - namespace xla { static absl::StatusOr AotCompileCpuExecutable( @@ -97,26 +88,27 @@ static absl::StatusOr CompileGpuExecutable( std::unique_ptr hlo_module, std::optional target_config, CompilationResult& result) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + TF_ASSIGN_OR_RETURN(std::string platform_name, + xla::PlatformUtil::CanonicalPlatformName("gpu")); + platform_name = absl::AsciiStrToUpper(platform_name); + TF_ASSIGN_OR_RETURN( + auto platform, + stream_executor::PlatformManager::PlatformWithName(platform_name)); const bool aot = target_config.has_value(); -#if GOOGLE_CUDA - auto gpu_compiler = gpu::NVPTXCompiler(); -#elif TENSORFLOW_USE_ROCM - auto gpu_compiler = gpu::AMDGPUCompiler(); -#endif + TF_ASSIGN_OR_RETURN(auto gpu_compiler, Compiler::GetForPlatform(platform)); auto module_group = std::make_unique(std::move(hlo_module)); if (aot) { - AotCompilationOptions aot_options(gpu_compiler.PlatformId()); + AotCompilationOptions aot_options(platform->id()); aot_options.set_target_config(*target_config); // We need the optimized module, so we call RunHloPasses ourselves above. aot_options.set_run_backend_only(true); TF_ASSIGN_OR_RETURN( std::vector> aot_results, - gpu_compiler.CompileAheadOfTime(std::move(module_group), aot_options)); + gpu_compiler->CompileAheadOfTime(std::move(module_group), aot_options)); TF_ASSIGN_OR_RETURN(std::string compile_result, aot_results[0]->SerializeAsString()); *result.mutable_hlo_module() = @@ -125,10 +117,8 @@ static absl::StatusOr CompileGpuExecutable( } Compiler::CompileOptions compile_options; - TF_RETURN_IF_ERROR(stream_executor::ValidateGPUMachineManager()); - TF_ASSIGN_OR_RETURN( - stream_executor::StreamExecutor * stream_executor, - stream_executor::GPUMachineManager()->ExecutorForDevice(0)); + TF_ASSIGN_OR_RETURN(stream_executor::StreamExecutor * stream_executor, + platform->ExecutorForDevice(0)); auto allocator = std::make_unique( stream_executor); @@ -136,14 +126,10 @@ static absl::StatusOr CompileGpuExecutable( TF_ASSIGN_OR_RETURN( std::vector> executables, - gpu_compiler.Compile(std::move(module_group), {{stream_executor}}, - compile_options)); + gpu_compiler->Compile(std::move(module_group), {{stream_executor}}, + compile_options)); *result.mutable_hlo_module() = executables[0]->module().ToProto(); return executables[0]->module().ToString(); -#else - LOG(ERROR) << "Neither ROCm nor CUDA present; returning empty."; - return ""; -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } absl::StatusOr CompileExecutable( @@ -235,13 +221,11 @@ ReadModuleFromSymbolRepo(absl::string_view symbol_repo, static std::unique_ptr ReadTargetConfigFromModule( HloModuleAndMetadata* mod, BackendType backend) { if (backend == BackendType::kGpu) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (auto* data = static_cast( mod->backend_specific_data.get()); data != nullptr) { return std::move(mod->target_config); } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } return nullptr; @@ -252,7 +236,6 @@ namespace internal { absl::StatusOr LoadAutotuneDataFromModule(HloModuleAndMetadata* mod, BackendType backend) { if (backend == BackendType::kGpu) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (auto* data = static_cast( mod->backend_specific_data.get()); data != nullptr && data->autotune_results.has_value() && @@ -262,7 +245,6 @@ absl::StatusOr LoadAutotuneDataFromModule(HloModuleAndMetadata* mod, gpu::AutotunerUtil::LoadAutotuneResults(*data->autotune_results)); return true; } -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } return false; } @@ -293,9 +275,7 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { TF_ASSIGN_OR_RETURN(hlo_module, LoadModule(options.module_path)); } -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM bool found_autotune = false; -#endif if (absl::string_view optimized_symbol_id = options.repo_options.optimized_symbol_id; @@ -304,10 +284,8 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { std::unique_ptr optimized_mod, ReadModuleFromSymbolRepo(symbol_repo, optimized_symbol_id, backend)); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN(found_autotune, internal::LoadAutotuneDataFromModule( optimized_mod.get(), backend)); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } xla::TimerStats stats; @@ -325,7 +303,6 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { // Run AOT compilation. std::optional cfg = std::nullopt; if (backend == BackendType::kGpu) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (absl::string_view gpu_target_config_path = options.gpu_options.gpu_target_config_path; !gpu_target_config_path.empty()) { @@ -356,7 +333,6 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { cfg = (options.gpu_options.use_attached_device) ? std::nullopt : std::make_optional(*std::move(target_config)); -#endif } auto result = CompileExecutable(std::move(hlo_module), backend, std::move(cfg), compilation_result); diff --git a/third_party/xla/xla/tsl/c/BUILD b/third_party/xla/xla/tsl/c/BUILD index e80786d59c847d..06b4e76c19c652 100644 --- a/third_party/xla/xla/tsl/c/BUILD +++ b/third_party/xla/xla/tsl/c/BUILD @@ -60,7 +60,7 @@ tsl_gpu_library( ], visibility = ["//visibility:public"], deps = [ - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:status", ], ) @@ -71,8 +71,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tsl_status_internal", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", ], ) @@ -82,10 +82,10 @@ tsl_cc_test( deps = [ ":tsl_status", ":tsl_status_internal", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -103,8 +103,8 @@ tsl_gpu_library( deps = [ ":tsl_status", ":tsl_status_internal", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", ], ) diff --git a/third_party/xla/xla/tsl/c/tsl_status.cc b/third_party/xla/xla/tsl/c/tsl_status.cc index 75b948129f2533..b68908e89ff598 100644 --- a/third_party/xla/xla/tsl/c/tsl_status.cc +++ b/third_party/xla/xla/tsl/c/tsl_status.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "xla/tsl/c/tsl_status_internal.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" using ::tsl::Status; using ::tsl::error::Code; diff --git a/third_party/xla/xla/tsl/c/tsl_status_helper.cc b/third_party/xla/xla/tsl/c/tsl_status_helper.cc index ca1c8b2dbe322b..a3bb572acb0417 100644 --- a/third_party/xla/xla/tsl/c/tsl_status_helper.cc +++ b/third_party/xla/xla/tsl/c/tsl_status_helper.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/tsl/c/tsl_status_helper.h" #include "xla/tsl/c/tsl_status_internal.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/c/tsl_status_helper.h b/third_party/xla/xla/tsl/c/tsl_status_helper.h index 905785dc678386..6199c8724d5453 100644 --- a/third_party/xla/xla/tsl/c/tsl_status_helper.h +++ b/third_party/xla/xla/tsl/c/tsl_status_helper.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/c/tsl_status.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/c/tsl_status_internal.h b/third_party/xla/xla/tsl/c/tsl_status_internal.h index 132adc62dac66f..a535fac0e65d5f 100644 --- a/third_party/xla/xla/tsl/c/tsl_status_internal.h +++ b/third_party/xla/xla/tsl/c/tsl_status_internal.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_C_TSL_STATUS_INTERNAL_H_ #define XLA_TSL_C_TSL_STATUS_INTERNAL_H_ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" // Internal structures used by the status C API. These are likely to change // and should not be depended on. diff --git a/third_party/xla/xla/tsl/c/tsl_status_test.cc b/third_party/xla/xla/tsl/c/tsl_status_test.cc index b4518644f837f2..366b810691fb3a 100644 --- a/third_party/xla/xla/tsl/c/tsl_status_test.cc +++ b/third_party/xla/xla/tsl/c/tsl_status_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "xla/tsl/c/tsl_status_internal.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/concurrency/BUILD b/third_party/xla/xla/tsl/concurrency/BUILD index ed3c7cb2730bfc..e6cfc40f0d726a 100644 --- a/third_party/xla/xla/tsl/concurrency/BUILD +++ b/third_party/xla/xla/tsl/concurrency/BUILD @@ -23,6 +23,7 @@ cc_library( deps = [ ":concurrent_vector", ":ref_count", + "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", @@ -30,7 +31,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", ], ) @@ -39,9 +39,9 @@ tsl_cc_test( srcs = ["async_value_test.cc"], deps = [ ":async_value", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -50,13 +50,13 @@ tsl_cc_test( srcs = ["async_value_ptr_test.cc"], deps = [ ":async_value", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", ], ) @@ -66,13 +66,13 @@ tsl_cc_test( deps = [ ":async_value", ":ref_count", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", ], ) @@ -81,9 +81,9 @@ cc_library( hdrs = ["concurrent_vector.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", ], ) @@ -92,10 +92,10 @@ tsl_cc_test( srcs = ["concurrent_vector_test.cc"], deps = [ ":concurrent_vector", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/concurrency/async_value.cc b/third_party/xla/xla/tsl/concurrency/async_value.cc index fa3f0582e779ef..fda3aa65911843 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/async_value.h b/third_party/xla/xla/tsl/concurrency/async_value.h index 30e0d8ee11ac90..d04efc88a551b9 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value.h +++ b/third_party/xla/xla/tsl/concurrency/async_value.h @@ -32,7 +32,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/tsl/concurrency/concurrent_vector.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ptr_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_ptr_test.cc index 597d5b66f2c1db..7e5322b654f112 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ptr_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_ptr_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.cc b/third_party/xla/xla/tsl/concurrency/async_value_ref.cc index 437ff310267140..d8af644eef53f1 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.h b/third_party/xla/xla/tsl/concurrency/async_value_ref.h index 20f491345cf698..76f39e55ca4757 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref.h +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.h @@ -37,7 +37,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc index 833b5e2fe543cb..d845c2b3e2e654 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/async_value_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_test.cc index eb14685f37903f..00d1dc55056834 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/concurrency/concurrent_vector.h b/third_party/xla/xla/tsl/concurrency/concurrent_vector.h index b7a033ddaa75a2..aebca0369d2f1b 100644 --- a/third_party/xla/xla/tsl/concurrency/concurrent_vector.h +++ b/third_party/xla/xla/tsl/concurrency/concurrent_vector.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace internal { diff --git a/third_party/xla/xla/tsl/concurrency/concurrent_vector_test.cc b/third_party/xla/xla/tsl/concurrency/concurrent_vector_test.cc index 5106909ce06146..2e1b41c37aff97 100644 --- a/third_party/xla/xla/tsl/concurrency/concurrent_vector_test.cc +++ b/third_party/xla/xla/tsl/concurrency/concurrent_vector_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/cuda/cublasLt_stub.cc b/third_party/xla/xla/tsl/cuda/cublasLt_stub.cc index db60995d59fa57..728c3affeaf387 100644 --- a/third_party/xla/xla/tsl/cuda/cublasLt_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cublasLt_stub.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/cublas_stub.cc b/third_party/xla/xla/tsl/cuda/cublas_stub.cc index a4b7fcbb828b68..bbec38bd3e868d 100644 --- a/third_party/xla/xla/tsl/cuda/cublas_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cublas_stub.cc @@ -23,9 +23,9 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the cuBLAS API by forwarding to cuBLAS loaded from the DSO. // Note that it does not implement the v1 interface. diff --git a/third_party/xla/xla/tsl/cuda/cuda_stub.cc b/third_party/xla/xla/tsl/cuda/cuda_stub.cc index e33535c16e33c3..4958b626c2fde0 100644 --- a/third_party/xla/xla/tsl/cuda/cuda_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cuda_stub.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the CUDA driver API by forwarding to CUDA loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/cudart_stub.cc b/third_party/xla/xla/tsl/cuda/cudart_stub.cc index 7064a72541eefd..55a6dd88309a39 100644 --- a/third_party/xla/xla/tsl/cuda/cudart_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cudart_stub.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" namespace { void *GetDsoHandle() { diff --git a/third_party/xla/xla/tsl/cuda/cudnn_stub.cc b/third_party/xla/xla/tsl/cuda/cudnn_stub.cc index 192009c9e8728d..483d391534a887 100644 --- a/third_party/xla/xla/tsl/cuda/cudnn_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cudnn_stub.cc @@ -15,9 +15,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "third_party/gpus/cudnn/cudnn.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the cuDNN API by forwarding to cuDNN loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/cufft_stub.cc b/third_party/xla/xla/tsl/cuda/cufft_stub.cc index ea7b08f8821891..3f890b20b95d73 100644 --- a/third_party/xla/xla/tsl/cuda/cufft_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cufft_stub.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cufftXt.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the cuFFT API by forwarding to cuFFT loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/cupti_stub.cc b/third_party/xla/xla/tsl/cuda/cupti_stub.cc index 01d13a8ea7d4f9..c95b38dc249b05 100644 --- a/third_party/xla/xla/tsl/cuda/cupti_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cupti_stub.cc @@ -15,9 +15,9 @@ limitations under the License. #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the CUPTI API by forwarding to CUPTI loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/cusolver_stub.cc b/third_party/xla/xla/tsl/cuda/cusolver_stub.cc index d76526042582e8..2cd67175b85f4c 100644 --- a/third_party/xla/xla/tsl/cuda/cusolver_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cusolver_stub.cc @@ -15,9 +15,9 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolverSp.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the cusolver API by forwarding to cusolver loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/cusparse_stub.cc b/third_party/xla/xla/tsl/cuda/cusparse_stub.cc index b8ab1d67354bd3..56730ea90d0a59 100644 --- a/third_party/xla/xla/tsl/cuda/cusparse_stub.cc +++ b/third_party/xla/xla/tsl/cuda/cusparse_stub.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusparse.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the cusparse API by forwarding to cusparse loaded from the DSO. diff --git a/third_party/xla/xla/tsl/cuda/nccl_stub.cc b/third_party/xla/xla/tsl/cuda/nccl_stub.cc index f3895da2451760..345e5e5a6d6a67 100644 --- a/third_party/xla/xla/tsl/cuda/nccl_stub.cc +++ b/third_party/xla/xla/tsl/cuda/nccl_stub.cc @@ -17,9 +17,9 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/nccl/nccl.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" // Implements the nccl API by forwarding to nccl loaded from a DSO. diff --git a/third_party/xla/xla/tsl/distributed_runtime/BUILD b/third_party/xla/xla/tsl/distributed_runtime/BUILD index 4ad349e9b7eb1a..e969e9d986b06e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/BUILD @@ -21,10 +21,10 @@ cc_library( srcs = ["call_options.cc"], hdrs = ["call_options.h"], deps = [ - "@local_tsl//tsl/platform:macros", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/call_options.h b/third_party/xla/xla/tsl/distributed_runtime/call_options.h index 99a66d4b42f311..95231e12b584d4 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/call_options.h +++ b/third_party/xla/xla/tsl/distributed_runtime/call_options.h @@ -18,10 +18,10 @@ limitations under the License. #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 2cd3c95ba96928..5b6032c8629f49 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -29,12 +29,12 @@ tsl_cc_test( srcs = ["coordination_service_error_util_test.cc"], deps = [ ":coordination_service_error_util", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -43,8 +43,8 @@ cc_library( hdrs = ["coordination_client.h"], deps = [ "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_service_proto_cc", - "@local_tsl//tsl/platform:status", ], ) @@ -53,14 +53,14 @@ cc_library( hdrs = ["coordination_service.h"], deps = [ ":coordination_client", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:status", ], ) @@ -75,6 +75,8 @@ tsl_gpu_library( ":coordination_service", ":coordination_service_error_util", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", "//xla/tsl/util:device_name_utils", @@ -91,9 +93,7 @@ tsl_gpu_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/platform:status", ], alwayslink = 1, ) @@ -119,6 +119,11 @@ tsl_cc_test( ":test_device_proto_cc", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", @@ -128,12 +133,7 @@ tsl_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", ], ) @@ -147,6 +147,8 @@ tsl_gpu_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/framework:cancellation", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", @@ -161,9 +163,7 @@ tsl_gpu_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -177,6 +177,11 @@ tsl_cc_test( ":coordination_service_error_util", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:coordination_config_proto_cc_impl", "//xla/tsl/protobuf:coordination_service_proto_cc_impl", "@com_google_absl//absl/log", @@ -184,11 +189,6 @@ tsl_cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -202,6 +202,7 @@ cc_library( ":coordination_service", ":coordination_service_agent", ":coordination_service_error_util", + "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", @@ -210,7 +211,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -227,6 +227,11 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:coordination_config_proto_cc_impl", "//xla/tsl/protobuf:coordination_service_proto_cc_impl", "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", @@ -236,11 +241,6 @@ tsl_cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ] + tsl_grpc_cc_dependencies(), ) @@ -258,6 +258,12 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:coordination_config_proto_cc_impl", "//xla/tsl/protobuf:coordination_service_proto_cc_impl", "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", @@ -267,12 +273,6 @@ tsl_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc index ae1346d1c37761..c52872a904dc36 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc @@ -44,13 +44,13 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" namespace tsl { namespace { @@ -1046,6 +1046,35 @@ TEST_F(ClientServerTest, } } +TEST_F(ClientServerTest, GetAliveTasks_Succeed) { + const int num_nodes = 2; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + auto client = GetClient(node_id); + TF_RETURN_IF_ERROR(client->Connect()); + absl::StatusOr> alive_tasks = + client->GetAliveTasks({GetTask(0), GetTask(1)}); + if (!alive_tasks.ok()) { + return alive_tasks.status(); + } + TF_RETURN_IF_ERROR(client->Shutdown()); + return absl::OkStatus(); + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + for (int i = 0; i < num_nodes; ++i) { + TF_EXPECT_OK(statuses[i]); + } +} + TEST_F(ClientServerTest, GetKeyValueDir) { StartService(/*num_nodes=*/1); auto client = GetClient(/*node_id=*/0); diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h index cbdd0f2147a35e..6efd02a736850d 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "xla/tsl/distributed_runtime/call_options.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/status.h" namespace tsl { using tensorflow::BarrierRequest; @@ -30,6 +30,8 @@ using tensorflow::CancelBarrierRequest; using tensorflow::CancelBarrierResponse; using tensorflow::DeleteKeyValueRequest; using tensorflow::DeleteKeyValueResponse; +using tensorflow::GetAliveTasksRequest; +using tensorflow::GetAliveTasksResponse; using tensorflow::GetKeyValueDirRequest; using tensorflow::GetKeyValueDirResponse; using tensorflow::GetKeyValueRequest; @@ -127,6 +129,11 @@ class CoordinationClient { virtual void CancelBarrierAsync(const CancelBarrierRequest* request, CancelBarrierResponse* response, StatusCallback done) = 0; + + virtual void GetAliveTasksAsync(const GetAliveTasksRequest* request, + GetAliveTasksResponse* response, + StatusCallback done) = 0; + virtual void PollForErrorAsync(CallOptions* call_opts, const PollForErrorRequest* request, PollForErrorResponse* response, diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index 6f6edabd2b786c..227e1ff5a1159f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -47,12 +48,12 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/platform/env.h" #include "tsl/platform/random.h" -#include "tsl/platform/status.h" namespace tsl { namespace { @@ -106,6 +107,10 @@ struct CoordinatedTaskEqual { } }; +using CoordinatedTaskSet = + absl::flat_hash_set; + absl::Status MakeShutdownBarrierError(const absl::Status& error) { return MakeCoordinationError(absl::InternalError(absl::StrCat( "Shutdown barrier has failed.\nBarrier result: '", error.ToString()))); @@ -159,6 +164,9 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { BarrierCallback done) override; absl::Status CancelBarrier(std::string barrier_id, int64_t counter, const CoordinatedTask& task) override; + void GetAliveTasksAsync(const tensorflow::CoordinatedTask& requesting_task, + const std::vector& tasks, + GetAliveTasksCallback done) override; void PollForErrorAsync(const CoordinatedTask& task, StatusCallback done) override; @@ -420,6 +428,25 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { bool recoverable_ = false; }; + // AlivenessState tracks the state of pending GetAliveTasks calls. + struct AlivenessState { + // All tasks that can participate in the GetAliveTasks barrier. + CoordinatedTaskSet tasks; + // All tasks currently blocked on the barrier. + CoordinatedTaskSet in_barrier; + // Done callbacks for the tasks blocked on the barrier. + std::vector dones; + }; + + // Returns the set of alive tasks drawn from the provided set of tasks. + CoordinatedTaskSet AliveTasks(const CoordinatedTaskSet& tasks) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + + // Refreshes the AlivenessStates of all pending GetAliveTasks call, + // potentially finishing some of the pending calls. The AlivenessStates should + // be refreshed, for example, after a task has failed. + void RefreshAliveness() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + std::unique_ptr client_cache_; Env& env_; const uint64_t service_incarnation_ = random::New64(); @@ -462,6 +489,9 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { // use a set. absl::flat_hash_set ongoing_barriers_ ABSL_GUARDED_BY(state_mu_); + // The state of all pending GetAliveTasks calls. + std::vector aliveness_states_ ABSL_GUARDED_BY(state_mu_); + absl::flat_hash_set recoverable_jobs_; ErrorPollingState error_polling_state_ ABSL_GUARDED_BY(state_mu_); @@ -1034,6 +1064,7 @@ absl::Status CoordinationServiceStandaloneImpl::DisconnectTask( task_state->Disconnect( /*grace_period_duration_us=*/heartbeat_timeout_ms_ * 1000); LeaveOngoingBarriers(task, "task disconnected"); + RefreshAliveness(); error_polling_state_.RemoveTask(task, "task has disconnected."); LOG(INFO) << task_name << " has disconnected from coordination service."; return absl::OkStatus(); @@ -1319,8 +1350,9 @@ std::vector CoordinationServiceStandaloneImpl::GetKeyValueDir( for (it = begin; it != kv_store_.end(); ++it) { // Stop once the next key does not have the directory prefix. Since keys are // ordered, none of the other keys would have a matching prefix. - if (std::mismatch(dir.begin(), dir.end(), it->first.begin()).first != - dir.end()) { + if (std::mismatch(dir.begin(), dir.end(), it->first.begin(), + it->first.end()) + .first != dir.end()) { break; } KeyValueEntry kv; @@ -1342,8 +1374,9 @@ absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue( auto begin = kv_store_.lower_bound(dir); std::map::iterator end; for (end = begin; end != kv_store_.end(); end++) { - if (std::mismatch(dir.begin(), dir.end(), end->first.begin()).first != - dir.end()) + if (std::mismatch(dir.begin(), dir.end(), end->first.begin(), + end->first.end()) + .first != dir.end()) break; } kv_store_.erase(begin, end); @@ -1368,6 +1401,7 @@ void CoordinationServiceStandaloneImpl::SetTaskError( if (task_state->SetError(error)) { LeaveOngoingBarriers( task, absl::StrCat("task is set to ERROR: ", error.ToString())); + RefreshAliveness(); } } @@ -1768,6 +1802,103 @@ void CoordinationServiceStandaloneImpl::PassBarrier( } } +// Returns true if x is a (non-strict) subset of y. +bool TaskSetSubset(const CoordinatedTaskSet& x, const CoordinatedTaskSet& y) { + return std::all_of(x.begin(), x.end(), [&y](const CoordinatedTask& task) { + return y.contains(task); + }); +} + +// Returns true if sets x and y are equal. +// +// Note that the default equality operator (==) on absl::flat_hash_set invokes +// the equal operator on the underlying elements in the sets, but the equal +// operator is not defined on protos. Thus, we have to implement our own +// equality function. +bool TaskSetEqual(const CoordinatedTaskSet& x, const CoordinatedTaskSet& y) { + return x.size() == y.size() && TaskSetSubset(x, y); +} + +CoordinatedTaskSet CoordinationServiceStandaloneImpl::AliveTasks( + const CoordinatedTaskSet& tasks) const { + CoordinatedTaskSet alive_tasks; + for (const CoordinatedTask& task : tasks) { + auto it = cluster_state_.find(GetTaskName(task)); + if (it != cluster_state_.end() && + it->second->GetState() == CoordinatedTaskState::TASKSTATE_CONNECTED) { + // We consider a task alive if it is CONNECTED. + alive_tasks.insert(task); + } + } + return alive_tasks; +} + +void CoordinationServiceStandaloneImpl::RefreshAliveness() { + // Try to finish every pending GetAliveTasks call. + auto it = aliveness_states_.begin(); + while (it != aliveness_states_.end()) { + CoordinatedTaskSet alive_tasks = AliveTasks(it->tasks); + if (TaskSetSubset(alive_tasks, it->in_barrier)) { + // Every alive task is in the barrier, so the barrier is satisfied. Return + // the same set of alive tasks (alive_tasks) to every task in the barrier. + std::vector v{alive_tasks.begin(), alive_tasks.end()}; + for (const GetAliveTasksCallback& done : it->dones) { + done(absl::OkStatus(), v); + } + + // Remove the pending GetAliveTasks call because it is no longer pending. + it = aliveness_states_.erase(it); + } else { + // The pending GetAliveTasks call is still pending. + ++it; + } + } +} + +void CoordinationServiceStandaloneImpl::GetAliveTasksAsync( + const tensorflow::CoordinatedTask& requesting_task, + const std::vector& tasks, + GetAliveTasksCallback done) { + // TODO(mwhittaker): Figure out good timeout semantics and add timeouts. + + // Validate that the requesting task is a member of tasks. + CoordinatedTaskSet task_set{tasks.begin(), tasks.end()}; + if (!task_set.contains(requesting_task)) { + // TODO(mwhittaker): Consider relaxing the requirement that the requesting + // task is one of the specified tasks. + absl::Status err = absl::InvalidArgumentError(absl::StrCat( + "Requesting task ", GetTaskName(requesting_task), + " is not one of the tasks specified in a GetAliveTasks request.")); + done(err, {}); + return; + } + + // Find the corresponding AlivenessState, creating a new one if needed. + absl::MutexLock l(&state_mu_); + auto it = std::find_if(aliveness_states_.begin(), aliveness_states_.end(), + [&task_set](const AlivenessState& state) { + return TaskSetEqual(state.tasks, task_set); + }); + if (it == aliveness_states_.end()) { + aliveness_states_.push_back(AlivenessState{task_set}); + it = std::prev(aliveness_states_.end()); + } + + // Enter the requesting task into the barrier. + it->in_barrier.insert(requesting_task); + it->dones.push_back(std::move(done)); + + // Finish the barrier, if possible. + CoordinatedTaskSet alive_tasks = AliveTasks(task_set); + if (TaskSetSubset(alive_tasks, it->in_barrier)) { + std::vector v{alive_tasks.begin(), alive_tasks.end()}; + for (const GetAliveTasksCallback& done : it->dones) { + done(absl::OkStatus(), v); + } + aliveness_states_.erase(it); + } +} + void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( const absl::Status& error) { CHECK(IsClientPollingForError()) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index 1fa2bd0b810627..0dac23f55b762f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -30,10 +30,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" namespace tsl { class Env; @@ -77,6 +77,8 @@ class CoordinationServiceInterface { using StatusOrValueCallback = std::function&)>; using BarrierCallback = std::function; + using GetAliveTasksCallback = std::function&)>; virtual ~CoordinationServiceInterface() = default; @@ -250,6 +252,40 @@ class CoordinationServiceInterface { std::string barrier_id, int64_t counter, const tensorflow::CoordinatedTask& task) = 0; + // Returns the set of currently alive tasks. More specifically, given a set of + // tasks T, GetAliveTasks(T) returns the subset T of alive tasks. Note that + // `tasks` must include `requesting_task`. + // + // # Barrier Semantics + // + // If multiple tasks call GetAliveTasks concurrently, it's important that they + // all agree on which tasks are alive. Otherwise, the tasks' behavior might + // diverge. For example, imagine a set of tasks trying to run an AllGather, + // but they all disagree on which tasks should be participating in the + // AllGather. This is buggy. + // + // To ensure that every task agrees on which tasks are alive, the + // GetAliveTasks RPC has barrier-like semantics. Consider an invocation + // GetAliveTasks(T) for a set of tasks T. The invocation acts as a barrier, + // waiting for every task in T to call GetAliveTasks(T). Afterwards, + // GetAliveTasks returns the same set of alive tasks A to all the tasks in T. + // This ensures that every task agrees which tasks are alive. + // + // One small correction. GetAliveTasks doesn't act as a barrier for *every* + // task in T. Some tasks in T might have failed, so we should not wait for + // them. Instead, the GetAliveTasks RPC waits only for the returned tasks A. + // + // # An Example + // + // Imagine we have four tasks: A, B, C, and D. Further imagine that task D + // has failed and that every task calls GetAliveTasks([A, B, C, D]). The + // invocation will return tasks [A, B, C]. The GetAliveTasks call acts as a + // barrier across tasks A, B, and C. Task D, which failed, is ignored. + virtual void GetAliveTasksAsync( + const tensorflow::CoordinatedTask& requesting_task, + const std::vector& tasks, + GetAliveTasksCallback done) = 0; + // Gets error from the coordination service. Block until the service // returns an error or the task/service is shutdown. This should never be used // when there is service to client connection (i.e. `CoordinationClientCache` diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc index 342ede2e05183c..dadf760dd9e9b4 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -48,11 +48,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/framework/cancellation.h" #include "xla/tsl/lib/monitoring/gauge.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/env.h" #include "tsl/platform/random.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" // TODO(b/342448688): Expose via config and API instead of flag. @@ -142,6 +142,8 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { absl::Status CancelBarrier(std::string_view barrier_id) override; void CancelBarrierAsync(std::string_view barrier_id, StatusCallback done) override; + absl::StatusOr> GetAliveTasks( + const std::vector& tasks) override; absl::StatusOr GetEnv() override; @@ -1064,6 +1066,39 @@ void CoordinationServiceAgentImpl::CancelBarrierAsync( }); } +absl::StatusOr> +CoordinationServiceAgentImpl::GetAliveTasks( + const std::vector& tasks) { + // Validate the agent. + if (absl::Status s = ValidateRunningAgent(/*allow_disconnected=*/true); + !s.ok()) { + return s; + } + + // Form the request and response. + auto request = std::make_shared(); + auto response = std::make_shared(); + *request->mutable_requesting_task() = task_; + *request->mutable_tasks() = {tasks.begin(), tasks.end()}; + + // Issue the request and wait for it to finish. + absl::Status status; + absl::Notification n; + auto done = [&status, &n](const absl::Status& s) { + status = s; + n.Notify(); + }; + leader_client_->GetAliveTasksAsync(request.get(), response.get(), done); + n.WaitForNotification(); + + // Parse the response. + if (!status.ok()) { + return status; + } + return std::vector( + response->alive_tasks().begin(), response->alive_tasks().end()); +} + // Returns an error if agent is not running. absl::Status CoordinationServiceAgentImpl::ValidateRunningAgent( bool allow_disconnected) { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h index 50dd8c86d87c69..2cfef926ae8be7 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h @@ -28,8 +28,8 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/status.h" namespace tensorflow { class CoordinationServiceConfig; @@ -272,6 +272,37 @@ class CoordinationServiceAgent { virtual void CancelBarrierAsync(std::string_view barrier_id, StatusCallback done) = 0; + // Returns the set of currently alive tasks. More specifically, given a set of + // tasks T, GetAliveTasks(T) returns the subset T of alive tasks. + // + // # Barrier Semantics + // + // If multiple tasks call GetAliveTasks concurrently, it's important that they + // all agree on which tasks are alive. Otherwise, the tasks' behavior might + // diverge. For example, imagine a set of tasks trying to run an AllGather, + // but they all disagree on which tasks should be participating in the + // AllGather. This is buggy. + // + // To ensure that every task agrees on which tasks are alive, the + // GetAliveTasks RPC has barrier-like semantics. Consider an invocation + // GetAliveTasks(T) for a set of tasks T. The invocation acts as a barrier, + // waiting for every task in T to call GetAliveTasks(T). Afterwards, + // GetAliveTasks returns the same set of alive tasks A to all the tasks in T. + // This ensures that every task agrees which tasks are alive. + // + // One small correction. GetAliveTasks doesn't act as a barrier for *every* + // task in T. Some tasks in T might have failed, so we should not wait for + // them. Instead, the GetAliveTasks RPC waits only for the returned tasks A. + // + // # An Example + // + // Imagine we have four tasks: A, B, C, and D. Further imagine that task D + // has failed and that every task calls GetAliveTasks([A, B, C, D]). The + // invocation will return tasks [A, B, C]. The GetAliveTasks call acts as a + // barrier across tasks A, B, and C. Task D, which failed, is ignored. + virtual absl::StatusOr> + GetAliveTasks(const std::vector& tasks) = 0; + // Get unowned Env* that the agent was initialized with. virtual absl::StatusOr GetEnv() = 0; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index 1d27195217c497..658f3c971056bc 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -31,11 +31,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" namespace tsl { namespace { @@ -138,6 +138,10 @@ class TestCoordinationClient : public CoordinationClient { (const CancelBarrierRequest*, CancelBarrierResponse*, StatusCallback), (override)); + MOCK_METHOD(void, GetAliveTasksAsync, + (const GetAliveTasksRequest*, GetAliveTasksResponse*, + StatusCallback), + (override)); MOCK_METHOD(void, GetTaskStateAsync, (const GetTaskStateRequest*, GetTaskStateResponse*, StatusCallback), diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc index 9b137d1e417f63..6cea4a579d08e6 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/test.h" namespace tsl { namespace { using ::tensorflow::BarrierError; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc index 737091b1ca7fc3..982bcd5d58a214 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc @@ -32,11 +32,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/tsl/protobuf/coordination_config.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc index 58a01ef5d3a296..4d1ac80e61110f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" namespace tsl { namespace { @@ -305,6 +305,29 @@ void CoordinationServiceRpcHandler::CancelBarrierAsync( request->source_task())); } +void CoordinationServiceRpcHandler::GetAliveTasksAsync( + const tensorflow::GetAliveTasksRequest* request, + tensorflow::GetAliveTasksResponse* response, StatusCallback done) { + absl::ReaderMutexLock l(&mu_); + if (service_ == nullptr) { + done(MakeCoordinationError( + absl::InternalError("Coordination service is not enabled."))); + return; + } + + std::vector tasks = {request->tasks().begin(), + request->tasks().end()}; + service_->GetAliveTasksAsync( + request->requesting_task(), tasks, + [done = std::move(done), response]( + const absl::Status& status, + const std::vector& alive_tasks) { + *response->mutable_alive_tasks() = {alive_tasks.begin(), + alive_tasks.end()}; + done(status); + }); +} + void CoordinationServiceRpcHandler::PollForErrorAsync( const tensorflow::PollForErrorRequest* request, tensorflow::PollForErrorResponse* response, StatusCallback done) { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h index 2b9ca2ef9f3d2e..b77fb54b559029 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h @@ -20,8 +20,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" namespace tsl { @@ -92,6 +92,10 @@ class CoordinationServiceRpcHandler { tensorflow::CancelBarrierResponse* response, StatusCallback done); + void GetAliveTasksAsync(const tensorflow::GetAliveTasksRequest* request, + tensorflow::GetAliveTasksResponse* response, + StatusCallback done); + void PollForErrorAsync(const tensorflow::PollForErrorRequest* request, tensorflow::PollForErrorResponse* response, StatusCallback done); diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index eb8dc35cac083a..3a1f2d100cf75e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" @@ -37,13 +38,13 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/distributed_runtime/coordination/test_device.pb.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/env.h" #include "tsl/platform/random.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" namespace tsl { namespace { @@ -53,6 +54,7 @@ using ::testing::EqualsProto; using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; +using ::testing::UnorderedPointwise; using ::testing::status::StatusIs; using tensorflow::CoordinatedJob; @@ -111,7 +113,8 @@ class TestCoordinationClient : public CoordinationClient { #define UNIMPLEMENTED(method) \ void method##Async(const method##Request* request, \ method##Response* response, StatusCallback done) \ - override{done(absl::UnimplementedError(#method "Async")); \ + override { \ + done(absl::UnimplementedError(#method "Async")); \ } UNIMPLEMENTED(WaitForAllTasks); @@ -123,6 +126,7 @@ class TestCoordinationClient : public CoordinationClient { UNIMPLEMENTED(GetKeyValueDir); UNIMPLEMENTED(DeleteKeyValue); UNIMPLEMENTED(CancelBarrier); + UNIMPLEMENTED(GetAliveTasks); #undef UNIMPLEMENTED #define UNIMPLEMENTED_WITH_CALL_OPTS(method) \ @@ -203,6 +207,7 @@ class CoordinationBarrierTest : public ::testing::Test { return coord_service_.get(); } CoordinatedTask GetTask(int i) { return tasks_[i]; } + const std::vector& GetTasks() { return tasks_; } // TODO(b/286141652) Refactor this method into a util file. std::string GetTaskName(const CoordinatedTask& task) { @@ -2407,4 +2412,129 @@ TEST_F(CoordinateTwoTasksTest, RegisterWithBarrier_Timeout) { EXPECT_THAT(coord_service_->RegisterTask(task_0_, incarnation_0_), StatusIs(absl::StatusCode::kDeadlineExceeded)); } + +using GetAliveTasksTest = CoordinationBarrierTest; + +TEST_F(GetAliveTasksTest, SuccessfulGetAliveTasks) { + // This test has three tasks successfully call GetAliveTasks. + absl::BlockingCounter finished(3); + auto done = [&](const absl::Status& status, + const std::vector& alive_tasks) { + EXPECT_OK(status); + EXPECT_THAT(alive_tasks, UnorderedPointwise(EqualsProto(), GetTasks())); + finished.DecrementCount(); + }; + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(2), GetTasks(), done); + finished.Wait(); +} + +TEST_F(GetAliveTasksTest, FailedTaskBeforeCallingGetAliveTasks) { + // This test involves three tasks: 0, 1, and 2. Task 2 is failed. Then, tasks + // 0 and 1 call GetAliveTasks on tasks [0, 1, 2], which should return [0, 1]. + absl::BlockingCounter finished(2); + auto done = [&](const absl::Status& status, + const std::vector& alive_tasks) { + EXPECT_OK(status); + EXPECT_THAT(alive_tasks, + UnorderedPointwise(EqualsProto(), {GetTask(0), GetTask(1)})); + finished.DecrementCount(); + }; + ASSERT_OK(GetCoordinationService()->ReportTaskError( + GetTask(2), absl::InternalError("failed"))); + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), GetTasks(), done); + finished.Wait(); +} + +TEST_F(GetAliveTasksTest, FailedTaskAfterCallingGetAliveTasks) { + // This test involves three tasks: 0, 1, and 2. Tasks 0 and 1 call + // GetAliveTasks on tasks [0, 1, 2]. Then, task 2 is failed, which should + // cause GetAliveTasks to return [0, 1]. + absl::BlockingCounter finished(2); + auto done = [&](const absl::Status& status, + const std::vector& alive_tasks) { + EXPECT_OK(status); + EXPECT_THAT(alive_tasks, + UnorderedPointwise(EqualsProto(), {GetTask(0), GetTask(1)})); + finished.DecrementCount(); + }; + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), GetTasks(), done); + ASSERT_OK(GetCoordinationService()->ReportTaskError( + GetTask(2), absl::InternalError("failed"))); + finished.Wait(); +} + +TEST_F(GetAliveTasksTest, ConcurrentGetAliveTasks) { + // This test involves three tasks: 0, 1, and 2. Tasks 0 and 1 call + // GetAliveTasks on tasks [0, 1], and concurrently tasks 1 and 2 call + // GetAliveTasks on tasks [1, 2]. + + // GetAliveTasks on tasks 0 and 1. + std::vector tasks_01{GetTask(0), GetTask(1)}; + absl::BlockingCounter finished_01(2); + auto done_01 = [&](const absl::Status& status, + const std::vector& alive_tasks) { + EXPECT_OK(status); + EXPECT_THAT(alive_tasks, UnorderedPointwise(EqualsProto(), tasks_01)); + finished_01.DecrementCount(); + }; + + // GetAliveTasks on tasks 1 and 2. + std::vector tasks_12{GetTask(1), GetTask(2)}; + absl::BlockingCounter finished_12(2); + auto done_12 = [&](const absl::Status& status, + const std::vector& alive_tasks) { + EXPECT_OK(status); + EXPECT_THAT(alive_tasks, UnorderedPointwise(EqualsProto(), tasks_12)); + finished_12.DecrementCount(); + }; + + // Run both GetAliveTasks concurrently. + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), tasks_01, done_01); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), tasks_12, done_12); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), tasks_01, done_01); + GetCoordinationService()->GetAliveTasksAsync(GetTask(2), tasks_12, done_12); + finished_01.Wait(); + finished_12.Wait(); +} + +TEST_F(GetAliveTasksTest, CallingGetAliveTasksWithoutBeingAMember) { + // This test includes calls to GetAliveTasks where the requesting task is not + // included in the specified set of tasks. This should return an error. + absl::BlockingCounter finished(3); + auto done = [&](const absl::Status& status, + const std::vector&) { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); + finished.DecrementCount(); + }; + + CoordinationServiceInterface* s = GetCoordinationService(); + s->GetAliveTasksAsync(GetTask(0), {GetTask(1), GetTask(2)}, done); + s->GetAliveTasksAsync(GetTask(1), {GetTask(0), GetTask(2)}, done); + s->GetAliveTasksAsync(GetTask(2), {GetTask(0), GetTask(1)}, done); + finished.Wait(); +} + +TEST_F(GetAliveTasksTest, RedundantGetAliveTasks) { + // This test has three tasks call GetAliveTasks, with the twist that some + // tasks call GetAliveTasks multiple times. + absl::BlockingCounter finished(6); + auto done = [&](const absl::Status& status, + const std::vector& alive_tasks) { + EXPECT_OK(status); + EXPECT_THAT(alive_tasks, UnorderedPointwise(EqualsProto(), GetTasks())); + finished.DecrementCount(); + }; + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(0), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(1), GetTasks(), done); + GetCoordinationService()->GetAliveTasksAsync(GetTask(2), GetTasks(), done); + finished.Wait(); +} + } // namespace tsl diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD index d17a9146697cec..1c0229dca29387 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD @@ -17,13 +17,13 @@ cc_library( hdrs = ["preemption_notifier.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:statusor", ], ) @@ -33,16 +33,16 @@ tsl_cc_test( srcs = ["preemption_notifier_test.cc"], deps = [ ":preemption_notifier", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -55,6 +55,8 @@ cc_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/platform:env", + "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -64,8 +66,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:statusor", ], ) @@ -83,6 +83,10 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:coordination_config_proto_cc_impl", "//xla/tsl/protobuf:coordination_service_proto_cc_impl", "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", @@ -91,9 +95,5 @@ tsl_cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.cc index b1656ef8d59989..e2c6e625d2ee67 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.cc @@ -23,10 +23,10 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/statusor.h" #if defined(PLATFORM_GOOGLE) #include "thread/executor.h" #include "thread/signal.h" diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.h b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.h index 6cdc16ff0f1733..97479dd06ae61c 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.h +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.h @@ -24,9 +24,9 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/time/time.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/statusor.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier_test.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier_test.cc index 837148e6add163..91aa778684a83f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier_test.cc @@ -22,11 +22,11 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #if defined(PLATFORM_GOOGLE) #include "thread/executor.h" #include "thread/signal.h" diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index c6e41a9f030f62..6b70f803573653 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" #include "xla/tsl/lib/monitoring/gauge.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/statusor.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc index 616b8ccd5fcf99..8598a4a56e7ef5 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/tsl/protobuf/coordination_config.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index 20fe4eb5a5f9b4..9971f49cf1362b 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -36,11 +36,11 @@ cc_library( srcs = ["grpc_util.cc"], hdrs = ["grpc_util.h"], deps = [ + "//xla/tsl/platform:status", "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:stringprintf", ] + tsl_grpc_cc_dependencies(), @@ -56,12 +56,12 @@ tsl_cc_test( deps = [ ":grpc_util", ":test_request_proto_cc_impl", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", ] + tsl_grpc_cc_dependencies(), ) @@ -70,8 +70,8 @@ cc_library( hdrs = ["grpc_channel_common.h"], deps = [ ":grpc_util", + "//xla/tsl/platform:logging", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:mutex", ], ) @@ -84,19 +84,19 @@ cc_library( ":grpc_channel_common", ":grpc_util", "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:rpc_options_proto_cc", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ] + tsl_grpc_cc_dependencies(), ) @@ -109,12 +109,12 @@ tsl_cc_test( deps = [ ":grpc_channel", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:rpc_options_proto_cc_impl", "//xla/tsl/util:device_name_utils", - "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -125,11 +125,11 @@ cc_library( ":grpc_client_cq_tag", ":grpc_util", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "//xla/tsl/util:env_var", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:strcat", ] + tsl_grpc_cc_dependencies(), ) @@ -139,7 +139,7 @@ cc_library( srcs = [], hdrs = ["grpc_client_cq_tag.h"], deps = [ - "@local_tsl//tsl/platform:macros", + "//xla/tsl/platform:macros", ], ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD index f50e88e89466e3..d1d1cbd1016763 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD @@ -21,13 +21,13 @@ cc_library( "//xla/tsl/distributed_runtime/rpc:grpc_client_cq_tag", "//xla/tsl/distributed_runtime/rpc:grpc_state", "//xla/tsl/distributed_runtime/rpc:grpc_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", ] + tsl_grpc_cc_dependencies(), ) @@ -42,12 +42,12 @@ cc_library( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc:grpc_call", "//xla/tsl/distributed_runtime/rpc:grpc_util", + "//xla/tsl/platform:env", "//xla/tsl/protobuf:coordination_service_cc_grpc_proto", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc index 8902f0859f0d0e..777f54cb21a933 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "xla/tsl/distributed_runtime/rpc/grpc_state.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" namespace tsl { namespace { @@ -47,6 +47,8 @@ using tensorflow::CancelBarrierRequest; using tensorflow::CancelBarrierResponse; using tensorflow::DeleteKeyValueRequest; using tensorflow::DeleteKeyValueResponse; +using tensorflow::GetAliveTasksRequest; +using tensorflow::GetAliveTasksResponse; using tensorflow::GetKeyValueDirRequest; using tensorflow::GetKeyValueDirResponse; using tensorflow::GetKeyValueRequest; @@ -271,6 +273,16 @@ class GrpcCoordinationClient : public CoordinationClient { &target_); } + void GetAliveTasksAsync(const GetAliveTasksRequest* request, + GetAliveTasksResponse* response, + StatusCallback done) override { + new RPCState( + &stub_, cq_, "/tensorflow.CoordinationService/GetAliveTasks", *request, + response, std::move(done), /*call_opts=*/nullptr, + /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, + &target_); + } + void PollForErrorAsync(CallOptions* call_opts, const PollForErrorRequest* request, PollForErrorResponse* response, diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc index d3187c291b2d92..460e408629983e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" #include "absl/synchronization/mutex.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/threadpool.h" namespace tsl { @@ -57,6 +57,7 @@ void GrpcCoordinationServiceImpl::HandleRPCsLoop() { ENQUEUE_REQUEST(DeleteKeyValue); ENQUEUE_REQUEST(Barrier); ENQUEUE_REQUEST(CancelBarrier); + ENQUEUE_REQUEST(GetAliveTasks); ENQUEUE_REQUEST(PollForError); #undef ENQUEUE_REQUEST diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h index 0fdaafc9f579bb..0550a8565e1e71 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h @@ -30,9 +30,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/tsl/protobuf/coordination_service.grpc.pb.h" #include "xla/tsl/protobuf/coordination_service.pb.h" -#include "tsl/platform/threadpool.h" namespace tsl { @@ -98,6 +98,7 @@ class GrpcCoordinationServiceImpl : public AsyncServiceInterface { HANDLER(DeleteKeyValue); HANDLER(Barrier); HANDLER(CancelBarrier); + HANDLER(GetAliveTasks); HANDLER(PollForError); #undef HANDLER diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc index 2ebb8cc7e9499b..6b919bebf19b12 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -27,18 +27,18 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "xla/tsl/distributed_runtime/rpc/grpc_channel_common.h" #include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/rpc_options.pb.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" #include "tsl/platform/numbers.h" -#include "tsl/platform/status.h" #include "tsl/platform/str_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { @@ -56,7 +56,7 @@ absl::Status ValidateHostPortPair(const string& host_port) { } uint32 port; auto colon_index = host_port.find_last_of(':'); - if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) || + if (!absl::SimpleAtoi(host_port.substr(colon_index + 1), &port) || host_port.substr(0, colon_index).find('/') != string::npos) { return errors::InvalidArgument("Could not interpret \"", host_port, "\" as a host-port pair."); @@ -88,7 +88,7 @@ ::grpc::ChannelArguments* CreateDefaultChannelArguments() { } } else { int64_t value; - if (strings::safe_strto64(name_value[1], &value)) { + if (absl::SimpleAtoi(name_value[1], &value)) { args->SetInt(name_value[0], value); } else { LOG(ERROR) << "Invalid integer value: " << grpc_option; diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_common.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_common.h index 61843ec9e20b76..8d37233abbf469 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_common.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_common.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/mutex.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 2790b0cd65dc44..3efae80a0511c2 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/rpc_options.pb.h" #include "xla/tsl/util/device_name_utils.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" namespace tsl { #define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h index 5acb5a5d42245c..eb547c827ff0c8 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ #define XLA_TSL_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_state.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_state.h index fca8da1e490bda..d59f2ced10ad81 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_state.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_state.h @@ -26,11 +26,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/threadpool.h" #include "xla/tsl/util/env_var.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/threadpool.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h index d39eb8e0f1be56..4b510b1a02afda 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h @@ -23,9 +23,9 @@ limitations under the License. #include "absl/strings/cord.h" #include "grpcpp/grpcpp.h" #include "grpcpp/support/byte_buffer.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/stringprintf.h" diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc index 99d34350533596..182b6d02343bd9 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include "grpcpp/grpcpp.h" #include "xla/tsl/distributed_runtime/rpc/test_request.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index fc7213dab4016b..11a2cbccb2cd5a 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ b/third_party/xla/xla/tsl/framework/BUILD @@ -121,21 +121,21 @@ cc_library( "//xla/tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringprintf", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], otherwise = [ "//xla/tsl/lib/gtl:inlined_vector", - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:env", + "//xla/tsl/platform:env", ], ), alwayslink = 1, @@ -164,17 +164,17 @@ cc_library( ":numeric_types", ":type_traits", "//xla/tsl/lib/gtl:inlined_vector", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringprintf", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:scoped_memory_debug_annotation", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -196,6 +196,9 @@ cc_library( ":metrics", ":shared_counter", "//xla/tsl/lib/core:bits", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:trace_filter_utils", "//xla/tsl/protobuf:bfc_memory_map_proto_cc", "@com_google_absl//absl/base:core_headers", @@ -204,11 +207,8 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:stacktrace", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:scoped_memory_debug_annotation", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -248,15 +248,15 @@ cc_library( deps = [ ":device_type", "//xla/tsl/lib/gtl:int_type", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:types", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:types", ], ) @@ -269,14 +269,14 @@ cc_library( deps = [ ":device_id_impl", ":device_type", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:str_util", ], ) @@ -298,7 +298,7 @@ cc_library( ]), deps = [ ":fixedpoint_types", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], ) @@ -308,7 +308,7 @@ cc_library( features = ["parse_headers"], visibility = ["//visibility:public"], deps = [ - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], ) @@ -339,7 +339,7 @@ cc_library( ]), deps = [ ":numeric_types", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], ) @@ -366,16 +366,16 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/tsl/lib/gtl:flatmap", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@com_google_absl//absl/memory", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -385,10 +385,10 @@ cc_library( hdrs = ["serving_device_selector.h"], visibility = ["//visibility:public"], deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", ], ) @@ -414,12 +414,12 @@ tsl_cc_test( srcs = ["cancellation_test.cc"], deps = [ ":cancellation", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -467,10 +467,10 @@ tsl_cc_test( ":device_id_impl", ":device_id_utils", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "//xla/tsl/util:device_name_utils", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test_main", ], ) @@ -479,7 +479,7 @@ tsl_cc_test( srcs = ["real_time_in_memory_metric_test.cc"], deps = [ ":real_time_in_memory_metric", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/framework/allocator.cc b/third_party/xla/xla/tsl/framework/allocator.cc index 5b0235d57834b8..870b3953c8b677 100644 --- a/third_party/xla/xla/tsl/framework/allocator.cc +++ b/third_party/xla/xla/tsl/framework/allocator.cc @@ -19,11 +19,11 @@ limitations under the License. #include "xla/tsl/framework/allocator_registry.h" #include "xla/tsl/framework/tracking_allocator.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" #include "tsl/platform/strcat.h" #include "tsl/platform/stringprintf.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/allocator.h b/third_party/xla/xla/tsl/framework/allocator.h index c289532c78a75e..a6ab9a67ad06a5 100644 --- a/third_party/xla/xla/tsl/framework/allocator.h +++ b/third_party/xla/xla/tsl/framework/allocator.h @@ -26,10 +26,10 @@ limitations under the License. #include "absl/types/optional.h" #include "xla/tsl/framework/numeric_types.h" #include "xla/tsl/framework/type_traits.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/numa.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/allocator_registry.cc b/third_party/xla/xla/tsl/framework/allocator_registry.cc index c56e777e9ffe9c..365f9c8ec814d6 100644 --- a/third_party/xla/xla/tsl/framework/allocator_registry.cc +++ b/third_party/xla/xla/tsl/framework/allocator_registry.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/allocator_registry.h b/third_party/xla/xla/tsl/framework/allocator_registry.h index 3487c95a40ec81..469072793d39f6 100644 --- a/third_party/xla/xla/tsl/framework/allocator_registry.h +++ b/third_party/xla/xla/tsl/framework/allocator_registry.h @@ -23,7 +23,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "xla/tsl/framework/allocator.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" #include "tsl/platform/mutex.h" #include "tsl/platform/numa.h" diff --git a/third_party/xla/xla/tsl/framework/allocator_retry.cc b/third_party/xla/xla/tsl/framework/allocator_retry.cc index 5ba0b4c585b379..8cc1bfc59e0477 100644 --- a/third_party/xla/xla/tsl/framework/allocator_retry.cc +++ b/third_party/xla/xla/tsl/framework/allocator_retry.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/optional.h" #include "xla/tsl/framework/metrics.h" -#include "tsl/platform/env.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/allocator_retry.h b/third_party/xla/xla/tsl/framework/allocator_retry.h index 32e7840b0fd89b..01e5d1d2613c11 100644 --- a/third_party/xla/xla/tsl/framework/allocator_retry.h +++ b/third_party/xla/xla/tsl/framework/allocator_retry.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/bfc_allocator.cc b/third_party/xla/xla/tsl/framework/bfc_allocator.cc index d8a3dd4ed39d94..f4ff011c874039 100644 --- a/third_party/xla/xla/tsl/framework/bfc_allocator.cc +++ b/third_party/xla/xla/tsl/framework/bfc_allocator.cc @@ -35,14 +35,14 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/framework/allocator_retry.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/trace_filter_utils.h" #include "xla/tsl/protobuf/bfc_memory_map.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" -#include "tsl/platform/types.h" #include "tsl/profiler/lib/scoped_memory_debug_annotation.h" #include "tsl/profiler/lib/traceme.h" diff --git a/third_party/xla/xla/tsl/framework/bfc_allocator.h b/third_party/xla/xla/tsl/framework/bfc_allocator.h index 0afd6fdf4cb0c1..a0d6568efab2fc 100644 --- a/third_party/xla/xla/tsl/framework/bfc_allocator.h +++ b/third_party/xla/xla/tsl/framework/bfc_allocator.h @@ -35,9 +35,9 @@ limitations under the License. #include "xla/tsl/framework/allocator_retry.h" #include "xla/tsl/framework/shared_counter.h" #include "xla/tsl/lib/core/bits.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/numbers.h" -#include "tsl/platform/types.h" namespace tensorflow { class MemoryDump; diff --git a/third_party/xla/xla/tsl/framework/cancellation.cc b/third_party/xla/xla/tsl/framework/cancellation.cc index 83d60bcddb96d6..54d4303d48837c 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.cc +++ b/third_party/xla/xla/tsl/framework/cancellation.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/cancellation.h b/third_party/xla/xla/tsl/framework/cancellation.h index 6dd04e269ff5d3..fcfd4c83e956aa 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.h +++ b/third_party/xla/xla/tsl/framework/cancellation.h @@ -20,13 +20,13 @@ limitations under the License. #include #include "xla/tsl/lib/gtl/flatmap.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/hash.h" #include "tsl/platform/mutex.h" #include "tsl/platform/notification.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/cancellation_test.cc b/third_party/xla/xla/tsl/framework/cancellation_test.cc index b9648fa8620939..6965d0b0b0270e 100644 --- a/third_party/xla/xla/tsl/framework/cancellation_test.cc +++ b/third_party/xla/xla/tsl/framework/cancellation_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/notification.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/convolution/BUILD b/third_party/xla/xla/tsl/framework/convolution/BUILD index a6d0dc08608d24..80cab0e4d97eab 100644 --- a/third_party/xla/xla/tsl/framework/convolution/BUILD +++ b/third_party/xla/xla/tsl/framework/convolution/BUILD @@ -97,9 +97,9 @@ tsl_cc_test( ], deps = [ ":eigen_helpers", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/framework/convolution/eigen_spatial_convolutions_test.cc b/third_party/xla/xla/tsl/framework/convolution/eigen_spatial_convolutions_test.cc index 84c70af927a138..85bb2ca40ba670 100644 --- a/third_party/xla/xla/tsl/framework/convolution/eigen_spatial_convolutions_test.cc +++ b/third_party/xla/xla/tsl/framework/convolution/eigen_spatial_convolutions_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" #include "absl/strings/str_cat.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace Eigen { diff --git a/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc b/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc index 9c9de966cfb67d..043e17d53e538f 100644 --- a/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc +++ b/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc @@ -19,11 +19,11 @@ limitations under the License. #include "xla/tsl/framework/allocator.h" #include "xla/tsl/framework/allocator_registry.h" #include "xla/tsl/framework/tracking_allocator.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" #include "tsl/platform/strcat.h" #include "tsl/platform/stringprintf.h" -#include "tsl/platform/types.h" #include "tsl/profiler/lib/scoped_memory_debug_annotation.h" #include "tsl/profiler/lib/traceme.h" diff --git a/third_party/xla/xla/tsl/framework/device_id_manager.cc b/third_party/xla/xla/tsl/framework/device_id_manager.cc index 46d9ba84b406c8..730718918902c7 100644 --- a/third_party/xla/xla/tsl/framework/device_id_manager.cc +++ b/third_party/xla/xla/tsl/framework/device_id_manager.cc @@ -21,12 +21,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/tsl/framework/device_id.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/framework/device_id_manager.h b/third_party/xla/xla/tsl/framework/device_id_manager.h index 7802206d6f3443..3de2413f6d4e4f 100644 --- a/third_party/xla/xla/tsl/framework/device_id_manager.h +++ b/third_party/xla/xla/tsl/framework/device_id_manager.h @@ -20,8 +20,8 @@ limitations under the License. #include "xla/tsl/framework/device_id.h" #include "xla/tsl/framework/device_type.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.cc b/third_party/xla/xla/tsl/framework/device_id_utils.cc index 343674c8d399d8..6d3b65562198b3 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils.cc @@ -28,8 +28,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/framework/device_id_manager.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/str_util.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.h b/third_party/xla/xla/tsl/framework/device_id_utils.h index b4552431cc97d5..871bc69bd1ac00 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.h +++ b/third_party/xla/xla/tsl/framework/device_id_utils.h @@ -22,9 +22,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/framework/device_type.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc index 9d2417e59765b2..2a798594e45eb0 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "xla/tsl/framework/device_id_manager.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/platform/status_matchers.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/framework/mlir/BUILD b/third_party/xla/xla/tsl/framework/mlir/BUILD index 2a24ecae52e5db..6695742add8fce 100644 --- a/third_party/xla/xla/tsl/framework/mlir/BUILD +++ b/third_party/xla/xla/tsl/framework/mlir/BUILD @@ -20,10 +20,10 @@ cc_library( "status_scoped_diagnostic_handler.h", ], deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:logging", ], ) diff --git a/third_party/xla/xla/tsl/framework/mlir/status_scoped_diagnostic_handler.cc b/third_party/xla/xla/tsl/framework/mlir/status_scoped_diagnostic_handler.cc index 5d2affac30571d..7ba988ecfb82bc 100644 --- a/third_party/xla/xla/tsl/framework/mlir/status_scoped_diagnostic_handler.cc +++ b/third_party/xla/xla/tsl/framework/mlir/status_scoped_diagnostic_handler.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LogicalResult.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/numeric_types.h b/third_party/xla/xla/tsl/framework/numeric_types.h index bfebc279b305bd..e7e7fcd2f8deba 100644 --- a/third_party/xla/xla/tsl/framework/numeric_types.h +++ b/third_party/xla/xla/tsl/framework/numeric_types.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/framework/fixedpoint_types.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/real_time_in_memory_metric_test.cc b/third_party/xla/xla/tsl/framework/real_time_in_memory_metric_test.cc index 36c5cbb52771ca..726cc74787ed88 100644 --- a/third_party/xla/xla/tsl/framework/real_time_in_memory_metric_test.cc +++ b/third_party/xla/xla/tsl/framework/real_time_in_memory_metric_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/framework/serving_device_selector.cc b/third_party/xla/xla/tsl/framework/serving_device_selector.cc index a0c4f17ec77cc1..a617c1166f8a3b 100644 --- a/third_party/xla/xla/tsl/framework/serving_device_selector.cc +++ b/third_party/xla/xla/tsl/framework/serving_device_selector.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/strings/string_view.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/serving_device_selector.h b/third_party/xla/xla/tsl/framework/serving_device_selector.h index 7baa9d338dccf6..2a5f6509e5ef5c 100644 --- a/third_party/xla/xla/tsl/framework/serving_device_selector.h +++ b/third_party/xla/xla/tsl/framework/serving_device_selector.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/shared_counter.h b/third_party/xla/xla/tsl/framework/shared_counter.h index 8b3eb27d20afa0..79a376757c15e6 100644 --- a/third_party/xla/xla/tsl/framework/shared_counter.h +++ b/third_party/xla/xla/tsl/framework/shared_counter.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { // A lightweight thread-safe monotone counter for establishing diff --git a/third_party/xla/xla/tsl/framework/test_util/BUILD b/third_party/xla/xla/tsl/framework/test_util/BUILD index ac2c9eff584028..a7a91c7708369b 100644 --- a/third_party/xla/xla/tsl/framework/test_util/BUILD +++ b/third_party/xla/xla/tsl/framework/test_util/BUILD @@ -21,8 +21,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/tsl/framework:serving_device_selector", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h b/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h index 80add74bbd413e..4e876ae389d6bf 100644 --- a/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h +++ b/third_party/xla/xla/tsl/framework/test_util/mock_serving_device_selector.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/framework/serving_device_selector.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace test_util { diff --git a/third_party/xla/xla/tsl/framework/tracking_allocator.cc b/third_party/xla/xla/tsl/framework/tracking_allocator.cc index 2ef740e602af8c..29e3b5e4386d76 100644 --- a/third_party/xla/xla/tsl/framework/tracking_allocator.cc +++ b/third_party/xla/xla/tsl/framework/tracking_allocator.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/tsl/framework/tracking_allocator.h" -#include "tsl/platform/env.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/tracking_allocator.h b/third_party/xla/xla/tsl/framework/tracking_allocator.h index b0e4288fc99617..a0d5260d5f71fb 100644 --- a/third_party/xla/xla/tsl/framework/tracking_allocator.h +++ b/third_party/xla/xla/tsl/framework/tracking_allocator.h @@ -20,9 +20,9 @@ limitations under the License. #include "xla/tsl/framework/allocator.h" #include "xla/tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/type_traits.h b/third_party/xla/xla/tsl/framework/type_traits.h index f7a9bd7a54bc91..5aabbf28b4baa8 100644 --- a/third_party/xla/xla/tsl/framework/type_traits.h +++ b/third_party/xla/xla/tsl/framework/type_traits.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "xla/tsl/framework/numeric_types.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/core/BUILD b/third_party/xla/xla/tsl/lib/core/BUILD index 6c2cc90bd2712e..73d443e41181e9 100644 --- a/third_party/xla/xla/tsl/lib/core/BUILD +++ b/third_party/xla/xla/tsl/lib/core/BUILD @@ -40,8 +40,8 @@ cc_library( hdrs = ["status_test_util.h"], compatible_with = get_compatible_with_portable(), deps = [ - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:test", ], ) @@ -94,8 +94,8 @@ cc_library( hdrs = ["bitmap.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/numeric:bits", - "@local_tsl//tsl/platform:logging", ], alwayslink = 1, ) @@ -104,8 +104,8 @@ cc_library( name = "bits", hdrs = ["bits.h"], deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/numeric:bits", - "@local_tsl//tsl/platform:logging", ], ) @@ -115,7 +115,7 @@ tsl_cc_test( srcs = ["bits_test.cc"], deps = [ ":bits", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/lib/core/bitmap.h b/third_party/xla/xla/tsl/lib/core/bitmap.h index 0766cdd339c733..173c0329aa16eb 100644 --- a/third_party/xla/xla/tsl/lib/core/bitmap.h +++ b/third_party/xla/xla/tsl/lib/core/bitmap.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace core { diff --git a/third_party/xla/xla/tsl/lib/core/bitmap_test.cc b/third_party/xla/xla/tsl/lib/core/bitmap_test.cc index bab7f7e4bc9bf5..447a6e59c12da3 100644 --- a/third_party/xla/xla/tsl/lib/core/bitmap_test.cc +++ b/third_party/xla/xla/tsl/lib/core/bitmap_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/tsl/lib/core/bitmap.h" #include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace core { diff --git a/third_party/xla/xla/tsl/lib/core/bits.h b/third_party/xla/xla/tsl/lib/core/bits.h index 7db02c2f913084..af4d6d251fe9a5 100644 --- a/third_party/xla/xla/tsl/lib/core/bits.h +++ b/third_party/xla/xla/tsl/lib/core/bits.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/numeric/bits.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/core/bits_test.cc b/third_party/xla/xla/tsl/lib/core/bits_test.cc index 8380214bf29e03..65ad4d338af5ec 100644 --- a/third_party/xla/xla/tsl/lib/core/bits_test.cc +++ b/third_party/xla/xla/tsl/lib/core/bits_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/lib/core/status_test_util.h b/third_party/xla/xla/tsl/lib/core/status_test_util.h index 0c8f5d9d50e4ea..a75eab84985877 100644 --- a/third_party/xla/xla/tsl/lib/core/status_test_util.h +++ b/third_party/xla/xla/tsl/lib/core/status_test_util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ #define XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/test.h" // Macros for testing the results of functions that return tensorflow::Status. #define TF_EXPECT_OK(statement) EXPECT_THAT((statement), ::tsl::testing::IsOk()) diff --git a/third_party/xla/xla/tsl/lib/gtl/BUILD b/third_party/xla/xla/tsl/lib/gtl/BUILD index a4ce425862dbca..6adb13bbd60200 100644 --- a/third_party/xla/xla/tsl/lib/gtl/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/BUILD @@ -49,9 +49,9 @@ cc_library( hdrs = ["flatmap.h"], deps = [ ":flatrep", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", ], ) @@ -59,8 +59,8 @@ cc_library( name = "flatrep", hdrs = ["flatrep.h"], deps = [ + "//xla/tsl/platform:types", "@com_google_absl//absl/base:prefetch", - "@local_tsl//tsl/platform:types", ], ) @@ -69,9 +69,9 @@ cc_library( hdrs = ["flatset.h"], deps = [ ":flatrep", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", ], ) @@ -79,10 +79,10 @@ cc_library( name = "inlined_vector", hdrs = ["inlined_vector.h"], deps = [ + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", ], ) @@ -90,8 +90,8 @@ cc_library( name = "int_type", hdrs = ["int_type.h"], deps = [ - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", ], ) @@ -222,10 +222,10 @@ tsl_cc_test( ":int_type", ":iterator_range", ":map_util", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc b/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc index 6f5e52dc085047..45ce9a91e272da 100644 --- a/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "xla/tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/hash.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/flatmap.h b/third_party/xla/xla/tsl/lib/gtl/flatmap.h index e74dbd46531d9a..63ece98a408e80 100644 --- a/third_party/xla/xla/tsl/lib/gtl/flatmap.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatmap.h @@ -24,9 +24,9 @@ limitations under the License. #include #include "xla/tsl/lib/gtl/flatrep.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/hash.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc b/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc index 231970ccbe45ac..2cf4f517bee6cf 100644 --- a/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/hash.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/flatrep.h b/third_party/xla/xla/tsl/lib/gtl/flatrep.h index 74ae18fc37c0f8..ed772875452c8a 100644 --- a/third_party/xla/xla/tsl/lib/gtl/flatrep.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatrep.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "absl/base/prefetch.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/flatset.h b/third_party/xla/xla/tsl/lib/gtl/flatset.h index f272ad1fa7bd1d..c4b44b9bb5a349 100644 --- a/third_party/xla/xla/tsl/lib/gtl/flatset.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatset.h @@ -24,9 +24,9 @@ limitations under the License. #include #include "xla/tsl/lib/gtl/flatrep.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/hash.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc b/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc index 8adb9133a76ecb..11cd92f5b4ec3f 100644 --- a/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/hash.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h b/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h index 6072f87ff6931d..40eb3c9f4b744e 100644 --- a/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h +++ b/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h @@ -22,8 +22,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" // IWYU pragma: export // TODO(kramerb): This is kept only because lots of targets transitively depend // on it. Remove all targets' dependencies. -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" // TODO: b/323943471 - This macro should eventually be provided by Abseil. #ifndef ABSL_DEPRECATE_AND_INLINE diff --git a/third_party/xla/xla/tsl/lib/gtl/int_type.h b/third_party/xla/xla/tsl/lib/gtl/int_type.h index 2a54fc58fada8f..c0760d45cae7c0 100644 --- a/third_party/xla/xla/tsl/lib/gtl/int_type.h +++ b/third_party/xla/xla/tsl/lib/gtl/int_type.h @@ -159,8 +159,8 @@ limitations under the License. #include // NOLINT #include -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc b/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc index 6ab47039fe1653..85b011ed5bcb19 100644 --- a/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc b/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc index 08028094552ff1..d84db4096f2805 100644 --- a/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace gtl { diff --git a/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc b/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc index ce2a13c9e394e9..92ac1d0e1c5e52 100644 --- a/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/hash/BUILD b/third_party/xla/xla/tsl/lib/hash/BUILD index 9c554e3cc614ae..8e2089595e2bd6 100644 --- a/third_party/xla/xla/tsl/lib/hash/BUILD +++ b/third_party/xla/xla/tsl/lib/hash/BUILD @@ -34,12 +34,12 @@ cc_library( # -msse4.2 enables the use of crc32c compiler builtins. copts = tsl_copts() + if_linux_x86_64(["-msse4.2"]), deps = [ + "//xla/tsl/platform:types", "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:cord", - "@local_tsl//tsl/platform:types", ], ) @@ -67,11 +67,11 @@ tsl_cc_test( srcs = ["crc32c_test.cc"], deps = [ ":crc32c", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "@com_google_absl//absl/strings:cord", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/lib/hash/crc32c.cc b/third_party/xla/xla/tsl/lib/hash/crc32c.cc index 8ad835fb1d80f8..37d0ed501ce785 100644 --- a/third_party/xla/xla/tsl/lib/hash/crc32c.cc +++ b/third_party/xla/xla/tsl/lib/hash/crc32c.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/strings/cord.h" #include "absl/strings/string_view.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace crc32c { diff --git a/third_party/xla/xla/tsl/lib/hash/crc32c.h b/third_party/xla/xla/tsl/lib/hash/crc32c.h index 29c71eed3f0a99..8d797dacf0572f 100644 --- a/third_party/xla/xla/tsl/lib/hash/crc32c.h +++ b/third_party/xla/xla/tsl/lib/hash/crc32c.h @@ -20,9 +20,9 @@ limitations under the License. #include "absl/crc/crc32c.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/cord.h" #include "tsl/platform/platform.h" -#include "tsl/platform/types.h" namespace tsl { namespace crc32c { diff --git a/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc b/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc index 291121d5043f6f..5082e27ac672e4 100644 --- a/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc +++ b/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/strings/cord.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace crc32c { diff --git a/third_party/xla/xla/tsl/lib/histogram/BUILD b/third_party/xla/xla/tsl/lib/histogram/BUILD index c182d05c27219d..f089754310dc3f 100644 --- a/third_party/xla/xla/tsl/lib/histogram/BUILD +++ b/third_party/xla/xla/tsl/lib/histogram/BUILD @@ -20,12 +20,12 @@ cc_library( hdrs = ["histogram.h"], visibility = ["//visibility:public"], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:histogram_proto_cc", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -55,9 +55,9 @@ tsl_cc_test( ], deps = [ ":histogram", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/protobuf:histogram_proto_cc", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram.cc b/third_party/xla/xla/tsl/lib/histogram/histogram.cc index 35ff514e1fe1dd..e333a419fe05e8 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.cc @@ -20,10 +20,10 @@ limitations under the License. #include +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" namespace tsl { namespace histogram { diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram.h b/third_party/xla/xla/tsl/lib/histogram/histogram.h index 64b0cd188e7222..88fe7be62dafb3 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram.h +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.h @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tensorflow { class HistogramProto; diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc index 1b2f1827521a17..42268a44b0cce5 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" namespace tsl { namespace histogram { diff --git a/third_party/xla/xla/tsl/lib/io/BUILD b/third_party/xla/xla/tsl/lib/io/BUILD index 7422d6eba391d0..e7d4dc4ce02a62 100644 --- a/third_party/xla/xla/tsl/lib/io/BUILD +++ b/third_party/xla/xla/tsl/lib/io/BUILD @@ -43,15 +43,15 @@ cc_library( ":iterator", ":table_options", "//xla/tsl/lib/hash:crc32c", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:coding", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:raw_coding", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -63,8 +63,8 @@ cc_library( deps = [ ":inputstream_interface", ":random_inputstream", + "//xla/tsl/platform:env", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:env", ], alwayslink = True, ) @@ -81,13 +81,13 @@ cc_library( srcs = ["inputbuffer.cc"], hdrs = ["inputbuffer.h"], deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:coding", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -97,10 +97,10 @@ cc_library( srcs = ["inputstream_interface.cc"], hdrs = ["inputstream_interface.h"], deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:cord", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -110,7 +110,7 @@ cc_library( srcs = ["iterator.cc"], hdrs = ["iterator.h"], deps = [ - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", ], alwayslink = True, @@ -120,8 +120,8 @@ cc_library( name = "proto_encode_helper", hdrs = ["proto_encode_helper.h"], deps = [ + "//xla/tsl/platform:logging", "@local_tsl//tsl/platform:coding", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:stringpiece", ], @@ -133,8 +133,8 @@ cc_library( hdrs = ["random_inputstream.h"], deps = [ ":inputstream_interface", + "//xla/tsl/platform:env", "@local_tsl//tsl/platform:cord", - "@local_tsl//tsl/platform:env", ], alwayslink = True, ) @@ -153,12 +153,12 @@ cc_library( ":zlib_compression_options", ":zlib_inputstream", "//xla/tsl/lib/hash:crc32c", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:macros", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:raw_coding", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -174,13 +174,13 @@ cc_library( ":zlib_compression_options", ":zlib_outputbuffer", "//xla/tsl/lib/hash:crc32c", + "//xla/tsl/platform:env", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:coding", "@local_tsl//tsl/platform:cord", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -235,9 +235,9 @@ cc_library( ":cache", ":iterator", ":table_options", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", "@local_tsl//tsl/platform:coding", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", ], alwayslink = True, ) @@ -253,9 +253,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/tsl/lib/hash:crc32c", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", "@local_tsl//tsl/platform:cord", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:status", ], ) @@ -266,11 +266,11 @@ tsl_cc_test( deps = [ ":buffered_file", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", ], ) @@ -279,7 +279,7 @@ cc_library( srcs = ["zlib_compression_options.cc"], hdrs = ["zlib_compression_options.h"], deps = [ - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", "@zlib", ], alwayslink = True, @@ -292,12 +292,12 @@ cc_library( deps = [ ":inputstream_interface", ":zlib_compression_options", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -309,12 +309,12 @@ cc_library( hdrs = ["zlib_outputbuffer.h"], deps = [ ":zlib_compression_options", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -446,11 +446,11 @@ tsl_cc_test( ":buffered_inputstream", ":random_inputstream", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", ], ) @@ -460,10 +460,10 @@ tsl_cc_test( srcs = ["cache_test.cc"], deps = [ ":cache", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:coding", "@local_tsl//tsl/platform:raw_coding", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -474,16 +474,16 @@ tsl_cc_test( deps = [ ":inputbuffer", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:coding", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -494,9 +494,9 @@ tsl_cc_test( deps = [ ":inputstream_interface", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -507,10 +507,10 @@ tsl_cc_test( deps = [ ":random_inputstream", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -522,14 +522,14 @@ tsl_cc_test( ":record_reader", ":record_writer", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", "@zlib", ], ) @@ -544,13 +544,13 @@ tsl_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/lib/hash:crc32c", "//xla/tsl/lib/random:philox", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:coding", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -563,13 +563,13 @@ tsl_cc_test( ":iterator", ":table", "//xla/tsl/lib/random:philox", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -583,11 +583,11 @@ tsl_cc_test( ":zlib_inputstream", ":zlib_outputbuffer", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/lib/io/block.cc b/third_party/xla/xla/tsl/lib/io/block.cc index ae1d40bff71628..2b76696c0d6b01 100644 --- a/third_party/xla/xla/tsl/lib/io/block.cc +++ b/third_party/xla/xla/tsl/lib/io/block.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include "xla/tsl/lib/io/format.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/coding.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #include "tsl/platform/raw_coding.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/block_builder.h b/third_party/xla/xla/tsl/lib/io/block_builder.h index 0defea6d866e0f..aef643d3395738 100644 --- a/third_party/xla/xla/tsl/lib/io/block_builder.h +++ b/third_party/xla/xla/tsl/lib/io/block_builder.h @@ -20,8 +20,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace table { diff --git a/third_party/xla/xla/tsl/lib/io/buffered_file.h b/third_party/xla/xla/tsl/lib/io/buffered_file.h index 6d173c83d12530..6fc9b994258411 100644 --- a/third_party/xla/xla/tsl/lib/io/buffered_file.h +++ b/third_party/xla/xla/tsl/lib/io/buffered_file.h @@ -22,9 +22,9 @@ limitations under the License. #include #include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/cord.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/status.h" namespace tsl { class BufferedWritableFile : public WritableFile { diff --git a/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc b/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc index 2c3fc0fe5070ca..f1faf55ef5353f 100644 --- a/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h index 1a187012766ab1..a06c79be944151 100644 --- a/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/lib/io/inputstream_interface.h" -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/file_system.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc index e7ad2c037844bd..3686ab55904bb1 100644 --- a/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/io/random_inputstream.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/cache_test.cc b/third_party/xla/xla/tsl/lib/io/cache_test.cc index 3c54c82a11ac25..7e7f10faf5582c 100644 --- a/third_party/xla/xla/tsl/lib/io/cache_test.cc +++ b/third_party/xla/xla/tsl/lib/io/cache_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "xla/tsl/platform/test.h" #include "tsl/platform/coding.h" #include "tsl/platform/raw_coding.h" -#include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/format.cc b/third_party/xla/xla/tsl/lib/io/format.cc index e02451c08d7e0e..1982d3ff407c51 100644 --- a/third_party/xla/xla/tsl/lib/io/format.cc +++ b/third_party/xla/xla/tsl/lib/io/format.cc @@ -19,9 +19,9 @@ limitations under the License. #include "xla/tsl/lib/hash/crc32c.h" #include "xla/tsl/lib/io/block.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/raw_coding.h" #include "tsl/platform/snappy.h" diff --git a/third_party/xla/xla/tsl/lib/io/format.h b/third_party/xla/xla/tsl/lib/io/format.h index 3cf5d6312a5f02..408be574f6b059 100644 --- a/third_party/xla/xla/tsl/lib/io/format.h +++ b/third_party/xla/xla/tsl/lib/io/format.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "xla/tsl/lib/io/table_builder.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/stringpiece.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc index 5fdff4943331ed..e7823794df8f76 100644 --- a/third_party/xla/xla/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc @@ -17,8 +17,8 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/inputbuffer.h b/third_party/xla/xla/tsl/lib/io/inputbuffer.h index bec656ecd00ef6..1d9db6bf19c5ad 100644 --- a/third_party/xla/xla/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.h @@ -18,11 +18,11 @@ limitations under the License. #include +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc index a4d170101ea675..555e664934b256 100644 --- a/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc @@ -18,14 +18,14 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" #include "tsl/platform/str_util.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc b/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc index 7bf261f6757609..4faaa07bcd9cb2 100644 --- a/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/lib/io/inputstream_interface.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/inputstream_interface.h b/third_party/xla/xla/tsl/lib/io/inputstream_interface.h index 3ecb5b5af9e8df..bde311a7cb4a23 100644 --- a/third_party/xla/xla/tsl/lib/io/inputstream_interface.h +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/cord.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc index 9021440b6e1d84..524345aaaf417b 100644 --- a/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/tsl/lib/io/inputstream_interface.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/iterator.h b/third_party/xla/xla/tsl/lib/io/iterator.h index ba0b1dbc4b76de..23774db476a122 100644 --- a/third_party/xla/xla/tsl/lib/io/iterator.h +++ b/third_party/xla/xla/tsl/lib/io/iterator.h @@ -26,7 +26,7 @@ limitations under the License. #ifndef XLA_TSL_LIB_IO_ITERATOR_H_ #define XLA_TSL_LIB_IO_ITERATOR_H_ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/stringpiece.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h b/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h index 33c2411cbc3ca3..a63a4f950f466d 100644 --- a/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h +++ b/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ #define XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#include "xla/tsl/platform/logging.h" #include "tsl/platform/coding.h" -#include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/stringpiece.h" diff --git a/third_party/xla/xla/tsl/lib/io/random_inputstream.h b/third_party/xla/xla/tsl/lib/io/random_inputstream.h index 99685ab055ac6a..04f5765469c3ac 100644 --- a/third_party/xla/xla/tsl/lib/io/random_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ #include "xla/tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/platform/file_system.h" #include "tsl/platform/cord.h" -#include "tsl/platform/file_system.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc b/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc index e2fc82374e47bb..1a50021e8191e7 100644 --- a/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/tsl/lib/io/random_inputstream.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/record_reader.cc b/third_party/xla/xla/tsl/lib/io/record_reader.cc index 8332debff876c2..6421a616b2d38d 100644 --- a/third_party/xla/xla/tsl/lib/io/record_reader.cc +++ b/third_party/xla/xla/tsl/lib/io/record_reader.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/tsl/lib/io/buffered_inputstream.h" #include "xla/tsl/lib/io/compression.h" #include "xla/tsl/lib/io/random_inputstream.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/raw_coding.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/record_reader.h b/third_party/xla/xla/tsl/lib/io/record_reader.h index 3c18992ec86279..8f144148ca33f5 100644 --- a/third_party/xla/xla/tsl/lib/io/record_reader.h +++ b/third_party/xla/xla/tsl/lib/io/record_reader.h @@ -17,7 +17,7 @@ limitations under the License. #define XLA_TSL_LIB_IO_RECORD_READER_H_ #include "xla/tsl/lib/io/inputstream_interface.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/stringpiece.h" #if !defined(IS_SLIM_BUILD) #include "xla/tsl/lib/io/snappy/snappy_compression_options.h" @@ -25,8 +25,8 @@ limitations under the License. #include "xla/tsl/lib/io/zlib_compression_options.h" #include "xla/tsl/lib/io/zlib_inputstream.h" #endif // IS_SLIM_BUILD -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { class RandomAccessFile; diff --git a/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc index e91f1ecaed1b99..2220a3ba0cc63c 100644 --- a/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc @@ -24,12 +24,12 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/record_writer.cc b/third_party/xla/xla/tsl/lib/io/record_writer.cc index ce6289a014fe04..985e415f632a61 100644 --- a/third_party/xla/xla/tsl/lib/io/record_writer.cc +++ b/third_party/xla/xla/tsl/lib/io/record_writer.cc @@ -17,8 +17,8 @@ limitations under the License. #include "xla/tsl/lib/hash/crc32c.h" #include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/platform/env.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/record_writer.h b/third_party/xla/xla/tsl/lib/io/record_writer.h index 5cb160790b9f1c..ced0bc687a6e28 100644 --- a/third_party/xla/xla/tsl/lib/io/record_writer.h +++ b/third_party/xla/xla/tsl/lib/io/record_writer.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_TSL_LIB_IO_RECORD_WRITER_H_ #include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/coding.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #if !defined(IS_SLIM_BUILD) #include "xla/tsl/lib/io/snappy/snappy_compression_options.h" @@ -26,9 +26,9 @@ limitations under the License. #include "xla/tsl/lib/io/zlib_compression_options.h" #include "xla/tsl/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/cord.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/recordio_test.cc b/third_party/xla/xla/tsl/lib/io/recordio_test.cc index 02d22ec4931218..9c31aa7eeda825 100644 --- a/third_party/xla/xla/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/xla/tsl/lib/io/recordio_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include "xla/tsl/lib/io/record_reader.h" #include "xla/tsl/lib/io/record_writer.h" #include "xla/tsl/lib/random/simple_philox.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/test.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/snappy/BUILD b/third_party/xla/xla/tsl/lib/io/snappy/BUILD index bf19bacf9af44f..ed8ae4d65ff269 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/BUILD +++ b/third_party/xla/xla/tsl/lib/io/snappy/BUILD @@ -35,12 +35,12 @@ cc_library( hdrs = ["snappy_inputbuffer.h"], deps = [ "//xla/tsl/lib/io:inputstream_interface", + "//xla/tsl/platform:env", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -50,12 +50,12 @@ cc_library( srcs = ["snappy_outputbuffer.cc"], hdrs = ["snappy_outputbuffer.h"], deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -66,8 +66,8 @@ cc_library( hdrs = ["snappy_inputstream.h"], deps = [ "//xla/tsl/lib/io:inputstream_interface", + "//xla/tsl/platform:errors", "@com_google_absl//absl/memory", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:platform_port", ], alwayslink = True, @@ -77,7 +77,7 @@ cc_library( name = "snappy_compression_options", hdrs = ["snappy_compression_options.h"], deps = [ - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], alwayslink = True, ) @@ -93,9 +93,9 @@ tsl_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/lib/io:inputbuffer", "//xla/tsl/lib/io:random_inputstream", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h index 3772a415056cf9..3dbf2ead90fe59 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ #define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h index 969c1e00c2bfe3..8688e368719828 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h @@ -20,11 +20,11 @@ limitations under the License. #include #include "xla/tsl/lib/io/inputstream_interface.h" -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/snappy.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc index bcbe96e21139e7..980807326e51ae 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/snappy.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h index 631014c3b6e189..d48ded2196a454 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/platform.h" #include "tsl/platform/snappy.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc index f3504e9268a76e..d7eb301c5f8bf3 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" #include "xla/tsl/lib/io/snappy/snappy_inputstream.h" #include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/table.cc b/third_party/xla/xla/tsl/lib/io/table.cc index 5c36b4649859b8..d3030af7aba0c0 100644 --- a/third_party/xla/xla/tsl/lib/io/table.cc +++ b/third_party/xla/xla/tsl/lib/io/table.cc @@ -20,9 +20,9 @@ limitations under the License. #include "xla/tsl/lib/io/format.h" #include "xla/tsl/lib/io/table_options.h" #include "xla/tsl/lib/io/two_level_iterator.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" namespace tsl { namespace table { diff --git a/third_party/xla/xla/tsl/lib/io/table_builder.cc b/third_party/xla/xla/tsl/lib/io/table_builder.cc index b5fcb0c9ed47dc..f7a18b5e9a946b 100644 --- a/third_party/xla/xla/tsl/lib/io/table_builder.cc +++ b/third_party/xla/xla/tsl/lib/io/table_builder.cc @@ -21,9 +21,9 @@ limitations under the License. #include "xla/tsl/lib/io/block_builder.h" #include "xla/tsl/lib/io/format.h" #include "xla/tsl/lib/io/table_options.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/coding.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/snappy.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/table_builder.h b/third_party/xla/xla/tsl/lib/io/table_builder.h index 059f9ab60546c1..a9ad59a7b89db9 100644 --- a/third_party/xla/xla/tsl/lib/io/table_builder.h +++ b/third_party/xla/xla/tsl/lib/io/table_builder.h @@ -27,7 +27,7 @@ limitations under the License. #include #include "xla/tsl/lib/io/table_options.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/stringpiece.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/table_test.cc b/third_party/xla/xla/tsl/lib/io/table_test.cc index 6671bc816abc17..ead7d32986ad0d 100644 --- a/third_party/xla/xla/tsl/lib/io/table_test.cc +++ b/third_party/xla/xla/tsl/lib/io/table_test.cc @@ -27,10 +27,10 @@ limitations under the License. #include "xla/tsl/lib/io/iterator.h" #include "xla/tsl/lib/io/table_builder.h" #include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/snappy.h" -#include "tsl/platform/test.h" namespace tsl { namespace table { diff --git a/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc b/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc index c66d9229e480c9..89c1dcb468202e 100644 --- a/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include "xla/tsl/lib/io/zlib_compression_options.h" #include "xla/tsl/lib/io/zlib_inputstream.h" #include "xla/tsl/lib/io/zlib_outputbuffer.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h index 0cae3a2ef54128..b0cb2f05724642 100644 --- a/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ #define XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc index fda83637279579..b5bfcd5b478e91 100644 --- a/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/strcat.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h index 16df9508636019..46d78fac0a8681 100644 --- a/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h @@ -20,10 +20,10 @@ limitations under the License. #include "xla/tsl/lib/io/inputstream_interface.h" #include "xla/tsl/lib/io/zlib_compression_options.h" -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc index 646e4397898841..483ab8d9691fac 100644 --- a/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/lib/io/zlib_outputbuffer.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h index 96b1d1bb9704da..3d7e3024993ee9 100644 --- a/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "xla/tsl/lib/io/zlib_compression_options.h" -#include "tsl/platform/env.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace io { diff --git a/third_party/xla/xla/tsl/lib/math/BUILD b/third_party/xla/xla/tsl/lib/math/BUILD index 137ff9aa961336..f0af1e91a9ddd5 100644 --- a/third_party/xla/xla/tsl/lib/math/BUILD +++ b/third_party/xla/xla/tsl/lib/math/BUILD @@ -29,11 +29,11 @@ tsl_cc_test( ], deps = [ ":math_util", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/lib/math/math_util_test.cc b/third_party/xla/xla/tsl/lib/math/math_util_test.cc index c60f9796695ceb..b7a91877b1168c 100644 --- a/third_party/xla/xla/tsl/lib/math/math_util_test.cc +++ b/third_party/xla/xla/tsl/lib/math/math_util_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/lib/monitoring/BUILD b/third_party/xla/xla/tsl/lib/monitoring/BUILD index ee0c361d22a21c..7fe002a48969c5 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/BUILD +++ b/third_party/xla/xla/tsl/lib/monitoring/BUILD @@ -39,13 +39,13 @@ cc_library( deps = [ ":collection_registry", ":metric_def", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -57,12 +57,12 @@ cc_library( deps = [ ":collection_registry", ":metric_def", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -74,15 +74,15 @@ cc_library( ":collection_registry", ":metric_def", "//xla/tsl/lib/histogram", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -92,7 +92,7 @@ cc_library( "types.h", ], deps = [ - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], ) @@ -102,9 +102,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], ) @@ -117,15 +117,15 @@ cc_library( ":collected_metrics", ":metric_def", ":types", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -165,14 +165,14 @@ cc_library( ":collection_registry", ":metric_def", ":test_utils", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:types", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:types", ], ) @@ -184,14 +184,14 @@ cc_library( ":collection_registry", ":metric_def", ":types", + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@com_google_absl//absl/status", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -203,12 +203,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) @@ -218,8 +218,8 @@ cc_library( "timed.h", ], deps = [ - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc b/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc index 6f7f21d4b7732b..69a5536fae9034 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc @@ -27,9 +27,9 @@ limitations under the License. #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/test_utils.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.h b/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.h index e58b1ee9698dad..8eb263ba4c0424 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.h +++ b/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.h @@ -27,9 +27,9 @@ limitations under the License. #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/test_utils.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.cc b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.cc index 90ce825e4a4db7..fbeccc3c617348 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.cc @@ -17,16 +17,16 @@ limitations under the License. #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/metric_def.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" // We replace this implementation with a null implementation for mobile // platforms. #ifndef IS_MOBILE_PLATFORM -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h index 6c48ea9114c8db..e2d370a27c4862 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h +++ b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h @@ -35,7 +35,7 @@ class CollectionRegistryTestAccess; #include #include "xla/tsl/lib/monitoring/metric_def.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" namespace tsl { namespace monitoring { @@ -110,14 +110,14 @@ class CollectionRegistry { #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/counter.h b/third_party/xla/xla/tsl/lib/monitoring/counter.h index e219512e2d6794..72777585afd70c 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/counter.h +++ b/third_party/xla/xla/tsl/lib/monitoring/counter.h @@ -25,9 +25,9 @@ limitations under the License. // platforms. #ifdef IS_MOBILE_PLATFORM -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace monitoring { @@ -86,10 +86,10 @@ class Counter { #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/monitoring/gauge.h b/third_party/xla/xla/tsl/lib/monitoring/gauge.h index eac1ea94249c12..2b1c7f8e1bd2f1 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/gauge.h +++ b/third_party/xla/xla/tsl/lib/monitoring/gauge.h @@ -28,9 +28,9 @@ limitations under the License. #include #include -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace monitoring { @@ -102,11 +102,11 @@ class Gauge { #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/metric_def.h b/third_party/xla/xla/tsl/lib/monitoring/metric_def.h index dcee3f92db4c30..82896f43a7e77e 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/metric_def.h +++ b/third_party/xla/xla/tsl/lib/monitoring/metric_def.h @@ -22,9 +22,9 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.cc b/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.cc index 46e71d1d30a51a..f298b9f81c0999 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/types.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" // We replace this implementation with a null implementation for mobile // platforms. diff --git a/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.h b/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.h index d419eb1934c5c4..5ee2ceea488d66 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.h +++ b/third_party/xla/xla/tsl/lib/monitoring/percentile_sampler.h @@ -20,7 +20,7 @@ limitations under the License. // Required for IS_MOBILE_PLATFORM #include "absl/status/status.h" #include "tsl/platform/platform.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" // clang-format on // We replace this implementation with a null implementation for mobile @@ -30,8 +30,8 @@ limitations under the License. #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" namespace tsl { namespace monitoring { @@ -88,9 +88,9 @@ PercentileSampler* PercentileSampler::New( #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/monitoring/sampler.h b/third_party/xla/xla/tsl/lib/monitoring/sampler.h index 3976e312876cb4..2fdbbd696b54c0 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/sampler.h +++ b/third_party/xla/xla/tsl/lib/monitoring/sampler.h @@ -29,10 +29,10 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/metric_def.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { namespace monitoring { @@ -125,10 +125,10 @@ class Sampler { #include "xla/tsl/lib/histogram/histogram.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc b/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc index 3691130880ab24..a519d68f9e5e14 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/errors.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/test_utils.h b/third_party/xla/xla/tsl/lib/monitoring/test_utils.h index 85101ebffc6d69..5f083d00a862da 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/test_utils.h +++ b/third_party/xla/xla/tsl/lib/monitoring/test_utils.h @@ -19,8 +19,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/protobuf/histogram.pb.h" -#include "tsl/platform/statusor.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/timed.h b/third_party/xla/xla/tsl/lib/monitoring/timed.h index 732971aa171a1d..10a76b1883f5af 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/timed.h +++ b/third_party/xla/xla/tsl/lib/monitoring/timed.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_LIB_MONITORING_TIMED_H_ #define XLA_TSL_LIB_MONITORING_TIMED_H_ -#include "tsl/platform/env_time.h" +#include "xla/tsl/platform/env_time.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/types.h b/third_party/xla/xla/tsl/lib/monitoring/types.h index 7a0358c52bd90f..4618308c8ce3e3 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/types.h +++ b/third_party/xla/xla/tsl/lib/monitoring/types.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/random/BUILD b/third_party/xla/xla/tsl/lib/random/BUILD index 71a3561d4d134a..bceb9dbe18a2bc 100644 --- a/third_party/xla/xla/tsl/lib/random/BUILD +++ b/third_party/xla/xla/tsl/lib/random/BUILD @@ -40,11 +40,11 @@ cc_library( ":exact_uniform_int", ":philox_random", ":random_distributions_utils", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", ], alwayslink = 1, ) @@ -73,7 +73,7 @@ cc_library( hdrs = ["philox_random_test_utils.h"], deps = [ ":philox_random", - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", "@local_tsl//tsl/platform:random", ], ) @@ -84,9 +84,9 @@ cc_library( hdrs = ["weighted_picker.h"], deps = [ ":philox", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", ], alwayslink = 1, ) @@ -159,11 +159,11 @@ tsl_cc_test( srcs = ["distribution_sampler_test.cc"], deps = [ ":philox", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", ], ) @@ -175,10 +175,10 @@ tsl_cc_test( ":philox", ":philox_random", ":philox_random_test_utils", - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -191,10 +191,10 @@ tsl_cc_test( ":philox_random", ":philox_random_test_utils", "//xla/tsl/lib/math:math_util", - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -204,10 +204,10 @@ tsl_cc_test( srcs = ["simple_philox_test.cc"], deps = [ ":philox", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", ], ) @@ -218,11 +218,11 @@ tsl_cc_test( deps = [ ":philox", ":weighted_picker", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/lib/random/distribution_sampler.h b/third_party/xla/xla/tsl/lib/random/distribution_sampler.h index ababcc6bf23a31..afa0dac4df1644 100644 --- a/third_party/xla/xla/tsl/lib/random/distribution_sampler.h +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler.h @@ -36,9 +36,9 @@ limitations under the License. #include "absl/types/span.h" #include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc b/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc index 16107ec61c26c0..c94d9ec2de73d7 100644 --- a/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/philox_random_test.cc b/third_party/xla/xla/tsl/lib/random/philox_random_test.cc index 7af1f9485754fd..3a4cc70d9f6ba8 100644 --- a/third_party/xla/xla/tsl/lib/random/philox_random_test.cc +++ b/third_party/xla/xla/tsl/lib/random/philox_random_test.cc @@ -24,9 +24,9 @@ limitations under the License. #include "xla/tsl/lib/random/philox_random_test_utils.h" #include "xla/tsl/lib/random/random_distributions.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/random.h" -#include "tsl/platform/test.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h b/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h index 6bbb1c89596b80..3c76e1553774f3 100644 --- a/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h +++ b/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/lib/random/philox_random.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/random.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/lib/random/random_distributions.h b/third_party/xla/xla/tsl/lib/random/random_distributions.h index ce231f9f652c27..72ee2ae49aa875 100644 --- a/third_party/xla/xla/tsl/lib/random/random_distributions.h +++ b/third_party/xla/xla/tsl/lib/random/random_distributions.h @@ -23,7 +23,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/tsl/lib/random/philox_random.h" #include "xla/tsl/lib/random/random_distributions_utils.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc b/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc index b1dab4cd81d6d8..cd31230654e2e7 100644 --- a/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc +++ b/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/tsl/lib/math/math_util.h" #include "xla/tsl/lib/random/philox_random.h" #include "xla/tsl/lib/random/philox_random_test_utils.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/random.h" -#include "tsl/platform/test.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/simple_philox.cc b/third_party/xla/xla/tsl/lib/random/simple_philox.cc index f2c2bbe5820863..8b3481ac7c4f39 100644 --- a/third_party/xla/xla/tsl/lib/random/simple_philox.cc +++ b/third_party/xla/xla/tsl/lib/random/simple_philox.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/tsl/lib/random/simple_philox.h" #include "xla/tsl/lib/random/exact_uniform_int.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc b/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc index 3eded84eb0ee33..7a20dbeccf56c0 100644 --- a/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc +++ b/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/weighted_picker.h b/third_party/xla/xla/tsl/lib/random/weighted_picker.h index 27903077df2a73..1300fba858d881 100644 --- a/third_party/xla/xla/tsl/lib/random/weighted_picker.h +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker.h @@ -29,9 +29,9 @@ limitations under the License. #include -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc b/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc index 64e40c05c432a8..c4ae1bb4a1b036 100644 --- a/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace random { diff --git a/third_party/xla/xla/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD index 0fd17fd53fb78b..fddf84b0a583da 100644 --- a/third_party/xla/xla/tsl/lib/strings/BUILD +++ b/third_party/xla/xla/tsl/lib/strings/BUILD @@ -14,11 +14,11 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//xla/tsl/lib/gtl:inlined_vector", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc index fef78bd1835a00..c952a87bb1cfa9 100644 --- a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "xla/tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" #include "tsl/platform/hash.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/platform/BUILD b/third_party/xla/xla/tsl/platform/BUILD index 8cdacfdfbefb90..4a4866296dfbbc 100644 --- a/third_party/xla/xla/tsl/platform/BUILD +++ b/third_party/xla/xla/tsl/platform/BUILD @@ -6,14 +6,21 @@ load( "//xla/tsl:tsl.bzl", "if_not_fuchsia", "internal_visibility", + "tsl_copts", ) load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") load( "//xla/tsl/platform:build_config.bzl", + "tf_logging_deps", "tf_platform_alias", + "tf_platform_deps", "tf_windows_aware_platform_deps", "tsl_cc_test", ) +load( + "//xla/tsl/platform:build_config_root.bzl", + "if_static", +) load( "//xla/tsl/platform:rules_cc.bzl", "cc_library", @@ -30,6 +37,17 @@ package( exports_files( [ "subprocess.h", + "env_time.h", + "env.cc", + "file_system.cc", + "logging.h", + "file_system.h", + "file_system_helper.cc", + "file_system_helper.h", + "test.h", + "threadpool.cc", + "threadpool.h", + "env.h", ], visibility = internal_visibility([ "//tensorflow/core/platform:__subpackages__", @@ -54,6 +72,8 @@ filegroup( name = "test_hdrs", testonly = 1, srcs = [ + "test.h", + "test_benchmark.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -67,6 +87,7 @@ filegroup( name = "android_test_srcs", testonly = 1, srcs = [ + "test.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -86,6 +107,12 @@ filegroup( filegroup( name = "lib_hdrs", srcs = [ + "env.h", + "errors.h", + "file_statistics.h", + "file_system.h", + "file_system_helper.h", + "statusor.h", "subprocess.h", ], compatible_with = get_compatible_with_portable(), @@ -95,6 +122,11 @@ filegroup( filegroup( name = "base_hdrs", srcs = [ + "env_time.h", + "macros.h", + "threadpool.h", + "threadpool_interface.h", + "threadpool_options.h", ], compatible_with = get_compatible_with_portable(), ) @@ -102,6 +134,7 @@ filegroup( filegroup( name = "framework_lite_hdrs", srcs = [ + "macros.h", ], compatible_with = get_compatible_with_portable(), ) @@ -109,7 +142,29 @@ filegroup( # Export source files needed for mobile builds, which do not use granular targets. filegroup( name = "mobile_srcs_no_runtime", - srcs = [], + srcs = [ + "env.cc", + "env.h", + "env_time.h", + "errors.cc", + "errors.h", + "file_statistics.h", + "file_system.cc", + "file_system.h", + "file_system_helper.h", + "macros.h", + "status.cc", + "status.h", + "statusor.h", + "threadpool.cc", + "threadpool.h", + "threadpool_interface.h", + ] + select({ + "//xla/tsl:fuchsia": [], + "//conditions:default": [ + "file_system_helper.cc", + ], + }), compatible_with = get_compatible_with_portable(), ) @@ -145,6 +200,7 @@ filegroup( filegroup( name = "lib_proto_parsing_hdrs", srcs = [ + "macros.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -156,6 +212,8 @@ filegroup( filegroup( name = "lib_internal_public_hdrs", srcs = [ + "status.h", + "statusor.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -168,6 +226,7 @@ filegroup( filegroup( name = "tflite_portable_logging_hdrs", srcs = [ + "macros.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -180,6 +239,7 @@ filegroup( filegroup( name = "jpeg_internal_hdrs", srcs = [ + "macros.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -193,6 +253,7 @@ filegroup( filegroup( name = "gif_internal_hdrs", srcs = [ + "macros.h", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -206,6 +267,7 @@ filegroup( filegroup( name = "xla_cpu_runtime_srcs", srcs = [ + "macros.h", ], compatible_with = get_compatible_with_portable(), ) @@ -254,11 +316,329 @@ tsl_cc_test( ], tags = ["no_oss"], # TODO(b/327036247): revisit after this moves to XLA deps = [ - ":subprocess", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:subprocess", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) + +cc_library( + name = "env", + textual_hdrs = [ + "env.h", + "file_system.h", + "file_system_helper.h", + "threadpool.h", + ], + deps = tf_windows_aware_platform_deps("env") + if_static([":env_impl"]), +) + +cc_library( + name = "env_impl", + deps = tf_windows_aware_platform_deps("env_impl"), +) + +cc_library( + name = "env_time", + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["env_time.h"], + deps = tf_windows_aware_platform_deps("env_time"), +) + +cc_library( + name = "errors", + srcs = ["errors.cc"], + hdrs = ["errors.h"], + deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:strcat", + ], +) + +tsl_cc_test( + name = "errors_test", + size = "small", + srcs = ["errors_test.cc"], + deps = [ + ":errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "file_statistics", + hdrs = ["file_statistics.h"], + deps = [ + "//xla/tsl/platform:types", + ], +) + +cc_library( + name = "logging", + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["logging.h"], + visibility = [ + "//visibility:public", + ], + deps = tf_logging_deps(), +) + +tsl_cc_test( + name = "logging_test", + size = "small", + srcs = [ + "logging_test.cc", + ], + deps = [ + ":logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:stacktrace_handler", + ], +) + +cc_library( + name = "macros", + hdrs = ["macros.h"], + compatible_with = get_compatible_with_portable(), +) + +cc_library( + name = "status", + srcs = ["status.cc"], + hdrs = ["status.h"], + deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:stack_frame", + "@local_tsl//tsl/platform:stacktrace", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:stringprintf", + ] + tf_platform_deps("status"), +) + +tsl_cc_test( + name = "status_test", + size = "small", + srcs = ["status_test.cc"], + deps = [ + ":status", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:status_to_from_proto", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/protobuf:status_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:stack_frame", + ], +) + +cc_library( + name = "status_matchers", + testonly = 1, + srcs = ["status_matchers.cc"], + hdrs = ["status_matchers.h"], + deps = [ + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + ], +) + +tsl_cc_test( + name = "status_matchers_test", + size = "small", + srcs = ["status_matchers_test.cc"], + deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:status_matchers", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + ], +) + +cc_library( + name = "status_to_from_proto", + srcs = [ + "status_to_from_proto.cc", + ], + hdrs = ["status_to_from_proto.h"], + deps = [ + "//xla/tsl/platform:status", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/protobuf:status_proto_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ] + tf_platform_deps("status"), +) + +cc_library( + name = "statusor", + hdrs = ["statusor.h"], + deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform", + ] + tf_platform_deps("statusor"), +) + +tsl_cc_test( + name = "statusor_test", + size = "small", + srcs = ["statusor_test.cc"], + deps = [ + ":statusor", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/base:config", + ], +) + +cc_library( + name = "test", + testonly = True, + srcs = ["test.cc"], + compatible_with = get_compatible_with_portable(), + textual_hdrs = ["test.h"], + deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:net", + "@local_tsl//tsl/platform:path", + ], +) + +cc_library( + name = "test_benchmark", + testonly = True, + hdrs = ["test_benchmark.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@com_google_benchmark//:benchmark", + "@local_tsl//tsl/platform", + ], +) + +cc_library( + name = "test_main", + testonly = 1, + srcs = ["test_main.cc"], + copts = tsl_copts(), + linkopts = select({ + "//xla/tsl:windows": [], + "//conditions:default": ["-lm"], + }), + deps = [ + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:stacktrace_handler", + ], + alwayslink = 1, +) + +cc_library( + name = "threadpool_async_executor", + hdrs = ["threadpool_async_executor.h"], + deps = [ + "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:env", + ], +) + +tsl_cc_test( + name = "threadpool_async_executor_test", + srcs = ["threadpool_async_executor_test.cc"], + deps = [ + ":threadpool_async_executor", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "threadpool_interface", + hdrs = ["threadpool_interface.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla/tsl/platform:types", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:mutex", + ], +) + +cc_library( + name = "threadpool_options", + hdrs = ["threadpool_options.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla/tsl/platform:threadpool_interface", + ], +) + +cc_library( + name = "types", + hdrs = ["types.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:bfloat16", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:tstring", + ] + tf_platform_deps("types"), +) diff --git a/third_party/xla/xla/tsl/platform/cloud/BUILD b/third_party/xla/xla/tsl/platform/cloud/BUILD index 46ef36438fcc36..450f5ea89af314 100644 --- a/third_party/xla/xla/tsl/platform/cloud/BUILD +++ b/third_party/xla/xla/tsl/platform/cloud/BUILD @@ -33,10 +33,10 @@ cc_library( hdrs = ["expiring_lru_cache.h"], copts = tsl_copts(), deps = [ - "@local_tsl//tsl/platform:env", + "//xla/tsl/platform:env", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -45,13 +45,13 @@ cc_library( hdrs = ["file_block_cache.h"], copts = tsl_copts(), deps = [ - "@local_tsl//tsl/platform:env", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -63,14 +63,14 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":file_block_cache", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@com_google_absl//absl/cleanup", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -81,12 +81,12 @@ cc_library( copts = tsl_copts(), deps = [ ":http_request", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/platform:status", ], ) @@ -96,7 +96,7 @@ cc_library( hdrs = ["gcs_throttle.h"], copts = tsl_copts(), deps = [ - "@local_tsl//tsl/platform:env", + "//xla/tsl/platform:env", ], ) @@ -118,25 +118,25 @@ cc_library( ":http_request", ":ram_file_block_cache", ":time_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:file_statistics", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:file_statistics", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:retrying_file_system", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringprintf", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:traceme", ], alwayslink = 1, @@ -163,25 +163,25 @@ cc_library( ":http_request", ":ram_file_block_cache", ":time_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:file_statistics", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:file_statistics", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:retrying_file_system", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringprintf", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:traceme", ], alwayslink = 1, @@ -192,13 +192,13 @@ cc_library( hdrs = ["http_request.h"], copts = tsl_copts(), deps = [ - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:macros", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], ) @@ -210,17 +210,17 @@ cc_library( deps = [ ":http_request", "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/util:env_var", "@curl", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:scanner", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], ) @@ -234,14 +234,14 @@ cc_library( deps = [ ":curl_http_request", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:types", "@curl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:types", ], ) @@ -256,15 +256,15 @@ cc_library( deps = [ ":compute_engine_metadata_client", ":oauth_client", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", "@local_tsl//tsl/platform:base64", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -281,10 +281,10 @@ cc_library( deps = [ ":curl_http_request", ":http_request", + "//xla/tsl/platform:env", + "//xla/tsl/platform:status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/platform:status", ], ) @@ -300,8 +300,8 @@ cc_library( copts = tsl_copts(), deps = [ ":compute_engine_metadata_client", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@local_tsl//tsl/platform:str_util", ], ) @@ -312,9 +312,9 @@ cc_library( hdrs = ["now_seconds_env.h"], copts = tsl_copts(), deps = [ - "@local_tsl//tsl/platform:env", + "//xla/tsl/platform:env", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:types", ], ) @@ -330,12 +330,12 @@ cc_library( deps = [ ":curl_http_request", ":http_request", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@boringssl//:crypto", "@jsoncpp_git//:jsoncpp", "@local_tsl//tsl/platform:base64", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", ], ) @@ -349,8 +349,8 @@ cc_library( ], copts = tsl_copts(), deps = [ - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", ], ) @@ -362,9 +362,9 @@ tsl_cc_test( ":expiring_lru_cache", ":now_seconds_env", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -376,12 +376,14 @@ tsl_cc_test( ":now_seconds_env", ":ram_file_block_cache", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -393,14 +395,14 @@ tsl_cc_test( ":gcs_file_system", ":http_request_fake", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", "//xla/tsl/profiler/utils:time_utils_impl", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -411,10 +413,10 @@ tsl_cc_test( linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), deps = [ ":gcs_dns_cache", - "@local_tsl//tsl/platform:env_impl", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -426,10 +428,10 @@ tsl_cc_test( deps = [ ":gcs_throttle", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env_impl", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -440,14 +442,14 @@ tsl_cc_test( deps = [ ":curl_http_request", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -464,14 +466,14 @@ tsl_cc_test( ":http_request_fake", ":oauth_client", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@boringssl//:crypto", "@local_tsl//tsl/platform:base64", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:scanner", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -489,10 +491,10 @@ tsl_cc_test( ":http_request_fake", ":oauth_client", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env_impl", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -503,10 +505,10 @@ tsl_cc_test( deps = [ ":compute_engine_metadata_client", ":http_request_fake", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -517,9 +519,9 @@ tsl_cc_test( deps = [ ":compute_engine_zone_provider", ":http_request_fake", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -530,7 +532,7 @@ tsl_cc_test( deps = [ ":time_util", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/platform/cloud/auth_provider.h b/third_party/xla/xla/tsl/platform/cloud/auth_provider.h index 6b18ed8175089e..5cbc1704baa498 100644 --- a/third_party/xla/xla/tsl/platform/cloud/auth_provider.h +++ b/third_party/xla/xla/tsl/platform/cloud/auth_provider.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.h b/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.h index c220d0a88c1bda..81863019a247ee 100644 --- a/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.h +++ b/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_TSL_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ #include "xla/tsl/platform/cloud/http_request.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/retrying_utils.h" -#include "tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client_test.cc b/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client_test.cc index 948d177fd84fe0..b89e63cfa0a303 100644 --- a/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/tsl/platform/cloud/compute_engine_metadata_client.h" #include "xla/tsl/platform/cloud/http_request_fake.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider_test.cc b/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider_test.cc index c78a7b19a4a762..e9ecd10f68743a 100644 --- a/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/tsl/platform/cloud/compute_engine_zone_provider.h" #include "xla/tsl/platform/cloud/http_request_fake.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc index fb0343332512e2..de26c04012c680 100644 --- a/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/util/env_var.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" #include "tsl/platform/scanner.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/types.h" #define CHECK_CURL_OK(expr) CHECK_EQ(expr, CURLE_OK) diff --git a/third_party/xla/xla/tsl/platform/cloud/curl_http_request.h b/third_party/xla/xla/tsl/platform/cloud/curl_http_request.h index d2ba933227a950..717e59b13e5507 100644 --- a/third_party/xla/xla/tsl/platform/cloud/curl_http_request.h +++ b/third_party/xla/xla/tsl/platform/cloud/curl_http_request.h @@ -23,13 +23,13 @@ limitations under the License. #include #include "xla/tsl/platform/cloud/http_request.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/xla/tsl/platform/cloud/curl_http_request_test.cc index d4469b491c27b1..fb13515c7a5446 100644 --- a/third_party/xla/xla/tsl/platform/cloud/curl_http_request_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/curl_http_request_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/mem.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" -#include "tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache.h b/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache.h index 58f86d1fa2516a..4858cf3b1b9c33 100644 --- a/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache.h +++ b/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache.h @@ -21,10 +21,10 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache_test.cc b/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache_test.cc index 58cb1aebfcf70f..9f107e59c29599 100644 --- a/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/expiring_lru_cache_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/cloud/now_seconds_env.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/platform/cloud/file_block_cache.h b/third_party/xla/xla/tsl/platform/cloud/file_block_cache.h index 20543efd881738..07dd253c85d379 100644 --- a/third_party/xla/xla/tsl/platform/cloud/file_block_cache.h +++ b/third_party/xla/xla/tsl/platform/cloud/file_block_cache.h @@ -23,13 +23,13 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/notification.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.cc index dea209a795adf9..5db8af208c1b72 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/retrying_utils.h" -#include "tsl/platform/status.h" #ifndef _WIN32 #include #include @@ -197,7 +197,7 @@ void GcsDnsCache::AnnotateRequest(HttpRequest* request) { LOG(ERROR) << "Error converting response to IP address for " << name << ": " << strerror(errno); } else { - output.emplace_back(buf); + output.push_back(buf); VLOG(1) << "... address: " << buf; } } diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.h b/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.h index a29fe502854e42..5753881a9ff67a 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.h +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/platform/cloud/http_request.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" namespace tsl { const int64_t kDefaultRefreshRateSecs = 60; diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache_test.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache_test.cc index dc250f0015bbcb..85b3b435d6599b 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_dns_cache_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/tsl/platform/cloud/gcs_dns_cache.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc index 923ad2692aeb55..45038d302ffec6 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc @@ -37,7 +37,7 @@ limitations under the License. #include #include -#include "tsl/platform/file_statistics.h" +#include "xla/tsl/platform/file_statistics.h" #include "tsl/platform/strcat.h" #ifdef _WIN32 #include // for _mktemp @@ -49,8 +49,8 @@ limitations under the License. #include "xla/tsl/platform/cloud/google_auth_provider.h" #include "xla/tsl/platform/cloud/ram_file_block_cache.h" #include "xla/tsl/platform/cloud/time_util.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/mutex.h" #include "tsl/platform/numbers.h" #include "tsl/platform/path.h" diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.h b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.h index d76768d3b1f9a9..811f9828a4f4d6 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.h +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.h @@ -30,10 +30,10 @@ limitations under the License. #include "xla/tsl/platform/cloud/gcs_dns_cache.h" #include "xla/tsl/platform/cloud/gcs_throttle.h" #include "xla/tsl/platform/cloud/http_request.h" -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/retrying_file_system.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc index d0f0ec2fc9be8c..414c2f2d51aa63 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/cloud/http_request_fake.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/str_util.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" // Undef DeleteFile macro defined in wndows.h. #ifdef PLATFORM_WINDOWS diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_throttle.h b/third_party/xla/xla/tsl/platform/cloud/gcs_throttle.h index c86305bc323033..be11261f93f607 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_throttle.h +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_throttle.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_CLOUD_GCS_THROTTLE_H_ #define XLA_TSL_PLATFORM_CLOUD_GCS_THROTTLE_H_ -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_throttle_test.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_throttle_test.cc index dfbd3c6e78e1cb..50e5aab36cab2e 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_throttle_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_throttle_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "xla/tsl/platform/cloud/gcs_throttle.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc b/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc index edf220b295c030..d29a70b601ba04 100644 --- a/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc +++ b/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc @@ -25,9 +25,9 @@ limitations under the License. #include "absl/strings/match.h" #include "json/json.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/base64.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/retrying_utils.h" diff --git a/third_party/xla/xla/tsl/platform/cloud/google_auth_provider_test.cc b/third_party/xla/xla/tsl/platform/cloud/google_auth_provider_test.cc index cd378144e899cb..3b87cb5aa0fa73 100644 --- a/third_party/xla/xla/tsl/platform/cloud/google_auth_provider_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/google_auth_provider_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/cloud/http_request_fake.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/path.h" -#include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/http_request.h b/third_party/xla/xla/tsl/platform/cloud/http_request.h index b9cb805e4bc789..9ca2391b86dd57 100644 --- a/third_party/xla/xla/tsl/platform/cloud/http_request.h +++ b/third_party/xla/xla/tsl/platform/cloud/http_request.h @@ -20,13 +20,13 @@ limitations under the License. #include #include -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/http_request_fake.h b/third_party/xla/xla/tsl/platform/cloud/http_request_fake.h index c166cba3117bc1..0df34865991bb8 100644 --- a/third_party/xla/xla/tsl/platform/cloud/http_request_fake.h +++ b/third_party/xla/xla/tsl/platform/cloud/http_request_fake.h @@ -23,13 +23,13 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/cloud/curl_http_request.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/now_seconds_env.h b/third_party/xla/xla/tsl/platform/cloud/now_seconds_env.h index 4f24d7c4094f65..db13a305ec7435 100644 --- a/third_party/xla/xla/tsl/platform/cloud/now_seconds_env.h +++ b/third_party/xla/xla/tsl/platform/cloud/now_seconds_env.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_CLOUD_NOW_SECONDS_ENV_H_ #define XLA_TSL_PLATFORM_CLOUD_NOW_SECONDS_ENV_H_ -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc b/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc index e4e16ef7423dfd..3559cf734cfc64 100644 --- a/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc +++ b/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc @@ -28,9 +28,9 @@ limitations under the License. #include #include #include "xla/tsl/platform/cloud/curl_http_request.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/base64.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/oauth_client.h b/third_party/xla/xla/tsl/platform/cloud/oauth_client.h index 409155acb0dbb0..578914ea0af507 100644 --- a/third_party/xla/xla/tsl/platform/cloud/oauth_client.h +++ b/third_party/xla/xla/tsl/platform/cloud/oauth_client.h @@ -20,8 +20,8 @@ limitations under the License. #include "json/json.h" #include "xla/tsl/platform/cloud/http_request.h" -#include "tsl/platform/env.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/oauth_client_test.cc b/third_party/xla/xla/tsl/platform/cloud/oauth_client_test.cc index cd91e664910de1..3a0a866bc53d1e 100644 --- a/third_party/xla/xla/tsl/platform/cloud/oauth_client_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/oauth_client_test.cc @@ -22,11 +22,11 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/cloud/http_request_fake.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/base64.h" -#include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/scanner.h" -#include "tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc index 50c7980e8663a0..79576b3e14f81d 100644 --- a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc +++ b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/cleanup/cleanup.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.h b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.h index da204d351b57ca..74faa7ac4d6cb8 100644 --- a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.h +++ b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.h @@ -24,13 +24,13 @@ limitations under the License. #include #include "xla/tsl/platform/cloud/file_block_cache.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/notification.h" -#include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache_test.cc b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache_test.cc index 7a6d0dd52c1cd5..b8a72f15a42601 100644 --- a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache_test.cc @@ -17,12 +17,13 @@ limitations under the License. #include +#include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/cloud/now_seconds_env.h" -#include "tsl/platform/blocking_counter.h" -#include "tsl/platform/env.h" -#include "tsl/platform/notification.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { @@ -483,11 +484,19 @@ TEST(RamFileBlockCacheTest, ParallelReads) { // concurrently (at which point it will respond with success to all callers), // or 10 seconds have elapsed (at which point it will respond with an error). const int callers = 4; - BlockingCounter counter(callers); - auto fetcher = [&counter](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred) { - counter.DecrementCount(); - if (!counter.WaitFor(std::chrono::seconds(10))) { + absl::BlockingCounter counter(callers); + absl::Notification notification; + auto fetcher = [&counter, ¬ification]( + const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred) { + if (counter.DecrementCount()) { + notification.Notify(); + // This call to `Wait()` is not expected to block. Calling `Wait()` here + // allows us to satisfy `BlockingCounter`'s requirement: "When `Wait()` + // returns, it is legal to destroy the `BlockingCounter`.". + counter.Wait(); + } + if (!notification.WaitForNotificationWithTimeout(absl::Seconds(10))) { // This avoids having the test time out, which is harder to debug. return errors::FailedPrecondition("desired concurrency not reached"); } @@ -517,7 +526,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { // Concurrent reads to the same file blocks should be de-duplicated. const size_t block_size = 16; int num_requests = 0; - Notification notification; + absl::Notification notification; auto fetcher = [&num_requests, ¬ification, block_size]( const string& filename, size_t offset, size_t n, char* buffer, size_t* bytes_transferred) { diff --git a/third_party/xla/xla/tsl/platform/cloud/time_util.cc b/third_party/xla/xla/tsl/platform/cloud/time_util.cc index 3950f387e72c3f..7f9816e6d350f0 100644 --- a/third_party/xla/xla/tsl/platform/cloud/time_util.cc +++ b/third_party/xla/xla/tsl/platform/cloud/time_util.cc @@ -23,7 +23,7 @@ limitations under the License. #ifdef _WIN32 #define timegm _mkgmtime #endif -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/time_util.h b/third_party/xla/xla/tsl/platform/cloud/time_util.h index 0b75a294bd300a..de9653b87acafe 100644 --- a/third_party/xla/xla/tsl/platform/cloud/time_util.h +++ b/third_party/xla/xla/tsl/platform/cloud/time_util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_CLOUD_TIME_UTIL_H_ #define XLA_TSL_PLATFORM_CLOUD_TIME_UTIL_H_ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/time_util_test.cc b/third_party/xla/xla/tsl/platform/cloud/time_util_test.cc index 9cb6f22dfeb30c..f8a5d04471add4 100644 --- a/third_party/xla/xla/tsl/platform/cloud/time_util_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/time_util_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/tsl/platform/cloud/time_util.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/cloud/zone_provider.h b/third_party/xla/xla/tsl/platform/cloud/zone_provider.h index c54b2f84a84f12..22a109500b94ad 100644 --- a/third_party/xla/xla/tsl/platform/cloud/zone_provider.h +++ b/third_party/xla/xla/tsl/platform/cloud/zone_provider.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/BUILD b/third_party/xla/xla/tsl/platform/default/BUILD index b614d6407825c8..7b6956585532cb 100644 --- a/third_party/xla/xla/tsl/platform/default/BUILD +++ b/third_party/xla/xla/tsl/platform/default/BUILD @@ -1,5 +1,6 @@ # Tensorflow default + linux implementations of tensorflow/core/platform libraries. load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load( "//xla/tsl:tsl.bzl", "if_cuda_tools", @@ -74,12 +75,12 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:types", ], ) @@ -98,35 +99,39 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", - "@local_config_rocm//rocm:rocm_headers", "@local_config_tensorrt//:tensorrt_headers", "@local_tsl//tsl/platform:load_library", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", - ] + if_oss(["@local_config_nccl//:nccl_config"]), + ] + if_oss([ + "@local_config_nccl//:nccl_config", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_config", + "@local_config_rocm//rocm:rocm_headers", + ]), ) cc_library( name = "env", srcs = [ "posix_file_system.cc", - "@local_tsl//tsl/platform:env.cc", - "@local_tsl//tsl/platform:file_system.cc", - "@local_tsl//tsl/platform:file_system_helper.cc", - "@local_tsl//tsl/platform:threadpool.cc", + "//xla/tsl/platform:env.cc", + "//xla/tsl/platform:file_system.cc", + "//xla/tsl/platform:file_system_helper.cc", + "//xla/tsl/platform:threadpool.cc", ], hdrs = [ "posix_file_system.h", - "@local_tsl//tsl/platform:env.h", - "@local_tsl//tsl/platform:file_system.h", - "@local_tsl//tsl/platform:file_system_helper.h", + "//xla/tsl/platform:env.h", + "//xla/tsl/platform:file_system.h", + "//xla/tsl/platform:file_system_helper.h", + "//xla/tsl/platform:threadpool.h", "@local_tsl//tsl/platform:ram_file_system.h", - "@local_tsl//tsl/platform:threadpool.h", ], copts = tsl_copts(), tags = [ @@ -135,9 +140,20 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:file_statistics", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:threadpool_interface", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@eigen_archive//:eigen3", @@ -146,12 +162,7 @@ cc_library( "@local_tsl//tsl/platform:context", "@local_tsl//tsl/platform:cord", "@local_tsl//tsl/platform:denormal", - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:file_statistics", "@local_tsl//tsl/platform:load_library", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", @@ -159,15 +170,11 @@ cc_library( "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:scanner", "@local_tsl//tsl/platform:setround", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:stringprintf", - "@local_tsl//tsl/platform:threadpool_interface", "@local_tsl//tsl/platform:tracing", - "@local_tsl//tsl/platform:types", ], ) @@ -183,9 +190,9 @@ cc_library( ], deps = [ ":env", + "//xla/tsl/platform:logging", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@local_tsl//tsl/platform:load_library", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:strcat", ], @@ -194,13 +201,13 @@ cc_library( cc_library( name = "env_time", srcs = ["env_time.cc"], - hdrs = ["@local_tsl//tsl/platform:env_time.h"], + hdrs = ["//xla/tsl/platform:env_time.h"], tags = [ "manual", "no_oss", "nobuilder", ], - deps = ["@local_tsl//tsl/platform:types"], + deps = ["//xla/tsl/platform:types"], ) cc_library( @@ -228,8 +235,8 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:logging", "@com_google_absl//absl/log:check", - "@local_tsl//tsl/platform:logging", ] + tsl_grpc_cc_dependencies(), ) @@ -243,13 +250,13 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:types", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], ) @@ -257,6 +264,7 @@ cc_library( name = "load_library", srcs = ["load_library.cc"], hdrs = ["@local_tsl//tsl/platform:load_library.h"], + linkstatic = True, tags = [ "manual", "no_oss", @@ -264,13 +272,15 @@ cc_library( ], deps = [ "@com_google_absl//absl/status", - ], + ] + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_rpath", + ]), ) cc_library( name = "logging", srcs = ["logging.cc"], - hdrs = ["@local_tsl//tsl/platform:logging.h"], + hdrs = ["//xla/tsl/platform:logging.h"], tags = [ "manual", "no_oss", @@ -278,14 +288,14 @@ cc_library( ], textual_hdrs = ["logging.h"], deps = [ + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/base", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:types", ], ) @@ -308,7 +318,7 @@ cc_library( "nobuilder", ], deps = [ - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", "@local_tsl//tsl/platform:strcat", ], alwayslink = True, @@ -342,14 +352,14 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "//xla/tsl/platform/profile_utils:profile_utils_cpu_utils", "@com_google_absl//absl/base", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:byte_order", "@local_tsl//tsl/platform:dynamic_annotations", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", "@snappy", ] + select({ # TF Additional NUMA dependencies @@ -386,10 +396,11 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", + "@local_config_rocm//rocm:rocm_config", "@local_config_rocm//rocm:rocm_headers", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:types", ], ) @@ -426,11 +437,11 @@ cc_library( ], textual_hdrs = ["subprocess.h"], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -449,15 +460,15 @@ cc_library( ], textual_hdrs = ["tracing_impl.h"], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "//xla/tsl/profiler/backends/cpu:threadpool_listener_state", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:hash", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -482,10 +493,10 @@ cc_library( "nobuilder", ], deps = [ - "@com_google_absl//absl/memory", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:notification", + "//xla/tsl/platform:env", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:platform_port", ], ) @@ -530,9 +541,9 @@ cc_library( textual_hdrs = ["statusor.h"], visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:status", ], ) diff --git a/third_party/xla/xla/tsl/platform/default/build_config.bzl b/third_party/xla/xla/tsl/platform/default/build_config.bzl index f6b5255441a15a..dd79a03cd8acac 100644 --- a/third_party/xla/xla/tsl/platform/default/build_config.bzl +++ b/third_party/xla/xla/tsl/platform/default/build_config.bzl @@ -1,4 +1,4 @@ -# Platform-specific build configurations. +"""Platform-specific build configurations.""" load("@com_github_grpc_grpc//bazel:generate_cc.bzl", "generate_cc") load("@com_google_protobuf//:protobuf.bzl", "proto_gen") @@ -31,8 +31,16 @@ def well_known_proto_libs(): "@com_google_protobuf//:wrappers_proto", ] -# Appends a suffix to a list of deps. def tf_deps(deps, suffix): + """Appends a suffix to a list of deps. + + Args: + deps: the list of deps which will be suffixed + suffix: the suffix to add + + Returns: + The list of deps with the suffix applied. + """ tf_deps = [] # If the package name is in shorthand form (ie: does not contain a ':'), @@ -44,7 +52,7 @@ def tf_deps(deps, suffix): dep_pieces = dep.split("/") tf_dep += ":" + dep_pieces[len(dep_pieces) - 1] - tf_deps += [tf_dep + suffix] + tf_deps.append(tf_dep + suffix) return tf_deps @@ -259,7 +267,6 @@ def cc_proto_library( ) else: header_only_name = name + "_headers_only" - header_only_deps = tf_deps(protolib_deps, "_cc_headers_only") if make_default_target_header_only: native.alias( @@ -287,8 +294,9 @@ def cc_proto_library( if use_pywrap_rules(): pass else: + header_only_deps = tf_deps(protolib_deps, "_cc_headers_only") native.cc_library( - name = header_only_name, + name = header_only_name, # buildifier: disable=uninitialized deps = [ "@com_google_protobuf//:protobuf_headers", ] + header_only_deps + if_tsl_link_protobuf([impl_name]), @@ -446,17 +454,18 @@ def py_proto_library( # TODO(b/356020232): cleanup non-use_pywrap_rules part and all logic reated to # protobuf header-only targets after migration is done +# buildifier: disable=function-docstring def tf_proto_library_cc( name, srcs = [], - has_services = None, + has_services = None, # @unused protodeps = [], visibility = None, testonly = 0, cc_libs = [], cc_grpc_version = None, use_grpc_namespace = False, - j2objc_api_version = 1, + j2objc_api_version = 1, # @unused js_codegen = "jspb", create_service = False, create_java_proto = False, @@ -470,7 +479,7 @@ def tf_proto_library_cc( testonly = testonly, visibility = visibility, ) - _ignore = (create_service, create_java_proto, create_kotlin_proto) + _ = (create_service, create_java_proto, create_kotlin_proto) # @unused use_grpc_plugin = None if cc_grpc_version: @@ -552,6 +561,7 @@ def tf_proto_library_cc( local_defines = local_defines, ) +# buildifier: disable=function-docstring def tf_proto_library_py( name, srcs = [], @@ -592,9 +602,12 @@ def tf_proto_library_py( deps = deps + py_deps + [clean_dep("@com_google_protobuf//:protobuf_python")], ) -def tf_jspb_proto_library(**kwargs): +def tf_jspb_proto_library(**_kwargs): pass +# buildifier: disable=function-docstring +# buildifier: disable=function-docstring-args +# buildifier: disable=function-docstring-return def tf_proto_library( name, srcs = [], @@ -603,9 +616,9 @@ def tf_proto_library( visibility = None, testonly = 0, cc_libs = [], - cc_grpc_version = None, + cc_grpc_version = None, # @unused use_grpc_namespace = False, - j2objc_api_version = 1, + j2objc_api_version = 1, # @unused js_codegen = "jspb", create_service = False, create_java_proto = False, @@ -621,7 +634,9 @@ def tf_proto_library( # TODO(b/145545130): Add docstring explaining what rules this creates and how # opensource projects importing TF in bazel can use them safely (i.e. w/o ODR or # ABI violations). - _ignore = ( + + # @unused + _ = ( js_codegen, create_service, create_java_proto, @@ -757,7 +772,8 @@ def tf_lib_proto_parsing_deps(): clean_dep("@local_xla//xla/tsl/protobuf:protos_all_cc"), ] -def tf_py_clif_cc(name, visibility = None, **kwargs): +def tf_py_clif_cc(name, visibility = None, **_kwargs): + _ = visibility # @unused pass def tf_pyclif_proto_library( @@ -765,7 +781,8 @@ def tf_pyclif_proto_library( proto_lib, proto_srcfile = "", visibility = None, - **kwargs): + **_kwargs): + _ = (proto_lib, proto_srcfile, visibility) # @unused native.filegroup(name = name) native.filegroup(name = name + "_pb2") @@ -862,7 +879,6 @@ def tf_resource_deps(): def tf_portable_deps_no_runtime(): return [ "@eigen_archive//:eigen3", - "@double_conversion//:double-conversion", "@com_googlesource_code_re2//:re2", "@farmhash_archive//:farmhash", ] diff --git a/third_party/xla/xla/tsl/platform/default/build_config_root.bzl b/third_party/xla/xla/tsl/platform/default/build_config_root.bzl index c26b0681e0328a..5a45456669c3ad 100644 --- a/third_party/xla/xla/tsl/platform/default/build_config_root.bzl +++ b/third_party/xla/xla/tsl/platform/default/build_config_root.bzl @@ -1,6 +1,8 @@ -# Lower-level functionality for build config. -# The functions in this file might be referred by tensorflow.bzl. They have to -# be separate to avoid cyclic references. +"""Lower-level functionality for build config. + +The functions in this file might be referred by tensorflow.bzl. They have to +be separate to avoid cyclic references. +""" load("@local_config_remote_execution//:remote_execution.bzl", "gpu_test_tags") load("@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules") @@ -46,6 +48,7 @@ def tf_additional_tpu_ops_deps(): # dependency list is used when using the framework_shared_object config # on MacOS platforms. If "macos" is not provided, the "otherwise" list is # used for all framework_shared_object platforms including MacOS. +# buildifier: disable=function-docstring def if_static(extra_deps, otherwise = [], macos = []): if use_pywrap_rules(): return extra_deps @@ -93,6 +96,7 @@ def if_llvm_arm_available(then, otherwise = []): }) def if_llvm_hexagon_available(then, otherwise = []): + _ = then # @unused return otherwise def if_llvm_powerpc_available(then, otherwise = []): diff --git a/third_party/xla/xla/tsl/platform/default/cuda_root_path.cc b/third_party/xla/xla/tsl/platform/default/cuda_root_path.cc index ca6da0e5532eaa..578d8b05c70e68 100644 --- a/third_party/xla/xla/tsl/platform/default/cuda_root_path.cc +++ b/third_party/xla/xla/tsl/platform/default/cuda_root_path.cc @@ -31,9 +31,9 @@ limitations under the License. #if !defined(PLATFORM_GOOGLE) #include "third_party/gpus/cuda/cuda_config.h" -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" #endif -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { @@ -46,7 +46,7 @@ std::vector CandidateCudaRoots() { std::string executable_path = tsl::Env::Default()->GetExecutablePath(); std::string cuda_nvcc_dir = io::JoinPath(executable_path + "." + runfiles_suffix, "cuda_nvcc"); - roots.emplace_back(cuda_nvcc_dir); + roots.push_back(cuda_nvcc_dir); // The CUDA candidate root for python targets. std::string runfiles_dir = tsl::Env::Default()->GetRunfilesDir(); @@ -54,10 +54,11 @@ std::vector CandidateCudaRoots() { cuda_nvcc_dir = io::JoinPath( runfiles_dir.substr(0, runfiles_ind + runfiles_suffix.length()), "cuda_nvcc"); - roots.emplace_back(cuda_nvcc_dir); + roots.push_back(cuda_nvcc_dir); - roots.emplace_back(TF_CUDA_TOOLKIT_PATH); + roots.push_back(TF_CUDA_TOOLKIT_PATH); roots.emplace_back(std::string("/usr/local/cuda")); + roots.emplace_back(std::string("/opt/cuda")); #if defined(PLATFORM_POSIX) && !defined(__APPLE__) Dl_info info; diff --git a/third_party/xla/xla/tsl/platform/default/dlopen_checker.cc b/third_party/xla/xla/tsl/platform/default/dlopen_checker.cc index 8e0bdddc701e5f..763df14caf62d3 100644 --- a/third_party/xla/xla/tsl/platform/default/dlopen_checker.cc +++ b/third_party/xla/xla/tsl/platform/default/dlopen_checker.cc @@ -16,7 +16,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/tsl/platform/default/dso_loader.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace internal { diff --git a/third_party/xla/xla/tsl/platform/default/dlopen_checker_stub.cc b/third_party/xla/xla/tsl/platform/default/dlopen_checker_stub.cc index 504c35f44ffa82..152578731ab5fa 100644 --- a/third_party/xla/xla/tsl/platform/default/dlopen_checker_stub.cc +++ b/third_party/xla/xla/tsl/platform/default/dlopen_checker_stub.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/status/status.h" #include "xla/tsl/platform/default/dso_loader.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace internal { diff --git a/third_party/xla/xla/tsl/platform/default/dso_loader.cc b/third_party/xla/xla/tsl/platform/default/dso_loader.cc index 0d246cb84cd682..5ebe6c9dae7a22 100644 --- a/third_party/xla/xla/tsl/platform/default/dso_loader.cc +++ b/third_party/xla/xla/tsl/platform/default/dso_loader.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "third_party/gpus/cuda/cuda_config.h" #include "third_party/nccl/nccl_config.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" #include "third_party/tensorrt/tensorrt_config.h" diff --git a/third_party/xla/xla/tsl/platform/default/env.cc b/third_party/xla/xla/tsl/platform/default/env.cc index d60c22c30d9bd8..3022af8b33866f 100644 --- a/third_party/xla/xla/tsl/platform/default/env.cc +++ b/third_party/xla/xla/tsl/platform/default/env.cc @@ -27,6 +27,8 @@ limitations under the License. #include #include +#include + #ifdef __FreeBSD__ #include #endif @@ -36,10 +38,10 @@ limitations under the License. #include #include "xla/tsl/platform/default/posix_file_system.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/env.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" #include "tsl/platform/ram_file_system.h" #include "tsl/platform/strcat.h" @@ -137,8 +139,9 @@ class PosixEnv : public Env { return new PThread(thread_options, name, std::move(fn)); } - int32 GetCurrentThreadId() override { - static thread_local int32 current_thread_id = GetCurrentThreadIdInternal(); + int64_t GetCurrentThreadId() override { + static thread_local int64_t current_thread_id = + GetCurrentThreadIdInternal(); return current_thread_id; } @@ -230,15 +233,15 @@ class PosixEnv : public Env { private: void GetLocalTempDirectories(std::vector* list) override; - int32 GetCurrentThreadIdInternal() { + int64_t GetCurrentThreadIdInternal() { #ifdef __APPLE__ uint64_t tid64; pthread_threadid_np(nullptr, &tid64); - return static_cast(tid64); + return static_cast(tid64); #elif defined(__FreeBSD__) return pthread_getthreadid_np(); #elif defined(__NR_gettid) - return static_cast(syscall(__NR_gettid)); + return static_cast(syscall(__NR_gettid)); #else return std::hash()(std::this_thread::get_id()); #endif diff --git a/third_party/xla/xla/tsl/platform/default/env_time.cc b/third_party/xla/xla/tsl/platform/default/env_time.cc index 6d8b583d527504..cfe7d23d1a2a72 100644 --- a/third_party/xla/xla/tsl/platform/default/env_time.cc +++ b/third_party/xla/xla/tsl/platform/default/env_time.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/env_time.h" +#include "xla/tsl/platform/env_time.h" #include #include diff --git a/third_party/xla/xla/tsl/platform/default/grpc_credentials.cc b/third_party/xla/xla/tsl/platform/default/grpc_credentials.cc index 44850f56e05195..a5c366a4dd0c29 100644 --- a/third_party/xla/xla/tsl/platform/default/grpc_credentials.cc +++ b/third_party/xla/xla/tsl/platform/default/grpc_credentials.cc @@ -19,7 +19,7 @@ #include "absl/log/check.h" #include "grpcpp/security/credentials.h" #include "grpcpp/security/server_credentials.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/human_readable_json.cc b/third_party/xla/xla/tsl/platform/default/human_readable_json.cc index 167cdd2b891312..5c3da22fddddc2 100644 --- a/third_party/xla/xla/tsl/platform/default/human_readable_json.cc +++ b/third_party/xla/xla/tsl/platform/default/human_readable_json.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/strcat.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/integral_types.h b/third_party/xla/xla/tsl/platform/default/integral_types.h index 0827b917369eab..0e67cdf9eb047d 100644 --- a/third_party/xla/xla/tsl/platform/default/integral_types.h +++ b/third_party/xla/xla/tsl/platform/default/integral_types.h @@ -18,8 +18,8 @@ limitations under the License. #include -// IWYU pragma: private, include "third_party/tensorflow/tsl/platform/types.h" -// IWYU pragma: friend third_party/tensorflow/tsl/platform/types.h +// IWYU pragma: private, include "xla/tsl/platform/types.h" +// IWYU pragma: friend third_party/tensorflow/compiler/xla/tsl/platform/types.h namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/logging.cc b/third_party/xla/xla/tsl/platform/default/logging.cc index 78a6db44efb281..31a3533bc193d9 100644 --- a/third_party/xla/xla/tsl/platform/default/logging.cc +++ b/third_party/xla/xla/tsl/platform/default/logging.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/base/internal/sysinfo.h" #include "absl/base/log_severity.h" #include "absl/strings/string_view.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/macros.h" #include "tsl/platform/mutex.h" #if defined(PLATFORM_POSIX_ANDROID) diff --git a/third_party/xla/xla/tsl/platform/default/logging.h b/third_party/xla/xla/tsl/platform/default/logging.h index c118157347aa20..bc72e301ffbca5 100644 --- a/third_party/xla/xla/tsl/platform/default/logging.h +++ b/third_party/xla/xla/tsl/platform/default/logging.h @@ -22,8 +22,8 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_DEFAULT_LOGGING_H_ #define XLA_TSL_PLATFORM_DEFAULT_LOGGING_H_ -// IWYU pragma: private, include "third_party/tensorflow/tsl/platform/logging.h" -// IWYU pragma: friend third_party/tensorflow/tsl/platform/logging.h +// IWYU pragma: private, include "xla/tsl/platform/logging.h" +// IWYU pragma: friend third_party/tensorflow/compiler/xla/tsl/platform/logging.h #include #include @@ -34,8 +34,8 @@ limitations under the License. #include "absl/base/log_severity.h" #include "absl/strings/string_view.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" // TODO(mrry): Prevent this Windows.h #define from leaking out of our headers. #undef ERROR diff --git a/third_party/xla/xla/tsl/platform/default/net.cc b/third_party/xla/xla/tsl/platform/default/net.cc index b487e35f4fb618..640f223071b232 100644 --- a/third_party/xla/xla/tsl/platform/default/net.cc +++ b/third_party/xla/xla/tsl/platform/default/net.cc @@ -26,7 +26,7 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/strcat.h" // https://en.wikipedia.org/wiki/Ephemeral_port diff --git a/third_party/xla/xla/tsl/platform/default/platform.bzl b/third_party/xla/xla/tsl/platform/default/platform.bzl index 76bfaa896efa2f..d5db2b948d0f8d 100644 --- a/third_party/xla/xla/tsl/platform/default/platform.bzl +++ b/third_party/xla/xla/tsl/platform/default/platform.bzl @@ -1,3 +1,4 @@ +"""Platform specific paths for various libraries and utilities.""" CUDA_VERSION = "" CUDNN_VERSION = "" @@ -10,6 +11,7 @@ def cuda_sdk_version(): def cudnn_sdk_version(): return CUDNN_VERSION +# buildifier: disable=function-docstring def cuda_library_path(name, version = cuda_sdk_version()): if PLATFORM == "Darwin": if not version: @@ -27,6 +29,7 @@ def cuda_static_library_path(name): else: return "lib64/lib{}_static.a".format(name) +# buildifier: disable=function-docstring def cudnn_library_path(version = cudnn_sdk_version()): if PLATFORM == "Darwin": if not version: @@ -38,6 +41,7 @@ def cudnn_library_path(version = cudnn_sdk_version()): else: return "lib64/libcudnn.so.{}".format(version) +# buildifier: disable=function-docstring def cupti_library_path(version = cuda_sdk_version()): if PLATFORM == "Darwin": if not version: diff --git a/third_party/xla/xla/tsl/platform/default/port.cc b/third_party/xla/xla/tsl/platform/default/port.cc index caf342c730ecb3..06322f61f1fda6 100644 --- a/third_party/xla/xla/tsl/platform/default/port.cc +++ b/third_party/xla/xla/tsl/platform/default/port.cc @@ -14,14 +14,14 @@ limitations under the License. ==============================================================================*/ #include "absl/base/internal/sysinfo.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/profile_utils/cpu_utils.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/host_info.h" -#include "tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/numa.h" #include "tsl/platform/snappy.h" -#include "tsl/platform/types.h" #if defined(__linux__) #include diff --git a/third_party/xla/xla/tsl/platform/default/posix_file_system.cc b/third_party/xla/xla/tsl/platform/default/posix_file_system.cc index 66f2d758d83d44..68ee3b1b7b9697 100644 --- a/third_party/xla/xla/tsl/platform/default/posix_file_system.cc +++ b/third_party/xla/xla/tsl/platform/default/posix_file_system.cc @@ -30,12 +30,12 @@ limitations under the License. #include #include "xla/tsl/platform/default/posix_file_system.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system_helper.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system_helper.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" #include "tsl/platform/strcat.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/posix_file_system.h b/third_party/xla/xla/tsl/platform/default/posix_file_system.h index e241305d6a12e8..a54ecf04017dcd 100644 --- a/third_party/xla/xla/tsl/platform/default/posix_file_system.h +++ b/third_party/xla/xla/tsl/platform/default/posix_file_system.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_DEFAULT_POSIX_FILE_SYSTEM_H_ #define XLA_TSL_PLATFORM_DEFAULT_POSIX_FILE_SYSTEM_H_ -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" #include "tsl/platform/path.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/rocm_rocdl_path.cc b/third_party/xla/xla/tsl/platform/default/rocm_rocdl_path.cc index a1934f81e35723..f5cd4c8595f744 100644 --- a/third_party/xla/xla/tsl/platform/default/rocm_rocdl_path.cc +++ b/third_party/xla/xla/tsl/platform/default/rocm_rocdl_path.cc @@ -22,7 +22,7 @@ limitations under the License. #if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" #endif -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/statusor.h b/third_party/xla/xla/tsl/platform/default/statusor.h index d5ddb2d0668c68..babd52ed96d7b7 100644 --- a/third_party/xla/xla/tsl/platform/default/statusor.h +++ b/third_party/xla/xla/tsl/platform/default/statusor.h @@ -16,8 +16,8 @@ limitations under the License. #define XLA_TSL_PLATFORM_DEFAULT_STATUSOR_H_ #include "absl/status/statusor.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" #define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ TF_ASSIGN_OR_RETURN_IMPL( \ diff --git a/third_party/xla/xla/tsl/platform/default/subprocess.cc b/third_party/xla/xla/tsl/platform/default/subprocess.cc index b3ffe1d441cd65..85cc2e3bcd9534 100644 --- a/third_party/xla/xla/tsl/platform/default/subprocess.cc +++ b/third_party/xla/xla/tsl/platform/default/subprocess.cc @@ -27,7 +27,7 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" // Android versions older than 28 do not have posix_spawn(). #if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 diff --git a/third_party/xla/xla/tsl/platform/default/subprocess.h b/third_party/xla/xla/tsl/platform/default/subprocess.h index 7366762bb1e102..e7ce0d88f601ac 100644 --- a/third_party/xla/xla/tsl/platform/default/subprocess.h +++ b/third_party/xla/xla/tsl/platform/default/subprocess.h @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.cc b/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.cc index 818d54435439d0..3b11354bbd14b7 100644 --- a/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.cc +++ b/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.cc @@ -15,24 +15,25 @@ limitations under the License. #include "xla/tsl/platform/default/unbounded_work_queue.h" -#include "absl/memory/memory.h" -#include "tsl/platform/env.h" -#include "tsl/platform/mutex.h" +#include + +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/env.h" #include "tsl/platform/numa.h" namespace tsl { -UnboundedWorkQueue::UnboundedWorkQueue(Env* env, const string& thread_name, +UnboundedWorkQueue::UnboundedWorkQueue(Env* env, absl::string_view thread_name, const ThreadOptions& thread_options) : env_(env), thread_name_(thread_name), thread_options_(thread_options) {} UnboundedWorkQueue::~UnboundedWorkQueue() { { - mutex_lock l(work_queue_mu_); + absl::MutexLock l(&work_queue_mu_); // Wake up all `PooledThreadFunc` threads and cause them to terminate before // joining them when `threads_` is cleared. cancelled_ = true; - work_queue_cv_.notify_all(); if (!work_queue_.empty()) { LOG(ERROR) << "UnboundedWorkQueue named \"" << thread_name_ << "\" was " << "deleted with pending work in its queue. This may indicate " @@ -41,7 +42,7 @@ UnboundedWorkQueue::~UnboundedWorkQueue() { } { - mutex_lock l(thread_pool_mu_); + absl::MutexLock l(&thread_pool_mu_); // Clear the list of pooled threads, which will eventually terminate due to // the previous notification. // @@ -55,9 +56,8 @@ UnboundedWorkQueue::~UnboundedWorkQueue() { void UnboundedWorkQueue::Schedule(WorkFunction fn) { // Enqueue a work item for the new thread's function, and wake up a // cached thread to process it. - mutex_lock l(work_queue_mu_); + absl::MutexLock l(&work_queue_mu_); work_queue_.push_back(std::move(fn)); - work_queue_cv_.notify_one(); // NOTE: The queue may be non-empty, so we must account for queued work when // considering how many threads are free. if (work_queue_.size() > num_idle_threads_) { @@ -67,7 +67,7 @@ void UnboundedWorkQueue::Schedule(WorkFunction fn) { Thread* new_thread = env_->StartThread({}, thread_name_, [this]() { PooledThreadFunc(); }); - mutex_lock l(thread_pool_mu_); + absl::MutexLock l(&thread_pool_mu_); thread_pool_.emplace_back(new_thread); } } @@ -81,13 +81,12 @@ void UnboundedWorkQueue::PooledThreadFunc() { while (true) { WorkFunction fn; { - mutex_lock l(work_queue_mu_); + absl::MutexLock l(&work_queue_mu_); ++num_idle_threads_; - while (!cancelled_ && work_queue_.empty()) { - // Wait for a new work function to be submitted, or the cache to be - // destroyed. - work_queue_cv_.wait(l); - } + // Wait for a new work function to be submitted, or the cache to be + // destroyed. + work_queue_mu_.Await( + absl::Condition(this, &UnboundedWorkQueue::HasWorkOrIsCancelled)); if (cancelled_) { return; } diff --git a/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.h b/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.h index 401b2b596d350d..8c3c34b594b7e4 100644 --- a/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.h +++ b/third_party/xla/xla/tsl/platform/default/unbounded_work_queue.h @@ -15,13 +15,17 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_ #define XLA_TSL_PLATFORM_DEFAULT_UNBOUNDED_WORK_QUEUE_H_ +#include #include +#include #include +#include #include -#include "tsl/platform/env.h" -#include "tsl/platform/mutex.h" -#include "tsl/platform/notification.h" +#include "absl/base/thread_annotations.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/env.h" namespace tsl { @@ -36,7 +40,7 @@ namespace tsl { // fragmentation that can result from excessive thread creation. class UnboundedWorkQueue { public: - UnboundedWorkQueue(Env* env, const string& thread_name, + UnboundedWorkQueue(Env* env, absl::string_view thread_name, const ThreadOptions& thread_options = {}); ~UnboundedWorkQueue(); @@ -50,17 +54,20 @@ class UnboundedWorkQueue { private: void PooledThreadFunc(); + bool HasWorkOrIsCancelled() const ABSL_SHARED_LOCKS_REQUIRED(work_queue_mu_) { + return !work_queue_.empty() || cancelled_; + } + Env* const env_; // Not owned. - const string thread_name_; + const std::string thread_name_; const ThreadOptions thread_options_; - mutex work_queue_mu_; - condition_variable work_queue_cv_ TF_GUARDED_BY(work_queue_mu_); - size_t num_idle_threads_ TF_GUARDED_BY(work_queue_mu_) = 0; - bool cancelled_ TF_GUARDED_BY(work_queue_mu_) = false; - std::deque work_queue_ TF_GUARDED_BY(work_queue_mu_); - mutex thread_pool_mu_; + absl::Mutex work_queue_mu_; + size_t num_idle_threads_ ABSL_GUARDED_BY(work_queue_mu_) = 0; + bool cancelled_ ABSL_GUARDED_BY(work_queue_mu_) = false; + std::deque work_queue_ ABSL_GUARDED_BY(work_queue_mu_); + absl::Mutex thread_pool_mu_; std::vector> thread_pool_ - TF_GUARDED_BY(thread_pool_mu_); + ABSL_GUARDED_BY(thread_pool_mu_); }; } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/env.cc b/third_party/xla/xla/tsl/platform/env.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/platform/env.cc rename to third_party/xla/xla/tsl/platform/env.cc index 0945a773c78851..088709dda87d9e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/env.cc +++ b/third_party/xla/xla/tsl/platform/env.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" #include +#include #include #include #include -#include "tsl/platform/env_time.h" -#include "tsl/platform/errors.h" +#include "absl/strings/str_format.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/host_info.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" @@ -445,12 +447,12 @@ bool Env::LocalTempFilename(string* filename) { } bool Env::CreateUniqueFileName(string* prefix, const string& suffix) { - int32_t tid = GetCurrentThreadId(); + int64_t tid = GetCurrentThreadId(); int32_t pid = GetProcessId(); long long now_microsec = NowMicros(); // NOLINT - *prefix += strings::Printf("%s-%x-%d-%llx", port::Hostname().c_str(), tid, - pid, now_microsec); + absl::StrAppendFormat(prefix, "%s-%x-%d-%llx", port::Hostname(), tid, pid, + now_microsec); if (!suffix.empty()) { *prefix += suffix; diff --git a/third_party/xla/xla/tsl/platform/env.h b/third_party/xla/xla/tsl/platform/env.h new file mode 100644 index 00000000000000..9b302b8090ba2d --- /dev/null +++ b/third_party/xla/xla/tsl/platform/env.h @@ -0,0 +1,737 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_ENV_H_ +#define XLA_TSL_PLATFORM_ENV_H_ + +#include + +#include +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/numa.h" +#include "tsl/platform/platform.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/stringpiece.h" + +// Delete leaked Windows definitions. +#ifdef PLATFORM_WINDOWS +#undef CopyFile +#undef DeleteFile +#endif + +namespace tsl { + +class Thread; +struct ThreadOptions; + +/// \brief An interface used by the tensorflow implementation to +/// access operating system functionality like the filesystem etc. +/// +/// Callers may wish to provide a custom Env object to get fine grain +/// control. +/// +/// All Env implementations of file-system modifying functionality are safe +/// for concurrent access from multiple threads without any external +/// synchronization, *however*, Envs and their underlying file systems are +/// global objects, and therefore, if any thread modifies options, the modified +/// options take effect process-wide. The SetOption functions themselves are +/// also *not* thread safe. +class Env { + public: + Env(); + virtual ~Env() = default; + + /// \brief Returns a default environment suitable for the current operating + /// system. + /// + /// Sophisticated users may wish to provide their own Env + /// implementation instead of relying on this default environment. + /// + /// The result of Default() belongs to this library and must never be deleted. + static Env* Default(); + + /// \brief Returns the FileSystem object to handle operations on the file + /// specified by 'fname'. The FileSystem object is used as the implementation + /// for the file system related (non-virtual) functions that follow. + /// Returned FileSystem object is still owned by the Env object and will + // (might) be destroyed when the environment is destroyed. + virtual absl::Status GetFileSystemForFile(const std::string& fname, + FileSystem** result); + + /// \brief Returns the file system schemes registered for this Env. + virtual absl::Status GetRegisteredFileSystemSchemes( + std::vector* schemes); + + /// \brief Register a file system for a scheme. + virtual absl::Status RegisterFileSystem(const std::string& scheme, + FileSystemRegistry::Factory factory); + + /// \brief Register a modular file system for a scheme. + /// + /// Same as `RegisterFileSystem` but for filesystems provided by plugins. + /// + /// TODO(b/139060984): After all filesystems are converted, make this be the + /// canonical registration function. + virtual absl::Status RegisterFileSystem( + const std::string& scheme, std::unique_ptr filesystem); + + absl::Status SetOption(const std::string& scheme, const std::string& key, + const std::string& value); + + absl::Status SetOption(const std::string& scheme, const std::string& key, + const std::vector& values); + + absl::Status SetOption(const std::string& scheme, const std::string& key, + const std::vector& values); + + absl::Status SetOption(const std::string& scheme, const std::string& key, + const std::vector& values); + + /// \brief Flush filesystem caches for all registered filesystems. + absl::Status FlushFileSystemCaches(); + + /// \brief Creates a brand new random access read-only file with the + /// specified name. + + /// On success, stores a pointer to the new file in + /// *result and returns OK. On failure stores NULL in *result and + /// returns non-OK. If the file does not exist, returns a non-OK + /// status. + /// + /// The returned file may be concurrently accessed by multiple threads. + /// + /// The ownership of the returned RandomAccessFile is passed to the caller + /// and the object should be deleted when is not used. The file object + /// shouldn't live longer than the Env object. + absl::Status NewRandomAccessFile(const std::string& fname, + std::unique_ptr* result); + + absl::Status NewRandomAccessFile(const std::string& fname, + TransactionToken* token, + std::unique_ptr* result) { + // We duplicate these methods due to Google internal coding style prevents + // virtual functions with default arguments. See PR #41615. + return absl::OkStatus(); + } + + /// \brief Creates an object that writes to a new file with the specified + /// name. + /// + /// Deletes any existing file with the same name and creates a + /// new file. On success, stores a pointer to the new file in + /// *result and returns OK. On failure stores NULL in *result and + /// returns non-OK. + /// + /// The returned file will only be accessed by one thread at a time. + /// + /// The ownership of the returned WritableFile is passed to the caller + /// and the object should be deleted when is not used. The file object + /// shouldn't live longer than the Env object. + absl::Status NewWritableFile(const std::string& fname, + std::unique_ptr* result); + + absl::Status NewWritableFile(const std::string& fname, + TransactionToken* token, + std::unique_ptr* result) { + return absl::OkStatus(); + } + + /// \brief Creates an object that either appends to an existing file, or + /// writes to a new file (if the file does not exist to begin with). + /// + /// On success, stores a pointer to the new file in *result and + /// returns OK. On failure stores NULL in *result and returns + /// non-OK. + /// + /// The returned file will only be accessed by one thread at a time. + /// + /// The ownership of the returned WritableFile is passed to the caller + /// and the object should be deleted when is not used. The file object + /// shouldn't live longer than the Env object. + absl::Status NewAppendableFile(const std::string& fname, + std::unique_ptr* result); + + absl::Status NewAppendableFile(const std::string& fname, + TransactionToken* token, + std::unique_ptr* result) { + return absl::OkStatus(); + } + /// \brief Creates a readonly region of memory with the file context. + /// + /// On success, it returns a pointer to read-only memory region + /// from the content of file fname. The ownership of the region is passed to + /// the caller. On failure stores nullptr in *result and returns non-OK. + /// + /// The returned memory region can be accessed from many threads in parallel. + /// + /// The ownership of the returned ReadOnlyMemoryRegion is passed to the caller + /// and the object should be deleted when is not used. The memory region + /// object shouldn't live longer than the Env object. + absl::Status NewReadOnlyMemoryRegionFromFile( + const std::string& fname, std::unique_ptr* result); + + absl::Status NewReadOnlyMemoryRegionFromFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + return absl::OkStatus(); + } + + /// Returns OK if the named path exists and NOT_FOUND otherwise. + absl::Status FileExists(const std::string& fname); + + absl::Status FileExists(const std::string& fname, TransactionToken* token) { + return absl::OkStatus(); + } + + /// Returns true if all the listed files exist, false otherwise. + /// if status is not null, populate the vector with a detailed status + /// for each file. + bool FilesExist(const std::vector& files, + std::vector* status); + + bool FilesExist(const std::vector& files, TransactionToken* token, + std::vector* status) { + return true; + } + + /// \brief Stores in *result the names of the children of the specified + /// directory. The names are relative to "dir". + /// + /// Original contents of *results are dropped. + absl::Status GetChildren(const std::string& dir, std::vector* result); + + absl::Status GetChildren(const std::string& dir, TransactionToken* token, + std::vector* result) { + return absl::OkStatus(); + } + + /// \brief Returns true if the path matches the given pattern. The wildcards + /// allowed in pattern are described in FileSystem::GetMatchingPaths. + virtual bool MatchPath(const std::string& path, + const std::string& pattern) = 0; + + /// \brief Given a pattern, stores in *results the set of paths that matches + /// that pattern. *results is cleared. + /// + /// More details about `pattern` in FileSystem::GetMatchingPaths. + virtual absl::Status GetMatchingPaths(const std::string& pattern, + std::vector* results); + + absl::Status GetMatchingPaths(const std::string& pattern, + TransactionToken* token, + std::vector* results) { + return absl::OkStatus(); + } + + /// Deletes the named file. + absl::Status DeleteFile(const std::string& fname); + + absl::Status DeleteFile(const std::string& fname, TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Deletes the specified directory and all subdirectories and files + /// underneath it. This is accomplished by traversing the directory tree + /// rooted at dirname and deleting entries as they are encountered. + /// + /// If dirname itself is not readable or does not exist, *undeleted_dir_count + /// is set to 1, *undeleted_file_count is set to 0 and an appropriate status + /// (e.g. NOT_FOUND) is returned. + /// + /// If dirname and all its descendants were successfully deleted, TF_OK is + /// returned and both error counters are set to zero. + /// + /// Otherwise, while traversing the tree, undeleted_file_count and + /// undeleted_dir_count are updated if an entry of the corresponding type + /// could not be deleted. The returned error status represents the reason that + /// any one of these entries could not be deleted. + /// + /// REQUIRES: undeleted_files, undeleted_dirs to be not null. + /// + /// Typical return codes: + /// * OK - dirname exists and we were able to delete everything underneath. + /// * NOT_FOUND - dirname doesn't exist + /// * PERMISSION_DENIED - dirname or some descendant is not writable + /// * UNIMPLEMENTED - Some underlying functions (like Delete) are not + /// implemented + absl::Status DeleteRecursively(const std::string& dirname, + int64_t* undeleted_files, + int64_t* undeleted_dirs); + + absl::Status DeleteRecursively(const std::string& dirname, + TransactionToken* token, + int64_t* undeleted_files, + int64_t* undeleted_dirs) { + return absl::OkStatus(); + } + + /// \brief Creates the specified directory and all the necessary + /// subdirectories. Typical return codes. + /// * OK - successfully created the directory and sub directories, even if + /// they were already created. + /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. + absl::Status RecursivelyCreateDir(const std::string& dirname); + + absl::Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) { + return absl::OkStatus(); + } + /// \brief Creates the specified directory. Typical return codes + /// * OK - successfully created the directory. + /// * ALREADY_EXISTS - directory already exists. + /// * PERMISSION_DENIED - dirname is not writable. + absl::Status CreateDir(const std::string& dirname); + + absl::Status CreateDir(const std::string& dirname, TransactionToken* token) { + return absl::OkStatus(); + } + + /// Deletes the specified directory. + absl::Status DeleteDir(const std::string& dirname); + + absl::Status DeleteDir(const std::string& dirname, TransactionToken* token) { + return absl::OkStatus(); + } + + /// Obtains statistics for the given path. + absl::Status Stat(const std::string& fname, FileStatistics* stat); + + absl::Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) { + return absl::OkStatus(); + } + + /// \brief Returns whether the given path is a directory or not. + /// Typical return codes (not guaranteed exhaustive): + /// * OK - The path exists and is a directory. + /// * FAILED_PRECONDITION - The path exists and is not a directory. + /// * NOT_FOUND - The path entry does not exist. + /// * PERMISSION_DENIED - Insufficient permissions. + /// * UNIMPLEMENTED - The file factory doesn't support directories. + absl::Status IsDirectory(const std::string& fname); + + /// \brief Returns whether the given path is on a file system + /// that has atomic move capabilities. This can be used + /// to determine if there needs to be a temp location to safely write objects. + /// The second boolean argument has_atomic_move contains this information. + /// + /// Returns one of the following status codes (not guaranteed exhaustive): + /// * OK - The path is on a recognized file system, + /// so has_atomic_move holds the above information. + /// * UNIMPLEMENTED - The file system of the path hasn't been implemented in + /// TF + absl::Status HasAtomicMove(const std::string& path, bool* has_atomic_move); + + /// Returns whether the give path is on a file system + /// that has ability to create a new temp file. This can be used + /// to determine if there needs to be a temp location to safely write objects. + /// If this returns false, TensorFlow will write directly to output files + /// instead of creating a temporary file and swapping it in. This may mean + /// that incomplete writes are visible to consumers. + absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file); + + /// Stores the size of `fname` in `*file_size`. + absl::Status GetFileSize(const std::string& fname, uint64* file_size); + + absl::Status GetFileSize(const std::string& fname, TransactionToken* token, + uint64* file_size) { + return absl::OkStatus(); + } + + /// \brief Renames file src to target. If target already exists, it will be + /// replaced. + absl::Status RenameFile(const std::string& src, const std::string& target); + + absl::Status RenameFile(const std::string& src, const std::string& target, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Copy the src to target. + absl::Status CopyFile(const std::string& src, const std::string& target); + + absl::Status CopyFile(const std::string& src, const std::string& target, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief starts a new transaction on the filesystem that handles filename + absl::Status StartTransaction(const std::string& filename, + TransactionToken** token) { + *token = nullptr; + return absl::OkStatus(); + } + + /// \brief Adds `path` to transaction in `token` if token belongs to + /// filesystem that handles the path. + absl::Status AddToTransaction(const std::string& path, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Get token for `path` or start a new transaction and add `path` to + /// it. + absl::Status GetTokenOrStartTransaction(const std::string& path, + TransactionToken** token) { + *token = nullptr; + return absl::OkStatus(); + } + + /// \brief Returns the transaction for `path` or nullptr in `token` + absl::Status GetTransactionForPath(const std::string& path, + TransactionToken** token) { + *token = nullptr; + return absl::OkStatus(); + } + + /// \brief Finalizes the transaction + absl::Status EndTransaction(TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Returns the absolute path of the current executable. It resolves + /// symlinks if there is any. + std::string GetExecutablePath(); + + /// Creates a local unique temporary file name. Returns true if success. + bool LocalTempFilename(std::string* filename); + + /// Creates a local unique file name that starts with |prefix| and ends with + /// |suffix|. Returns true if success. + bool CreateUniqueFileName(std::string* prefix, const std::string& suffix); + + /// \brief Return the runfiles directory if running under bazel. Returns + /// the directory the executable is located in if not running under bazel. + virtual std::string GetRunfilesDir() = 0; + + // TODO(jeff,sanjay): Add back thread/thread-pool support if needed. + // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or + // provide a routine to get the absolute time. + + /// \brief Returns the number of nano-seconds since the Unix epoch. + virtual uint64 NowNanos() const { return EnvTime::NowNanos(); } + + /// \brief Returns the number of micro-seconds since the Unix epoch. + virtual uint64 NowMicros() const { return EnvTime::NowMicros(); } + + /// \brief Returns the number of seconds since the Unix epoch. + virtual uint64 NowSeconds() const { return EnvTime::NowSeconds(); } + + /// Sleeps/delays the thread for the prescribed number of micro-seconds. + virtual void SleepForMicroseconds(int64_t micros) = 0; + + /// Returns the process ID of the calling process. + int32 GetProcessId(); + + /// \brief Returns a new thread that is running fn() and is identified + /// (for debugging/performance-analysis) by "name". + /// + /// Caller takes ownership of the result and must delete it eventually + /// (the deletion will block until fn() stops running). + virtual Thread* StartThread( + const ThreadOptions& thread_options, const std::string& name, + absl::AnyInvocable fn) TF_MUST_USE_RESULT = 0; + + // Returns the thread id of calling thread. + // Posix: Returns pthread id which is only guaranteed to be unique within a + // process. + // Windows: Returns thread id which is unique. + virtual int64_t GetCurrentThreadId() = 0; + + // Copies current thread name to "name". Returns true if success. + virtual bool GetCurrentThreadName(std::string* name) = 0; + + // \brief Schedules the given closure on a thread-pool. + // + // NOTE(mrry): This closure may block. + virtual void SchedClosure(absl::AnyInvocable closure) = 0; + + // \brief Schedules the given closure on a thread-pool after the given number + // of microseconds. + // + // NOTE(mrry): This closure must not block. + virtual void SchedClosureAfter(int64_t micros, + absl::AnyInvocable closure) = 0; + + // \brief Load a dynamic library. + // + // Pass "library_filename" to a platform-specific mechanism for dynamically + // loading a library. The rules for determining the exact location of the + // library are platform-specific and are not documented here. + // + // On success, returns a handle to the library in "*handle" and returns + // OK from the function. + // Otherwise returns nullptr in "*handle" and an error status from the + // function. + virtual absl::Status LoadDynamicLibrary(const char* library_filename, + void** handle) = 0; + + // \brief Get a pointer to a symbol from a dynamic library. + // + // "handle" should be a pointer returned from a previous call to LoadLibrary. + // On success, store a pointer to the located symbol in "*symbol" and return + // OK from the function. Otherwise, returns nullptr in "*symbol" and an error + // status from the function. + virtual absl::Status GetSymbolFromLibrary(void* handle, + const char* symbol_name, + void** symbol) = 0; + + // \brief build the name of dynamic library. + // + // "name" should be name of the library. + // "version" should be the version of the library or NULL + // returns the name that LoadLibrary() can use + virtual std::string FormatLibraryFileName(const std::string& name, + const std::string& version) = 0; + + // Returns a possible list of local temporary directories. + virtual void GetLocalTempDirectories(std::vector* list) = 0; + + private: + std::unique_ptr file_system_registry_; + Env(const Env&) = delete; + void operator=(const Env&) = delete; +}; + +/// \brief An implementation of Env that forwards all calls to another Env. +/// +/// May be useful to clients who wish to override just part of the +/// functionality of another Env. +class EnvWrapper : public Env { + public: + /// Initializes an EnvWrapper that delegates all calls to *t + explicit EnvWrapper(Env* t) : target_(t) {} + ~EnvWrapper() override; + + /// Returns the target to which this Env forwards all calls + Env* target() const { return target_; } + + absl::Status GetFileSystemForFile(const std::string& fname, + FileSystem** result) override { + return target_->GetFileSystemForFile(fname, result); + } + + absl::Status GetRegisteredFileSystemSchemes( + std::vector* schemes) override { + return target_->GetRegisteredFileSystemSchemes(schemes); + } + + absl::Status RegisterFileSystem( + const std::string& scheme, FileSystemRegistry::Factory factory) override { + return target_->RegisterFileSystem(scheme, factory); + } + + bool MatchPath(const std::string& path, const std::string& pattern) override { + return target_->MatchPath(path, pattern); + } + + uint64 NowMicros() const override { return target_->NowMicros(); } + void SleepForMicroseconds(int64_t micros) override { + target_->SleepForMicroseconds(micros); + } + Thread* StartThread(const ThreadOptions& thread_options, + const std::string& name, + absl::AnyInvocable fn) override { + return target_->StartThread(thread_options, name, std::move(fn)); + } + int64_t GetCurrentThreadId() override { + return target_->GetCurrentThreadId(); + } + bool GetCurrentThreadName(std::string* name) override { + return target_->GetCurrentThreadName(name); + } + void SchedClosure(absl::AnyInvocable closure) override { + target_->SchedClosure(std::move(closure)); + } + void SchedClosureAfter(int64_t micros, + absl::AnyInvocable closure) override { + target_->SchedClosureAfter(micros, std::move(closure)); + } + absl::Status LoadDynamicLibrary(const char* library_filename, + void** handle) override { + return target_->LoadDynamicLibrary(library_filename, handle); + } + absl::Status GetSymbolFromLibrary(void* handle, const char* symbol_name, + void** symbol) override { + return target_->GetSymbolFromLibrary(handle, symbol_name, symbol); + } + std::string FormatLibraryFileName(const std::string& name, + const std::string& version) override { + return target_->FormatLibraryFileName(name, version); + } + + std::string GetRunfilesDir() override { return target_->GetRunfilesDir(); } + + private: + void GetLocalTempDirectories(std::vector* list) override { + target_->GetLocalTempDirectories(list); + } + + Env* target_; +}; + +/// Represents a thread used to run a TSL function. +class Thread { + public: + Thread() {} + + /// Blocks until the thread of control stops running. + virtual ~Thread(); + + private: + Thread(const Thread&) = delete; + void operator=(const Thread&) = delete; +}; + +/// \brief Cross-platform setenv. +/// +/// Since setenv() is not available on windows, we provide an +/// alternative with platform specific implementations here. +int setenv(const char* name, const char* value, int overwrite); + +/// Cross-platform unsetenv. +int unsetenv(const char* name); + +/// \brief Options to configure a Thread. +/// +/// Note that the options are all hints, and the +/// underlying implementation may choose to ignore it. +struct ThreadOptions { + /// Thread stack size to use (in bytes). + size_t stack_size = 0; // 0: use system default value + /// Guard area size to use near thread stacks to use (in bytes) + size_t guard_size = 0; // 0: use system default value + int numa_node = port::kNUMANoAffinity; +}; + +/// A utility routine: copy contents of `src` in file system `src_fs` +/// to `target` in file system `target_fs`. +absl::Status FileSystemCopyFile(FileSystem* src_fs, const std::string& src, + FileSystem* target_fs, + const std::string& target); + +/// A utility routine: reads contents of named file into `*data` +absl::Status ReadFileToString(Env* env, const std::string& fname, + std::string* data); + +/// A utility routine: write contents of `data` to file named `fname` +/// (overwriting existing contents, if any). +absl::Status WriteStringToFile(Env* env, const std::string& fname, + const absl::string_view& data); + +/// Write binary representation of "proto" to the named file. +absl::Status WriteBinaryProto(Env* env, const std::string& fname, + const protobuf::MessageLite& proto); + +/// Reads contents of named file and parse as binary encoded proto data +/// and store into `*proto`. +absl::Status ReadBinaryProto(Env* env, const std::string& fname, + protobuf::MessageLite* proto); + +/// Write the text representation of "proto" to the named file. +inline absl::Status WriteTextProto(Env* /* env */, + const std::string& /* fname */, + const protobuf::MessageLite& /* proto */) { + return errors::Unimplemented("Can't write text protos with protolite."); +} +absl::Status WriteTextProto(Env* env, const std::string& fname, + const protobuf::Message& proto); + +/// Read contents of named file and parse as text encoded proto data +/// and store into `*proto`. +inline absl::Status ReadTextProto(Env* /* env */, + const std::string& /* fname */, + protobuf::MessageLite* /* proto */) { + return errors::Unimplemented("Can't parse text protos with protolite."); +} +absl::Status ReadTextProto(Env* env, const std::string& fname, + protobuf::Message* proto); + +/// Read contents of named file and parse as either text or binary encoded proto +/// data and store into `*proto`. +absl::Status ReadTextOrBinaryProto(Env* env, const std::string& fname, + protobuf::Message* proto); +absl::Status ReadTextOrBinaryProto(Env* env, const std::string& fname, + protobuf::MessageLite* proto); + +// START_SKIP_DOXYGEN + +// The following approach to register filesystems is deprecated and will be +// replaced with modular filesystem plugins registration. +// TODO(b/139060984): After all filesystems are converted, remove this. +namespace register_file_system { + +template +struct Register { + Register(Env* env, const std::string& scheme, bool try_modular_filesystems) { + // TODO(yongtang): Remove legacy file system registration for hdfs/s3/gcs + // after TF 2.6+. + if (try_modular_filesystems) { + const char* env_value = getenv("TF_USE_MODULAR_FILESYSTEM"); + string load_plugin = env_value ? absl::AsciiStrToLower(env_value) : ""; + if (load_plugin == "true" || load_plugin == "1") { + // We don't register the static filesystem and wait for SIG IO one + LOG(WARNING) << "Using modular file system for '" << scheme << "'." + << " Please switch to tensorflow-io" + << " (https://github.com/tensorflow/io) for file system" + << " support of '" << scheme << "'."; + return; + } + // If the envvar is missing or not "true"/"1", then fall back to legacy + // implementation to be backwards compatible. + } + // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! + env->RegisterFileSystem(scheme, []() -> FileSystem* { return new Factory; }) + .IgnoreError(); + } +}; + +} // namespace register_file_system + +// END_SKIP_DOXYGEN + +} // namespace tsl + +// Register a FileSystem implementation for a scheme. Files with names that have +// "scheme://" prefixes are routed to use this implementation. +#define REGISTER_FILE_SYSTEM_ENV(env, scheme, factory, modular) \ + REGISTER_FILE_SYSTEM_UNIQ_HELPER(__COUNTER__, env, scheme, factory, modular) +#define REGISTER_FILE_SYSTEM_UNIQ_HELPER(ctr, env, scheme, factory, modular) \ + REGISTER_FILE_SYSTEM_UNIQ(ctr, env, scheme, factory, modular) +#define REGISTER_FILE_SYSTEM_UNIQ(ctr, env, scheme, factory, modular) \ + static ::tsl::register_file_system::Register register_ff##ctr \ + TF_ATTRIBUTE_UNUSED = \ + ::tsl::register_file_system::Register(env, scheme, modular) + +#define REGISTER_FILE_SYSTEM(scheme, factory) \ + REGISTER_FILE_SYSTEM_ENV(::tsl::Env::Default(), scheme, factory, false); + +#define REGISTER_LEGACY_FILE_SYSTEM(scheme, factory) \ + REGISTER_FILE_SYSTEM_ENV(::tsl::Env::Default(), scheme, factory, true); + +#endif // XLA_TSL_PLATFORM_ENV_H_ diff --git a/third_party/xla/xla/tsl/platform/env_time.h b/third_party/xla/xla/tsl/platform/env_time.h new file mode 100644 index 00000000000000..f37e3129f45697 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/env_time.h @@ -0,0 +1,65 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_TSL_PLATFORM_ENV_TIME_H_ +#define XLA_TSL_PLATFORM_ENV_TIME_H_ + +#include + +#include "xla/tsl/platform/types.h" + +namespace tsl { + +/// \brief An interface used by the tsl implementation to +/// access timer related operations. +class EnvTime { + public: + static constexpr uint64 kMicrosToPicos = 1000ULL * 1000ULL; + static constexpr uint64 kMicrosToNanos = 1000ULL; + static constexpr uint64 kMillisToMicros = 1000ULL; + static constexpr uint64 kMillisToNanos = 1000ULL * 1000ULL; + static constexpr uint64 kNanosToPicos = 1000ULL; + static constexpr uint64 kSecondsToMillis = 1000ULL; + static constexpr uint64 kSecondsToMicros = 1000ULL * 1000ULL; + static constexpr uint64 kSecondsToNanos = 1000ULL * 1000ULL * 1000ULL; + + EnvTime() = default; + virtual ~EnvTime() = default; + + /// \brief Returns the number of nano-seconds since the Unix epoch. + static uint64 NowNanos(); + + /// \brief Returns the number of micro-seconds since the Unix epoch. + static uint64 NowMicros() { return NowNanos() / kMicrosToNanos; } + + /// \brief Returns the number of seconds since the Unix epoch. + static uint64 NowSeconds() { return NowNanos() / kSecondsToNanos; } + + /// \brief A version of NowNanos() that may be overridden by a subclass. + virtual uint64 GetOverridableNowNanos() const { return NowNanos(); } + + /// \brief A version of NowMicros() that may be overridden by a subclass. + virtual uint64 GetOverridableNowMicros() const { + return GetOverridableNowNanos() / kMicrosToNanos; + } + + /// \brief A version of NowSeconds() that may be overridden by a subclass. + virtual uint64 GetOverridableNowSeconds() const { + return GetOverridableNowNanos() / kSecondsToNanos; + } +}; + +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_ENV_TIME_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/errors.cc b/third_party/xla/xla/tsl/platform/errors.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/platform/errors.cc rename to third_party/xla/xla/tsl/platform/errors.cc index 6c732a47849113..88aadeb1ac9f95 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/errors.cc +++ b/third_party/xla/xla/tsl/platform/errors.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" #include #include -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/strcat.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/errors.h b/third_party/xla/xla/tsl/platform/errors.h new file mode 100644 index 00000000000000..a285c1f9041e5d --- /dev/null +++ b/third_party/xla/xla/tsl/platform/errors.h @@ -0,0 +1,646 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_ERRORS_H_ +#define XLA_TSL_PLATFORM_ERRORS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_join.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/str_util.h" +#include "tsl/platform/strcat.h" + +namespace tsl { +namespace error { +// NOLINTBEGIN(misc-unused-using-decls) +// TODO(aminim): figure out the protobuf migration story. +using tensorflow::error::ABORTED; +using tensorflow::error::ALREADY_EXISTS; +using tensorflow::error::CANCELLED; +using tensorflow::error::Code; +using tensorflow::error::DATA_LOSS; +using tensorflow::error::DEADLINE_EXCEEDED; +using tensorflow::error::FAILED_PRECONDITION; +using tensorflow::error::INTERNAL; +using tensorflow::error::INVALID_ARGUMENT; +using tensorflow::error::NOT_FOUND; +using tensorflow::error::OK; +using tensorflow::error::OUT_OF_RANGE; +using tensorflow::error::PERMISSION_DENIED; +using tensorflow::error::RESOURCE_EXHAUSTED; +using tensorflow::error::UNAUTHENTICATED; +using tensorflow::error::UNAVAILABLE; +using tensorflow::error::UNIMPLEMENTED; +using tensorflow::error::UNKNOWN; +// NOLINTEND(misc-unused-using-decls) +} // namespace error + +namespace errors { + +namespace internal { + +// The DECLARE_ERROR macro below only supports types that can be converted +// into StrCat's AlphaNum. For the other types we rely on a slower path +// through std::stringstream. To add support of a new type, it is enough to +// make sure there is an operator<<() for it: +// +// std::ostream& operator<<(std::ostream& os, const MyType& foo) { +// os << foo.ToString(); +// return os; +// } +// Eventually absl::strings will have native support for this and we will be +// able to completely remove PrepareForStrCat(). +template +typename std::enable_if::value, + std::string>::type +PrepareForStrCat(const T& t) { + std::stringstream ss; + ss << t; + return ss.str(); +} +inline const strings::AlphaNum& PrepareForStrCat(const strings::AlphaNum& a) { + return a; +} + +} // namespace internal + +// Maps UNIX errors into a Status. +absl::Status IOError(const string& context, int err_number); + +// Returns all payloads from a Status as a key-value map. +inline std::unordered_map GetPayloads( + const absl::Status& status) { + std::unordered_map payloads; + status.ForEachPayload( + [&payloads](absl::string_view key, const absl::Cord& value) { + payloads[std::string(key)] = std::string(value); + }); + return payloads; +} + +// Inserts all given payloads into the given status. Will overwrite existing +// payloads if they exist with the same key. +inline void InsertPayloads( + absl::Status& status, + const std::unordered_map& payloads) { + for (const auto& payload : payloads) { + status.SetPayload(payload.first, absl::Cord(payload.second)); + } +} + +// Copies all payloads from one Status to another. Will overwrite existing +// payloads in the destination if they exist with the same key. +inline void CopyPayloads(const absl::Status& from, absl::Status& to) { + from.ForEachPayload([&to](absl::string_view key, const absl::Cord& value) { + to.SetPayload(key, value); + }); +} + +#if defined(PLATFORM_GOOGLE) +// Creates a new status with the given code, message and payloads. +inline absl::Status Create( + absl::StatusCode code, absl::string_view message, + const std::unordered_map& payloads, + absl::SourceLocation loc = absl::SourceLocation::current()) { + absl::Status status(code, message, loc); + InsertPayloads(status, payloads); + return status; +} +// Returns a new Status, replacing its message with the given. +inline absl::Status CreateWithUpdatedMessage(const absl::Status& status, + absl::string_view message) { + auto locations = status.GetSourceLocations(); + auto initial_loc = + locations.empty() ? absl::SourceLocation::current() : locations[0]; + absl::Status new_status = Create(static_cast(status.code()), + message, GetPayloads(status), initial_loc); + if (locations.size() > 1) { + for (auto loc : locations.subspan(1)) { + new_status.AddSourceLocation(loc); + } + } + return new_status; +} + +#else +inline ::absl::Status Create( + absl::StatusCode code, ::tsl::StringPiece message, + const std::unordered_map& payloads) { + Status status(code, message); + InsertPayloads(status, payloads); + return status; +} +// Returns a new Status, replacing its message with the given. +inline ::tsl::Status CreateWithUpdatedMessage(const ::tsl::Status& status, + ::tsl::StringPiece message) { + return Create(static_cast(status.code()), message, + GetPayloads(status)); +} +#endif + +// Append some context to an error message. Each time we append +// context put it on a new line, since it is possible for there +// to be several layers of additional context. +template +void AppendToMessage(absl::Status* status, Args... args) { + auto new_status = CreateWithUpdatedMessage( + *status, ::tsl::strings::StrCat(status->message(), "\n\t", args...)); + CopyPayloads(*status, new_status); + *status = std::move(new_status); +} + +// For propagating errors when calling a function. +#define TF_RETURN_IF_ERROR(...) \ + do { \ + ::absl::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + MAYBE_ADD_SOURCE_LOCATION(_status) \ + return _status; \ + } \ + } while (0) + +#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ + do { \ + ::tsl::Status _status = (expr); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + ::tsl::errors::AppendToMessage(&_status, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Convenience functions for generating and using error status. +// Example usage: +// status.Update(errors::InvalidArgument("The ", foo, " isn't right.")); +// if (errors::IsInvalidArgument(status)) { ... } +// switch (status.code()) { case error::INVALID_ARGUMENT: ... } + +// CANCELLED +template +absl::Status Cancelled(Args... args) { + return absl::Status(absl::StatusCode::kCancelled, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status CancelledWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kCancelled, message, payloads); +} + +// InvalidArgument +template +absl::Status InvalidArgument(Args... args) { + return absl::Status(absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} + +#if defined(PLATFORM_GOOGLE) +// Specialized overloads to capture source location for up to three arguments. +template +::absl::Status InvalidArgument( + Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2), + ::tsl::errors::internal::PrepareForStrCat(arg3), + ::tsl::errors::internal::PrepareForStrCat(arg4)), + loc); +} +template +::absl::Status InvalidArgument( + Arg1 arg1, Arg2 arg2, Arg3 arg3, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2), + ::tsl::errors::internal::PrepareForStrCat(arg3)), + loc); +} +template +::absl::Status InvalidArgument( + Arg1 arg1, Arg2 arg2, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2)), + loc); +} +template +::absl::Status InvalidArgument( + Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), + loc); +} +template +::absl::Status InvalidArgumentWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads, + loc); +} +#else +template +::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3) { + return ::absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2), + ::tsl::errors::internal::PrepareForStrCat(arg3))); +} +template +::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) { + return ::absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2))); +} +template +::absl::Status InvalidArgument(Arg1 arg1) { + return ::absl::Status( + absl::StatusCode::kInvalidArgument, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); +} +template +::absl::Status InvalidArgumentWithPayloads( + const ::tsl::StringPiece& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads); +} +#endif + +// NotFound +template +absl::Status NotFound(Args... args) { + return absl::Status(absl::StatusCode::kNotFound, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +#if defined(PLATFORM_GOOGLE) +// Specialized overloads to capture source location for up to three arguments. +template +::absl::Status NotFound( + Arg1 arg1, Arg2 arg2, Arg3 arg3, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kNotFound, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2), + ::tsl::errors::internal::PrepareForStrCat(arg3)), + loc); +} +template +::absl::Status NotFound( + Arg1 arg1, Arg2 arg2, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kNotFound, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2)), + loc); +} +template +::absl::Status NotFound( + Arg1 arg1, absl::SourceLocation loc = absl::SourceLocation::current()) { + return absl::Status( + absl::StatusCode::kNotFound, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1)), + loc); +} +template +::absl::Status NotFoundWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return errors::Create(absl::StatusCode::kNotFound, message, payloads, loc); +} +#else +template +::absl::Status NotFound(Arg1 arg1, Arg2 arg2, Arg3 arg3) { + return ::absl::Status( + absl::StatusCode::kNotFound, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2), + ::tsl::errors::internal::PrepareForStrCat(arg3))); +} +template +::absl::Status NotFound(Arg1 arg1, Arg2 arg2) { + return ::absl::Status( + absl::StatusCode::kNotFound, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1), + ::tsl::errors::internal::PrepareForStrCat(arg2))); +} +template +::absl::Status NotFound(Arg1 arg1) { + return ::absl::Status( + absl::StatusCode::kNotFound, + ::tsl::strings::StrCat(::tsl::errors::internal::PrepareForStrCat(arg1))); +} +template +::absl::Status NotFoundWithPayloads( + const ::tsl::StringPiece& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kNotFound, message, payloads); +} +#endif + +// AlreadyExists +template +absl::Status AlreadyExists(Args... args) { + return absl::Status(absl::StatusCode::kAlreadyExists, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status AlreadyExistsWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kAlreadyExists, message, payloads); +} + +// ResourceExhausted +template +absl::Status ResourceExhausted(Args... args) { + return absl::Status(absl::StatusCode::kResourceExhausted, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status ResourceExhaustedWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kResourceExhausted, message, + payloads); +} + +// Unavailable +template +absl::Status Unavailable(Args... args) { + return absl::Status(absl::StatusCode::kUnavailable, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status UnavailableWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kUnavailable, message, payloads); +} + +// FailedPrecondition +template +absl::Status FailedPrecondition(Args... args) { + return absl::Status(absl::StatusCode::kFailedPrecondition, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status FailedPreconditionWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kFailedPrecondition, message, + payloads); +} + +// OutOfRange +template +absl::Status OutOfRange(Args... args) { + return absl::Status(absl::StatusCode::kOutOfRange, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status OutOfRangeWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kOutOfRange, message, payloads); +} + +// Unimplemented +template +absl::Status Unimplemented(Args... args) { + return absl::Status(absl::StatusCode::kUnimplemented, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status UnimplementedWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kUnimplemented, message, payloads); +} + +// Internal +template +absl::Status Internal(Args... args) { + return absl::Status(absl::StatusCode::kInternal, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status InternalWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kInternal, message, payloads); +} + +// Aborted +template +absl::Status Aborted(Args... args) { + return absl::Status(absl::StatusCode::kAborted, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status AbortedWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kAborted, message, payloads); +} + +// DeadlineExceeded +template +absl::Status DeadlineExceeded(Args... args) { + return absl::Status(absl::StatusCode::kDeadlineExceeded, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status DeadlineExceededWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kDeadlineExceeded, message, payloads); +} + +// DataLoss +template +absl::Status DataLoss(Args... args) { + return absl::Status(absl::StatusCode::kDataLoss, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status DataLossWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kDataLoss, message, payloads); +} + +// Unknown +template +absl::Status Unknown(Args... args) { + return absl::Status(absl::StatusCode::kUnknown, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status UnknownPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kUnknown, message, payloads); +} +// PermissionDenied +template +absl::Status PermissionDenied(Args... args) { + return absl::Status(absl::StatusCode::kPermissionDenied, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status PermissionDeniedWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kPermissionDenied, message, payloads); +} + +// Unauthenticated +template +absl::Status Unauthenticated(Args... args) { + return absl::Status(absl::StatusCode::kUnauthenticated, + ::tsl::strings::StrCat( + ::tsl::errors::internal::PrepareForStrCat(args)...)); +} +template +absl::Status UnauthenticatedWithPayloads( + const absl::string_view& message, + const std::unordered_map& payloads) { + return errors::Create(absl::StatusCode::kUnauthenticated, message, payloads); +} + +bool IsAborted(const absl::Status& status); +bool IsAlreadyExists(const absl::Status& status); +bool IsCancelled(const absl::Status& status); +bool IsDataLoss(const absl::Status& status); +bool IsDeadlineExceeded(const absl::Status& status); +bool IsFailedPrecondition(const absl::Status& status); +bool IsInternal(const absl::Status& status); +bool IsInvalidArgument(const absl::Status& status); +bool IsNotFound(const absl::Status& status); +bool IsOutOfRange(const absl::Status& status); +bool IsPermissionDenied(const absl::Status& status); +bool IsResourceExhausted(const absl::Status& status); +bool IsUnauthenticated(const absl::Status& status); +bool IsUnavailable(const absl::Status& status); +bool IsUnimplemented(const absl::Status& status); +bool IsUnknown(const absl::Status& status); + +// Produces a formatted string pattern from the name which can uniquely identify +// this node upstream to produce an informative error message. The pattern +// followed is: {{node }} +// Note: The pattern below determines the regex _NODEDEF_NAME_RE in the file +// tensorflow/python/client/session.py +// LINT.IfChange +inline std::string FormatNodeNameForError(absl::string_view name) { + return strings::StrCat("{{node ", name, "}}"); +} +// LINT.ThenChange(//tensorflow/python/client/session.py) +template +std::string FormatNodeNamesForError(const T& names) { + return absl::StrJoin( + names, ", ", [](std::string* output, absl::string_view s) { + ::tsl::strings::StrAppend(output, FormatNodeNameForError(s)); + }); +} +// LINT.IfChange +inline std::string FormatColocationNodeForError(absl::string_view name) { + return strings::StrCat("{{colocation_node ", name, "}}"); +} +// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) +template >> +std::string FormatColocationNodeForError(const T& names) { + return absl::StrJoin( + names, ", ", [](std::string* output, absl::string_view s) { + ::tsl::strings::StrAppend(output, FormatColocationNodeForError(s)); + }); +} + +inline std::string FormatFunctionForError(absl::string_view name) { + return strings::StrCat("{{function_node ", name, "}}"); +} + +inline absl::Status ReplaceErrorFromNonCommunicationOps( + const absl::Status s, absl::string_view op_name) { + assert(::tsl::errors::IsUnavailable(s)); + return absl::Status( + absl::StatusCode::kInternal, + strings::StrCat( + s.message(), "\nExecuting non-communication op <", op_name, + "> originally returned UnavailableError, and was replaced by " + "InternalError to avoid invoking TF network error handling logic.")); +} + +template +std::string FormatOriginalNodeLocationForError(const T& node_names, + const T& func_names) { + std::vector error_message; + for (int i = 0; i != node_names.size(); ++i) { + if (i != 0) { + error_message.push_back(", "); + } + if (i < func_names.size()) { + error_message.push_back(FormatFunctionForError(func_names[i])); + } + error_message.push_back(FormatNodeNameForError(node_names[i])); + } + return absl::StrJoin(error_message, ""); +} + +// The CanonicalCode() for non-errors. +using ::tsl::error::OK; // NOLINT + +} // namespace errors +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_ERRORS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/errors_test.cc b/third_party/xla/xla/tsl/platform/errors_test.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/platform/errors_test.cc rename to third_party/xla/xla/tsl/platform/errors_test.cc index 88a3a5a78f72a5..9058fcce8500f0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/errors_test.cc +++ b/third_party/xla/xla/tsl/platform/errors_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" #include "absl/status/status.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { @@ -99,8 +99,9 @@ TEST(Status, StackTracePropagation) { ASSERT_EQ(sources.size(), 3); for (int i = 0; i < 3; ++i) { - ASSERT_EQ(sources[i].file_name(), - "third_party/tensorflow/tsl/platform/errors_test.cc"); + ASSERT_EQ( + sources[i].file_name(), + "third_party/tensorflow/compiler/xla/tsl/platform/errors_test.cc"); } } diff --git a/third_party/xla/xla/tsl/platform/file_statistics.h b/third_party/xla/xla/tsl/platform/file_statistics.h new file mode 100644 index 00000000000000..9686f54836c8a8 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/file_statistics.h @@ -0,0 +1,39 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_FILE_STATISTICS_H_ +#define XLA_TSL_PLATFORM_FILE_STATISTICS_H_ + +#include "xla/tsl/platform/types.h" + +namespace tsl { + +struct FileStatistics { + // The length of the file or -1 if finding file length is not supported. + int64_t length = -1; + // The last modified time in nanoseconds. + int64_t mtime_nsec = 0; + // True if the file is a directory, otherwise false. + bool is_directory = false; + + FileStatistics() {} + FileStatistics(int64_t length, int64_t mtime_nsec, bool is_directory) + : length(length), mtime_nsec(mtime_nsec), is_directory(is_directory) {} + ~FileStatistics() {} +}; + +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_FILE_STATISTICS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc b/third_party/xla/xla/tsl/platform/file_system.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/platform/file_system.cc rename to third_party/xla/xla/tsl/platform/file_system.cc index 453e04b3942e8a..715037913ed788 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system.cc +++ b/third_party/xla/xla/tsl/platform/file_system.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/file_system.h" #include @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #if defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) || \ defined(PLATFORM_GOOGLE) @@ -33,8 +33,8 @@ limitations under the License. #endif // defined(PLATFORM_POSIX) || defined(IS_MOBILE_PLATFORM) || \ // defined(PLATFORM_GOOGLE) -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/platform.h" #include "tsl/platform/scanner.h" #include "tsl/platform/str_util.h" diff --git a/third_party/xla/xla/tsl/platform/file_system.h b/third_party/xla/xla/tsl/platform/file_system.h new file mode 100644 index 00000000000000..ba046fde42c11e --- /dev/null +++ b/third_party/xla/xla/tsl/platform/file_system.h @@ -0,0 +1,936 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_FILE_SYSTEM_H_ +#define XLA_TSL_PLATFORM_FILE_SYSTEM_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_statistics.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" +#include "tsl/platform/cord.h" +#include "tsl/platform/platform.h" +#include "tsl/platform/stringpiece.h" + +#ifdef PLATFORM_WINDOWS +#undef DeleteFile +#undef CopyFile +#undef TranslateName +#endif + +namespace tsl { + +class FileAcl; +class RandomAccessFile; +class ReadOnlyMemoryRegion; +class WritableFile; + +class FileSystem; +struct TransactionToken { + FileSystem* owner; + void* token; +}; + +/// A generic interface for accessing a file system. Implementations +/// of custom filesystem adapters must implement this interface, +/// RandomAccessFile, WritableFile, and ReadOnlyMemoryRegion classes. +class FileSystem { + public: + /// \brief Creates a brand new random access read-only file with the + /// specified name. + /// + /// On success, stores a pointer to the new file in + /// *result and returns OK. On failure stores NULL in *result and + /// returns non-OK. If the file does not exist, returns a non-OK + /// status. + /// + /// The returned file may be concurrently accessed by multiple threads. + /// + /// The ownership of the returned RandomAccessFile is passed to the caller + /// and the object should be deleted when is not used. + virtual absl::Status NewRandomAccessFile( + const std::string& fname, std::unique_ptr* result) { + return NewRandomAccessFile(fname, nullptr, result); + } + + virtual absl::Status NewRandomAccessFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + // We duplicate these methods due to Google internal coding style prevents + // virtual functions with default arguments. See PR #41615. + return absl::OkStatus(); + } + + /// \brief Creates an object that writes to a new file with the specified + /// name. + /// + /// Deletes any existing file with the same name and creates a + /// new file. On success, stores a pointer to the new file in + /// *result and returns OK. On failure stores NULL in *result and + /// returns non-OK. + /// + /// The returned file will only be accessed by one thread at a time. + /// + /// The ownership of the returned WritableFile is passed to the caller + /// and the object should be deleted when is not used. + virtual absl::Status NewWritableFile(const std::string& fname, + std::unique_ptr* result) { + return NewWritableFile(fname, nullptr, result); + } + + virtual absl::Status NewWritableFile(const std::string& fname, + TransactionToken* token, + std::unique_ptr* result) { + return absl::OkStatus(); + } + + /// \brief Creates an object that either appends to an existing file, or + /// writes to a new file (if the file does not exist to begin with). + /// + /// On success, stores a pointer to the new file in *result and + /// returns OK. On failure stores NULL in *result and returns + /// non-OK. + /// + /// The returned file will only be accessed by one thread at a time. + /// + /// The ownership of the returned WritableFile is passed to the caller + /// and the object should be deleted when is not used. + virtual absl::Status NewAppendableFile( + const std::string& fname, std::unique_ptr* result) { + return NewAppendableFile(fname, nullptr, result); + } + + virtual absl::Status NewAppendableFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + return absl::OkStatus(); + } + + /// \brief Creates a readonly region of memory with the file context. + /// + /// On success, it returns a pointer to read-only memory region + /// from the content of file fname. The ownership of the region is passed to + /// the caller. On failure stores nullptr in *result and returns non-OK. + /// + /// The returned memory region can be accessed from many threads in parallel. + /// + /// The ownership of the returned ReadOnlyMemoryRegion is passed to the caller + /// and the object should be deleted when is not used. + virtual absl::Status NewReadOnlyMemoryRegionFromFile( + const std::string& fname, std::unique_ptr* result) { + return NewReadOnlyMemoryRegionFromFile(fname, nullptr, result); + } + + virtual absl::Status NewReadOnlyMemoryRegionFromFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { + return absl::OkStatus(); + } + + /// Returns OK if the named path exists and NOT_FOUND otherwise. + virtual absl::Status FileExists(const std::string& fname) { + return FileExists(fname, nullptr); + } + + virtual absl::Status FileExists(const std::string& fname, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// Returns true if all the listed files exist, false otherwise. + /// if status is not null, populate the vector with a detailed status + /// for each file. + virtual bool FilesExist(const std::vector& files, + std::vector* status) { + return FilesExist(files, nullptr, status); + } + + virtual bool FilesExist(const std::vector& files, + TransactionToken* token, + std::vector* status); + + /// \brief Returns the immediate children in the given directory. + /// + /// The returned paths are relative to 'dir'. + virtual absl::Status GetChildren(const std::string& dir, + std::vector* result) { + return GetChildren(dir, nullptr, result); + } + + virtual absl::Status GetChildren(const std::string& dir, + TransactionToken* token, + std::vector* result) { + return absl::OkStatus(); + } + + /// \brief Given a pattern, stores in *results the set of paths that matches + /// that pattern. *results is cleared. + /// + /// pattern must match all of a name, not just a substring. + /// + /// pattern: { term } + /// term: + /// '*': matches any sequence of non-'/' characters + /// '?': matches a single non-'/' character + /// '[' [ '^' ] { match-list } ']': + /// matches any single character (not) on the list + /// c: matches character c (c != '*', '?', '\\', '[') + /// '\\' c: matches character c + /// character-range: + /// c: matches character c (c != '\\', '-', ']') + /// '\\' c: matches character c + /// lo '-' hi: matches character c for lo <= c <= hi + /// + /// Typical return codes: + /// * OK - no errors + /// * UNIMPLEMENTED - Some underlying functions (like GetChildren) are not + /// implemented + virtual absl::Status GetMatchingPaths(const std::string& pattern, + std::vector* results) { + return GetMatchingPaths(pattern, nullptr, results); + } + + virtual absl::Status GetMatchingPaths(const std::string& pattern, + TransactionToken* token, + std::vector* results) { + return absl::OkStatus(); + } + + /// \brief Checks if the given filename matches the pattern. + /// + /// This function provides the equivalent of posix fnmatch, however it is + /// implemented without fnmatch to ensure that this can be used for cloud + /// filesystems on windows. For windows filesystems, it uses PathMatchSpec. + virtual bool Match(const std::string& filename, const std::string& pattern); + + /// \brief Obtains statistics for the given path. + virtual absl::Status Stat(const std::string& fname, FileStatistics* stat) { + return Stat(fname, nullptr, stat); + } + + virtual absl::Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) { + return absl::OkStatus(); + } + + /// \brief Deletes the named file. + virtual absl::Status DeleteFile(const std::string& fname) { + return DeleteFile(fname, nullptr); + } + + virtual absl::Status DeleteFile(const std::string& fname, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Creates the specified directory. + /// Typical return codes: + /// * OK - successfully created the directory. + /// * ALREADY_EXISTS - directory with name dirname already exists. + /// * PERMISSION_DENIED - dirname is not writable. + virtual absl::Status CreateDir(const std::string& dirname) { + return CreateDir(dirname, nullptr); + } + + virtual absl::Status CreateDir(const std::string& dirname, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Creates the specified directory and all the necessary + /// subdirectories. + /// Typical return codes: + /// * OK - successfully created the directory and sub directories, even if + /// they were already created. + /// * PERMISSION_DENIED - dirname or some subdirectory is not writable. + virtual absl::Status RecursivelyCreateDir(const std::string& dirname) { + return RecursivelyCreateDir(dirname, nullptr); + } + + virtual absl::Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token); + + /// \brief Deletes the specified directory. + virtual absl::Status DeleteDir(const std::string& dirname) { + return DeleteDir(dirname, nullptr); + } + + virtual absl::Status DeleteDir(const std::string& dirname, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Deletes the specified directory and all subdirectories and files + /// underneath it. This is accomplished by traversing the directory tree + /// rooted at dirname and deleting entries as they are encountered. + /// + /// If dirname itself is not readable or does not exist, *undeleted_dir_count + /// is set to 1, *undeleted_file_count is set to 0 and an appropriate status + /// (e.g. NOT_FOUND) is returned. + /// + /// If dirname and all its descendants were successfully deleted, TF_OK is + /// returned and both error counters are set to zero. + /// + /// Otherwise, while traversing the tree, undeleted_file_count and + /// undeleted_dir_count are updated if an entry of the corresponding type + /// could not be deleted. The returned error status represents the reason that + /// any one of these entries could not be deleted. + /// + /// REQUIRES: undeleted_files, undeleted_dirs to be not null. + /// + /// Typical return codes: + /// * OK - dirname exists and we were able to delete everything underneath. + /// * NOT_FOUND - dirname doesn't exist + /// * PERMISSION_DENIED - dirname or some descendant is not writable + /// * UNIMPLEMENTED - Some underlying functions (like Delete) are not + /// implemented + virtual absl::Status DeleteRecursively(const std::string& dirname, + int64_t* undeleted_files, + int64_t* undeleted_dirs) { + return DeleteRecursively(dirname, nullptr, undeleted_files, undeleted_dirs); + } + + virtual absl::Status DeleteRecursively(const std::string& dirname, + TransactionToken* token, + int64_t* undeleted_files, + int64_t* undeleted_dirs); + + /// \brief Stores the size of `fname` in `*file_size`. + virtual absl::Status GetFileSize(const std::string& fname, + uint64* file_size) { + return GetFileSize(fname, nullptr, file_size); + } + + virtual absl::Status GetFileSize(const std::string& fname, + TransactionToken* token, uint64* file_size) { + return absl::OkStatus(); + } + + /// \brief Overwrites the target if it exists. + virtual absl::Status RenameFile(const std::string& src, + const std::string& target) { + return RenameFile(src, target, nullptr); + } + + virtual absl::Status RenameFile(const std::string& src, + const std::string& target, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Copy the src to target. + virtual absl::Status CopyFile(const std::string& src, + const std::string& target) { + return CopyFile(src, target, nullptr); + } + + virtual absl::Status CopyFile(const std::string& src, + const std::string& target, + TransactionToken* token); + + /// \brief Translate an URI to a filename for the FileSystem implementation. + /// + /// The implementation in this class cleans up the path, removing + /// duplicate /'s, resolving .. and removing trailing '/'. + /// This respects relative vs. absolute paths, but does not + /// invoke any system calls (getcwd(2)) in order to resolve relative + /// paths with respect to the actual working directory. That is, this is + /// purely string manipulation, completely independent of process state. + virtual std::string TranslateName(const std::string& name) const; + + /// \brief Returns whether the given path is a directory or not. + /// + /// Typical return codes (not guaranteed exhaustive): + /// * OK - The path exists and is a directory. + /// * FAILED_PRECONDITION - The path exists and is not a directory. + /// * NOT_FOUND - The path entry does not exist. + /// * PERMISSION_DENIED - Insufficient permissions. + /// * UNIMPLEMENTED - The file factory doesn't support directories. + virtual absl::Status IsDirectory(const std::string& fname) { + return IsDirectory(fname, nullptr); + } + + virtual absl::Status IsDirectory(const std::string& fname, + TransactionToken* token); + + /// \brief Returns whether the given path is on a file system + /// that has atomic move capabilities. This can be used + /// to determine if there needs to be a temp location to safely write objects. + /// The second boolean argument has_atomic_move contains this information. + /// + /// Returns one of the following status codes (not guaranteed exhaustive): + /// * OK - The path is on a recognized file system, + /// so has_atomic_move holds the above information. + /// * UNIMPLEMENTED - The file system of the path hasn't been implemented in + /// TF + virtual absl::Status HasAtomicMove(const std::string& path, + bool* has_atomic_move); + + /// Returns whether the give path is on a file system + /// that has ability to create a new temp file. This can be used + /// to determine if there needs to be a temp location to safely write objects. + /// If the file system cannot create a temp file, it's possibile that + /// uncomplete result may appear in the given file. + virtual absl::Status CanCreateTempFile(const std::string& fname, + bool* can_create_temp_file); + + /// \brief Flushes any cached filesystem objects from memory. + virtual void FlushCaches() { FlushCaches(nullptr); } + + virtual void FlushCaches(TransactionToken* token); + + /// \brief The separator this filesystem uses. + /// + /// This is implemented as a part of the filesystem, because even on windows, + /// a user may need access to filesystems with '/' separators, such as cloud + /// filesystems. + virtual char Separator() const; + + /// \brief Split a path to its basename and dirname. + /// + /// Helper function for Basename and Dirname. + std::pair SplitPath( + absl::string_view uri) const; + + /// \brief returns the final file name in the given path. + /// + /// Returns the part of the path after the final "/". If there is no + /// "/" in the path, the result is the same as the input. + virtual absl::string_view Basename(absl::string_view path) const; + + /// \brief Returns the part of the path before the final "/". + /// + /// If there is a single leading "/" in the path, the result will be the + /// leading "/". If there is no "/" in the path, the result is the empty + /// prefix of the input. + absl::string_view Dirname(absl::string_view path) const; + + /// \brief Returns the part of the basename of path after the final ".". + /// + /// If there is no "." in the basename, the result is empty. + absl::string_view Extension(absl::string_view path) const; + + /// \brief Clean duplicate and trailing, "/"s, and resolve ".." and ".". + /// + /// NOTE: This respects relative vs. absolute paths, but does not + /// invoke any system calls (getcwd(2)) in order to resolve relative + /// paths with respect to the actual working directory. That is, this is + /// purely string manipulation, completely independent of process state. + std::string CleanPath(absl::string_view path) const; + + /// \brief Creates a URI from a scheme, host, and path. + /// + /// If the scheme is empty, we just return the path. + std::string CreateURI(absl::string_view scheme, absl::string_view host, + absl::string_view path) const; + + /// \brief Return true if path is absolute. + bool IsAbsolutePath(absl::string_view path) const; + +#ifndef SWIG // variadic templates + /// \brief Join multiple paths together. + /// + /// This function also removes the unnecessary path separators. + /// For example: + /// + /// Arguments | JoinPath + /// ---------------------------+---------- + /// '/foo', 'bar' | /foo/bar + /// '/foo/', 'bar' | /foo/bar + /// '/foo', '/bar' | /foo/bar + /// + /// Usage: + /// string path = io::JoinPath("/mydir", filename); + /// string path = io::JoinPath(FLAGS_test_srcdir, filename); + /// string path = io::JoinPath("/full", "path", "to", "filename"); + template + std::string JoinPath(const T&... args) { + return JoinPathImpl({args...}); + } +#endif /* SWIG */ + + std::string JoinPathImpl(std::initializer_list paths); + + /// \brief Populates the scheme, host, and path from a URI. + /// + /// scheme, host, and path are guaranteed by this function to point into the + /// contents of uri, even if empty. + /// + /// Corner cases: + /// - If the URI is invalid, scheme and host are set to empty strings and the + /// passed string is assumed to be a path + /// - If the URI omits the path (e.g. file://host), then the path is left + /// empty. + void ParseURI(absl::string_view remaining, absl::string_view* scheme, + absl::string_view* host, absl::string_view* path) const; + + // Transaction related API + + /// \brief Starts a new transaction + virtual absl::Status StartTransaction(TransactionToken** token) { + *token = nullptr; + return absl::OkStatus(); + } + + /// \brief Adds `path` to transaction in `token` + virtual absl::Status AddToTransaction(const std::string& path, + TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Ends transaction + virtual absl::Status EndTransaction(TransactionToken* token) { + return absl::OkStatus(); + } + + /// \brief Get token for `path` or start a new transaction and add `path` to + /// it. + virtual absl::Status GetTokenOrStartTransaction(const std::string& path, + TransactionToken** token) { + *token = nullptr; + return absl::OkStatus(); + } + + /// \brief Return transaction for `path` or nullptr in `token` + virtual absl::Status GetTransactionForPath(const std::string& path, + TransactionToken** token) { + *token = nullptr; + return absl::OkStatus(); + } + + /// \brief Decode transaction to human readable string. + virtual std::string DecodeTransaction(const TransactionToken* token); + + /// \brief Set File System Configuration Options + virtual absl::Status SetOption(const string& key, const string& value) { + return errors::Unimplemented("SetOption"); + } + + /// \brief Set File System Configuration Option + virtual absl::Status SetOption(const std::string& name, + const std::vector& values) { + return errors::Unimplemented("SetOption"); + } + + /// \brief Set File System Configuration Option + virtual absl::Status SetOption(const std::string& name, + const std::vector& values) { + return errors::Unimplemented("SetOption"); + } + + /// \brief Set File System Configuration Option + virtual absl::Status SetOption(const std::string& name, + const std::vector& values) { + return errors::Unimplemented("SetOption"); + } + + /// \brief Set File System ACL checker. + /// + /// No checks are enforced if a FileAcl is never set. + virtual absl::Status SetFileAcl(std::shared_ptr file_acl) { + return errors::Unimplemented("SetFileAcl"); + } + + FileSystem() {} + + virtual ~FileSystem() = default; +}; +/// This macro adds forwarding methods from FileSystem class to +/// used class since name hiding will prevent these to be accessed from +/// derived classes and would require all use locations to migrate to +/// Transactional API. This is an interim solution until ModularFileSystem class +/// becomes a singleton. +// TODO(sami): Remove this macro when filesystem plugins migration is complete. +#define TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT \ + using FileSystem::NewRandomAccessFile; \ + using FileSystem::NewWritableFile; \ + using FileSystem::NewAppendableFile; \ + using FileSystem::NewReadOnlyMemoryRegionFromFile; \ + using FileSystem::FileExists; \ + using FileSystem::GetChildren; \ + using FileSystem::GetMatchingPaths; \ + using FileSystem::Stat; \ + using FileSystem::DeleteFile; \ + using FileSystem::RecursivelyCreateDir; \ + using FileSystem::DeleteDir; \ + using FileSystem::DeleteRecursively; \ + using FileSystem::GetFileSize; \ + using FileSystem::RenameFile; \ + using FileSystem::CopyFile; \ + using FileSystem::IsDirectory; \ + using FileSystem::FlushCaches + +/// A Wrapper class for Transactional FileSystem support. +/// This provides means to make use of the transactions with minimal code change +/// Any operations that are done through this interface will be through the +/// transaction created at the time of construction of this instance. +/// See FileSystem documentation for method descriptions. +/// This class simply forwards all calls to wrapped filesystem either with given +/// transaction token or with token used in its construction. This allows doing +/// transactional filesystem access with minimal code change. +class WrappedFileSystem : public FileSystem { + public: + TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; + + absl::Status NewRandomAccessFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override { + return fs_->NewRandomAccessFile(fname, (token ? token : token_), result); + } + + absl::Status NewWritableFile(const std::string& fname, + TransactionToken* token, + std::unique_ptr* result) override { + return fs_->NewWritableFile(fname, (token ? token : token_), result); + } + + absl::Status NewAppendableFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override { + return fs_->NewAppendableFile(fname, (token ? token : token_), result); + } + + absl::Status NewReadOnlyMemoryRegionFromFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override { + return fs_->NewReadOnlyMemoryRegionFromFile(fname, (token ? token : token_), + result); + } + + absl::Status FileExists(const std::string& fname, + TransactionToken* token) override { + return fs_->FileExists(fname, (token ? token : token_)); + } + + bool FilesExist(const std::vector& files, TransactionToken* token, + std::vector* status) override { + return fs_->FilesExist(files, (token ? token : token_), status); + } + + absl::Status GetChildren(const std::string& dir, TransactionToken* token, + std::vector* result) override { + return fs_->GetChildren(dir, (token ? token : token_), result); + } + + absl::Status GetMatchingPaths(const std::string& pattern, + TransactionToken* token, + std::vector* results) override { + return fs_->GetMatchingPaths(pattern, (token ? token : token_), results); + } + + bool Match(const std::string& filename, const std::string& pattern) override { + return fs_->Match(filename, pattern); + } + + absl::Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) override { + return fs_->Stat(fname, (token ? token : token_), stat); + } + + absl::Status DeleteFile(const std::string& fname, + TransactionToken* token) override { + return fs_->DeleteFile(fname, (token ? token : token_)); + } + + absl::Status CreateDir(const std::string& dirname, + TransactionToken* token) override { + return fs_->CreateDir(dirname, (token ? token : token_)); + } + + absl::Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) override { + return fs_->RecursivelyCreateDir(dirname, (token ? token : token_)); + } + + absl::Status DeleteDir(const std::string& dirname, + TransactionToken* token) override { + return fs_->DeleteDir(dirname, (token ? token : token_)); + } + + absl::Status DeleteRecursively(const std::string& dirname, + TransactionToken* token, + int64_t* undeleted_files, + int64_t* undeleted_dirs) override { + return fs_->DeleteRecursively(dirname, (token ? token : token_), + undeleted_files, undeleted_dirs); + } + + absl::Status GetFileSize(const std::string& fname, TransactionToken* token, + uint64* file_size) override { + return fs_->GetFileSize(fname, (token ? token : token_), file_size); + } + + absl::Status RenameFile(const std::string& src, const std::string& target, + TransactionToken* token) override { + return fs_->RenameFile(src, target, (token ? token : token_)); + } + + absl::Status CopyFile(const std::string& src, const std::string& target, + TransactionToken* token) override { + return fs_->CopyFile(src, target, (token ? token : token_)); + } + + std::string TranslateName(const std::string& name) const override { + return fs_->TranslateName(name); + } + + absl::Status IsDirectory(const std::string& fname, + TransactionToken* token) override { + return fs_->IsDirectory(fname, (token ? token : token_)); + } + + absl::Status HasAtomicMove(const std::string& path, + bool* has_atomic_move) override { + return fs_->HasAtomicMove(path, has_atomic_move); + } + + void FlushCaches(TransactionToken* token) override { + return fs_->FlushCaches((token ? token : token_)); + } + + char Separator() const override { return fs_->Separator(); } + + absl::string_view Basename(absl::string_view path) const override { + return fs_->Basename(path); + } + + absl::Status StartTransaction(TransactionToken** token) override { + return fs_->StartTransaction(token); + } + + absl::Status AddToTransaction(const std::string& path, + TransactionToken* token) override { + return fs_->AddToTransaction(path, (token ? token : token_)); + } + + absl::Status EndTransaction(TransactionToken* token) override { + return fs_->EndTransaction(token); + } + + absl::Status GetTransactionForPath(const std::string& path, + TransactionToken** token) override { + return fs_->GetTransactionForPath(path, token); + } + + absl::Status GetTokenOrStartTransaction(const std::string& path, + TransactionToken** token) override { + return fs_->GetTokenOrStartTransaction(path, token); + } + + std::string DecodeTransaction(const TransactionToken* token) override { + return fs_->DecodeTransaction((token ? token : token_)); + } + + WrappedFileSystem(FileSystem* file_system, TransactionToken* token) + : fs_(file_system), token_(token) {} + + ~WrappedFileSystem() override = default; + + private: + FileSystem* fs_; + TransactionToken* token_; +}; + +/// A file abstraction for randomly reading the contents of a file. +class RandomAccessFile { + public: + RandomAccessFile() {} + virtual ~RandomAccessFile() = default; + + /// \brief Returns the name of the file. + /// + /// This is an optional operation that may not be implemented by every + /// filesystem. + virtual absl::Status Name(absl::string_view* result) const { + return errors::Unimplemented("This filesystem does not support Name()"); + } + + /// \brief Reads up to `n` bytes from the file starting at `offset`. + /// + /// `scratch[0..n-1]` may be written by this routine. Sets `*result` + /// to the data that was read (including if fewer than `n` bytes were + /// successfully read). May set `*result` to point at data in + /// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when + /// `*result` is used. + /// + /// On OK returned status: `n` bytes have been stored in `*result`. + /// On non-OK returned status: `[0..n]` bytes have been stored in `*result`. + /// + /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` + /// because of EOF. + /// + /// Safe for concurrent use by multiple threads. + virtual absl::Status Read(uint64 offset, size_t n, absl::string_view* result, + char* scratch) const = 0; + +#if defined(TF_CORD_SUPPORT) + /// \brief Read up to `n` bytes from the file starting at `offset`. + virtual absl::Status Read(uint64 offset, size_t n, absl::Cord* cord) const { + return errors::Unimplemented( + "Read(uint64, size_t, absl::Cord*) is not " + "implemented"); + } +#endif + + private: + RandomAccessFile(const RandomAccessFile&) = delete; + void operator=(const RandomAccessFile&) = delete; +}; + +/// \brief A file abstraction for sequential writing. +/// +/// The implementation must provide buffering since callers may append +/// small fragments at a time to the file. +class WritableFile { + public: + WritableFile() {} + virtual ~WritableFile() = default; + + /// \brief Append 'data' to the file. + virtual absl::Status Append(absl::string_view data) = 0; + +#if defined(TF_CORD_SUPPORT) + // \brief Append 'data' to the file. + virtual absl::Status Append(const absl::Cord& cord) { + for (absl::string_view chunk : cord.Chunks()) { + TF_RETURN_IF_ERROR(Append(chunk)); + } + return absl::OkStatus(); + } +#endif + + /// \brief Close the file. + /// + /// Flush() and de-allocate resources associated with this file + /// + /// Typical return codes (not guaranteed to be exhaustive): + /// * OK + /// * Other codes, as returned from Flush() + virtual absl::Status Close() = 0; + + /// \brief Flushes the file and optionally syncs contents to filesystem. + /// + /// This should flush any local buffers whose contents have not been + /// delivered to the filesystem. + /// + /// If the process terminates after a successful flush, the contents + /// may still be persisted, since the underlying filesystem may + /// eventually flush the contents. If the OS or machine crashes + /// after a successful flush, the contents may or may not be + /// persisted, depending on the implementation. + virtual absl::Status Flush() = 0; + + // \brief Returns the name of the file. + /// + /// This is an optional operation that may not be implemented by every + /// filesystem. + virtual absl::Status Name(absl::string_view* result) const { + return errors::Unimplemented("This filesystem does not support Name()"); + } + + /// \brief Syncs contents of file to filesystem. + /// + /// This waits for confirmation from the filesystem that the contents + /// of the file have been persisted to the filesystem; if the OS + /// or machine crashes after a successful Sync, the contents should + /// be properly saved. + virtual absl::Status Sync() = 0; + + /// \brief Retrieves the current write position in the file, or -1 on + /// error. + /// + /// This is an optional operation, subclasses may choose to return + /// errors::Unimplemented. + virtual absl::Status Tell(int64_t* position) { + *position = -1; + return errors::Unimplemented("This filesystem does not support Tell()"); + } + + private: + WritableFile(const WritableFile&) = delete; + void operator=(const WritableFile&) = delete; +}; + +/// \brief A readonly memmapped file abstraction. +/// +/// The implementation must guarantee that all memory is accessible when the +/// object exists, independently from the Env that created it. +class ReadOnlyMemoryRegion { + public: + ReadOnlyMemoryRegion() {} + virtual ~ReadOnlyMemoryRegion() = default; + + /// \brief Returns a pointer to the memory region. + virtual const void* data() = 0; + + /// \brief Returns the length of the memory region in bytes. + virtual uint64 length() = 0; +}; + +/// \brief A registry for file system implementations. +/// +/// Filenames are specified as an URI, which is of the form +/// [scheme://]. +/// File system implementations are registered using the REGISTER_FILE_SYSTEM +/// macro, providing the 'scheme' as the key. +/// +/// There are two `Register` methods: one using `Factory` for legacy filesystems +/// (deprecated mechanism of subclassing `FileSystem` and using +/// `REGISTER_FILE_SYSTEM` macro), and one using `std::unique_ptr` +/// for the new modular approach. +/// +/// Note that the new API expects a pointer to `ModularFileSystem` but this is +/// not checked as there should be exactly one caller to the API and doing the +/// check results in a circular dependency between `BUILD` targets. +/// +/// Plan is to completely remove the filesystem registration from `Env` and +/// incorporate it into `ModularFileSystem` class (which will be renamed to be +/// the only `FileSystem` class and marked as `final`). But this will happen at +/// a later time, after we convert all filesystems to the new API. +/// +/// TODO(b/139060984): After all filesystems are converted, remove old +/// registration and update comment. +class FileSystemRegistry { + public: + typedef std::function Factory; + + virtual ~FileSystemRegistry() = default; + virtual absl::Status Register(const std::string& scheme, Factory factory) = 0; + virtual absl::Status Register(const std::string& scheme, + std::unique_ptr filesystem) = 0; + virtual FileSystem* Lookup(const std::string& scheme) = 0; + virtual absl::Status GetRegisteredFileSystemSchemes( + std::vector* schemes) = 0; +}; + +/// \brief An abstraction for enforcing ACL checks in FileSystem. +class FileAcl { + public: + virtual absl::Status CheckAccess(std::string_view path) = 0; + virtual ~FileAcl() = default; +}; + +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_FILE_SYSTEM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/file_system_helper.cc b/third_party/xla/xla/tsl/platform/file_system_helper.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/platform/file_system_helper.cc rename to third_party/xla/xla/tsl/platform/file_system_helper.cc index bfbea9808675e2..ffa288b4e25428 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/file_system_helper.cc +++ b/third_party/xla/xla/tsl/platform/file_system_helper.cc @@ -13,22 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/file_system_helper.h" +#include "xla/tsl/platform/file_system_helper.h" #include #include #include +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" #include "tsl/platform/mutex.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" -#include "tsl/platform/status.h" #include "tsl/platform/str_util.h" -#include "tsl/platform/threadpool.h" namespace tsl { namespace internal { diff --git a/third_party/xla/xla/tsl/platform/file_system_helper.h b/third_party/xla/xla/tsl/platform/file_system_helper.h new file mode 100644 index 00000000000000..218b4c887b3a0a --- /dev/null +++ b/third_party/xla/xla/tsl/platform/file_system_helper.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ +#define XLA_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ + +#include +#include + +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" + +namespace tsl { + +class FileSystem; +class Env; + +namespace internal { + +// Given a pattern, stores in 'results' the set of paths (in the given file +// system) that match that pattern. +// +// This helper may be used by implementations of FileSystem::GetMatchingPaths() +// in order to provide parallel scanning of subdirectories (except on iOS). +// +// Arguments: +// fs: may not be null and will be used to identify directories and list +// their contents. +// env: may not be null and will be used to check if a match has been found. +// pattern: see FileSystem::GetMatchingPaths() for details. +// results: will be cleared and may not be null. +// +// Returns an error status if any call to 'fs' failed. +absl::Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, + std::vector* results); + +// Given a file path, determines whether the file exists. This helper simplifies +// the use of Env::FileExists. +// +// Arguments: +// env: may not be null. +// fname: the file path to look up +// +// Returns true if the file exists, false if it does not exist, or an error +// Status. +absl::StatusOr FileExists(Env* env, const string& fname); + +} // namespace internal +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_FILE_SYSTEM_HELPER_H_ diff --git a/third_party/xla/xla/tsl/platform/logging.h b/third_party/xla/xla/tsl/platform/logging.h new file mode 100644 index 00000000000000..a50fd04bdaa359 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/logging.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_LOGGING_H_ +#define XLA_TSL_PLATFORM_LOGGING_H_ + +#include "tsl/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) || \ + defined(PLATFORM_GOOGLE_IOS) || defined(GOOGLE_LOGGING) || \ + defined(__EMSCRIPTEN__) || defined(PLATFORM_CHROMIUMOS) +#include "xla/tsl/platform/google/logging.h" // IWYU pragma: export +#else +#include "xla/tsl/platform/default/logging.h" // IWYU pragma: export +#endif + +#endif // XLA_TSL_PLATFORM_LOGGING_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/logging_test.cc b/third_party/xla/xla/tsl/platform/logging_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/platform/logging_test.cc rename to third_party/xla/xla/tsl/platform/logging_test.cc index 070696f19f2885..1988174095f0c3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/logging_test.cc +++ b/third_party/xla/xla/tsl/platform/logging_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include #include @@ -28,10 +28,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/path.h" #include "tsl/platform/stacktrace_handler.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" // Make sure popen and pclose are available on Windows. #ifdef PLATFORM_WINDOWS diff --git a/third_party/xla/xla/tsl/platform/macros.h b/third_party/xla/xla/tsl/platform/macros.h new file mode 100644 index 00000000000000..e635f98f08a34c --- /dev/null +++ b/third_party/xla/xla/tsl/platform/macros.h @@ -0,0 +1,162 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_MACROS_H_ +#define XLA_TSL_PLATFORM_MACROS_H_ + +// Compiler attributes +#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG) +// Compiler supports GCC-style attributes +#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn)) +#define TF_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline)) +#define TF_ATTRIBUTE_UNUSED __attribute__((unused)) +#define TF_ATTRIBUTE_COLD __attribute__((cold)) +#define TF_ATTRIBUTE_WEAK __attribute__((weak)) +#define TF_PACKED __attribute__((packed)) +#define TF_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__printf__, string_index, first_to_check))) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \ + __attribute__((__format__(__scanf__, string_index, first_to_check))) +#elif defined(_MSC_VER) +// Non-GCC equivalents +#define TF_ATTRIBUTE_NORETURN __declspec(noreturn) +#define TF_ATTRIBUTE_ALWAYS_INLINE __forceinline +#define TF_ATTRIBUTE_NOINLINE +#define TF_ATTRIBUTE_UNUSED +#define TF_ATTRIBUTE_COLD +#define TF_ATTRIBUTE_WEAK +#define TF_MUST_USE_RESULT +#define TF_PACKED +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) +#else +// Non-GCC equivalents +#define TF_ATTRIBUTE_NORETURN +#define TF_ATTRIBUTE_ALWAYS_INLINE +#define TF_ATTRIBUTE_NOINLINE +#define TF_ATTRIBUTE_UNUSED +#define TF_ATTRIBUTE_COLD +#define TF_ATTRIBUTE_WEAK +#define TF_MUST_USE_RESULT +#define TF_PACKED +#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) +#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) +#endif + +// Control visibility outside .so +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_EXPORT __declspec(dllexport) +#else +#define TF_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +#ifdef __has_builtin +#define TF_HAS_BUILTIN(x) __has_builtin(x) +#else +#define TF_HAS_BUILTIN(x) 0 +#endif + +// C++11-style attributes (N2761) +#if defined(__has_cpp_attribute) +// Safely checks if an attribute is supported. Equivalent to +// ABSL_HAVE_CPP_ATTRIBUTE. +#define TF_HAS_CPP_ATTRIBUTE(n) __has_cpp_attribute(n) +#else +#define TF_HAS_CPP_ATTRIBUTE(n) 0 +#endif + +// [[clang::annotate("x")]] allows attaching custom strings (e.g. "x") to +// declarations (variables, functions, fields, etc.) for use by tools. They are +// represented in the Clang AST (as AnnotateAttr nodes) and in LLVM IR, but not +// in final output. +#if TF_HAS_CPP_ATTRIBUTE(clang::annotate) +#define TF_ATTRIBUTE_ANNOTATE(str) [[clang::annotate(str)]] +#else +#define TF_ATTRIBUTE_ANNOTATE(str) +#endif + +// A variable declaration annotated with the `TF_CONST_INIT` attribute will +// not compile (on supported platforms) unless the variable has a constant +// initializer. +#if TF_HAS_CPP_ATTRIBUTE(clang::require_constant_initialization) +#define TF_CONST_INIT [[clang::require_constant_initialization]] +#else +#define TF_CONST_INIT +#endif + +// Compilers can be told that a certain branch is not likely to be taken +// (for instance, a CHECK failure), and use that information in static +// analysis. Giving it this information can help it optimize for the +// common case in the absence of better information (ie. +// -fprofile-arcs). +#if TF_HAS_BUILTIN(__builtin_expect) || (defined(__GNUC__) && __GNUC__ >= 3) +#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define TF_PREDICT_FALSE(x) (x) +#define TF_PREDICT_TRUE(x) (x) +#endif + +// DEPRECATED: directly use the macro implementation instead. +// A macro to disallow the copy constructor and operator= functions +// This is usually placed in the private: declarations for a class. +#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + +// The TF_ARRAYSIZE(arr) macro returns the # of elements in an array arr. +// +// The expression TF_ARRAYSIZE(a) is a compile-time constant of type +// size_t. +#define TF_ARRAYSIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) + +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900) +// Define this to 1 if the code is compiled in C++11 mode; leave it +// undefined otherwise. Do NOT define it to 0 -- that causes +// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. +#define LANG_CXX11 1 +#endif + +#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning) +#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough") +#define TF_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT +#endif +#endif + +#ifndef TF_FALLTHROUGH_INTENDED +#define TF_FALLTHROUGH_INTENDED \ + do { \ + } while (0) +#endif + +namespace tsl { +namespace internal { +template +void remove_unused_variable_compiler_warning(const T&){}; +} // namespace internal +} // namespace tsl +#define TF_UNUSED_VARIABLE(x) \ + tensorflow::internal::remove_unused_variable_compiler_warning(x) + +#endif // XLA_TSL_PLATFORM_MACROS_H_ diff --git a/third_party/xla/xla/tsl/platform/profile_utils/BUILD b/third_party/xla/xla/tsl/platform/profile_utils/BUILD index f4bf80c0a5c09d..e1712dc13c24ec 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/BUILD +++ b/third_party/xla/xla/tsl/platform/profile_utils/BUILD @@ -53,10 +53,10 @@ cc_library( ], copts = tsl_copts(), deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/base", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", ], alwayslink = 1, ) diff --git a/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.cc b/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.cc index 557a54ecc1afc7..00072b2ce91b33 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.cc +++ b/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.cc @@ -29,7 +29,7 @@ limitations under the License. #include #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/stringprintf.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.h b/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.h index e4385a9e76ad49..b796ef5b5e6e20 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.h +++ b/third_party/xla/xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.h @@ -18,9 +18,9 @@ limitations under the License. #include +#include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/profile_utils/i_cpu_utils_helper.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #if defined(__ANDROID__) && (__ANDROID_API__ >= 21) && \ (defined(__ARM_ARCH_7A__) || defined(__aarch64__)) diff --git a/third_party/xla/xla/tsl/platform/profile_utils/clock_cycle_profiler.h b/third_party/xla/xla/tsl/platform/profile_utils/clock_cycle_profiler.h index 7ef8af80ecbb7b..b922cb942902a3 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/clock_cycle_profiler.h +++ b/third_party/xla/xla/tsl/platform/profile_utils/clock_cycle_profiler.h @@ -18,9 +18,9 @@ limitations under the License. #include +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/profile_utils/cpu_utils.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.cc b/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.cc index 85a7a7b840bf32..394d1f87a341ff 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.cc +++ b/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.cc @@ -28,8 +28,8 @@ limitations under the License. #endif #include "absl/base/call_once.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.h" -#include "tsl/platform/logging.h" namespace tsl { namespace profile_utils { diff --git a/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.h b/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.h index caff59be57eec9..f3d6d42566496b 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.h +++ b/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils.h @@ -20,9 +20,9 @@ limitations under the License. #include #include +#include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/profile_utils/i_cpu_utils_helper.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #if defined(ARMV6) || defined(__ARM_ARCH_7A__) #include diff --git a/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils_test.cc b/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils_test.cc index cc92395c61678c..968846acb40f5a 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils_test.cc +++ b/third_party/xla/xla/tsl/platform/profile_utils/cpu_utils_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "xla/tsl/platform/profile_utils/cpu_utils.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/profile_utils/clock_cycle_profiler.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profile_utils { diff --git a/third_party/xla/xla/tsl/platform/profile_utils/i_cpu_utils_helper.h b/third_party/xla/xla/tsl/platform/profile_utils/i_cpu_utils_helper.h index f434c1b17955b8..11d5bf2f4b675f 100644 --- a/third_party/xla/xla/tsl/platform/profile_utils/i_cpu_utils_helper.h +++ b/third_party/xla/xla/tsl/platform/profile_utils/i_cpu_utils_helper.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_ #define XLA_TSL_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_ -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace profile_utils { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status.cc b/third_party/xla/xla/tsl/platform/status.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/platform/status.cc rename to third_party/xla/xla/tsl/platform/status.cc index f6d4aed1d71984..20d14c089562f5 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status.cc +++ b/third_party/xla/xla/tsl/platform/status.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include diff --git a/third_party/xla/xla/tsl/platform/status.h b/third_party/xla/xla/tsl/platform/status.h new file mode 100644 index 00000000000000..0086587b629def --- /dev/null +++ b/third_party/xla/xla/tsl/platform/status.h @@ -0,0 +1,226 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_STATUS_H_ +#define XLA_TSL_PLATFORM_STATUS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "tsl/platform/platform.h" +#include "tsl/platform/stack_frame.h" + +// Include appropriate platform-dependent parts of status. +#if defined(PLATFORM_GOOGLE) +#include "xla/tsl/platform/google/status.h" // IWYU pragma: export +#else +#include "xla/tsl/platform/default/status.h" // IWYU pragma: export +#endif + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tsl { + +// Since April 2023, tensorflow::Status is an alias to absl::Status. The first +// TF release including this change will be TF 2.14 (the latest release in +// April 2023 is 2.13). +// At the same time `tsl::errors::Code` aliases `absl::StatusCode`. +// +// Here is a set of correspondences: +// - Use `absl::OkStatus()` instead of `tsl::OkStatus()`. +typedef absl::Status Status ABSL_DEPRECATE_AND_INLINE(); + +namespace errors { +typedef absl::StatusCode Code ABSL_DEPRECATE_AND_INLINE(); +} // namespace errors +namespace error { +typedef ::tensorflow::error::Code Code; +} // namespace error +} // namespace tsl + +// Transparent comparison between tensorflow::error::Code protobuf enum and +// absl::Status. +// +// The longer term objective is to delete these when we have done the transition +// to absl::Status. +namespace tensorflow::error { +inline bool operator==(const ::tensorflow::error::Code& c1, + const absl::StatusCode& c2) { + return static_cast(c1) == static_cast(c2); +} + +inline bool operator!=(const ::tensorflow::error::Code& c1, + const absl::StatusCode& c2) { + return static_cast(c1) != static_cast(c2); +} +} // namespace tensorflow::error + +namespace absl { +inline bool operator==(const ::absl::StatusCode& c1, + const ::tensorflow::error::Code& c2) { + return static_cast(c1) == static_cast(c2); +} + +inline bool operator!=(const ::absl::StatusCode& c1, + const ::tensorflow::error::Code& c2) { + return static_cast(c1) != static_cast(c2); +} +} // namespace absl + +namespace tsl { + +// OkStatus() +// +// Returns an OK status, equivalent to a default constructed instance. Prefer +// usage of `OkStatus()` when constructing such an OK status. +ABSL_DEPRECATE_AND_INLINE() inline absl::Status OkStatus() { + return absl::OkStatus(); +}; + +ABSL_DEPRECATE_AND_INLINE() +inline absl::Status FromAbslStatus(const absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } + +// Given `Status.message()` does not guarantee to be always backed by a +// null-terminated string, we have this utility function when it's needed for +// the Tensorflow C-API. +// A more robust API would be to get both a `char*` of the beginning of the +// string, plus the size (see e.g. `XlaCustomCallStatusSetFailure`). +// NB: This Windows-only implementation is exists only to avoid a linker error. +// Remove if this is resolved. +#ifdef _WIN32 +const char* NullTerminatedMessage(const absl::Status& status); +#else +ABSL_DEPRECATE_AND_INLINE() +inline const char* NullTerminatedMessage(const absl::Status& status) { + return absl::StatusMessageAsCStr(status); +} +#endif + +// TODO(b/197552541) Move this namespace to errors.h. +namespace errors { + +void SetStackTrace(absl::Status& status, std::vector stack_trace); + +std::vector GetStackTrace(const absl::Status& status); +} // namespace errors + +// Helper class to manage multiple child status values. +class StatusGroup { + public: + StatusGroup(); + // Constructor to form a StatusGroup from any N set of Status arguments. + // Usage: StatusGroup({status_a, status_b, status_c}); + StatusGroup(std::initializer_list statuses); + + // Utility function to mark a Status as derived. By marking derived status, + // Derived status messages are ignored when reporting errors to end users. + static absl::Status MakeDerived(const absl::Status& s); + static bool IsDerived(const absl::Status& s); + + // Enable warning and error log collection for appending to the aggregated + // status. This function may be called more than once. + static void ConfigureLogHistory(); + + // Returns merged payloads of all statuses. In case multiple statuses have the + // same payload key, non-derived statuses have priority over derived ones, + // otherwise one payload value will be chosen in an unspecified but + // deterministic order. + // NOTE: The payload marking derived statuses as derived will not be returned. + std::unordered_map GetPayloads() const; + + // Return a merged status with combined child status messages with a summary. + absl::Status as_summary_status() const; + // Return a merged status with combined child status messages with + // concatenation. + absl::Status as_concatenated_status() const; + + bool ok() const { return ok_; } + + // Augment this group with the child status `status`. + void Update(const absl::Status& status); + + // Attach recent warning and error log messages + void AttachLogMessages(); + bool HasLogMessages() const { return !recent_logs_.empty(); } + + private: + bool ok_ = true; + size_t num_ok_ = 0; + + // Maintain a sorted collection of statuses. + struct CompareStatus { + bool operator()(const absl::Status& a, const absl::Status& b) const { + return a.ToString() > b.ToString(); + } + }; + // Using std::set instead of absl::btree_set to keep size for certain + // dependent libraries under the limit. + std::set derived_; + std::set non_derived_; + + std::vector recent_logs_; // recent warning and error logs +}; + +typedef std::function StatusCallback; + +extern ::tsl::string* TfCheckOpHelperOutOfLine(const absl::Status& v, + const char* msg); + +inline ::tsl::string* TfCheckOpHelper(absl::Status v, const char* msg) { + if (v.ok()) return nullptr; + return TfCheckOpHelperOutOfLine(v, msg); +} + +#define TF_DO_CHECK_OK(val, level) \ + while (auto* _result = ::tsl::TfCheckOpHelper(val, #val)) \ + LOG(level) << *(_result) + +#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) +#define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) + +// DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt +// mode. +#ifndef NDEBUG +#define TF_DCHECK_OK(val) TF_CHECK_OK(val) +#else +#define TF_DCHECK_OK(val) \ + while (false && (::tsl::OkStatus() == (val))) LOG(FATAL) +#endif + +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_STATUS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc b/third_party/xla/xla/tsl/platform/status_matchers.cc similarity index 94% rename from third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc rename to third_party/xla/xla/tsl/platform/status_matchers.cc index bcb04018dbc7f9..ee4c204798a15f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc +++ b/third_party/xla/xla/tsl/platform/status_matchers.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/status_matchers.h" +#include "xla/tsl/platform/status_matchers.h" #include #include +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" namespace tsl { namespace testing { diff --git a/third_party/xla/xla/tsl/platform/status_matchers.h b/third_party/xla/xla/tsl/platform/status_matchers.h new file mode 100644 index 00000000000000..9650ec28754c2a --- /dev/null +++ b/third_party/xla/xla/tsl/platform/status_matchers.h @@ -0,0 +1,343 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_TSL_PLATFORM_STATUS_MATCHERS_H_ +#define XLA_TSL_PLATFORM_STATUS_MATCHERS_H_ + +#include +#include +#include + +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/protobuf/error_codes.pb.h" + +// Defines the following utilities: +// +// =============== +// IsOkAndHolds(m) +// =============== +// +// This matcher matches a StatusOr value whose status is OK and whose inner +// value matches matcher m. Example: +// +// using ::tsl::testing::IsOkAndHolds; +// using ::testing::HasSubstr; +// ... +// StatusOr status_or_message("Hello, world"); +// EXPECT_THAT(status_or_message, IsOkAndHolds("Hello, world"))); +// EXPECT_THAT(status_or_message, IsOkAndHolds(HasSubstr("Hello,"))); +// +// =============================== +// StatusIs(status_code_matcher, +// error_message_matcher) +// =============================== +// +// This matcher matches a Status or StatusOr if the following are true: +// +// - the status's code() matches status_code_matcher, and +// - the status's error_message() matches error_message_matcher. +// +// Example: +// +// using ::tsl::testing::StatusIs; +// using ::testing::HasSubstr; +// using ::testing::MatchesRegex; +// using ::testing::Ne; +// using ::testing::_; +// StatusOr GetMessage(int id); +// ... +// +// // The status code must be CANCELLED; the error message can be anything. +// EXPECT_THAT(GetName(42), +// StatusIs(tsl::error::CANCELLED, _)); +// +// // The status code can be anything; the error message must match the regex. +// EXPECT_THAT(GetName(43), +// StatusIs(_, MatchesRegex("server.*time-out"))); +// +// // The status code should not be CANCELLED; the error message can be +// // anything with "Cancelled" in it. +// EXPECT_THAT(GetName(44), +// StatusIs(Ne(tsl::error::CANCELLED), +// HasSubstr("Cancelled")))); +// +// ============================= +// StatusIs(status_code_matcher) +// ============================= +// +// This is a shorthand for +// StatusIs(status_code_matcher, ::testing::_) +// +// In other words, it's like the two-argument StatusIs(), except that it ignores +// error messages. +// +// ====== +// IsOk() +// ====== +// +// Matches a Status or StatusOr whose status value is OK. +// Equivalent to 'StatusIs(error::OK)'. +// +// Example: +// ... +// StatusOr message("Hello, world"); +// EXPECT_THAT(message, IsOk()); +// Status status = OkStatus(); +// EXPECT_THAT(status, IsOk()); + +namespace tsl { + +inline void PrintTo(const tsl::error::Code code, std::ostream* os) { + *os << Code_Name(code); +} + +template +void PrintTo(const StatusOr& status_or, std::ostream* os) { + *os << ::testing::PrintToString(status_or.status()); + if (status_or.ok()) { + *os << ": " << ::testing::PrintToString(status_or.value()); + } +} + +namespace testing { +namespace internal_status { + +inline const absl::Status& GetStatus(const absl::Status& status) { + return status; +} + +template +inline const absl::Status& GetStatus(const StatusOr& status) { + return status.status(); +} + +//////////////////////////////////////////////////////////// +// Implementation of IsOkAndHolds(). +// +// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a +// reference to StatusOr. +template +class IsOkAndHoldsMatcherImpl + : public ::testing::MatcherInterface { + public: + typedef + typename std::remove_reference::type::value_type value_type; + + template + explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) + : inner_matcher_(::testing::SafeMatcherCast( + std::forward(inner_matcher))) {} + + void DescribeTo(std::ostream* os) const override { + *os << "is OK and has a value that "; + inner_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "isn't OK or has a value that "; + inner_matcher_.DescribeNegationTo(os); + } + + bool MatchAndExplain( + StatusOrType actual_value, + ::testing::MatchResultListener* result_listener) const override { + if (!actual_value.ok()) { + *result_listener << "which has status " << actual_value.status(); + return false; + } + + ::testing::StringMatchResultListener inner_listener; + const bool matches = + inner_matcher_.MatchAndExplain(*actual_value, &inner_listener); + const std::string inner_explanation = inner_listener.str(); + if (!inner_explanation.empty()) { + *result_listener << "which contains value " + << ::testing::PrintToString(*actual_value) << ", " + << inner_explanation; + } + return matches; + } + + private: + const ::testing::Matcher inner_matcher_; +}; + +// Implements IsOkAndHolds(m) as a polymorphic matcher. +template +class IsOkAndHoldsMatcher { + public: + explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher) + : inner_matcher_(std::move(inner_matcher)) {} + + // Converts this polymorphic matcher to a monomorphic matcher of the given + // type. StatusOrType can be either StatusOr or a reference to StatusOr. + template + operator ::testing::Matcher() const { // NOLINT + return ::testing::Matcher( + new IsOkAndHoldsMatcherImpl(inner_matcher_)); + } + + private: + const InnerMatcher inner_matcher_; +}; + +//////////////////////////////////////////////////////////// +// Implementation of StatusIs(). +// +// StatusIs() is a polymorphic matcher. This class is the common +// implementation of it shared by all types T where StatusIs() can be used as +// a Matcher. + +class StatusIsMatcherCommonImpl { + public: + StatusIsMatcherCommonImpl( + ::testing::Matcher code_matcher, + ::testing::Matcher message_matcher) + : code_matcher_(std::move(code_matcher)), + message_matcher_(std::move(message_matcher)) {} + + void DescribeTo(std::ostream* os) const; + + void DescribeNegationTo(std::ostream* os) const; + + bool MatchAndExplain(const absl::Status& status, + ::testing::MatchResultListener* result_listener) const; + + private: + const ::testing::Matcher code_matcher_; + const ::testing::Matcher message_matcher_; +}; + +// Monomorphic implementation of matcher StatusIs() for a given type T. T can +// be Status, StatusOr<>, or a reference to either of them. +template +class MonoStatusIsMatcherImpl : public ::testing::MatcherInterface { + public: + explicit MonoStatusIsMatcherImpl(StatusIsMatcherCommonImpl common_impl) + : common_impl_(std::move(common_impl)) {} + + void DescribeTo(std::ostream* os) const override { + common_impl_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + common_impl_.DescribeNegationTo(os); + } + + bool MatchAndExplain( + T actual_value, + ::testing::MatchResultListener* result_listener) const override { + return common_impl_.MatchAndExplain(GetStatus(actual_value), + result_listener); + } + + private: + StatusIsMatcherCommonImpl common_impl_; +}; + +// Implements StatusIs() as a polymorphic matcher. +class StatusIsMatcher { + public: + StatusIsMatcher(::testing::Matcher code_matcher, + ::testing::Matcher message_matcher) + : common_impl_( + ::testing::MatcherCast(code_matcher), + ::testing::MatcherCast(message_matcher)) {} + + // Converts this polymorphic matcher to a monomorphic matcher of the given + // type. T can be StatusOr<>, Status, or a reference to either of them. + template + operator ::testing::Matcher() const { // NOLINT + return ::testing::MakeMatcher(new MonoStatusIsMatcherImpl(common_impl_)); + } + + private: + const StatusIsMatcherCommonImpl common_impl_; +}; + +// Monomorphic implementation of matcher IsOk() for a given type T. +// T can be Status, StatusOr<>, or a reference to either of them. +template +class MonoIsOkMatcherImpl : public ::testing::MatcherInterface { + public: + void DescribeTo(std::ostream* os) const override { *os << "is OK"; } + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not OK"; + } + bool MatchAndExplain(T actual_value, + ::testing::MatchResultListener*) const override { + return GetStatus(actual_value).ok(); + } +}; + +// Implements IsOk() as a polymorphic matcher. +class IsOkMatcher { + public: + template + operator ::testing::Matcher() const { // NOLINT + return ::testing::Matcher(new MonoIsOkMatcherImpl()); + } +}; +} // namespace internal_status + +// Returns a matcher that matches a StatusOr<> whose status is OK and whose +// value matches the inner matcher. +template +internal_status::IsOkAndHoldsMatcher::type> +IsOkAndHolds(InnerMatcher&& inner_matcher) { + return internal_status::IsOkAndHoldsMatcher< + typename std::decay::type>( + std::forward(inner_matcher)); +} + +// Returns a matcher that matches a Status or StatusOr<> whose status code +// matches code_matcher, and whose error message matches message_matcher. +template +internal_status::StatusIsMatcher StatusIs(CodeMatcher code_matcher, + MessageMatcher message_matcher) { + return internal_status::StatusIsMatcher(std::move(code_matcher), + std::move(message_matcher)); +} +// Remove this specialization when tensorflow::Status is absl::Status +template +internal_status::StatusIsMatcher StatusIs(tensorflow::error::Code code_matcher, + MessageMatcher message_matcher) { + return internal_status::StatusIsMatcher( + static_cast(code_matcher), std::move(message_matcher)); +} + +// Returns a matcher that matches a Status or StatusOr<> whose status code +// matches code_matcher. +template +internal_status::StatusIsMatcher StatusIs(CodeMatcher code_matcher) { + return StatusIs(std::move(code_matcher), ::testing::_); +} +// Remove this specialization when tensorflow::Status is absl::Status +template <> +inline internal_status::StatusIsMatcher StatusIs( + tensorflow::error::Code code_matcher) { + return StatusIs(static_cast(code_matcher), ::testing::_); +} + +// Returns a matcher that matches a Status or StatusOr<> which is OK. +inline internal_status::IsOkMatcher IsOk() { + return internal_status::IsOkMatcher(); +} + +} // namespace testing +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_STATUS_MATCHERS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc b/third_party/xla/xla/tsl/platform/status_matchers_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc rename to third_party/xla/xla/tsl/platform/status_matchers_test.cc index 3a681f6f3aed31..e8e73d4c9bb5d2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc +++ b/third_party/xla/xla/tsl/platform/status_matchers_test.cc @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/status_matchers.h" +#include "xla/tsl/platform/status_matchers.h" #include #include #include +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace tsl { namespace testing { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_test.cc b/third_party/xla/xla/tsl/platform/status_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/platform/status_test.cc rename to third_party/xla/xla/tsl/platform/status_test.cc index e716a15b96e46e..1cba0f61046f91 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_test.cc +++ b/third_party/xla/xla/tsl/platform/status_test.cc @@ -10,7 +10,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include #include @@ -19,13 +19,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/status_to_from_proto.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/stack_frame.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/status_to_from_proto.h" -#include "tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc b/third_party/xla/xla/tsl/platform/status_to_from_proto.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc rename to third_party/xla/xla/tsl/platform/status_to_from_proto.cc index 54e2b2ef3391ab..59bf21dcf53d7b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc +++ b/third_party/xla/xla/tsl/platform/status_to_from_proto.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/status_to_from_proto.h" +#include "xla/tsl/platform/status_to_from_proto.h" #include #include "absl/strings/cord.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/tsl/protobuf/status.pb.h" -#include "tsl/platform/status.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/status_to_from_proto.h b/third_party/xla/xla/tsl/platform/status_to_from_proto.h new file mode 100644 index 00000000000000..b26d824bd8aa61 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/status_to_from_proto.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ +#define XLA_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ + +#include "xla/tsl/platform/status.h" +#include "xla/tsl/protobuf/status.pb.h" + +namespace tsl { + +// TODO(b/250921378): Merge this file with `status.h` once we figure out how to +// fix the following error with the MacOS build: +// +// ImportError: +// dlopen(/org_tensorflow/tensorflow/python/platform/_pywrap_tf2.so, 2): +// Symbol not found: tensorflow11StatusProtoC1EPN6protobuf5ArenaEb + +// Converts a `Status` to a `StatusProto`. +tensorflow::StatusProto StatusToProto(const absl::Status& s); + +#if defined(PLATFORM_GOOGLE) +// Constructs a `Status` from a `StatusProto`. +absl::Status StatusFromProto( + const tensorflow::StatusProto& proto, + absl::SourceLocation loc = absl::SourceLocation::current()); +#else +Status StatusFromProto(const tensorflow::StatusProto& proto); +#endif +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ diff --git a/third_party/xla/xla/tsl/platform/statusor.h b/third_party/xla/xla/tsl/platform/statusor.h new file mode 100644 index 00000000000000..f638fe3f2cda32 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/statusor.h @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.value(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.value()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// StatusOr> result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.value()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return tsl::InvalidArgument("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. +#ifndef XLA_TSL_PLATFORM_STATUSOR_H_ +#define XLA_TSL_PLATFORM_STATUSOR_H_ + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/status/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/platform.h" + +// Include appropriate platform-dependent `TF_ASSIGN_OR_RETURN`. +#if defined(PLATFORM_GOOGLE) +#include "xla/tsl/platform/google/statusor.h" // IWYU pragma: export +#else +#include "xla/tsl/platform/default/statusor.h" // IWYU pragma: export +#endif + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tsl { + +template +using StatusOr ABSL_DEPRECATE_AND_INLINE() = absl::StatusOr; + +} // namespace tsl + +#define TF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + TF_ASSERT_OK_AND_ASSIGN_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ + rexpr); + +#define TF_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \ + lhs = std::move(statusor).value() + +#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) +#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y + +#endif // XLA_TSL_PLATFORM_STATUSOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/statusor_test.cc b/third_party/xla/xla/tsl/platform/statusor_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/platform/statusor_test.cc rename to third_party/xla/xla/tsl/platform/statusor_test.cc index fd0ee7886073b4..41706938273124 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/statusor_test.cc +++ b/third_party/xla/xla/tsl/platform/statusor_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Unit tests for StatusOr -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" #include #include @@ -23,10 +23,10 @@ limitations under the License. #include #include "absl/base/config.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" namespace tsl { namespace { @@ -731,8 +731,9 @@ TEST(Status, StackTracePropagation) { ASSERT_EQ(sources.size(), 3); for (int i = 0; i < 3; ++i) { - ASSERT_EQ(sources[i].file_name(), - "third_party/tensorflow/tsl/platform/statusor_test.cc"); + ASSERT_EQ( + sources[i].file_name(), + "third_party/tensorflow/compiler/xla/tsl/platform/statusor_test.cc"); } } diff --git a/third_party/xla/xla/tsl/platform/subprocess.h b/third_party/xla/xla/tsl/platform/subprocess.h index d43d70bd6a1f5e..8702b7795a8062 100644 --- a/third_party/xla/xla/tsl/platform/subprocess.h +++ b/third_party/xla/xla/tsl/platform/subprocess.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/subprocess_test.cc b/third_party/xla/xla/tsl/platform/subprocess_test.cc index 807de31bc3e907..5bcf7824177964 100644 --- a/third_party/xla/xla/tsl/platform/subprocess_test.cc +++ b/third_party/xla/xla/tsl/platform/subprocess_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test.h" #include "tsl/platform/path.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" #ifdef PLATFORM_WINDOWS #define WIFEXITED(code) ((code) != 3) diff --git a/third_party/xla/third_party/tsl/tsl/platform/test.cc b/third_party/xla/xla/tsl/platform/test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/platform/test.cc rename to third_party/xla/xla/tsl/platform/test.cc index b2b2a8936c81e9..25a697f85f25aa 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/test.cc +++ b/third_party/xla/xla/tsl/platform/test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" #include #include #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/net.h" #include "tsl/platform/path.h" diff --git a/third_party/xla/xla/tsl/platform/test.h b/third_party/xla/xla/tsl/platform/test.h new file mode 100644 index 00000000000000..2569bc57d989e2 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/test.h @@ -0,0 +1,86 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_TEST_H_ +#define XLA_TSL_PLATFORM_TEST_H_ + +#include +#include +#include + +#include // IWYU pragma: export +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" +#include "tsl/platform/platform.h" + +// Includes gmock.h and enables the use of gmock matchers in tensorflow tests. +// +// Test including this header can use the macros EXPECT_THAT(...) and +// ASSERT_THAT(...) in combination with gmock matchers. +// Example: +// std::vector vec = Foo(); +// EXPECT_THAT(vec, ::testing::ElementsAre(1,2,3)); +// EXPECT_THAT(vec, ::testing::UnorderedElementsAre(2,3,1)); +// +// For more details on gmock matchers see: +// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers +// +// The advantages of using gmock matchers instead of self defined matchers are +// better error messages, more maintainable tests and more test coverage. +#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID) && \ + !defined(PLATFORM_CHROMIUMOS) +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#endif +#include // IWYU pragma: export + +namespace tsl { +namespace testing { + +// Return a temporary directory suitable for temporary testing files. +// +// Where possible, consider using Env::LocalTempFilename over this function. +std::string TmpDir(); + +// Returns the path to TensorFlow in the directory containing data +// dependencies. +// +// A better alternative would be making use if +// tensorflow/tsl/platform/resource_loader.h:GetDataDependencyFilepath. That +// function should do the right thing both within and outside of tests allowing +// avoiding test specific APIs. +std::string TensorFlowSrcRoot(); + +// Returns the path to XLA in the directory containing data +// dependencies. +std::string XlaSrcRoot(); + +// Returns the path to TSL in the directory containing data +// dependencies. +std::string TslSrcRoot(); + +// Return a random number generator seed to use in randomized tests. +// Returns the same value for the lifetime of the process. +int RandomSeed(); + +// Returns an unused port number, for use in multi-process testing. +// NOTE: This function is not thread-safe. +int PickUnusedPortOrDie(); + +} // namespace testing +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_TEST_H_ diff --git a/third_party/xla/xla/tsl/platform/test_benchmark.h b/third_party/xla/xla/tsl/platform/test_benchmark.h new file mode 100644 index 00000000000000..2d0c4435dc182f --- /dev/null +++ b/third_party/xla/xla/tsl/platform/test_benchmark.h @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple benchmarking facility. +#ifndef XLA_TSL_PLATFORM_TEST_BENCHMARK_H_ +#define XLA_TSL_PLATFORM_TEST_BENCHMARK_H_ + +#include "benchmark/benchmark.h" // IWYU pragma: export +#include "tsl/platform/platform.h" + +// FIXME(vyng): Remove this. +// Background: During the benchmark-migration projects, all benchmarks were made +// to use "testing::benchmark::" prefix because that is what the internal +// Google benchmark library use. +namespace testing { +namespace benchmark { +using ::benchmark::State; // NOLINT +} // namespace benchmark +} // namespace testing + +namespace tsl { +namespace testing { + +inline void RunBenchmarks() { benchmark::RunSpecifiedBenchmarks(); } +inline void InitializeBenchmarks(int* argc, char** argv) { + benchmark::Initialize(argc, argv); +} + +template +void DoNotOptimize(const T& var) { + ::benchmark::DoNotOptimize(var); +} +} // namespace testing +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_TEST_BENCHMARK_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/test_main.cc b/third_party/xla/xla/tsl/platform/test_main.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/platform/test_main.cc rename to third_party/xla/xla/tsl/platform/test_main.cc index fb9265618f2553..3a0660540a0ecc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/test_main.cc +++ b/third_party/xla/xla/tsl/platform/test_main.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/platform.h" #include "tsl/platform/stacktrace_handler.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { tsl::testing::InstallStacktraceHandler(); diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool.cc b/third_party/xla/xla/tsl/platform/threadpool.cc similarity index 87% rename from third_party/xla/third_party/tsl/tsl/platform/threadpool.cc rename to third_party/xla/xla/tsl/platform/threadpool.cc index 8b2c850331e944..36757031ad4b94 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool.cc +++ b/third_party/xla/xla/tsl/platform/threadpool.cc @@ -13,17 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/threadpool.h" + +#include // NOLINT +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/types.h" #define EIGEN_USE_THREADS #include "absl/types/optional.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/context.h" #include "tsl/platform/denormal.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/mutex.h" #include "tsl/platform/numa.h" #include "tsl/platform/setround.h" #include "tsl/platform/tracing.h" @@ -45,32 +55,42 @@ namespace tsl { namespace thread { struct EigenEnvironment { - typedef Thread EnvThread; + using EnvThread = Thread; + struct TaskImpl { - std::function f; + std::function fn; Context context; uint64 trace_id; }; + struct Task { - std::unique_ptr f; + Task() = default; + + Task(std::function fn, Context context, uint64 trace_id) + : f(TaskImpl{std::move(fn), std::move(context), trace_id}) {} + + Task(Task&&) = default; + Task& operator=(Task&&) = default; + + std::optional f; }; - Env* const env_; - const ThreadOptions thread_options_; - const string name_; + Env* const env; + const ThreadOptions thread_options; + const std::string name; EigenEnvironment(Env* env, const ThreadOptions& thread_options, - const string& name) - : env_(env), thread_options_(thread_options), name_(name) {} + std::string name) + : env(env), thread_options(thread_options), name(std::move(name)) {} EnvThread* CreateThread(std::function f) { - return env_->StartThread(thread_options_, name_, [=]() { + return env->StartThread(thread_options, name, [this, f = std::move(f)]() { // Set the processor flag to flush denormals to zero. port::ScopedFlushDenormal flush; // Set the processor rounding mode to ROUND TO NEAREST. tsl::port::ScopedSetRound round(FE_TONEAREST); - if (thread_options_.numa_node != port::kNUMANoAffinity) { - port::NUMASetThreadNodeAffinity(thread_options_.numa_node); + if (thread_options.numa_node != port::kNUMANoAffinity) { + port::NUMASetThreadNodeAffinity(thread_options.numa_node); } f(); }); @@ -78,36 +98,30 @@ struct EigenEnvironment { Task CreateTask(std::function f) { uint64 id = 0; - if (tracing::EventCollector::IsEnabled()) { + if (ABSL_PREDICT_FALSE(tracing::EventCollector::IsEnabled())) { id = tracing::GetUniqueArg(); tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id); } - return Task{ - std::unique_ptr(new TaskImpl{ - std::move(f), - Context(ContextKind::kThread), - id, - }), - }; + return Task(std::move(f), Context(ContextKind::kThread), id); } void ExecuteTask(const Task& t) { WithContext wc(t.f->context); tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, t.f->trace_id); - t.f->f(); + t.f->fn(); } }; -ThreadPool::ThreadPool(Env* env, const string& name, int num_threads) +ThreadPool::ThreadPool(Env* env, const std::string& name, int num_threads) : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {} ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads) + const std::string& name, int num_threads) : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {} ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, + const std::string& name, int num_threads, bool low_latency_hint, Eigen::Allocator* allocator) { CHECK_GE(num_threads, 1); @@ -185,7 +199,7 @@ void ThreadPool::TransformRangeConcurrently( const std::function& fn) { ParallelFor(total, SchedulingParams(SchedulingStrategy::kFixedBlockSize, - absl::nullopt /* cost_per_unit */, block_size), + /*cost_per_unit=*/std::nullopt, block_size), fn); } diff --git a/third_party/xla/xla/tsl/platform/threadpool.h b/third_party/xla/xla/tsl/platform/threadpool.h new file mode 100644 index 00000000000000..ebd6ea596abb7a --- /dev/null +++ b/third_party/xla/xla/tsl/platform/threadpool.h @@ -0,0 +1,245 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_THREADPOOL_H_ +#define XLA_TSL_PLATFORM_THREADPOOL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/threadpool_interface.h" +#include "xla/tsl/platform/types.h" + +namespace Eigen { +class Allocator; +class ThreadPoolInterface; +struct ThreadPoolDevice; + +template +class ThreadPoolTempl; +} // namespace Eigen + +namespace tsl { +namespace thread { + +struct EigenEnvironment; + +class ThreadPool { + public: + // Scheduling strategies for ParallelFor. The strategy governs how the given + // units of work are distributed among the available threads in the + // threadpool. + enum class SchedulingStrategy { + // The Adaptive scheduling strategy adaptively chooses the shard sizes based + // on the cost of each unit of work, and the cost model of the underlying + // threadpool device. + // + // The 'cost_per_unit' is an estimate of the number of CPU cycles (or + // nanoseconds if not CPU-bound) to complete a unit of work. Overestimating + // creates too many shards and CPU time will be dominated by per-shard + // overhead, such as Context creation. Underestimating may not fully make + // use of the specified parallelism, and may also cause inefficiencies due + // to load balancing issues and stragglers. + kAdaptive, + // The Fixed Block Size scheduling strategy shards the given units of work + // into shards of fixed size. In case the total number of units is not + // evenly divisible by 'block_size', at most one of the shards may be of + // smaller size. The exact number of shards may be found by a call to + // NumShardsUsedByFixedBlockSizeScheduling. + // + // Each shard may be executed on a different thread in parallel, depending + // on the number of threads available in the pool. Note that when there + // aren't enough threads in the pool to achieve full parallelism, function + // calls will be automatically queued. + kFixedBlockSize + }; + + // Contains additional parameters for either the Adaptive or the Fixed Block + // Size scheduling strategy. + class SchedulingParams { + public: + explicit SchedulingParams(SchedulingStrategy strategy, + absl::optional cost_per_unit, + absl::optional block_size) + : strategy_(strategy), + cost_per_unit_(cost_per_unit), + block_size_(block_size) {} + + SchedulingStrategy strategy() const { return strategy_; } + absl::optional cost_per_unit() const { return cost_per_unit_; } + absl::optional block_size() const { return block_size_; } + + private: + // The underlying Scheduling Strategy for which this instance contains + // additional parameters. + SchedulingStrategy strategy_; + + // The estimated cost per unit of work in number of CPU cycles (or + // nanoseconds if not CPU-bound). Only applicable for Adaptive scheduling + // strategy. + absl::optional cost_per_unit_; + + // The block size of each shard. Only applicable for Fixed Block Size + // scheduling strategy. + absl::optional block_size_; + }; + + // Constructs a pool that contains "num_threads" threads with specified + // "name". env->StartThread() is used to create individual threads with the + // given ThreadOptions. If "low_latency_hint" is true the thread pool + // implementation may use it as a hint that lower latency is preferred at the + // cost of higher CPU usage, e.g. by letting one or more idle threads spin + // wait. Conversely, if the threadpool is used to schedule high-latency + // operations like I/O the hint should be set to false. + // + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const ThreadOptions& thread_options, + const std::string& name, int num_threads, bool low_latency_hint, + Eigen::Allocator* allocator = nullptr); + + // Constructs a pool for low-latency ops that contains "num_threads" threads + // with specified "name". env->StartThread() is used to create individual + // threads. + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const std::string& name, int num_threads); + + // Constructs a pool for low-latency ops that contains "num_threads" threads + // with specified "name". env->StartThread() is used to create individual + // threads with the given ThreadOptions. + // REQUIRES: num_threads > 0 + ThreadPool(Env* env, const ThreadOptions& thread_options, + const std::string& name, int num_threads); + + // Constructs a pool that wraps around the thread::ThreadPoolInterface + // instance provided by the caller. Caller retains ownership of + // `user_threadpool` and must ensure its lifetime is longer than the + // ThreadPool instance. + explicit ThreadPool(thread::ThreadPoolInterface* user_threadpool); + + // Waits until all scheduled work has finished and then destroy the + // set of threads. + ~ThreadPool(); + + // Schedules fn() for execution in the pool of threads. + void Schedule(std::function fn); + + void SetStealPartitions( + const std::vector>& partitions); + + void ScheduleWithHint(std::function fn, int start, int limit); + + // Returns the number of shards used by ParallelForFixedBlockSizeScheduling + // with these parameters. + int NumShardsUsedByFixedBlockSizeScheduling(const int64_t total, + const int64_t block_size); + + // Returns the number of threads spawned by calling TransformRangeConcurrently + // with these parameters. + // Deprecated. Use NumShardsUsedByFixedBlockSizeScheduling. + int NumShardsUsedByTransformRangeConcurrently(const int64_t block_size, + const int64_t total); + + // ParallelFor shards the "total" units of work assuming each unit of work + // having roughly "cost_per_unit" cost, in cycles. Each unit of work is + // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work + // and the total cost of each shard is roughly the same. + // + // "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds + // if not CPU-bound) to complete a unit of work. Overestimating creates too + // many shards and CPU time will be dominated by per-shard overhead, such as + // Context creation. Underestimating may not fully make use of the specified + // parallelism, and may also cause inefficiencies due to load balancing + // issues and stragglers. + void ParallelFor(int64_t total, int64_t cost_per_unit, + const std::function& fn); + + // Similar to ParallelFor above, but takes the specified scheduling strategy + // into account. + void ParallelFor(int64_t total, const SchedulingParams& scheduling_params, + const std::function& fn); + + // Same as ParallelFor with Fixed Block Size scheduling strategy. + // Deprecated. Prefer ParallelFor with a SchedulingStrategy argument. + void TransformRangeConcurrently( + const int64_t block_size, const int64_t total, + const std::function& fn); + + // Shards the "total" units of work. For more details, see "ParallelFor". + // + // The function is passed a thread_id between 0 and NumThreads() *inclusive*. + // This is because some work can happen on the caller thread while the threads + // in the pool are also being used. + // + // The caller can allocate NumThreads() + 1 separate buffers for each thread. + // Each thread can safely write to the buffer given by its id without + // synchronization. However, the worker fn may be called multiple times + // sequentially with the same id. + // + // At most NumThreads() unique ids will actually be used, and only a few may + // be used for small workloads. If each buffer is expensive, the buffers + // should be stored in an array initially filled with null, and a buffer + // should be allocated by fn the first time that the id is used. + void ParallelForWithWorkerId( + int64_t total, int64_t cost_per_unit, + const std::function& fn); + + // Similar to ParallelForWithWorkerId above, but takes the specified + // scheduling strategy into account. + void ParallelForWithWorkerId( + int64_t total, const SchedulingParams& scheduling_params, + const std::function& fn); + + // Returns the number of threads in the pool. + int NumThreads() const; + + // Returns current thread id between 0 and NumThreads() - 1, if called from a + // thread in the pool. Returns -1 otherwise. + int CurrentThreadId() const; + + // If ThreadPool implementation is compatible with Eigen::ThreadPoolInterface, + // returns a non-null pointer. The caller does not own the object the returned + // pointer points to, and should not attempt to delete. + Eigen::ThreadPoolInterface* AsEigenThreadPool() const; + + private: + // Divides the work represented by the range [0, total) into k shards. + // Calls fn(i*block_size, (i+1)*block_size) from the ith shard (0 <= i < k). + // Each shard may be executed on a different thread in parallel, depending on + // the number of threads available in the pool. + // When (i+1)*block_size > total, fn(i*block_size, total) is called instead. + // Here, k = NumShardsUsedByFixedBlockSizeScheduling(total, block_size). + // Requires 0 < block_size <= total. + void ParallelForFixedBlockSizeScheduling( + const int64_t total, const int64_t block_size, + const std::function& fn); + + // underlying_threadpool_ is the user_threadpool if user_threadpool is + // provided in the constructor. Otherwise it is the eigen_threadpool_. + Eigen::ThreadPoolInterface* underlying_threadpool_; + // eigen_threadpool_ is instantiated and owned by thread::ThreadPool if + // user_threadpool is not in the constructor. + std::unique_ptr> eigen_threadpool_; + std::unique_ptr threadpool_device_; + ThreadPool(const ThreadPool&) = delete; + void operator=(const ThreadPool&) = delete; +}; + +} // namespace thread +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_THREADPOOL_H_ diff --git a/third_party/xla/xla/tsl/platform/threadpool_async_executor.h b/third_party/xla/xla/tsl/platform/threadpool_async_executor.h new file mode 100644 index 00000000000000..3d35b5f57e6916 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/threadpool_async_executor.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ +#define XLA_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ + +#include + +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/platform/threadpool.h" + +namespace tsl::thread { + +// An adaptor for a ThreadPool that converts it into the AsyncValue:Executor. +// +// AsncValue::Executor task is a move-only absl::AnyInvocable, and ThreadPool +// expects a copyable std::function. This class adapts the two and makes sure +// that the task is deleted when it's done executing. +class ThreadPoolAsyncExecutor : public AsyncValue::Executor { + public: + explicit ThreadPoolAsyncExecutor(ThreadPool* thread_pool) + : thread_pool_(thread_pool) {} + + void Execute(Task task) final { + auto* task_ptr = new Task(std::move(task)); + thread_pool_->Schedule([task_ptr] { + (*task_ptr)(); + delete task_ptr; + }); + } + + private: + ThreadPool* thread_pool_; +}; + +} // namespace tsl::thread + +#endif // XLA_TSL_PLATFORM_THREADPOOL_ASYNC_EXECUTOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor_test.cc b/third_party/xla/xla/tsl/platform/threadpool_async_executor_test.cc similarity index 86% rename from third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor_test.cc rename to third_party/xla/xla/tsl/platform/threadpool_async_executor_test.cc index acc00aa210b174..074b87fe58f1b2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/threadpool_async_executor_test.cc +++ b/third_party/xla/xla/tsl/platform/threadpool_async_executor_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/threadpool_async_executor.h" +#include "xla/tsl/platform/threadpool_async_executor.h" #include "absl/synchronization/notification.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" namespace tsl::thread { namespace { diff --git a/third_party/xla/xla/tsl/platform/threadpool_interface.h b/third_party/xla/xla/tsl/platform/threadpool_interface.h new file mode 100644 index 00000000000000..95ad088b90d347 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/threadpool_interface.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ +#define XLA_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ + +#include "unsupported/Eigen/CXX11/ThreadPool" +#include "xla/tsl/platform/types.h" +#include "tsl/platform/mutex.h" + +namespace tsl { +namespace thread { + +class ThreadPoolInterface : public Eigen::ThreadPoolInterface {}; + +} // namespace thread +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_THREADPOOL_INTERFACE_H_ diff --git a/third_party/xla/xla/tsl/platform/threadpool_options.h b/third_party/xla/xla/tsl/platform/threadpool_options.h new file mode 100644 index 00000000000000..aa2ac294ebc771 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/threadpool_options.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ +#define XLA_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ + +#include "xla/tsl/platform/threadpool_interface.h" + +namespace tsl { +namespace thread { + +struct ThreadPoolOptions { + // If not null, use this threadpool to schedule inter-op operation + thread::ThreadPoolInterface* inter_op_threadpool = nullptr; + + // If not null, use this threadpool to schedule intra-op operation + thread::ThreadPoolInterface* intra_op_threadpool = nullptr; +}; + +} // namespace thread +} // namespace tsl + +#endif // XLA_TSL_PLATFORM_THREADPOOL_OPTIONS_H_ diff --git a/third_party/xla/xla/tsl/platform/types.h b/third_party/xla/xla/tsl/platform/types.h new file mode 100644 index 00000000000000..22131e33f7ca09 --- /dev/null +++ b/third_party/xla/xla/tsl/platform/types.h @@ -0,0 +1,74 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PLATFORM_TYPES_H_ +#define XLA_TSL_PLATFORM_TYPES_H_ + +#include + +#include "tsl/platform/bfloat16.h" +#include "tsl/platform/ml_dtypes.h" // IWYU pragma: export +#include "tsl/platform/platform.h" +#include "tsl/platform/tstring.h" + +// Include appropriate platform-dependent implementations +#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES) +#include "xla/tsl/platform/google/integral_types.h" // IWYU pragma: export +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ + defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ + defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) +#include "xla/tsl/platform/default/integral_types.h" // IWYU pragma: export +#else +#error Define the appropriate PLATFORM_ macro for this platform +#endif + +namespace tsl { + +// Alias tsl::string to std::string. +using std::string; + +static const uint4 kuint4max = static_cast(0x0F); +static const uint8 kuint8max = static_cast(0xFF); +static const uint16 kuint16max = static_cast(0xFFFF); +static const uint32 kuint32max = static_cast(0xFFFFFFFF); +static const uint64 kuint64max = static_cast(0xFFFFFFFFFFFFFFFFull); +static const int8_t kint8min = static_cast(~0x7F); +static const int8_t kint8max = static_cast(0x7F); +static const int4 kint4min = static_cast(0x08); +static const int4 kint4max = static_cast(0x07); +static const int16_t kint16min = static_cast(~0x7FFF); +static const int16_t kint16max = static_cast(0x7FFF); +static const int32_t kint32min = static_cast(~0x7FFFFFFF); +static const int32_t kint32max = static_cast(0x7FFFFFFF); +static const int64_t kint64min = static_cast(~0x7FFFFFFFFFFFFFFFll); +static const int64_t kint64max = static_cast(0x7FFFFFFFFFFFFFFFll); + +// A typedef for a uint64 used as a short fingerprint. +using Fprint = uint64; + +} // namespace tsl + +// Alias namespace ::stream_executor as ::tensorflow::se. +namespace stream_executor {} +namespace tensorflow { +namespace se = ::stream_executor; +} // namespace tensorflow + +#if defined(PLATFORM_WINDOWS) +#include +typedef std::ptrdiff_t ssize_t; +#endif + +#endif // XLA_TSL_PLATFORM_TYPES_H_ diff --git a/third_party/xla/xla/tsl/platform/windows/BUILD b/third_party/xla/xla/tsl/platform/windows/BUILD index c5104f6176a77d..0fdd26c7612ddd 100644 --- a/third_party/xla/xla/tsl/platform/windows/BUILD +++ b/third_party/xla/xla/tsl/platform/windows/BUILD @@ -15,6 +15,7 @@ package( default_visibility = internal_visibility([ "//tensorflow/core/platform:__pkg__", "@local_tsl//tsl/platform:__pkg__", + "//xla/tsl/platform:__pkg__", ]), licenses = ["notice"], ) @@ -24,17 +25,17 @@ cc_library( srcs = [ "windows_file_system.cc", "windows_file_system.h", - "@local_tsl//tsl/platform:env.cc", - "@local_tsl//tsl/platform:file_system.cc", - "@local_tsl//tsl/platform:file_system_helper.cc", + "//xla/tsl/platform:env.cc", + "//xla/tsl/platform:file_system.cc", + "//xla/tsl/platform:file_system_helper.cc", + "//xla/tsl/platform:threadpool.cc", "@local_tsl//tsl/platform:ram_file_system.h", - "@local_tsl//tsl/platform:threadpool.cc", ], hdrs = [ - "@local_tsl//tsl/platform:env.h", - "@local_tsl//tsl/platform:file_system.h", - "@local_tsl//tsl/platform:file_system_helper.h", - "@local_tsl//tsl/platform:threadpool.h", + "//xla/tsl/platform:env.h", + "//xla/tsl/platform:file_system.h", + "//xla/tsl/platform:file_system_helper.h", + "//xla/tsl/platform:threadpool.h", ], tags = [ "manual", @@ -44,6 +45,15 @@ cc_library( deps = [ ":error_windows", ":wide_char", + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:file_statistics", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:threadpool_interface", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/time", @@ -54,27 +64,18 @@ cc_library( "@local_tsl//tsl/platform:context", "@local_tsl//tsl/platform:cord", "@local_tsl//tsl/platform:denormal", - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:file_statistics", "@local_tsl//tsl/platform:load_library", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:setround", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:stringprintf", - "@local_tsl//tsl/platform:threadpool_interface", "@local_tsl//tsl/platform:tracing", - "@local_tsl//tsl/platform:types", ], ) @@ -96,14 +97,14 @@ cc_library( cc_library( name = "env_time", srcs = ["env_time.cc"], - hdrs = ["@local_tsl//tsl/platform:env_time.h"], + hdrs = ["//xla/tsl/platform:env_time.h"], tags = [ "manual", "no_oss", "nobuilder", ], deps = [ - "@local_tsl//tsl/platform:types", + "//xla/tsl/platform:types", ], ) @@ -130,7 +131,7 @@ cc_library( "no_oss", "nobuilder", ], - deps = ["@local_tsl//tsl/platform:types"], + deps = ["//xla/tsl/platform:types"], ) cc_library( @@ -144,8 +145,8 @@ cc_library( ], deps = [ ":wide_char", + "//xla/tsl/platform:errors", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:errors", ], ) @@ -163,8 +164,8 @@ cc_library( ], deps = [ ":error_windows", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", ], ) @@ -191,11 +192,11 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:byte_order", "@local_tsl//tsl/platform:dynamic_annotations", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", "@snappy", ], ) @@ -222,9 +223,9 @@ cc_library( "nobuilder", ], deps = [ + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:stacktrace", - "@local_tsl//tsl/platform:types", ], ) @@ -239,12 +240,12 @@ cc_library( ], textual_hdrs = ["subprocess.h"], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/platform/windows/env.cc b/third_party/xla/xla/tsl/platform/windows/env.cc index 58382bafd240b3..414159dc3590fc 100644 --- a/third_party/xla/xla/tsl/platform/windows/env.cc +++ b/third_party/xla/xla/tsl/platform/windows/env.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" #include #include @@ -22,17 +22,19 @@ limitations under the License. #include #include #include + +#include #undef ERROR #include #include #include +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/windows/wide_char.h" #include "xla/tsl/platform/windows/windows_file_system.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/load_library.h" -#include "tsl/platform/logging.h" #include "tsl/platform/ram_file_system.h" #pragma comment(lib, "shlwapi.lib") @@ -102,8 +104,8 @@ class WindowsEnv : public Env { return new StdThread(thread_options, name, std::move(fn)); } - int32 GetCurrentThreadId() override { - return static_cast(::GetCurrentThreadId()); + int64_t GetCurrentThreadId() override { + return static_cast(::GetCurrentThreadId()); } bool GetCurrentThreadName(string* name) override { diff --git a/third_party/xla/xla/tsl/platform/windows/env_time.cc b/third_party/xla/xla/tsl/platform/windows/env_time.cc index 19a58de6f6ac2e..bc73285cbc5995 100644 --- a/third_party/xla/xla/tsl/platform/windows/env_time.cc +++ b/third_party/xla/xla/tsl/platform/windows/env_time.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/env_time.h" +#include "xla/tsl/platform/env_time.h" #include #include diff --git a/third_party/xla/xla/tsl/platform/windows/intrinsics_port.h b/third_party/xla/xla/tsl/platform/windows/intrinsics_port.h index e8a64a4684a8a5..0f2fa1d8424757 100644 --- a/third_party/xla/xla/tsl/platform/windows/intrinsics_port.h +++ b/third_party/xla/xla/tsl/platform/windows/intrinsics_port.h @@ -20,7 +20,7 @@ limitations under the License. // the following avx intrinsics are not defined on windows // in immintrin.h so we define them here. // -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #define _mm_load_pd1 _mm_load1_pd diff --git a/third_party/xla/xla/tsl/platform/windows/net.cc b/third_party/xla/xla/tsl/platform/windows/net.cc index 1823ef8f679fb9..63f00b4b95ffb9 100644 --- a/third_party/xla/xla/tsl/platform/windows/net.cc +++ b/third_party/xla/xla/tsl/platform/windows/net.cc @@ -21,9 +21,9 @@ limitations under the License. #include #include +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/windows/error_windows.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #undef ERROR diff --git a/third_party/xla/xla/tsl/platform/windows/port.cc b/third_party/xla/xla/tsl/platform/windows/port.cc index 57600173577329..e4e122ddfcaac3 100644 --- a/third_party/xla/xla/tsl/platform/windows/port.cc +++ b/third_party/xla/xla/tsl/platform/windows/port.cc @@ -24,15 +24,15 @@ limitations under the License. #include #include +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/demangle.h" #include "tsl/platform/host_info.h" #include "tsl/platform/init_main.h" -#include "tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/numa.h" #include "tsl/platform/snappy.h" -#include "tsl/platform/types.h" namespace tsl { namespace port { diff --git a/third_party/xla/xla/tsl/platform/windows/stacktrace_handler.cc b/third_party/xla/xla/tsl/platform/windows/stacktrace_handler.cc index 76aa873b64ce13..7f00be5e3e43b9 100644 --- a/third_party/xla/xla/tsl/platform/windows/stacktrace_handler.cc +++ b/third_party/xla/xla/tsl/platform/windows/stacktrace_handler.cc @@ -28,9 +28,9 @@ limitations under the License. #include #include // NOLINT(build/c++11) +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" #include "tsl/platform/stacktrace.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/windows/subprocess.cc b/third_party/xla/xla/tsl/platform/windows/subprocess.cc index 1dee6fccff6051..c44483e7b40ee0 100644 --- a/third_party/xla/xla/tsl/platform/windows/subprocess.cc +++ b/third_party/xla/xla/tsl/platform/windows/subprocess.cc @@ -24,7 +24,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/strcat.h" #define PIPE_BUF_SIZE 4096 diff --git a/third_party/xla/xla/tsl/platform/windows/subprocess.h b/third_party/xla/xla/tsl/platform/windows/subprocess.h index 8c5909953784bc..f815355390d4b4 100644 --- a/third_party/xla/xla/tsl/platform/windows/subprocess.h +++ b/third_party/xla/xla/tsl/platform/windows/subprocess.h @@ -19,9 +19,9 @@ limitations under the License. #include #include -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc b/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc index c5de08a515c571..f4c47064204e3c 100644 --- a/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc +++ b/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc @@ -27,13 +27,13 @@ limitations under the License. #include #include +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system_helper.h" +#include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/windows/error_windows.h" #include "xla/tsl/platform/windows/wide_char.h" #include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system_helper.h" -#include "tsl/platform/logging.h" #include "tsl/platform/strcat.h" // TODO(mrry): Prevent this Windows.h #define from leaking out of our headers. diff --git a/third_party/xla/xla/tsl/platform/windows/windows_file_system.h b/third_party/xla/xla/tsl/platform/windows/windows_file_system.h index c29294d33fa2f5..4dad78172ea441 100644 --- a/third_party/xla/xla/tsl/platform/windows/windows_file_system.h +++ b/third_party/xla/xla/tsl/platform/windows/windows_file_system.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_ #define XLA_TSL_PLATFORM_WINDOWS_WINDOWS_FILE_SYSTEM_H_ -#include "tsl/platform/file_system.h" +#include "xla/tsl/platform/file_system.h" #include "tsl/platform/path.h" #include "tsl/platform/platform.h" diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD b/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD index 71601beb67e60b..3c2073289dc4ce 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD @@ -16,11 +16,11 @@ cc_library( "//tensorflow/lite:__pkg__", ]), deps = [ + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ] + if_static([ ":traceme_recorder_impl", ]), @@ -42,14 +42,14 @@ cc_library( "//xla/tsl/profiler:xla_internal", ]), deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:lock_free_queue", "//xla/tsl/profiler/utils:per_thread", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -60,18 +60,18 @@ tsl_cc_test( deps = [ ":traceme_recorder", ":traceme_recorder_impl", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:math_utils", "//xla/tsl/profiler/utils:time_utils", "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", ], ) @@ -83,10 +83,10 @@ cc_library( "//xla/tsl/profiler:internal", ]), deps = [ + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", ] + if_static([ ":annotation_stack_impl", ]), @@ -104,10 +104,10 @@ cc_library( "//xla/tsl/profiler:internal", ]), deps = [ + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -123,12 +123,12 @@ cc_library( ]), deps = [ ":traceme_recorder", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:parse_annotation", "//xla/tsl/profiler/utils:tf_op_utils", "//xla/tsl/profiler/utils:xplane_builder", "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -148,13 +148,13 @@ cc_library( deps = [ ":threadpool_listener_state", ":traceme_recorder", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:time_utils", "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:tracing", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:context_types_hdrs", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/lib:traceme_encode", diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.cc index a7b35b8626de70..586410fc4eac8b 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.cc @@ -26,8 +26,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.h b/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.h index 18fe3a2a1f7e9c..0e3d1d0e16662b 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.h +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/annotation_stack.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc index 3ee8fae3f04883..d72984d2605335 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" #include "xla/tsl/profiler/utils/parse_annotation.h" #include "xla/tsl/profiler/utils/tf_op_utils.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.h b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.h index 438cdbbe24c601..eb0d7dd4c08117 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.h +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef XLA_TSL_PROFILER_BACKENDS_CPU_HOST_TRACER_UTILS_H_ #define XLA_TSL_PROFILER_BACKENDS_CPU_HOST_TRACER_UTILS_H_ +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc index af9fc451b2d238..e10e0e445183c5 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc @@ -19,13 +19,13 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/backends/cpu/threadpool_listener_state.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" #include "xla/tsl/profiler/utils/time_utils.h" #include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tsl/platform/logging.h" #include "tsl/platform/tracing.h" -#include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/lib/traceme_encode.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.h b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.h index c6376978d86a8e..5cef72cc83bd02 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.h +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.h @@ -17,9 +17,9 @@ limitations under the License. #define XLA_TSL_PROFILER_BACKENDS_CPU_THREADPOOL_LISTENER_H_ #include "absl/status/status.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/backends/cpu/threadpool_listener_state.h" #include "tsl/platform/tracing.h" -#include "tsl/platform/types.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc index df81cb4ba52b96..c047c531f64c66 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc @@ -26,12 +26,12 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/lock_free_queue.h" #include "xla/tsl/profiler/utils/per_thread.h" -#include "tsl/platform/env.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.h b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.h index 62f2f7d91c6005..ed8477bde6c0c7 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.h +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.h @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace profiler { @@ -72,7 +72,7 @@ class TraceMeRecorder { int64_t end_time; }; struct ThreadInfo { - uint32 tid; + int64_t tid; std::string name; }; struct ThreadEvents { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc index 9fa89ed3d5e400..2d423148704e8d 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" #include +#include #include #include #include @@ -23,14 +24,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/threadpool.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/time_utils.h" -#include "tsl/platform/env.h" -#include "tsl/platform/logging.h" #include "tsl/platform/notification.h" -#include "tsl/platform/test.h" -#include "tsl/platform/threadpool.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { @@ -119,7 +120,7 @@ TEST(RecorderTest, Multithreaded) { bool overlapping_sessions = false; std::set events; }; - absl::flat_hash_map thread_state; + absl::flat_hash_map thread_state; // We expect each thread to eventually have multiple events, not all in a // contiguous range. auto done = [&thread_state] { diff --git a/third_party/xla/xla/tsl/profiler/convert/BUILD b/third_party/xla/xla/tsl/profiler/convert/BUILD index 8f56410a1e4a2b..fd5cb99dc37351 100644 --- a/third_party/xla/xla/tsl/profiler/convert/BUILD +++ b/third_party/xla/xla/tsl/profiler/convert/BUILD @@ -4,7 +4,7 @@ load( "//xla/tsl/platform:rules_cc.bzl", "cc_library", ) -load("//xla/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") +load("//xla/tsl/profiler/builds:build_config.bzl", "tf_profiler_alias", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -21,6 +21,7 @@ cc_library( "//xla/tsl/profiler:internal", ], deps = [ + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", ], @@ -28,13 +29,17 @@ cc_library( cc_library( name = "xla_op_utils", + srcs = [tf_profiler_alias("//xla/tsl/profiler/convert/", "xla_op_utils.cc")], hdrs = ["xla_op_utils.h"], visibility = internal_visibility([ "//xla/tsl/profiler:internal", "//xla/tsl/profiler:xla_profiler_backends", "//xla/python:__pkg__", ]), - deps = ["@com_google_absl//absl/strings"], + deps = [ + "//xla/tsl/platform:macros", + "@com_google_absl//absl/strings", + ], ) tsl_cc_test( @@ -43,8 +48,8 @@ tsl_cc_test( srcs = ["xla_op_utils_test.cc"], deps = [ ":xla_op_utils", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -55,10 +60,10 @@ cc_library( copts = tf_profiler_copts(), visibility = internal_visibility(["//xla/tsl/profiler:internal"]), deps = [ + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:timestamp_utils", "//xla/tsl/profiler/utils:xplane_schema", "//xla/tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -73,13 +78,13 @@ cc_library( ]), deps = [ ":trace_container", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:format_utils", "//xla/tsl/profiler/utils:math_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", ], ) @@ -89,10 +94,9 @@ tsl_cc_test( srcs = ["trace_container_test.cc"], deps = [ ":trace_container", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", ], ) @@ -102,10 +106,10 @@ tsl_cc_test( deps = [ ":trace_container", ":trace_events_to_json", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@jsoncpp_git//:jsoncpp", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", ], ) @@ -120,6 +124,7 @@ cc_library( ]), deps = [ ":trace_container", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:tf_xplane_visitor", "//xla/tsl/profiler/utils:trace_utils", "//xla/tsl/profiler/utils:xplane_schema", @@ -127,7 +132,6 @@ cc_library( "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], @@ -139,11 +143,11 @@ tsl_cc_test( srcs = ["xplane_to_trace_events_test.cc"], deps = [ ":xplane_to_trace_events", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "//xla/tsl/profiler/utils:trace_utils", "//xla/tsl/profiler/utils:xplane_builder", "//xla/tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], diff --git a/third_party/xla/xla/tsl/profiler/convert/oss/BUILD b/third_party/xla/xla/tsl/profiler/convert/oss/BUILD new file mode 100644 index 00000000000000..446e9973d9f445 --- /dev/null +++ b/third_party/xla/xla/tsl/profiler/convert/oss/BUILD @@ -0,0 +1,4 @@ +exports_files( + ["xla_op_utils.cc"], + visibility = ["//xla/tsl/profiler/convert:__pkg__"], +) diff --git a/third_party/xla/xla/tsl/profiler/convert/oss/xla_op_utils.cc b/third_party/xla/xla/tsl/profiler/convert/oss/xla_op_utils.cc new file mode 100644 index 00000000000000..cb19d36c8972ee --- /dev/null +++ b/third_party/xla/xla/tsl/profiler/convert/oss/xla_op_utils.cc @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tsl/profiler/convert/xla_op_utils.h" + +#include "absl/strings/string_view.h" + +namespace tsl { +namespace profiler { + +// LINT.IfChange +constexpr absl::string_view kHloSparseCoreV0Infeed = "sparsecorev0 infeed"; +constexpr absl::string_view kHloSparseCoreV0Outfeed = "sparsecorev0 outfeed"; +constexpr absl::string_view kHloSparseCoreV0InfeedWait = + "sparsecorev0 infeed wait"; +constexpr absl::string_view kHloSparseCoreV0InfeedTransform = + "sparsecorev0 infeed transform"; +// LINT.ThenChange(//tensorflow/compiler/xla/tsl/profiler/convert/google/xla_op_utils.cc) + +} // namespace profiler +} // namespace tsl diff --git a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc index 864da925423d99..427fa1cdf0c3db 100644 --- a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc +++ b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/timestamp_utils.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.h b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.h index 7aa5bf6a5db7b2..287e76586e2748 100644 --- a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.h +++ b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef XLA_TSL_PROFILER_CONVERT_POST_PROCESS_SINGLE_HOST_XPLANE_H_ #define XLA_TSL_PROFILER_CONVERT_POST_PROCESS_SINGLE_HOST_XPLANE_H_ -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/convert/trace_container_test.cc b/third_party/xla/xla/tsl/profiler/convert/trace_container_test.cc index ccd06d81590c97..fe3d4b39c2ee76 100644 --- a/third_party/xla/xla/tsl/profiler/convert/trace_container_test.cc +++ b/third_party/xla/xla/tsl/profiler/convert/trace_container_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "xla/tsl/platform/test.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc index d9bc3319fdbb5d..9796e29ec18702 100644 --- a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc +++ b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc @@ -22,10 +22,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "json/json.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/format_utils.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/trace_events.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.h b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.h index 8c3b690c795bfd..2f64ee222237b8 100644 --- a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.h +++ b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/convert/trace_container.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json_test.cc b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json_test.cc index dbbc9b1272df6f..b96bd698dea3d5 100644 --- a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json_test.cc +++ b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "json/json.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/convert/trace_container.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/test.h" #include "tsl/profiler/protobuf/trace_events.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h b/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h index 7ea44e211ca09e..b743dc23f89467 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h +++ b/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/macros.h" namespace tsl { namespace profiler { @@ -75,6 +76,12 @@ inline constexpr absl::string_view kHloAsyncDone = "async-done"; inline constexpr absl::string_view kHloReshape = "reshape"; inline constexpr absl::string_view kHloTranspose = "transpose"; +// SparseCore V0 sub-categories. +TF_CONST_INIT extern const absl::string_view kHloSparseCoreV0Infeed; +TF_CONST_INIT extern const absl::string_view kHloSparseCoreV0Outfeed; +TF_CONST_INIT extern const absl::string_view kHloSparseCoreV0InfeedWait; +TF_CONST_INIT extern const absl::string_view kHloSparseCoreV0InfeedTransform; + // Return if a category is fusion. inline bool IsFusion(absl::string_view category) { return absl::EndsWith(category, " fusion"); @@ -111,11 +118,33 @@ inline bool IsInfeedOrOutfeed(absl::string_view category) { absl::StrContains(category, kHloInfeed) || absl::StrContains(category, kHloOutfeed); } + +inline bool IsHostOrSparseCoreV0Infeed(absl::string_view category) { + return category == tsl::profiler::kHloInfeed || + category == tsl::profiler::kHloSparseCoreV0Infeed; +} + inline bool MayHaveInnerOps(absl::string_view category) { return category == kHloCall || category == kHloConditional || category == kHloWhile || category == kHloMegacoreFusion; } +// File and line that the framework op corresponding to an HLO op is associated +// to in a user's program; e.g. it could be the file and line of user code that +// generated the op. +struct OpSourceInfo { + absl::string_view source_file; + int32_t source_line = -1; + std::string stack_frame; + + std::string GetSourceTopLine() const { + if (source_file.empty()) return ""; + return absl::StrCat(source_file, ":", source_line); + } + + std::string GetSourceStack() const { return stack_frame; } +}; + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/xla/tsl/profiler/convert/xla_op_utils_test.cc b/third_party/xla/xla/tsl/profiler/convert/xla_op_utils_test.cc index f288d6d52344cb..9869688b6e9bbf 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xla_op_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xla_op_utils_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/profiler/convert/xla_op_utils.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { @@ -52,6 +52,13 @@ TEST(XlaOpUtilsTest, IsRematerialization) { "test_function_name/reshape/dot_general")); } +TEST(XlaOpUtilsTest, IsHostOrSparseCoreV0Infeed) { + EXPECT_TRUE(IsHostOrSparseCoreV0Infeed(kHloInfeed)); + EXPECT_TRUE(IsHostOrSparseCoreV0Infeed(kHloSparseCoreV0Infeed)); + EXPECT_FALSE(IsHostOrSparseCoreV0Infeed(kHloSparseCoreV0InfeedWait)); + EXPECT_FALSE(IsHostOrSparseCoreV0Infeed(kHloSparseCoreV0InfeedTransform)); +} + } // namespace } // namespace profiler } // namespace tsl diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc index c37951436d7168..99c8d00e87a81d 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc @@ -24,12 +24,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/trace_utils.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/trace_events.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.h b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.h index d1416395d1e08c..83d5fcf0b4dea7 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.h +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/convert/trace_container.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc index 6e0d3955c84cbf..f21f2d280d1660 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/utils/trace_utils.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tsl/platform/test.h" #include "tsl/profiler/protobuf/trace_events.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/BUILD b/third_party/xla/xla/tsl/profiler/rpc/BUILD index f05a50ccb65417..69fa636cf1b8b9 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/BUILD @@ -29,6 +29,12 @@ cc_library( "//tensorflow/python/profiler/internal:__pkg__", ]), deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", "//xla/tsl/profiler/rpc/client:save_profile", "//xla/tsl/profiler/utils:file_system_utils", "//xla/tsl/profiler/utils:math_utils", @@ -37,13 +43,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", @@ -72,9 +72,9 @@ cc_library( ]), deps = [ ":profiler_service_impl", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", ] + tsl_grpc_cc_dependencies(), alwayslink = True, diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD index f9dd5e0eeb0795..3310e438565c62 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD @@ -34,16 +34,16 @@ cc_library( ":profiler_client_for_pybind", ":remote_profiler_session_manager", ":save_profile", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/profiler/convert:trace_events_to_json", "//xla/tsl/profiler/convert:xplane_to_trace_events", "//xla/tsl/profiler/utils:session_manager", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", @@ -66,14 +66,14 @@ cc_library( deps = [ "//xla/tsl/lib/io:zlib_compression_options", "//xla/tsl/lib/io:zlib_outputbuffer", + "//xla/tsl/platform:env", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:file_system_utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], @@ -98,9 +98,9 @@ cc_library( ]), deps = [ ":profiler_client_impl", + "//xla/tsl/platform:status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/protobuf:profiler_analysis_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", ], @@ -121,14 +121,14 @@ cc_library( "//tensorflow/python/profiler/internal:__pkg__", ]), deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_analysis_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", ] + tsl_grpc_cc_dependencies(), @@ -140,13 +140,13 @@ cc_library( testonly = 1, hdrs = ["profiler_client_test_util.h"], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", + "//xla/tsl/platform:types", "//xla/tsl/profiler/rpc:profiler_server_impl", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ] + tf_protos_profiler_service(), @@ -159,16 +159,16 @@ tsl_cc_test( ":profiler_client", ":profiler_client_impl", # for oss ":profiler_client_test_util", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc:profiler_service_impl", "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_factory_impl", "@local_tsl//tsl/profiler/lib:profiler_session_impl", ] + tf_protos_profiler_service(), @@ -181,18 +181,18 @@ cc_library( copts = tf_profiler_copts(), deps = [ ":profiler_client_for_pybind", + "//xla/tsl/platform:env_time", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:time_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env_time", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/platform:types", ], ) @@ -203,17 +203,17 @@ tsl_cc_test( ":profiler_client_impl", # for oss ":profiler_client_test_util", ":remote_profiler_session_manager", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc:profiler_service_impl", "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_factory_impl", "@local_tsl//tsl/profiler/lib:profiler_session_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc index 84dd66b6e2f118..939ea500af7014 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc @@ -26,16 +26,16 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/convert/trace_events_to_json.h" #include "xla/tsl/profiler/convert/xplane_to_trace_events.h" #include "xla/tsl/profiler/rpc/client/profiler_client.h" #include "xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" #include "xla/tsl/profiler/utils/session_manager.h" -#include "tsl/platform/errors.h" #include "tsl/platform/host_info.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_analysis.pb.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.h b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.h index 42e27fd1934687..bf5b52a79a1ccd 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.h +++ b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc index 892bef42bb0a82..e4fa849fab0e9a 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc @@ -21,11 +21,11 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "grpcpp/grpcpp.h" // IWYU pragma: keep +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/error_codes.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.h b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.h index b5020d0ba3f34a..37bf7fdd36e379 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.h +++ b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/profiler/protobuf/profiler_analysis.grpc.pb.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test.cc b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test.cc index 58fa263794ccb8..fc7c6bde6134f8 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/client/profiler_client_test_util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test_util.h b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test_util.h index e2bd41bb0d7335..d0a61f450b05b0 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test_util.h +++ b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client_test_util.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/profiler_server.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" #include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc index 2eb7e0d6743180..ec34644fb62cd7 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc @@ -22,12 +22,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/client/profiler_client.h" #include "xla/tsl/profiler/utils/time_utils.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h index 404e45187702c2..d75eac5794b731 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h +++ b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/client/profiler_client.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager_test.cc b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager_test.cc index 7386f065041b45..78d671601cc1b2 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager_test.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/client/profiler_client_test_util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc index bc8bf69f492bfd..8dc7dfadd6dcb1 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc @@ -30,12 +30,12 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/lib/io/zlib_compression_options.h" #include "xla/tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/profiler/utils/file_system_utils.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/file_system.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.h b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.h index 2b5b9ac7125483..c27942f3801cc7 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.h +++ b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tsl/platform/status.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/rpc/profiler_server.cc b/third_party/xla/xla/tsl/profiler/rpc/profiler_server.cc index 7679534d875dc4..a19b9be37d5c4c 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/profiler_server.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/profiler_server.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "grpcpp/grpcpp.h" // IWYU pragma: keep +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/rpc/profiler_service_impl.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/rpc/profiler_server.h b/third_party/xla/xla/tsl/profiler/rpc/profiler_server.h index 5ea10ec82c473c..8021de58481919 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/profiler_server.h +++ b/third_party/xla/xla/tsl/profiler/rpc/profiler_server.h @@ -18,7 +18,7 @@ limitations under the License. #include #include "grpcpp/grpcpp.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc index d359f0bdadb1fd..8501048944acd3 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc @@ -20,18 +20,18 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_replace.h" #include "grpcpp/support/status.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/env_time.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/status.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" #include "xla/tsl/profiler/utils/file_system_utils.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/time_utils.h" #include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tsl/platform/env.h" -#include "tsl/platform/env_time.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/status.h" #include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/utils/BUILD b/third_party/xla/xla/tsl/profiler/utils/BUILD index 3230a7b2ba51f5..c0be7d2109c0f6 100644 --- a/third_party/xla/xla/tsl/profiler/utils/BUILD +++ b/third_party/xla/xla/tsl/profiler/utils/BUILD @@ -29,7 +29,7 @@ cc_library( name = "format_utils", hdrs = ["format_utils.h"], deps = [ - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", ], ) @@ -71,9 +71,9 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":math_utils", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", ], ) @@ -82,8 +82,8 @@ tsl_cc_test( srcs = ["timespan_test.cc"], deps = [ ":timespan", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -93,8 +93,8 @@ cc_library( hdrs = ["tf_op_utils.h"], copts = tf_profiler_copts(), deps = [ + "//xla/tsl/platform:macros", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:regexp", ], ) @@ -105,9 +105,9 @@ tsl_cc_test( srcs = ["tf_op_utils_test.cc"], deps = [ ":tf_op_utils", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -120,13 +120,13 @@ cc_library( deps = [ ":tf_op_utils", "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:context_types_hdrs", ], ) @@ -150,11 +150,11 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":timespan", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -169,13 +169,13 @@ cc_library( ":math_utils", ":timespan", ":xplane_schema", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -187,9 +187,9 @@ tsl_cc_test( deps = [ ":xplane_builder", ":xplane_visitor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -203,8 +203,8 @@ cc_library( "//xla/tsl/profiler:internal", ]), deps = [ + "//xla/tsl/platform:types", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:types", ], ) @@ -222,6 +222,7 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_visitor", + "//xla/tsl/platform:types", "//xla/tsl/util:stats_calculator_portable", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -230,7 +231,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:fingerprint", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:context_types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], @@ -241,17 +241,18 @@ tsl_cc_test( srcs = ["xplane_utils_test.cc"], deps = [ ":math_utils", + ":tf_xplane_visitor", ":xplane_builder", ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -284,9 +285,9 @@ tsl_cc_test( srcs = ["parse_annotation_test.cc"], deps = [ ":parse_annotation", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -298,11 +299,15 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":tf_xplane_visitor", + ":timespan", ":xplane_builder", ":xplane_schema", ":xplane_utils", ":xplane_visitor", "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/platform:env", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -310,9 +315,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:dso_loader", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -328,10 +330,10 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_utils", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -346,12 +348,12 @@ tsl_cc_test( ":xplane_schema", ":xplane_test_utils", ":xplane_visitor", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", + "//xla/tsl/platform:types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -360,6 +362,7 @@ cc_library( name = "tpu_xplane_utils", srcs = ["tpu_xplane_utils.cc"], hdrs = ["tpu_xplane_utils.h"], + visibility = internal_visibility([":friends"]), deps = [ ":xplane_schema", ":xplane_utils", @@ -378,9 +381,9 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -409,7 +412,7 @@ cc_library( "//xla/tsl/profiler:internal", ]), deps = [ - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:logging", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:thread_annotations", @@ -421,8 +424,8 @@ tsl_cc_test( srcs = ["buffer_pool_test.cc"], deps = [ ":buffer_pool", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -459,10 +462,10 @@ tsl_cc_test( ":xplane_schema", ":xplane_test_utils", ":xplane_visitor", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/lib:connected_traceme", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], @@ -473,9 +476,9 @@ cc_library( srcs = ["session_manager.cc"], hdrs = ["session_manager.h"], deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ], @@ -502,8 +505,8 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -530,8 +533,8 @@ cc_library( hdrs = ["lock_free_queue.h"], deps = [ ":no_init", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:macros", ], ) @@ -542,11 +545,11 @@ tsl_cc_test( srcs = ["lock_free_queue_test.cc"], deps = [ ":lock_free_queue", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -573,13 +576,13 @@ tsl_cc_test( srcs = ["per_thread_test.cc"], deps = [ ":per_thread", + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) @@ -600,9 +603,9 @@ tsl_cc_test( deps = [ ":device_utils", ":xplane_schema", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc index f16fe91d573a8b..17bcb573b01cbf 100644 --- a/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" diff --git a/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc index 4e5dbab63085de..38af82e31359ae 100644 --- a/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/profiler/utils/buffer_pool.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc index 1698357e36b330..e2a64d5f396acf 100644 --- a/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/format_utils.h b/third_party/xla/xla/tsl/profiler/utils/format_utils.h index d93d69e8592d70..583c68842e5bd8 100644 --- a/third_party/xla/xla/tsl/profiler/utils/format_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/format_utils.h @@ -20,7 +20,7 @@ limitations under the License. #include -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/group_events.cc b/third_party/xla/xla/tsl/profiler/utils/group_events.cc index 393e170b839446..72619100eaddab 100644 --- a/third_party/xla/xla/tsl/profiler/utils/group_events.cc +++ b/third_party/xla/xla/tsl/profiler/utils/group_events.cc @@ -32,14 +32,16 @@ limitations under the License. #include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/dso_loader.h" -#include "tsl/platform/env.h" -#include "tsl/platform/types.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { namespace profiler { @@ -905,6 +907,30 @@ void GroupXplaneEvents(tensorflow::profiler::XPlane* plane, group_line = nullptr; } else { // host loop if (group_line) { + // Determine whether the module line has been grouped. + bool is_grouped = false; + for (XEvent& event : *module_line->mutable_events()) { + XEventVisitor module_visitor(&plane_visitor, module_line, &event); + if (module_visitor.GetStat(StatType::kGroupId).has_value()) { + is_grouped = true; + break; + } + } + if (!is_grouped) { + // If the module line has not been grouped, then: + // (1) Assign group_id to each step event. + int32_t group_id = 0; + for (XEvent& event : *step_line->mutable_events()) { + XEventBuilder step_builder(step_line, &plane_builder, &event); + XEventVisitor step_visitor(&plane_visitor, step_line, &event); + if (!step_visitor.GetStat(StatType::kGroupId).has_value()) { + step_builder.AddStatValue(*group_id_stat_metadata, group_id++); + } + } + // (2) Group the module events nested by the step events. + GroupLine(*group_id_stat_metadata, plane_visitor, *step_line, + &plane_builder, module_line); + } // Host loop steps take the group_id from their module. GroupLine(*group_id_stat_metadata, plane_visitor, *group_line, &plane_builder, step_line); diff --git a/third_party/xla/xla/tsl/profiler/utils/group_events.h b/third_party/xla/xla/tsl/profiler/utils/group_events.h index 52a73529fb734c..cdacea2b8bd0cf 100644 --- a/third_party/xla/xla/tsl/profiler/utils/group_events.h +++ b/third_party/xla/xla/tsl/profiler/utils/group_events.h @@ -28,10 +28,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc b/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc index e8c3306ee4ea3d..e65281bb59a048 100644 --- a/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc @@ -18,14 +18,15 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_test_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { @@ -717,6 +718,36 @@ TEST(GroupTPUEventsTest, TpuProgramCallbackTest) { }); } +TEST(GroupTPUEventsTest, ModuleRootEventTest) { + tensorflow::profiler::XSpace space; + tensorflow::profiler::XPlane* device_plane = space.add_planes(); + XPlaneBuilder device_plane_builder(device_plane); + device_plane_builder.ReserveLines(1); + auto step_line = device_plane_builder.GetOrCreateLine(0); + step_line.SetName("Steps"); + CreateXEvent(&device_plane_builder, &step_line, "1", 100, 200, + {{StatType::kStepNum, int64_t{1}}}); + auto module_line = device_plane_builder.GetOrCreateLine(1); + module_line.SetName("XLA Modules"); + CreateXEvent(&device_plane_builder, &module_line, "module", 105, 199, + {{StatType::kRunId, int64_t{123}}, + {StatType::kQueueId, int64_t{0}}, + {StatType::kDeviceOrdinal, int64_t{1}}}); + auto hlo_line = device_plane_builder.GetOrCreateLine(2); + hlo_line.SetName("XLA Ops"); + CreateXEvent(&device_plane_builder, &hlo_line, "matmul", 110, 190, {}); + EventForest event_forest; + GroupTpuEventsOSS(&space, {device_plane}, &event_forest); + XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(&space.planes(0)); + device_plane_visitor.ForEachLine([&](const XLineVisitor& line) { + line.ForEachEvent([&](const XEventVisitor& event) { + SCOPED_TRACE(absl::StrCat(line.Name(), " ", event.Name())); + // All events should be grouped and have `group_id` set. + EXPECT_TRUE(event.GetStat(StatType::kGroupId).has_value()); + }); + }); +} + } // namespace } // namespace profiler } // namespace tsl diff --git a/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h index 9f22aa8b8e5094..4b8b05a248bbe9 100644 --- a/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h +++ b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h @@ -23,9 +23,9 @@ limitations under the License. #include #include +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" #include "xla/tsl/profiler/utils/no_init.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc index fd8ccdfb659207..df9c4f3cdf4b00 100644 --- a/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/synchronization/notification.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { @@ -53,14 +53,14 @@ void FillEvents2Stage(LockFreeQueue& queue, expected2.clear(); for (size_t i = 0; i < event_count1; ++i) { T event = gen(i); - expected1.emplace_back(event); + expected1.push_back(event); queue.Push(std::move(event)); } stage1_filled.Notify(); stage1_grabbed.WaitForNotification(); for (size_t i = 0; i < event_count2; ++i) { T event = gen(i + event_count1); - expected2.emplace_back(event); + expected2.push_back(event); queue.Push(std::move(event)); } stage2_filled.Notify(); diff --git a/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc index 67328c1ea6e9bc..f790bc0e5ff59b 100644 --- a/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc @@ -31,7 +31,7 @@ std::vector SplitNameAndMetadata( absl::string_view annotation) { std::vector parts; if (!HasMetadata(annotation)) { - parts.emplace_back(annotation); + parts.push_back(annotation); } else { annotation.remove_suffix(1); parts = absl::StrSplit(annotation, '#'); diff --git a/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc b/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc index 6225916ef96cfc..a31afd7b796b65 100644 --- a/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc b/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc index 9007319c4d0c74..684b0c4f22d8df 100644 --- a/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/notification.h" -#include "tsl/platform/env.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h index c64a6d02417e48..46f94c166cb280 100644 --- a/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h @@ -409,7 +409,7 @@ class TpuModuleLineMutatorFactory : public XplaneEventMutatorFactory { // consistent with other kTpuLaunch types. std::vector> required_stats; required_stats.reserve(4); - required_stats.emplace_back(device_ordinal_); + required_stats.push_back(device_ordinal_); required_stats.emplace_back(*queue_id); required_stats.emplace_back(*run_id); required_stats.emplace_back(static_cast(*core_type)); @@ -501,7 +501,7 @@ class ThreadpoolLineMutatorFactory : public XplaneEventMutatorFactory { metadata.start_region_timestamp_ps = start_region_timestamp_ps; metadata.region_id = region_id; metadata.end_region_timestamp_ps = event.TimestampPs(); - event_metadata.emplace_back(metadata); + event_metadata.push_back(metadata); } }); for (const auto& event_metadata : event_metadata) { diff --git a/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc index d18d6452a6a85d..3a52d032dfcd7b 100644 --- a/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_test_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/test.h" #include "tsl/profiler/lib/connected_traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/utils/session_manager.cc b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc index d45b6edd83efba..db512507d62be8 100644 --- a/third_party/xla/xla/tsl/profiler/utils/session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" #include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" diff --git a/third_party/xla/xla/tsl/profiler/utils/session_manager.h b/third_party/xla/xla/tsl/profiler/utils/session_manager.h index fd8c60cbc63d13..557f708500c263 100644 --- a/third_party/xla/xla/tsl/profiler/utils/session_manager.h +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h index 078d4d7c3b6f9c..6ef73646dc4b64 100644 --- a/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tsl/platform/macros.h" +#include "xla/tsl/platform/macros.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc index aef2bbc686f4d8..8379c3bddda9d4 100644 --- a/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/timespan.h b/third_party/xla/xla/tsl/profiler/utils/timespan.h index d1883b8566a6ae..ea913b1438e8e3 100644 --- a/third_party/xla/xla/tsl/profiler/utils/timespan.h +++ b/third_party/xla/xla/tsl/profiler/utils/timespan.h @@ -20,9 +20,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/math_utils.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { @@ -100,6 +100,12 @@ class Timespan { return begin_ps_ == other.begin_ps_ && duration_ps_ == other.duration_ps_; } + // The compiler can't synthesize <= from < and == until C++ 20's <=>, but we + // can't yet assume C++20 support. + bool operator<=(const Timespan& other) const { + return *this < other || *this == other; + } + // Returns a string that shows the begin and end times. std::string DebugString() const { return absl::StrCat("[", begin_ps(), ", ", end_ps(), "]"); diff --git a/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc index 57d7876365c904..5e68072a2621d8 100644 --- a/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/profiler/utils/timespan.h" -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace profiler { @@ -80,5 +80,25 @@ TEST(TimespanTests, InstantSpanNonInstantSpanOverlappedDuration) { EXPECT_EQ(0, Timespan(12, 0).OverlappedDurationPs(Timespan(8, 16))); } +TEST(TimespanTests, Operators) { + EXPECT_LT(Timespan(11, 0), Timespan(12, 0)); + EXPECT_LT(Timespan(12, 1), Timespan(12, 0)); + + EXPECT_FALSE(Timespan(12, 0) < Timespan(12, 1)); + EXPECT_FALSE(Timespan(12, 0) < Timespan(11, 0)); + EXPECT_FALSE(Timespan(12, 0) < Timespan(12, 0)); + + EXPECT_FALSE(Timespan(12, 0) == Timespan(12, 1)); + EXPECT_FALSE(Timespan(12, 0) == Timespan(11, 0)); + + EXPECT_EQ(Timespan(12, 0), Timespan(12, 0)); + + EXPECT_LE(Timespan(12, 0), Timespan(12, 0)); + EXPECT_LE(Timespan(12, 0), Timespan(13, 0)); + EXPECT_LE(Timespan(11, 0), Timespan(12, 0)); + + EXPECT_FALSE(Timespan(12, 0) <= Timespan(11, 0)); +} + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc index dd2e434adbc0f3..0c68572bc75927 100644 --- a/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include "xla/tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/test.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc index fc341c98582cc9..c68084a3548f0b 100644 --- a/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/test.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/trace_utils.h b/third_party/xla/xla/tsl/profiler/utils/trace_utils.h index ef53e611ab95fa..090e9ae164c2a4 100644 --- a/third_party/xla/xla/tsl/profiler/utils/trace_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/trace_utils.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc index 717650cc06bbdb..d4aba52c0317d4 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc @@ -24,10 +24,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/timespan.h" #include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h index a665cece663cb8..02be5da574b3cf 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h @@ -28,11 +28,11 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/timespan.h" -#include "tsl/platform/macros.h" #include "tsl/platform/protobuf.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc index ee2c8e4df0400b..6af0502acdd5d5 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/platform/test.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/test.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc index edfc9639a5c18d..314b28c8a99bf4 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc @@ -276,6 +276,7 @@ const StatTypeMap& GetStatTypeMap() { {"flops", kFlops}, {"model_flops", kModelFlops}, {"bytes_accessed", kBytesAccessed}, + {"raw_bytes_accessed", kRawBytesAccessed}, {"memory_access_breakdown", kMemoryAccessBreakdown}, {"shape_with_layout", kShapeWithLayout}, {"source", kSourceInfo}, @@ -316,6 +317,8 @@ const StatTypeMap& GetStatTypeMap() { {"peak_sram_wr_bw_gigabytes_per_second", kDevCapPeakSramWrBwGigabytesPerSecond}, {"device_vendor", kDevVendor}, + {"has_megacore", kDevHasMegacore}, + {"has_merged_vmem", kDevHasMergedVmem}, // Batching related. {"batch_size_after_padding", kBatchSizeAfterPadding}, {"padding_amount", kPaddingAmount}, @@ -357,7 +360,8 @@ const StatTypeMap& GetStatTypeMap() { {"source_stack", kSourceStack}, {"device_offset_ps", kDeviceOffsetPs}, {"device_duration_ps", kDeviceDurationPs}, - {"scope_range_id", kScopeRangeId}}); + {"scope_range_id", kScopeRangeId}, + {"core_details", kCoreDetails}}); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; } @@ -377,6 +381,7 @@ const MegaScaleStatTypeMap& GetMegaScaleStatTypeMap() { {"action_inputs", kMegaScaleActionInputs}, {"transfer_source", kMegaScaleTransferSource}, {"transfer_destinations", kMegaScaleTransferDestinations}, + {"dcn_topology_level", kMegaScaleTransferDcnTopologyLevel}, {"buffer_sizes", kMegaScaleBufferSizes}, {"compute_operation", kMegaScaleComputeOperation}, {"chunk", kMegaScaleChunk}, diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h index c92a79fa771895..ad0af06e21314f 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h @@ -25,9 +25,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" namespace tsl { @@ -264,6 +264,7 @@ enum StatType { kFlops, kModelFlops, kBytesAccessed, + kRawBytesAccessed, kMemoryAccessBreakdown, kShapeWithLayout, kSourceInfo, @@ -296,6 +297,8 @@ enum StatType { kDevCapPeakSramRdBwGigabytesPerSecond, kDevCapPeakSramWrBwGigabytesPerSecond, kDevVendor, + kDevHasMegacore, + kDevHasMergedVmem, // Batching related. kBatchSizeAfterPadding, kPaddingAmount, @@ -343,7 +346,8 @@ enum StatType { kDeviceOffsetPs, kDeviceDurationPs, kScopeRangeId, - kLastStatType = kScopeRangeId, + kCoreDetails, + kLastStatType = kCoreDetails, }; enum MegaScaleStatType : uint8_t { @@ -361,6 +365,7 @@ enum MegaScaleStatType : uint8_t { kMegaScaleActionInputs, kMegaScaleTransferSource, kMegaScaleTransferDestinations, + kMegaScaleTransferDcnTopologyLevel, kMegaScaleBufferSizes, kMegaScaleComputeOperation, kMegaScaleChunk, diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc index 548b444b912263..e6079d5d11c7a3 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_utils.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h index b2e5e58494c67a..ed78ed5a42b773 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" -#include "tsl/platform/types.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc index 5ceb72059073d6..b83df1f676b97d 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/timespan.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/tsl/util/stats_calculator.h" #include "tsl/platform/fingerprint.h" -#include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -556,6 +556,7 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { const XPlaneVisitor& plane = CreateTfXPlaneVisitor(&full_trace); XPlaneBuilder aggregated_plane(&aggregated_trace); aggregated_plane.SetName(plane.Name()); + aggregated_plane.SetId(plane.Id()); uint64_t first_op_start_ps = kint64max; uint64_t last_op_end_ps = 0; @@ -619,14 +620,17 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { XStatMetadata* kGroupId = aggregated_plane.GetOrCreateStatMetadata( GetStatTypeStr(StatType::kGroupId)); + // TODO(b/384550563): Remove this offset once we have a better way to + // aggregate XPlanes. + int64_t metadata_id_offset = aggregated_plane.CreateEventMetadata()->id() - 1; for (const auto& [line_id, stats_by_group] : stats) { XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line_id); for (const auto& [group_id, stat_by_event] : stats_by_group) { for (const auto& [event_id, event_stat] : stat_by_event) { const auto& src_event_metadata = *plane.GetEventMetadata(event_id); XEventMetadata& event_metadata = - *aggregated_plane.GetOrCreateEventMetadata( - src_event_metadata.name()); + *aggregated_plane.GetOrCreateEventMetadata(src_event_metadata.id() + + metadata_id_offset); CopyEventMetadata(src_event_metadata, plane, event_metadata, aggregated_plane); XEventBuilder aggregated_event = diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h index b2b3784267bac8..273804bbdc98fd 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/timespan.h" #include "xla/tsl/profiler/utils/trace_utils.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc index ec44a499f56ad9..3c012456af3b60 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc @@ -22,15 +22,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/math_utils.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "xla/tsl/profiler/utils/xplane_builder.h" #include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/profiler/utils/xplane_visitor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { @@ -396,6 +397,7 @@ TEST(XplaneUtilsTest, FindMutablePlanesWithPredicate) { TEST(XplaneUtilsTest, TestAggregateXPlanes) { XPlane xplane; XPlaneBuilder builder(&xplane); + builder.SetId(123); auto& event_metadata1 = *builder.GetOrCreateEventMetadata("EventMetadata1"); auto& event_metadata2 = *builder.GetOrCreateEventMetadata("EventMetadata2"); auto& event_metadata3 = *builder.GetOrCreateEventMetadata("EventMetadata3"); @@ -441,6 +443,7 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) { XPlane aggregated_xplane; AggregateXPlane(xplane, aggregated_xplane); + EXPECT_EQ(aggregated_xplane.id(), 123); // Protobuf matchers are unavailable in OSS (b/169705709) #if defined(PLATFORM_GOOGLE) // TODO(b/238349654): Proto matcher are ineffective for XPlanes. @@ -448,7 +451,8 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) { aggregated_xplane, IgnoringFields( {"tensorflow.profiler.XEvent.metadata_id", - "tensorflow.profiler.XPlane.event_metadata"}, + "tensorflow.profiler.XPlane.event_metadata", + "tensorflow.profiler.XPlane.id"}, IgnoringRepeatedFieldOrdering(EqualsProto( R"pb(lines { id: 1 @@ -518,6 +522,63 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) { #endif } +TEST(XplaneUtilsTest, TestAggregateXPlanesWithNonUniqueMetadataNames) { + XPlane xplane; + XPlaneBuilder builder(&xplane); + const XStatMetadata& program_id_stat = + *builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kProgramId)); + XEventMetadata& event_metadata1 = + *builder.GetOrCreateEventMetadata("EventMetadata1"); + XStatsBuilder event_metadata1_stats(&event_metadata1, + &builder); + event_metadata1_stats.AddStatValue(program_id_stat, 1); + XEventMetadata& event_metadata1p2 = *builder.CreateEventMetadata(); + event_metadata1p2.set_name("EventMetadata1"); + XStatsBuilder event_metadata1p2_stats(&event_metadata1p2, + &builder); + event_metadata1p2_stats.AddStatValue(program_id_stat, 2); + XEventMetadata& step_event_metadata1 = + *builder.GetOrCreateEventMetadata("StepEventMetadata1"); + XEventMetadata& step_event_metadata1p2 = + *builder.GetOrCreateEventMetadata("StepEventMetadata2"); + + XLineBuilder step_line = builder.GetOrCreateLine(1); + step_line.SetName(kStepLineName); + XEventBuilder step1 = step_line.AddEvent(step_event_metadata1); + step1.SetOffsetNs(0); + step1.SetDurationNs(10); + XEventBuilder step2 = step_line.AddEvent(step_event_metadata1p2); + step2.SetOffsetNs(10); + step2.SetDurationNs(10); + + XLineBuilder xla_line = builder.GetOrCreateLine(2); + xla_line.SetName(kXlaOpLineName); + XEventBuilder event1 = xla_line.AddEvent(event_metadata1); + event1.SetOffsetNs(0); + event1.SetDurationNs(5); + XEventBuilder event2 = xla_line.AddEvent(event_metadata1p2); + event2.SetOffsetNs(0); + event2.SetDurationNs(5); + XEventBuilder event3 = xla_line.AddEvent(event_metadata1); + event3.SetOffsetNs(5); + event3.SetDurationNs(5); + XEventBuilder event4 = xla_line.AddEvent(event_metadata1p2); + event4.SetOffsetNs(5); + event4.SetDurationNs(5); + + XPlane aggregated_xplane; + AggregateXPlane(xplane, aggregated_xplane); + + absl::flat_hash_set program_ids; + for (const auto& [id, event_metadata] : aggregated_xplane.event_metadata()) { + if (event_metadata.name() == "EventMetadata1") { + program_ids.insert(event_metadata.stats(0).int64_value()); + } + } + EXPECT_TRUE(program_ids.contains(1)); + EXPECT_TRUE(program_ids.contains(2)); +} + TEST(XPlaneUtilsTest, TestAggregateXPlaneWithCycleStats) { XPlane xplane; XPlaneBuilder builder(&xplane); diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc index b7bfad3f7211eb..2ea1fb86a803df 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc @@ -22,8 +22,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/types.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h index a9c8510355cde2..7dce2e1fbca2cf 100644 --- a/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h @@ -25,8 +25,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/profiler/utils/timespan.h" -#include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" namespace tsl { @@ -209,6 +209,14 @@ class XEventVisitor : public XStatsOwner { return GetTimespan() < other.GetTimespan(); } + bool operator==(const XEventVisitor& other) const { + return GetTimespan() == other.GetTimespan(); + } + + bool operator<=(const XEventVisitor& other) const { + return GetTimespan() <= other.GetTimespan(); + } + const XEventMetadata* metadata() const { return metadata_; } XEventMetadataVisitor Metadata() const { diff --git a/third_party/xla/xla/tsl/protobuf/coordination_service.proto b/third_party/xla/xla/tsl/protobuf/coordination_service.proto index f593feace8723d..2740f1c685660a 100644 --- a/third_party/xla/xla/tsl/protobuf/coordination_service.proto +++ b/third_party/xla/xla/tsl/protobuf/coordination_service.proto @@ -230,6 +230,22 @@ message BarrierResponse { int64 counter = 1; } +// Request and response messages for querying the set of alive tasks. +message GetAliveTasksRequest { + // The task that is making the GetAliveTasks request. + CoordinatedTask requesting_task = 1; + + // The tasks to check for aliveness. This list must include the requesting + // task. + repeated CoordinatedTask tasks = 2; +} + +message GetAliveTasksResponse { + // The set of alive tasks. This set is a (non-strict) subset of the tasks + // provided in the GetAliveTasksRequest. + repeated CoordinatedTask alive_tasks = 1; +} + // Request and response messages for cancelling generic sync barriers. message CancelBarrierRequest { // Barrier key. @@ -363,6 +379,36 @@ service CoordinationService { // - FailedPrecondition: Barrier has already been passed. rpc CancelBarrier(CancelBarrierRequest) returns (CancelBarrierResponse); + // Returns the set of currently alive tasks. More specifically, given a set of + // tasks T, GetAliveTasks(T) returns the subset T of alive tasks. + // + // # Barrier Semantics + // + // If multiple tasks call GetAliveTasks concurrently, it's important that they + // all agree on which tasks are alive. Otherwise, the tasks' behavior might + // diverge. For example, imagine a set of tasks trying to run an AllGather, + // but they all disagree on which tasks should be participating in the + // AllGather. This is buggy. + // + // To ensure that every task agrees on which tasks are alive, the + // GetAliveTasks RPC has barrier-like semantics. Consider an invocation + // GetAliveTasks(T) for a set of tasks T. The invocation acts as a barrier, + // waiting for every task in T to call GetAliveTasks(T). Afterwards, + // GetAliveTasks returns the same set of alive tasks A to all the tasks in T. + // This ensures that every task agrees which tasks are alive. + // + // One small correction. GetAliveTasks doesn't act as a barrier for *every* + // task in T. Some tasks in T might have failed, so we should not wait for + // them. Instead, the GetAliveTasks RPC waits only for the returned tasks A. + // + // # An Example + // + // Imagine we have four tasks: A, B, C, and D. Further imagine that task D + // has failed and that every task calls GetAliveTasks([A, B, C, D]). The + // invocation will return tasks [A, B, C]. The GetAliveTasks call acts as a + // barrier across tasks A, B, and C. Task D, which failed, is ignored. + rpc GetAliveTasks(GetAliveTasksRequest) returns (GetAliveTasksResponse); + // Polls the service for errors. // // This RPC is used by the coordination service agent to send long polling diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 337ee2cb6208b5..a3e906e055bb93 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -35,7 +35,6 @@ load( load( "@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules", - _pybind_extension = "pybind_extension", ) # Internally this loads a macro, but in OSS this is a function @@ -838,4 +837,4 @@ def tsl_extra_config_settings_targets(): return [] # TODO(b/356020232): remove after migration is done -tsl_pybind_extension = _pybind_extension if use_pywrap_rules() else tsl_pybind_extension_opensource +tsl_pybind_extension = tsl_pybind_extension_opensource diff --git a/third_party/xla/xla/tsl/util/BUILD b/third_party/xla/xla/tsl/util/BUILD index 50b07a331df0f2..d80f015b79bcd4 100644 --- a/third_party/xla/xla/tsl/util/BUILD +++ b/third_party/xla/xla/tsl/util/BUILD @@ -127,9 +127,9 @@ cc_library( srcs = ["byte_swap_array.cc"], hdrs = ["byte_swap_array.h"], deps = [ + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@local_tsl//tsl/platform:byte_order", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", ], ) @@ -200,14 +200,14 @@ cc_library( srcs = ["env_var.cc"], hdrs = ["env_var.h"], deps = [ - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:status", + "//xla/tsl/platform:types", "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/platform:types", ], ) @@ -220,14 +220,14 @@ cc_library( "@local_tsl//tsl:__subpackages__", ]), deps = [ + "//xla/tsl/platform:env", + "//xla/tsl/platform:env_impl", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", + "//xla/tsl/platform:types", "//xla/tsl/protobuf:test_log_proto_cc", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/platform:types", ], ) @@ -243,6 +243,7 @@ cc_library( copts = tsl_copts(), visibility = internal_visibility([ "//xla/tsl:internal", + "//xla/tsl/profiler:friends", ]), ) @@ -251,8 +252,8 @@ tsl_cc_test( srcs = ["stats_calculator_test.cc"], deps = [ ":stats_calculator_portable", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_main", ], ) @@ -261,8 +262,8 @@ cc_library( srcs = ["device_name_utils.cc"], hdrs = ["device_name_utils.h"], deps = [ - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", ], ) @@ -274,11 +275,11 @@ tsl_cc_test( deps = [ ":device_name_utils", "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:errors", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@local_tsl//tsl/platform:strcat", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/platform:test_main", ], ) @@ -287,12 +288,12 @@ cc_library( srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], deps = [ + "//xla/tsl/platform:logging", + "//xla/tsl/platform:types", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:str_util", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:stringprintf", - "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/xla/tsl/util/byte_swap_array.cc b/third_party/xla/xla/tsl/util/byte_swap_array.cc index 2c80e8cb928d0d..53bc7d9124f6be 100644 --- a/third_party/xla/xla/tsl/util/byte_swap_array.cc +++ b/third_party/xla/xla/tsl/util/byte_swap_array.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/util/byte_swap_array.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/byte_swap_array.h b/third_party/xla/xla/tsl/util/byte_swap_array.h index a2ff2a864ee2dd..d6eff172cea2f2 100644 --- a/third_party/xla/xla/tsl/util/byte_swap_array.h +++ b/third_party/xla/xla/tsl/util/byte_swap_array.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ #define XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/byte_order.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" // Define basic byte swapping operations. // These operations must be macros to use compiler intrinsics. diff --git a/third_party/xla/xla/tsl/util/command_line_flags.cc b/third_party/xla/xla/tsl/util/command_line_flags.cc index d61e88e744c994..226377ddca6047 100644 --- a/third_party/xla/xla/tsl/util/command_line_flags.cc +++ b/third_party/xla/xla/tsl/util/command_line_flags.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "absl/strings/match.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/str_util.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/stringprintf.h" diff --git a/third_party/xla/xla/tsl/util/command_line_flags.h b/third_party/xla/xla/tsl/util/command_line_flags.h index d4b3efd662a94d..50888879219f3c 100644 --- a/third_party/xla/xla/tsl/util/command_line_flags.h +++ b/third_party/xla/xla/tsl/util/command_line_flags.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/platform/types.h" +#include "xla/tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/device_name_utils.cc b/third_party/xla/xla/tsl/util/device_name_utils.cc index 91003750e6ce23..c16b22fa9daad0 100644 --- a/third_party/xla/xla/tsl/util/device_name_utils.cc +++ b/third_party/xla/xla/tsl/util/device_name_utils.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/device_name_utils.h b/third_party/xla/xla/tsl/util/device_name_utils.h index 1fbe606aed1967..950387a6827023 100644 --- a/third_party/xla/xla/tsl/util/device_name_utils.h +++ b/third_party/xla/xla/tsl/util/device_name_utils.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/stringpiece.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/device_name_utils_test.cc b/third_party/xla/xla/tsl/util/device_name_utils_test.cc index e93ae83f6009ff..5651e1078a80a2 100644 --- a/third_party/xla/xla/tsl/util/device_name_utils_test.cc +++ b/third_party/xla/xla/tsl/util/device_name_utils_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/strcat.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/env_var.cc b/third_party/xla/xla/tsl/util/env_var.cc index 95b744cda43c99..9215c745e0fcfc 100644 --- a/third_party/xla/xla/tsl/util/env_var.cc +++ b/third_party/xla/xla/tsl/util/env_var.cc @@ -17,8 +17,8 @@ limitations under the License. #include -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" #include "tsl/platform/numbers.h" #include "tsl/platform/str_util.h" #include "tsl/platform/strcat.h" @@ -52,7 +52,7 @@ absl::Status ReadInt64FromEnvVar(absl::string_view env_var_name, if (tf_env_var_val == nullptr) { return absl::OkStatus(); } - if (strings::safe_strto64(tf_env_var_val, value)) { + if (absl::SimpleAtoi(tf_env_var_val, value)) { return absl::OkStatus(); } return errors::InvalidArgument(strings::StrCat( @@ -67,7 +67,7 @@ absl::Status ReadFloatFromEnvVar(absl::string_view env_var_name, if (tf_env_var_val == nullptr) { return absl::OkStatus(); } - if (strings::safe_strtof(tf_env_var_val, value)) { + if (absl::SimpleAtof(tf_env_var_val, value)) { return absl::OkStatus(); } return errors::InvalidArgument(strings::StrCat( diff --git a/third_party/xla/xla/tsl/util/env_var.h b/third_party/xla/xla/tsl/util/env_var.h index 87bf9ae78befdc..fdfb366dd8c7a9 100644 --- a/third_party/xla/xla/tsl/util/env_var.h +++ b/third_party/xla/xla/tsl/util/env_var.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_TSL_UTIL_ENV_VAR_H_ #define XLA_TSL_UTIL_ENV_VAR_H_ -#include "tsl/platform/status.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/types.h" #include "tsl/platform/stringpiece.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/onednn_threadpool.h b/third_party/xla/xla/tsl/util/onednn_threadpool.h index a191e566b0eea0..7bf988b57585a4 100644 --- a/third_party/xla/xla/tsl/util/onednn_threadpool.h +++ b/third_party/xla/xla/tsl/util/onednn_threadpool.h @@ -28,10 +28,10 @@ limitations under the License. #define EIGEN_USE_THREADS #include "dnnl_threadpool.hpp" +#include "absl/synchronization/blocking_counter.h" #include "dnnl.hpp" -#include "tsl/platform/blocking_counter.h" +#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/threadpool.h" namespace tsl { @@ -124,7 +124,7 @@ class OneDnnThreadPool : public threadpool_iface { } run_jobs(balance, njobs_to_schedule, n, njobs, fn); } else { - tsl::BlockingCounter counter(njobs); + absl::BlockingCounter counter(njobs); std::function handle_range = [=, &handle_range, &counter]( int first, int last) { while (last - first > 1) { diff --git a/third_party/xla/xla/tsl/util/reporter.cc b/third_party/xla/xla/tsl/util/reporter.cc index 1d08abf7b2e6c2..08bdcd6c8fb13f 100644 --- a/third_party/xla/xla/tsl/util/reporter.cc +++ b/third_party/xla/xla/tsl/util/reporter.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/util/reporter.h" -#include "tsl/platform/errors.h" +#include "xla/tsl/platform/errors.h" #include "tsl/platform/mutex.h" #include "tsl/platform/str_util.h" diff --git a/third_party/xla/xla/tsl/util/reporter.h b/third_party/xla/xla/tsl/util/reporter.h index e270dd1e23085f..be504656c3e942 100644 --- a/third_party/xla/xla/tsl/util/reporter.h +++ b/third_party/xla/xla/tsl/util/reporter.h @@ -21,11 +21,11 @@ limitations under the License. #include #include +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" #include "xla/tsl/protobuf/test_log.pb.h" -#include "tsl/platform/env.h" -#include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" -#include "tsl/platform/types.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/util/stats_calculator_test.cc b/third_party/xla/xla/tsl/util/stats_calculator_test.cc index bab88a0236fe7e..bbd75845f583d6 100644 --- a/third_party/xla/xla/tsl/util/stats_calculator_test.cc +++ b/third_party/xla/xla/tsl/util/stats_calculator_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tsl/platform/test.h" +#include "xla/tsl/platform/test.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/types.h b/third_party/xla/xla/types.h index 8d30a2b2500131..98e3d7c9331ffc 100644 --- a/third_party/xla/xla/types.h +++ b/third_party/xla/xla/types.h @@ -60,6 +60,8 @@ template inline constexpr bool is_specialized_integral_v = is_specialized_integral::value; +using u1 = tsl::uint1; +using s1 = tsl::int1; using u2 = tsl::uint2; using s2 = tsl::int2; using u4 = tsl::uint4; diff --git a/third_party/xla/xla/types_test.cc b/third_party/xla/xla/types_test.cc index 2d6d288bf9690a..7f16fb8a2056f8 100644 --- a/third_party/xla/xla/types_test.cc +++ b/third_party/xla/xla/types_test.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "xla/test.h" +#include +#include "xla/hlo/testlib/test.h" namespace xla { namespace { diff --git a/third_party/xla/xla/util_test.cc b/third_party/xla/xla/util_test.cc index 2fe6317bfbb8ea..d15329872d911b 100644 --- a/third_party/xla/xla/util_test.cc +++ b/third_party/xla/xla/util_test.cc @@ -23,15 +23,18 @@ limitations under the License. #include #include #include -#include #include #include +#include +#include "absl/base/log_severity.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ml_dtypes/include/float8.h" +#include "xla/hlo/testlib/test.h" #include "xla/maybe_owning.h" -#include "xla/test.h" #include "xla/types.h" #include "tsl/platform/logging.h" #include "tsl/platform/ml_dtypes.h" @@ -70,8 +73,8 @@ TEST(UtilTest, VectorString) { std::vector float_vector = {5.5}; EXPECT_EQ(VectorString(float_vector), "(5.5)"); - std::set string_set = {std::string_view("a"), - std::string_view("b")}; + std::set string_set = {absl::string_view("a"), + absl::string_view("b")}; EXPECT_EQ(VectorString(string_set), "(a, b)"); EXPECT_EQ(VectorString({}), "()"); diff --git a/third_party/xla/xla/window_util.cc b/third_party/xla/xla/window_util.cc index 66614613b98d02..affb4ae347d7aa 100644 --- a/third_party/xla/xla/window_util.cc +++ b/third_party/xla/xla/window_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/window_util_test.cc b/third_party/xla/xla/window_util_test.cc index e1f6e13597a54e..0fcaa1e297d0f7 100644 --- a/third_party/xla/xla/window_util_test.cc +++ b/third_party/xla/xla/window_util_test.cc @@ -15,7 +15,8 @@ limitations under the License. #include "xla/window_util.h" -#include "xla/test.h" +#include +#include "xla/hlo/testlib/test.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 26dd55f3ea8fbf..c6e0e4a033242c 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -101,6 +101,9 @@ message DebugOptions { // When true, XLA:CPU uses the thunk runtime to execute compiled program. bool xla_cpu_use_thunk_runtime = 298; + // When true, XLA:CPU uses XNNPACK to execute supported operations. + bool xla_cpu_use_xnnpack = 359; + // Enabling this will enable optimizations that ignore the possibility of NaN. bool xla_enable_fast_math = 335; @@ -128,6 +131,13 @@ message DebugOptions { // Specifies the behavior of per kernel autotuning cache. AutotuneCacheMode xla_gpu_experimental_autotune_cache_mode = 324; + // Do not lock collective cliques for each XLA:GPU execution, and instead + // use per-process cliques that are never unlocked. This disables deadlock + // prevention mechanism in XLA:GPU and should be used at you own risk. If + // collective operations from concurrent executions are not correcctly ordered + // it may lead to deadlocks, crashes or will produce garbage. + bool xla_gpu_collectives_use_persistent_cliques = 354; + // Experimentally disables binary libraries in GPU compiler passes. bool xla_gpu_experimental_disable_binary_libraries = 329; @@ -153,9 +163,18 @@ message DebugOptions { // supported by XLA's Triton emitter. Tile sizes are assigned automatically. bool xla_gpu_experimental_enable_triton_heroless_priority_fusion = 340; - // Gates the experimental feature coupling the Triton Softmax pattern matcher - // with priority fusion. - bool xla_gpu_experimental_enable_triton_softmax_priority_fusion = 325; + // When enabled, the Triton emitter for dot will use int4 as native type and + // later the Triton IR will be rewritten by Triton IR rewriting pass to use + // int4 packed into int8. + bool xla_gpu_experimental_enable_triton_i4_rewrites = 361; + + // When possible, XLA will use Triton's experimental TMA feature. + bool xla_gpu_experimental_enable_triton_tma = 355; + + // If true, XLA will annotate instructions in the dumps with emitter code + // location (source:line) annotations. This helps to identify the source of + // the code that emits a particular instruction. + bool xla_gpu_unsupported_annotate_with_emitter_loc = 358; // Internal testing flag to switch RaggedAllToAllDecomposer on or off. bool xla_gpu_unsupported_enable_ragged_all_to_all_decomposer = 350; @@ -502,6 +521,7 @@ message DebugOptions { COLLECTIVEBROADCAST = 4; ALLTOALL = 5; COLLECTIVEPERMUTE = 6; + RAGGEDALLTOALL = 7; } repeated CollectiveOpType xla_gpu_disable_async_collectives = 289; @@ -515,6 +535,15 @@ message DebugOptions { // xla_gpu_enable_async_collectives reserved 152, 278, 183, 199, 200, 201, 238; + // Enables NCCL Speed-of-Light (SoL) analytical cost model + bool xla_gpu_enable_analytical_sol_latency_estimator = 356; + // Extra platform-specific options to improve analytical latency + // estimator precision; comma-separated list of 'key=val' strings (=val may be + // omitted); no whitespace around commas. Available options: + // --xla_gpu_analytical_latency_estimator_options= + //'nccl_op_launch_ms=55,nic_speed_gbps=40, + // chunk_prep_ms=1,rtt_ms=2,gpus_per_node=4,chunk_size_bytes=1024' + map xla_gpu_analytical_latency_estimator_options = 357; // Size threshold (in bytes) for the GPU collective combiners. int64 xla_gpu_all_reduce_combine_threshold_bytes = 157; int64 xla_gpu_all_gather_combine_threshold_bytes = 212; @@ -684,7 +713,7 @@ message DebugOptions { bool xla_gpu_enable_highest_priority_async_stream = 216; bool xla_gpu_enable_analytical_latency_estimator = 255; - bool xla_gpu_lhs_enable_gpu_async_tracker = 204; + reserved 204; // Was xla_gpu_lhs_enable_gpu_async_tracker. string xla_gpu_pgle_profile_file_or_directory_path = 210; int32 xla_gpu_memory_limit_slop_factor = 260; @@ -875,9 +904,8 @@ message DebugOptions { // Let GEMM fusion autotuning probe cuDNN as a backend. // Current levels: // 0: Disabled. - // 1: Fusions of GEMM, elementwise, transpose/reshape operations. - // 2: + Broadcasts, slicing. - // 3: + Nontrivial noncontracting dimension reshapes/transposes. + // 1: Enabled on Blackwell+ GPUs. + // 2: Enabled on all supported GPUs (Ampere+). int32 xla_gpu_cudnn_gemm_fusion_level = 285; // This instructs the runtime whether to use @@ -1028,7 +1056,7 @@ message DebugOptions { AUTOTUNE_CACHE_MODE_READ = 2; } - // Timeouts for RendezvousSingle stuck warning and termination. + // Timeouts for Rendezvous stuck warning and termination. int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327; int32 xla_gpu_executable_terminate_timeout_seconds = 328; @@ -1077,7 +1105,12 @@ message DebugOptions { // be deterministic, although with additional overhead. bool xla_gpu_enable_scatter_determinism_expander = 345; - // Next id: 354 + // Enable windowed einsum(collective matmul) rewrite for all-to-all + gemm + // This feature is still experimental and effective only + // xla_gpu_multi_streamed_windowed_einsum is set to true. + bool xla_gpu_experimental_enable_alltoall_windowed_einsum = 360; + + // Next id: 362 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -1095,10 +1128,11 @@ message DebugOptions { // xla_gpu_enable_persistent_temp_buffers // xla_gpu_enable_triton_gemm_int4 // xla_gpu_enable_priority_fusion + // xla_gpu_experimental_enable_triton_softmax_priority_fusion // xla_gpu_pgle_accuracy_checker // xla_gpu_enable_heuristic_pass_configuration reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320, - 326, 332; + 325, 326, 332; } // Contains flags which affects the GPU compilation result. diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 7d9563b11ab795..01a6415549b584 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -32,6 +32,7 @@ enum PrimitiveType { PRED = 1; // Signed integral values of fixed width. + S1 = 30; S2 = 26; S4 = 21; S8 = 2; @@ -40,6 +41,7 @@ enum PrimitiveType { S64 = 5; // Unsigned integral values of fixed width. + U1 = 31; U2 = 27; U4 = 22; U8 = 6; @@ -134,7 +136,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 30 + // Next = 32 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -431,6 +433,8 @@ message OpMetadata { // Profile information for the Op. ProfileInfo profile_info = 10; + reserved 11; + // Deduplicated HLO name for this op. In some cases, we can have multiple // instructions (e.g. fusions) that are considered duplicates. We want to // group them together under the same name so that we can group them together @@ -439,8 +443,7 @@ message OpMetadata { // fusion.2 and fusion.3 will have deduplicated_name = fusion.1 string deduplicated_name = 12; - // Whether to preserve the layout of the HLO op. - bool preserve_layout = 13; + reserved 13; // 1-based position of the frame in frames flat array. // Ids are 1-based to keep 0 value as representation of non-set property. @@ -559,9 +562,11 @@ message DeviceAssignmentProto { message LiteralProto { ShapeProto shape = 1; repeated bool preds = 2; + bytes s1s = 30; bytes s2s = 26; bytes s4s = 21; bytes s8s = 15; + bytes u1s = 31; bytes u2s = 27; bytes u4s = 22; bytes u8s = 3; @@ -587,7 +592,7 @@ message LiteralProto { bytes f8e4m3fnuzs = 25; bytes f8e3m4s = 29; repeated int64 sparse_indices = 14; - // Next = 30 + // Next = 32 } message WindowDimension {